本文属于机器翻译版本。若本译文内容与英语原文存在差异,则一律以英文原文为准。
SageMaker 智能筛选 Python SDK 参考
本页提供了在训练脚本中应用 SageMaker 智能筛选所需的 Python 模块的参考。
SageMaker 智能筛选配置模块
class
smart_sifting.sift_config.sift_configs.RelativeProbabilisticSiftConfig()
SageMaker 智能筛选配置类。
参数
-
beta_value
(浮点数)- β(常数)值。它用于根据损失值历史记录中的损失百分位数计算选择样本进行训练的概率。降低 β 值会降低筛选数据的百分比,而提高此值则会提高筛选数据的百分比。β 值没有最小值或最大值之分,但必须是正值。下表提供了与beta_value
有关的筛选率信息。beta_value
保留数据的比例 (%) 筛选出的数据比例 (%) 0.1 90.91 9.01 0.25 80 20 0.5 66.67 33.33 1 50 50 2 33.33 66.67 3 25 75 10 9.09 90.92 100 0.99 99.01 -
loss_history_length
(int):基于相对阈值损失的采样要存储的先前训练损失的数量。 -
loss_based_sift_config
(dict 或LossConfig
对象)— 指定返回 SageMaker 智能筛选 Loss 接口配置的LossConfig
对象。
class
smart_sifting.sift_config.sift_configs.LossConfig()
RelativeProbabilisticSiftConfig
类 loss_based_sift_config
参数的配置类。
参数
-
sift_config
(dict 或SiftingBaseConfig
对象):指定返回筛选基础配置字典的SiftingBaseConfig
对象。
class
smart_sifting.sift_config.sift_configs.SiftingBaseConfig()
LossConfig
的 sift_config
参数的配置类。
参数
-
sift_delay
(int):开始筛选之前要等待的训练步骤数。我们建议您在模型中的所有层都有足够的训练数据视图后再开始筛选。默认值为1000
。 -
repeat_delay_per_epoch
(bool):指定是否延迟筛选每个历时的时间。默认值为False
。
SageMaker 智能筛选数据批量转换模块
class
smart_sifting.data_model.data_model_interface.SiftingBatchTransform
一个 SageMaker 智能筛选 Python 模块,用于定义如何执行批量转换。使用它,您可以设置一个批处理转换类,将训练数据的数据SiftingBatch
格式转换为格式。 SageMaker 智能筛选可以将这种格式的数据筛选并累积成经过筛选的批次。
class
smart_sifting.data_model.data_model_interface.SiftingBatch
用于定义可筛选和累积的批次数据类型的界面。
class
smart_sifting.data_model.list_batch.ListBatch
用于跟踪列表批次以进行筛选的模块。
class
smart_sifting.data_model.tensor_batch.TensorBatch
用于跟踪张量批次以进行筛选的模块。
SageMaker 智能筛选损失实现模块
class
smart_sifting.loss.abstract_sift_loss_module.Loss
一个包装模块,用于将 SageMaker 智能筛选接口注册到 PyTorch基于模型的损失函数。
SageMaker 智能筛选数据加载器封装模块
class
smart_sifting.dataloader.sift_dataloader.SiftingDataloader
一个封装模块,用于将 SageMaker 智能筛选接口注册到 PyTorch基于模型的数据加载器。
主筛选数据加载器迭代器根据筛选配置从数据加载器中筛选出训练样本。
参数
-
sift_config
(dict 或RelativeProbabilisticSiftConfig
对象):RelativeProbabilisticSiftConfig
对象。 -
orig_dataloader
( PyTorch DataLoader 对象)— 指定要封装的 PyTorch Dataloader 对象。 -
batch_transforms
(SiftingBatchTransform
对象)—(可选)如果 SageMaker 智能筛选库的默认转换不支持您的数据格式,则必须使用该SiftingBatchTransform
模块创建批处理转换类。此参数用于传递批次转换类。该类用于将数据SiftingDataloader
转换为 SageMaker 智能筛选算法可以接受的格式。 -
model
( PyTorch 模型对象)-原始 PyTorch模型 -
loss_impl
(的筛选损失函数smart_sifting.loss.abstract_sift_loss_module.Loss
)— 一种筛选损失函数,它与Loss
模块一起配置并封装损失函数。 PyTorch -
log_batch_data
(bool):指定是否记录批次数据。如果设置为True
,则 SageMaker 智能筛选会记录保留或筛选的批次的详细信息。我们建议您只在测试训练作业时打开它。开启日志记录时,样本会被加载到 GPU 并传输到 CPU,这会带来开销。默认值为False
。