本文為英文版的機器翻譯版本,如內容有任何歧義或不一致之處,概以英文版為準。
將 SageMaker 智慧篩選套用至 PyTorch 指令碼
這些指示示範如何使用訓練指令碼啟用 SageMaker 智慧篩選。
-
設定 SageMaker 智慧篩選介面。
SageMaker 智慧篩選程式庫實作相對閾值損失型取樣技術,有助於篩選對降低損失值影響較低的範例。 SageMaker 智慧篩選演算法會使用向前傳遞計算每個輸入資料範例的損失值,並根據先前資料的損失值計算其相對百分位數。
下列兩個參數是建立篩選組態物件時,您需要指定
RelativeProbabilisticSiftConfig
類別的參數。-
指定訓練
beta_value
參數時應使用的資料比例。 -
指定用於與
loss_history_length
參數比較的樣本數量。
下列程式碼範例示範設定
RelativeProbabilisticSiftConfig
類別的物件。from smart_sifting.sift_config.sift_configs import ( RelativeProbabilisticSiftConfig LossConfig SiftingBaseConfig ) sift_config=RelativeProbabilisticSiftConfig( beta_value=0.5, loss_history_length=500, loss_based_sift_config=LossConfig( sift_config=SiftingBaseConfig(sift_delay=0) ) )
如需有關
loss_based_sift_config
參數和相關類別的詳細資訊,請參閱 SageMaker 智慧篩選 Python SDK參考章節SageMaker 智慧篩選組態模組中的 。上述程式碼範例中的
sift_config
物件用於步驟 4 中,以設定SiftingDataloader
類別。 -
-
(選用) 設定 SageMaker 智慧篩選批次轉換類別。
不同的訓練使用案例需要不同的訓練資料格式。考慮到各種資料格式, SageMaker 智慧篩分演算法需要識別如何在特定批次上執行篩分。為了解決此問題, SageMaker 智慧篩分提供批次轉換模組,可協助將批次轉換為標準化格式,以便有效率地篩選。
-
SageMaker 智慧型篩選會以下列格式處理訓練資料的批次轉換:Python 清單、字典、組合和張量。對於這些資料格式, SageMaker 智慧篩選會自動處理批次資料格式轉換,您可以略過此步驟的其餘部分。如果您略過此步驟,請在步驟 4 中設定
SiftingDataloader
,將batch_transforms
參數保留SiftingDataloader
為其預設值,即None
。 -
如果您的資料集不是這些格式,您應該繼續進行此步驟的其餘部分,以使用 建立自訂批次轉換
SiftingBatchTransform
。如果您的資料集不是 SageMaker 智慧篩選支援的格式之一,您可能會遇到錯誤。將
batch_format_index
或batch_transforms
參數新增至您在步驟 4 中設定的SiftingDataloader
類別,即可解決此類資料格式錯誤。以下顯示因資料格式和解析度不相容而造成的範例錯誤。錯誤訊息 解析度 類型的批次
{type(batch)}
預設不支援 。此錯誤表示預設不支援批次格式。您應該實作自訂批次轉換類別,並將它指定至 SiftingDataloader
類別的batch_transforms
參數來使用它。無法為 類型的批次編製索引
{type(batch)}
此錯誤表示批次物件無法正常編製索引。使用者必須實作自訂批次轉換,並使用 batch_transforms
參數傳遞此轉換。批次大小
{batch_size}
不符合維度 0 或維度 1 大小當提供的批次大小與批次的第 0 或第 1 個維度不相符時,就會發生此錯誤。使用者必須實作自訂批次轉換,並使用 batch_transforms
參數傳遞此轉換。維度 0 和維度 1 皆符合批次大小
此錯誤表示,由於多個維度與提供的批次大小相符,因此需要更多資訊才能篩選批次。使用者可以提供 batch_format_index
參數,以指示批次是否可依範例或功能編製索引。使用者也可以實作自訂批次轉換,但這比所需的工作更多。若要解決上述問題,您需要使用
SiftingBatchTransform
模組建立自訂批次轉換類別。批次轉換類別應包含一對轉換和反向轉換函數。函數對會將您的資料格式轉換為 SageMaker 智慧篩選演算法可以處理的格式。建立批次轉換類別之後,類別會傳回要在步驟 4 中傳遞給SiftingDataloader
類別的SiftingBatch
物件。以下是
SiftingBatchTransform
模組的自訂批次轉換類別範例。-
自訂清單批次轉換實作的範例,針對資料載入器區塊具有輸入、遮罩和標籤的案例使用 SageMaker 智慧篩選。
from typing import Any import torch from smart_sifting.data_model.data_model_interface import SiftingBatchTransform from smart_sifting.data_model.list_batch import ListBatch class
ListBatchTransform
(SiftingBatchTransform): def transform(self, batch: Any): inputs = batch[0].tolist() labels = batch[-1].tolist() # assume the last one is the list of labels return ListBatch(inputs, labels) def reverse_transform(self, list_batch: ListBatch): a_batch = [torch.tensor(list_batch.inputs), torch.tensor(list_batch.labels)] return a_batch -
自訂清單批次轉換實作的範例,針對不需要標籤即可進行反向轉換的案例,使用 SageMaker 智慧篩選。
class
ListBatchTransformNoLabels
(SiftingBatchTransform): def transform(self, batch: Any): return ListBatch(batch[0].tolist()) def reverse_transform(self, list_batch: ListBatch): a_batch = [torch.tensor(list_batch.inputs)] return a_batch -
自訂張量批次實作的範例,針對資料載入器區塊具有輸入、遮罩和標籤的情況,採用 SageMaker 智慧篩選。
from typing import Any from smart_sifting.data_model.data_model_interface import SiftingBatchTransform from smart_sifting.data_model.tensor_batch import TensorBatch class
TensorBatchTransform
(SiftingBatchTransform): def transform(self, batch: Any): a_tensor_batch = TensorBatch( batch[0], batch[-1] ) # assume the last one is the list of labels return a_tensor_batch def reverse_transform(self, tensor_batch: TensorBatch): a_batch = [tensor_batch.inputs, tensor_batch.labels] return a_batch
建立
SiftingBatchTransform
植入的批次轉換類別後,您可以在步驟 4 中使用此類別來設定SiftingDataloader
類別。本指南的其餘部分假設已建立ListBatchTransform
類別。在步驟 4 中,此類別會傳遞至batch_transforms
。 -
-
-
建立類別以實作 SageMaker 智慧型篩選
Loss
介面。本教學課程假設 類別名為SiftingImplementedLoss
。設定此類別時,建議您在模型訓練迴圈中使用相同的損失函數。執行下列子步驟,以建立Loss
實作的 SageMaker 智慧型篩選類別。-
SageMaker 智慧型篩選會計算每個訓練資料範例的損失值,而不是計算批次的單一損失值。為了確保 SageMaker 智慧篩分使用相同的損失計算邏輯,請使用 SageMaker 智慧型篩分
Loss
模組建立 smart-sifting-implemented損失函數,該模組使用您的損失函數並計算每個訓練範例的損失。提示
SageMaker 智慧篩選演算法會在每個資料範例上執行,而不是在整個批次上執行,因此您應該新增初始化函數來設定 PyTorch 損失函數,而不需要任何減少策略。
class
SiftingImplementedLoss
(Loss): def __init__(self): self.loss =torch.nn.CrossEntropyLoss
(reduction='none')這也會顯示在下列程式碼範例中。
-
定義接受
original_batch
(或transformed_batch
如果您已在步驟 2 中設定批次轉換) 和模型的 PyTorch損失函數。使用指定的損失函數而不減少, SageMaker 智慧篩分會為每個資料範例執行向前傳遞,以評估其損失值。
下列程式碼是名為 的 smart-sifting-implemented
Loss
介面範例SiftingImplementedLoss
。from typing import Any import torch import torch.nn as nn from torch import Tensor from smart_sifting.data_model.data_model_interface import SiftingBatch from smart_sifting.loss.abstract_sift_loss_module import Loss model=... # a PyTorch model based on torch.nn.Module class
SiftingImplementedLoss
(Loss): # You should add the following initializaztion function # to calculate loss per sample, not per batch. def __init__(self): self.loss_no_reduction
=torch.nn.CrossEntropyLoss
(reduction='none') def loss( self, model: torch.nn.Module, transformed_batch: SiftingBatch, original_batch: Any = None, ) -> torch.Tensor: device = next(model.parameters()).device batch = [t.to(device) for t in original_batch] # use this if you use original batch and skipped step 2 # batch = [t.to(device) for t in transformed_batch] # use this if you transformed batches in step 2 # compute loss outputs = model(batch) return self.loss_no_reduction
(outputs.logits, batch[2])在訓練循環達到實際向前傳遞之前,此篩選損失計算會在每次迭代中擷取批次的資料載入階段完成。然後,將個別損失值與先前的損失值進行比較,並根據您在步驟 1
RelativeProbabilisticSiftConfig
中設定的物件來估計其相對百分位數。 -
-
依 SageMaker
SiftingDataloader
類別包裝 PyTroch 資料載入器。最後,使用您在先前步驟中設定的所有 SageMaker 智慧篩選實作類別,以 SageMaker
SiftingDataloder
傳送至組態類別。此類別是 的包裝函式 PyTorchDataLoader
。透過包裝 PyTorch DataLoader
, SageMaker 智慧型篩選會註冊為 PyTorch 在訓練任務的每個迭代中作為資料載入的一部分執行。下列程式碼範例示範實作 SageMaker 資料篩選至 PyTorchDataLoader
。from smart_sifting.dataloader.sift_dataloader import SiftingDataloader from torch.utils.data import DataLoader train_dataloader = DataLoader(...) # PyTorch data loader # Wrap the PyTorch data loader by SiftingDataloder train_dataloader = SiftingDataloader( sift_config=
sift_config
, # config object of RelativeProbabilisticSiftConfig orig_dataloader=train_dataloader
, batch_transforms=ListBatchTransform
(), # Optional, this is the custom class from step 2 loss_impl=SiftingImplementedLoss
(), # PyTorch loss function wrapped by the Sifting Loss interface model=model
, log_batch_data=False
)