本文属于机器翻译版本。若本译文内容与英语原文存在差异,则一律以英文原文为准。
这些说明演示了如何使用训练脚本启用 SageMaker 智能筛选。
-
配置 SageMaker 智能筛选界面。
SageMaker 智能筛选库实现了一种基于相对阈值损耗的采样技术,该技术有助于筛选出对降低损耗值影响较小的样本。 SageMaker 智能筛选算法使用正向传递计算每个输入数据样本的损失值,并根据先前数据的损失值计算其相对百分位数。
以下两个参数是创建筛选配置对象时需要为
RelativeProbabilisticSiftConfig
类指定的参数。-
指定用于
beta_value
参数训练的数据比例。 -
使用
loss_history_length
参数指定用于比较的样本数。
以下代码示例演示了如何设置
RelativeProbabilisticSiftConfig
类的对象。from smart_sifting.sift_config.sift_configs import ( RelativeProbabilisticSiftConfig LossConfig SiftingBaseConfig ) sift_config=RelativeProbabilisticSiftConfig( beta_value=0.5, loss_history_length=500, loss_based_sift_config=LossConfig( sift_config=SiftingBaseConfig(sift_delay=0) ) )
有关
loss_based_sift_config
参数和相关类的更多信息,请参阅SageMaker 智能筛选配置模块 SageMaker 智能筛选 Python SDK 参考部分中的。前面代码示例中的
sift_config
对象在第 4 步中用于设置SiftingDataloader
类。 -
-
(可选)配置 SageMaker 智能筛选批量转换类。
不同的训练使用场景需要不同的训练数据格式。鉴于数据格式多种多样, SageMaker 智能筛选算法需要确定如何对特定批次进行筛选。为了解决这个问题, SageMaker 智能筛选提供了一个批量转换模块,可以帮助将批次转换为可以高效筛选的标准化格式。
-
SageMaker 智能筛选处理以下格式的训练数据的批量转换:Python 列表、字典、元组和张量。对于这些数据格式, SageMaker 智能筛选会自动处理批量数据格式转换,您可以跳过此步骤的其余部分。如果您跳过此步骤,在配置
SiftingDataloader
的第 4 步中,请将SiftingDataloader
的batch_transforms
参数保留为默认值None
。 -
如果您的数据集不是这些格式,则您应继续本步骤的其余部分,使用
SiftingBatchTransform
创建自定义批量转换。如果您的数据集不是 SageMaker 智能筛选支持的格式之一,则可能会遇到错误。此类数据格式错误可以通过在
SiftingDataloader
类中添加batch_format_index
或batch_transforms
参数来解决,您可以在第 4 步中进行设置。下面显示了由于数据格式不兼容而导致的错误示例以及解决方法。错误消息 解决方案 默认情况下,
{type(batch)}
不支持该类型的批处理。此错误表示默认不支持批次格式。您应该实现一个自定义批次转换类,并通过将其指定给 SiftingDataloader
类的batch_transforms
参数中来使用它。无法为该批次编制索引
{type(batch)}
此错误表明无法正常为批次对象编制索引。用户必须实现自定义批次转换,并使用 batch_transforms
参数传递。Batch 大小与维度 0 或维度 1 的大小
{batch_size}
不匹配当提供的批次大小与批次的维度 0 或维度 1 不匹配时,会出现此错误。用户必须实现自定义批次转换,并使用 batch_transforms
参数传递。维度 0 和维度 1 都匹配批次大小
此错误表明,由于多个维度与提供的批次大小相匹配,因此需要更多信息来筛选批次。用户可提供 batch_format_index
参数,指示批次是否可按样本或特征编制索引。用户也可以实施自定义批次转换,但这比所需的工作量更大。要解决上述问题,您需要使用
SiftingBatchTransform
模块创建自定义批处理转换类。批次转换类应由一对转换和反向转换函数组成。函数对将您的数据格式转换为 SageMaker 智能筛选算法可以处理的格式。创建批次转换类后,此类会返回一个SiftingBatch
对象,您将在第 4 步中把此对象传递给SiftingDataloader
类。以下是
SiftingBatchTransform
模块中自定义批次转换类的示例。-
使用 SageMaker 智能筛选实现自定义列表批量转换的示例,适用于数据加载器块包含输入、掩码和标签的情况。
from typing import Any import torch from smart_sifting.data_model.data_model_interface import SiftingBatchTransform from smart_sifting.data_model.list_batch import ListBatch class
ListBatchTransform
(SiftingBatchTransform): def transform(self, batch: Any): inputs = batch[0].tolist() labels = batch[-1].tolist() # assume the last one is the list of labels return ListBatch(inputs, labels) def reverse_transform(self, list_batch: ListBatch): a_batch = [torch.tensor(list_batch.inputs), torch.tensor(list_batch.labels)] return a_batch -
使用 SageMaker 智能筛选实现自定义列表批量转换的示例,适用于不需要标签进行反向转换的情况。
class
ListBatchTransformNoLabels
(SiftingBatchTransform): def transform(self, batch: Any): return ListBatch(batch[0].tolist()) def reverse_transform(self, list_batch: ListBatch): a_batch = [torch.tensor(list_batch.inputs)] return a_batch -
在数据加载器块有输入、掩码和标签的情况下,使用 SageMaker 智能筛选的自定义张量批处理实现示例。
from typing import Any from smart_sifting.data_model.data_model_interface import SiftingBatchTransform from smart_sifting.data_model.tensor_batch import TensorBatch class
TensorBatchTransform
(SiftingBatchTransform): def transform(self, batch: Any): a_tensor_batch = TensorBatch( batch[0], batch[-1] ) # assume the last one is the list of labels return a_tensor_batch def reverse_transform(self, tensor_batch: TensorBatch): a_batch = [tensor_batch.inputs, tensor_batch.labels] return a_batch
在您创建已执行
SiftingBatchTransform
批次转换类后,可在第 4 步中使用SiftingDataloader
类进行设置。本指南的其余部分假设已创建了一个ListBatchTransform
类。在第 4 步中,此类将传递给batch_transforms
。 -
-
-
创建用于实现 SageMaker 智能筛选
Loss
接口的类。本教程假定此类名为SiftingImplementedLoss
。在设置此类时,我们建议您在模型训练循环中使用相同的损失函数。按照以下子步骤创建 SageMaker 智能筛选Loss
实现的类。-
SageMaker 智能筛选计算每个训练数据样本的损失值,而不是计算批次的单个损失值。为确保 SageMaker 智能筛选使用相同的损失计算逻辑,请使用 SageMaker 智能筛选
Loss
模块创建 smart-sifting-implemented损失函数,该模块使用您的损失函数并计算每个训练样本的损失。提示
SageMaker 智能筛选算法在每个数据样本上运行,而不是在整个批次上运行,因此您应该添加一个初始化函数来设置 PyTorch 损失函数,而无需任何还原策略。
class
SiftingImplementedLoss
(Loss): def __init__(self): self.loss =torch.nn.CrossEntropyLoss
(reduction='none')以下代码示例也说明了这一点。
-
定义一个接受
original_batch
(或者transformed_batch
如果您在步骤 2 中设置了批量变换)和 PyTorch模型的损失函数。 SageMaker 智能筛选使用不减值的指定损失函数,对每个数据样本进行正向传递,以评估其损失值。
以下代码是一个名为的 smart-sifting-implemented
Loss
接口的示例SiftingImplementedLoss
。from typing import Any import torch import torch.nn as nn from torch import Tensor from smart_sifting.data_model.data_model_interface import SiftingBatch from smart_sifting.loss.abstract_sift_loss_module import Loss model=... # a PyTorch model based on torch.nn.Module class
SiftingImplementedLoss
(Loss): # You should add the following initializaztion function # to calculate loss per sample, not per batch. def __init__(self): self.loss_no_reduction
=torch.nn.CrossEntropyLoss
(reduction='none') def loss( self, model: torch.nn.Module, transformed_batch: SiftingBatch, original_batch: Any = None, ) -> torch.Tensor: device = next(model.parameters()).device batch = [t.to(device) for t in original_batch] # use this if you use original batch and skipped step 2 # batch = [t.to(device) for t in transformed_batch] # use this if you transformed batches in step 2 # compute loss outputs = model(batch) return self.loss_no_reduction
(outputs.logits, batch[2])在训练循环进入实际前向传递之前,每次迭代获取批次数据的数据加载阶段都会进行筛选损失计算。然后将单个损失值与之前的损失值进行比较,并根据步骤 1 中设置的
RelativeProbabilisticSiftConfig
对象估算出其相对百分位数。 -
-
按 SageMaker AI
SiftingDataloader
类封装 PyTroch 数据加载器。最后,将您在前面步骤中配置的所有 SageMaker 智能筛选实现的类用于 SageMaker AI
SiftingDataloder
配置类。这个类是的封装器。 PyTorchDataLoader
通过封装 PyTorch DataLoader
, SageMaker 智能筛选被注册为在 PyTorch 训练作业的每次迭代中作为数据加载的一部分运行。以下代码示例演示如何实现 SageMaker AI 数据筛选到. PyTorchDataLoader
from smart_sifting.dataloader.sift_dataloader import SiftingDataloader from torch.utils.data import DataLoader train_dataloader = DataLoader(...) # PyTorch data loader # Wrap the PyTorch data loader by SiftingDataloder train_dataloader = SiftingDataloader( sift_config=
sift_config
, # config object of RelativeProbabilisticSiftConfig orig_dataloader=train_dataloader
, batch_transforms=ListBatchTransform
(), # Optional, this is the custom class from step 2 loss_impl=SiftingImplementedLoss
(), # PyTorch loss function wrapped by the Sifting Loss interface model=model
, log_batch_data=False
)