选择您的 Cookie 首选项

我们使用必要 Cookie 和类似工具提供我们的网站和服务。我们使用性能 Cookie 收集匿名统计数据,以便我们可以了解客户如何使用我们的网站并进行改进。必要 Cookie 无法停用,但您可以单击“自定义”或“拒绝”来拒绝性能 Cookie。

如果您同意,AWS 和经批准的第三方还将使用 Cookie 提供有用的网站功能、记住您的首选项并显示相关内容,包括相关广告。要接受或拒绝所有非必要 Cookie,请单击“接受”或“拒绝”。要做出更详细的选择,请单击“自定义”。

对脚本应用 SageMaker 智能筛选 PyTorch

聚焦模式
对脚本应用 SageMaker 智能筛选 PyTorch - 亚马逊 SageMaker AI

本文属于机器翻译版本。若本译文内容与英语原文存在差异,则一律以英文原文为准。

本文属于机器翻译版本。若本译文内容与英语原文存在差异,则一律以英文原文为准。

这些说明演示了如何使用训练脚本启用 SageMaker 智能筛选。

  1. 配置 SageMaker 智能筛选界面。

    SageMaker 智能筛选库实现了一种基于相对阈值损耗的采样技术,该技术有助于筛选出对降低损耗值影响较小的样本。 SageMaker 智能筛选算法使用正向传递计算每个输入数据样本的损失值,并根据先前数据的损失值计算其相对百分位数。

    以下两个参数是创建筛选配置对象时需要为 RelativeProbabilisticSiftConfig 类指定的参数。

    • 指定用于 beta_value 参数训练的数据比例。

    • 使用 loss_history_length 参数指定用于比较的样本数。

    以下代码示例演示了如何设置 RelativeProbabilisticSiftConfig 类的对象。

    from smart_sifting.sift_config.sift_configs import ( RelativeProbabilisticSiftConfig LossConfig SiftingBaseConfig ) sift_config=RelativeProbabilisticSiftConfig( beta_value=0.5, loss_history_length=500, loss_based_sift_config=LossConfig( sift_config=SiftingBaseConfig(sift_delay=0) ) )

    有关loss_based_sift_config参数和相关类的更多信息,请参阅SageMaker 智能筛选配置模块 SageMaker 智能筛选 Python SDK 参考部分中的。

    前面代码示例中的 sift_config 对象在第 4 步中用于设置 SiftingDataloader 类。

  2. (可选)配置 SageMaker 智能筛选批量转换类。

    不同的训练使用场景需要不同的训练数据格式。鉴于数据格式多种多样, SageMaker 智能筛选算法需要确定如何对特定批次进行筛选。为了解决这个问题, SageMaker 智能筛选提供了一个批量转换模块,可以帮助将批次转换为可以高效筛选的标准化格式。

    1. SageMaker 智能筛选处理以下格式的训练数据的批量转换:Python 列表、字典、元组和张量。对于这些数据格式, SageMaker 智能筛选会自动处理批量数据格式转换,您可以跳过此步骤的其余部分。如果您跳过此步骤,在配置 SiftingDataloader 的第 4 步中,请将 SiftingDataloaderbatch_transforms 参数保留为默认值 None

    2. 如果您的数据集不是这些格式,则您应继续本步骤的其余部分,使用 SiftingBatchTransform 创建自定义批量转换。

      如果您的数据集不是 SageMaker 智能筛选支持的格式之一,则可能会遇到错误。此类数据格式错误可以通过在 SiftingDataloader 类中添加 batch_format_indexbatch_transforms 参数来解决,您可以在第 4 步中进行设置。下面显示了由于数据格式不兼容而导致的错误示例以及解决方法。

      错误消息 解决方案

      默认情况下,{type(batch)}不支持该类型的批处理。

      此错误表示默认不支持批次格式。您应该实现一个自定义批次转换类,并通过将其指定给 SiftingDataloader 类的 batch_transforms 参数中来使用它。

      无法为该批次编制索引 {type(batch)}

      此错误表明无法正常为批次对象编制索引。用户必须实现自定义批次转换,并使用 batch_transforms 参数传递。

      Batch 大小与维度 0 或维度 1 的大小{batch_size}不匹配

      当提供的批次大小与批次的维度 0 或维度 1 不匹配时,会出现此错误。用户必须实现自定义批次转换,并使用 batch_transforms 参数传递。

      维度 0 和维度 1 都匹配批次大小

      此错误表明,由于多个维度与提供的批次大小相匹配,因此需要更多信息来筛选批次。用户可提供 batch_format_index 参数,指示批次是否可按样本或特征编制索引。用户也可以实施自定义批次转换,但这比所需的工作量更大。

      要解决上述问题,您需要使用 SiftingBatchTransform 模块创建自定义批处理转换类。批次转换类应由一对转换和反向转换函数组成。函数对将您的数据格式转换为 SageMaker 智能筛选算法可以处理的格式。创建批次转换类后,此类会返回一个 SiftingBatch 对象,您将在第 4 步中把此对象传递给 SiftingDataloader 类。

      以下是 SiftingBatchTransform 模块中自定义批次转换类的示例。

      • 使用 SageMaker 智能筛选实现自定义列表批量转换的示例,适用于数据加载器块包含输入、掩码和标签的情况。

        from typing import Any import torch from smart_sifting.data_model.data_model_interface import SiftingBatchTransform from smart_sifting.data_model.list_batch import ListBatch class ListBatchTransform(SiftingBatchTransform): def transform(self, batch: Any): inputs = batch[0].tolist() labels = batch[-1].tolist() # assume the last one is the list of labels return ListBatch(inputs, labels) def reverse_transform(self, list_batch: ListBatch): a_batch = [torch.tensor(list_batch.inputs), torch.tensor(list_batch.labels)] return a_batch
      • 使用 SageMaker 智能筛选实现自定义列表批量转换的示例,适用于不需要标签进行反向转换的情况。

        class ListBatchTransformNoLabels(SiftingBatchTransform): def transform(self, batch: Any): return ListBatch(batch[0].tolist()) def reverse_transform(self, list_batch: ListBatch): a_batch = [torch.tensor(list_batch.inputs)] return a_batch
      • 在数据加载器块有输入、掩码和标签的情况下,使用 SageMaker 智能筛选的自定义张量批处理实现示例。

        from typing import Any from smart_sifting.data_model.data_model_interface import SiftingBatchTransform from smart_sifting.data_model.tensor_batch import TensorBatch class TensorBatchTransform(SiftingBatchTransform): def transform(self, batch: Any): a_tensor_batch = TensorBatch( batch[0], batch[-1] ) # assume the last one is the list of labels return a_tensor_batch def reverse_transform(self, tensor_batch: TensorBatch): a_batch = [tensor_batch.inputs, tensor_batch.labels] return a_batch

      在您创建已执行 SiftingBatchTransform 批次转换类后,可在第 4 步中使用 SiftingDataloader 类进行设置。本指南的其余部分假设已创建了一个 ListBatchTransform 类。在第 4 步中,此类将传递给 batch_transforms

  3. 创建用于实现 SageMaker 智能筛选Loss接口的类。本教程假定此类名为 SiftingImplementedLoss。在设置此类时,我们建议您在模型训练循环中使用相同的损失函数。按照以下子步骤创建 SageMaker 智能筛选Loss实现的类。

    1. SageMaker 智能筛选计算每个训练数据样本的损失值,而不是计算批次的单个损失值。为确保 SageMaker 智能筛选使用相同的损失计算逻辑,请使用 SageMaker 智能筛选Loss模块创建 smart-sifting-implemented损失函数,该模块使用您的损失函数并计算每个训练样本的损失。

      提示

      SageMaker 智能筛选算法在每个数据样本上运行,而不是在整个批次上运行,因此您应该添加一个初始化函数来设置 PyTorch 损失函数,而无需任何还原策略。

      class SiftingImplementedLoss(Loss): def __init__(self): self.loss = torch.nn.CrossEntropyLoss(reduction='none')

      以下代码示例也说明了这一点。

    2. 定义一个接受original_batch(或者transformed_batch如果您在步骤 2 中设置了批量变换)和 PyTorch模型的损失函数。 SageMaker 智能筛选使用不减值的指定损失函数,对每个数据样本进行正向传递,以评估其损失值。

    以下代码是一个名为的 smart-sifting-implementedLoss接口的示例SiftingImplementedLoss

    from typing import Any import torch import torch.nn as nn from torch import Tensor from smart_sifting.data_model.data_model_interface import SiftingBatch from smart_sifting.loss.abstract_sift_loss_module import Loss model=... # a PyTorch model based on torch.nn.Module class SiftingImplementedLoss(Loss): # You should add the following initializaztion function # to calculate loss per sample, not per batch. def __init__(self): self.loss_no_reduction = torch.nn.CrossEntropyLoss(reduction='none') def loss( self, model: torch.nn.Module, transformed_batch: SiftingBatch, original_batch: Any = None, ) -> torch.Tensor: device = next(model.parameters()).device batch = [t.to(device) for t in original_batch] # use this if you use original batch and skipped step 2 # batch = [t.to(device) for t in transformed_batch] # use this if you transformed batches in step 2 # compute loss outputs = model(batch) return self.loss_no_reduction(outputs.logits, batch[2])

    在训练循环进入实际前向传递之前,每次迭代获取批次数据的数据加载阶段都会进行筛选损失计算。然后将单个损失值与之前的损失值进行比较,并根据步骤 1 中设置的 RelativeProbabilisticSiftConfig 对象估算出其相对百分位数。

  4. 按 SageMaker AI SiftingDataloader 类封装 PyTroch 数据加载器。

    最后,将您在前面步骤中配置的所有 SageMaker 智能筛选实现的类用于 SageMaker AI SiftingDataloder 配置类。这个类是的封装器。 PyTorch DataLoader通过封装 PyTorchDataLoader, SageMaker 智能筛选被注册为在 PyTorch 训练作业的每次迭代中作为数据加载的一部分运行。以下代码示例演示如何实现 SageMaker AI 数据筛选到. PyTorch DataLoader

    from smart_sifting.dataloader.sift_dataloader import SiftingDataloader from torch.utils.data import DataLoader train_dataloader = DataLoader(...) # PyTorch data loader # Wrap the PyTorch data loader by SiftingDataloder train_dataloader = SiftingDataloader( sift_config=sift_config, # config object of RelativeProbabilisticSiftConfig orig_dataloader=train_dataloader, batch_transforms=ListBatchTransform(), # Optional, this is the custom class from step 2 loss_impl=SiftingImplementedLoss(), # PyTorch loss function wrapped by the Sifting Loss interface model=model, log_batch_data=False )
隐私网站条款Cookie 首选项
© 2025, Amazon Web Services, Inc. 或其附属公司。保留所有权利。