本文属于机器翻译版本。若本译文内容与英语原文存在差异,则一律以英文原文为准。
TabTransformer 算法的输入和输出接口
TabTransformer 对表格数据进行操作,行代表观测值,一列代表目标变量或标签,其余列代表特征。
CSV为训练和推理提供 TabTransformer 支持的 SageMaker 实施:
-
对于训练 ContentType,有效的输入必须是文本/ csv。
-
要进行推理 ContentType,有效的输入必须是文本 /csv。
注意
对于CSV训练,该算法假设目标变量位于第一列,并且CSV没有标题记录。
为了进行CSV推断,该算法假设CSV输入没有标签列。
训练数据、验证数据和类别特征的输入格式
请注意如何格式化训练数据,以便输入到 TabTransformer 模型中。您必须提供包含训练和验证数据的 Amazon S3 存储桶的路径。您还可以包含类别特征列表。请使用 training
和 validation
通道来提供您的输入数据。您也可以只使用 training
通道。
使用 training
和 validation
通道
您可以通过两条 S3 路径来提供输入数据,一条用于 training
通道,一条用于 validation
通道。每个 S3 路径可以是指向一个或多个CSV文件的 S3 前缀,也可以是指向一个特定CSV文件的完整 S3 路径。目标变量应位于CSV文件的第一列。预测器变量(特征)应位于其余列。如果为training
或validation
通道提供了多个CSV文件,则 TabTransformer 算法会将这些文件连接起来。验证数据用于在每次提升迭代结束时计算验证分数。当验证分数停止提高时,将应用提前停止。
如果您的预测变量包含类别特征,则可以提供一个JSON或多个与训练数据文件同名的categorical_index.json
文件。如果您为分类要素提供JSON文件,则您的training
频道必须指向 S3 前缀而不是特定CSV文件。此文件应包含一个 Python 字典,其中的键是字符串 "cat_index_list"
,值是唯一整数列表。值列表中的每个整数都应表示训练数据CSV文件中相应类别特征的列索引。每个值都应为正整数(大于零,因为零表示目标值),小于 Int32.MaxValue
(2147483647),并且小于列的总数。应该只有一个分类索引JSON文件。
仅使用 training
通道:
您也可以通过单个 S3 路径,为 training
通道提供输入数据。此 S3 路径应指向一个名为、training/
包含一个或多个CSV文件的子目录的目录。您可以选择在名为的同一位置添加另一个子目录validation/
,该子目录也包含一个或多个CSV文件。如果未提供验证数据,则会随机采样 20% 的训练数据作为验证数据。如果您的预测变量包含类别特征,则可以提供一个与数据子目录同名的JSONcategorical_index.json
文件。
注意
对于CSV训练输入模式,算法可用的总内存(实例数乘以中的可用内存InstanceType
)必须能够容纳训练数据集。