光照GBM算法的输入和输出接口 - Amazon SageMaker

本文属于机器翻译版本。若本译文内容与英语原文存在差异,则一律以英文原文为准。

光照GBM算法的输入和输出接口

梯度提升对表格数据进行操作,其中行表示观察、一个列表示目标变量或标签,其余列表示特征。

Light 的 SageMaker 实现GBMCSV支持训练和推理:

  • 对于训练 ContentType,有效的输入必须是文本/ csv。

  • 要进行推理 ContentType,有效的输入必须是文本 /csv。

注意

对于CSV训练,该算法假设目标变量位于第一列,并且CSV没有标题记录。

为了进行CSV推断,该算法假设CSV输入没有标签列。

训练数据、验证数据和类别特征的输入格式

请注意如何格式化训练数据,以便输入到 Light GBM 模型。您必须提供包含训练和验证数据的 Amazon S3 存储桶的路径。您还可以包含类别特征列表。请使用 trainvalidation 通道来提供您的输入数据。您也可以只使用 train 通道。

注意

train和都training是 Light GBM 训练的有效频道名称。

使用 trainvalidation 通道

您可以通过两条 S3 路径来提供输入数据,一条用于 train 通道,一条用于 validation 通道。每个 S3 路径可以是指向一个或多个CSV文件的 S3 前缀,也可以是指向一个特定CSV文件的完整 S3 路径。目标变量应位于CSV文件的第一列。预测器变量(特征)应位于其余列。如果为trainvalidation通道提供了多个CSV文件,则 Light GBM 算法会将这些文件连接起来。验证数据用于在每次提升迭代结束时计算验证分数。当验证分数停止提高时,将应用提前停止。

如果您的预测变量包含类别特征,则可以提供一个JSON或多个与训练数据文件同名的categorical_index.json文件。如果您为分类要素提供JSON文件,则您的train频道必须指向 S3 前缀而不是特定CSV文件。此文件应包含一个 Python 字典,其中的键是字符串 "cat_index_list",值是唯一整数列表。值列表中的每个整数都应表示训练数据CSV文件中相应类别特征的列索引。每个值都应为正整数(大于零,因为零表示目标值),小于 Int32.MaxValue (2147483647),并且小于列的总数。应该只有一个分类索引JSON文件。

仅使用 train 通道

您也可以通过单个 S3 路径,为 train 通道提供输入数据。此 S3 路径应指向一个名为、train/包含一个或多个CSV文件的子目录的目录。您可以选择在名为的同一位置添加另一个子目录validation/,该子目录也包含一个或多个CSV文件。如果未提供验证数据,则会随机采样 20% 的训练数据作为验证数据。如果您的预测变量包含类别特征,则可以提供一个与数据子目录同名的JSONcategorical_index.json文件。

注意

对于CSV训练输入模式,算法可用的总内存(实例数乘以中的可用内存InstanceType)必须能够容纳训练数据集。

SageMaker Light GBM 使用 Python Joblib 模块对模型进行序列化或反序列化,该模块可用于保存或加载模型。

在 JobLib 模块中使用使用 L SageMaker ight 训练过GBM的模型
  • 使用以下 Python 代码:

    import joblib import tarfile t = tarfile.open('model.tar.gz', 'r:gz') t.extractall() model = joblib.load(model_file_path) # prediction with test data # dtest should be a pandas DataFrame with column names feature_0, feature_1, ..., feature_d pred = model.predict(dtest)