Terjemahan disediakan oleh mesin penerjemah. Jika konten terjemahan yang diberikan bertentangan dengan versi bahasa Inggris aslinya, utamakan versi bahasa Inggris.
Terapkan penyaringan SageMaker cerdas ke skrip Anda PyTorch
Instruksi ini menunjukkan cara mengaktifkan penyaringan SageMaker cerdas dengan skrip pelatihan Anda.
-
Konfigurasikan antarmuka penyaringan SageMaker cerdas.
Pustaka penyaringan SageMaker cerdas menerapkan teknik pengambilan sampel berbasis kerugian ambang batas relatif yang membantu menyaring sampel dengan dampak yang lebih rendah dalam mengurangi nilai kerugian. Algoritma penyaringan SageMaker cerdas menghitung nilai kerugian dari setiap sampel data input menggunakan pass maju, dan menghitung persentil relatifnya terhadap nilai kehilangan data sebelumnya.
Dua parameter berikut adalah apa yang perlu Anda tentukan ke
RelativeProbabilisticSiftConfig
kelas untuk membuat objek konfigurasi penyaringan.-
Tentukan proporsi data yang harus digunakan untuk pelatihan ke
beta_value
parameter. -
Tentukan jumlah sampel yang digunakan dalam perbandingan dengan
loss_history_length
parameter.
Contoh kode berikut menunjukkan pengaturan objek
RelativeProbabilisticSiftConfig
kelas.from smart_sifting.sift_config.sift_configs import ( RelativeProbabilisticSiftConfig LossConfig SiftingBaseConfig ) sift_config=RelativeProbabilisticSiftConfig( beta_value=0.5, loss_history_length=500, loss_based_sift_config=LossConfig( sift_config=SiftingBaseConfig(sift_delay=0) ) )
Untuk informasi selengkapnya tentang
loss_based_sift_config
parameter dan kelas terkait, lihat SageMaker modul konfigurasi penyaringan cerdas di bagian referensi SDK Python penyaringan SageMaker cerdas.sift_config
Objek dalam contoh kode sebelumnya digunakan pada langkah 4 untuk menyiapkan kelas.SiftingDataloader
-
-
(Opsional) Konfigurasikan kelas transformasi batch penyaringan SageMaker cerdas.
Kasus penggunaan pelatihan yang berbeda memerlukan format data pelatihan yang berbeda. Mengingat berbagai format data, algoritma penyaringan SageMaker cerdas perlu mengidentifikasi cara melakukan penyaringan pada batch tertentu. Untuk mengatasi hal ini, SageMaker smart sifting menyediakan modul transformasi batch yang membantu mengubah batch menjadi format standar yang dapat disaring secara efisien.
-
SageMaker smart sifting menangani transformasi batch data pelatihan dalam format berikut: Daftar Python, kamus, tupel, dan tensor. Untuk format data ini, SageMaker smart sifting secara otomatis menangani konversi format data batch, dan Anda dapat melewati sisa langkah ini. Jika Anda melewati langkah ini, pada langkah 4 untuk mengkonfigurasi
SiftingDataloader
, biarkanbatch_transforms
parameterSiftingDataloader
ke nilai defaultnya, yaituNone
. -
Jika kumpulan data Anda tidak dalam format ini, Anda harus melanjutkan ke sisa langkah ini untuk membuat transformasi batch khusus menggunakan
SiftingBatchTransform
.Dalam kasus di mana kumpulan data Anda tidak berada dalam salah satu format yang didukung oleh penyaringan SageMaker cerdas, Anda mungkin mengalami kesalahan. Kesalahan format data tersebut dapat diatasi dengan menambahkan
batch_transforms
parameterbatch_format_index
or keSiftingDataloader
kelas, yang Anda atur di langkah 4. Berikut ini menunjukkan contoh kesalahan karena format data yang tidak kompatibel dan resolusi untuk mereka.Pesan Kesalahan Resolusi Jenis batch
{type(batch)}
tidak didukung secara default.Kesalahan ini menunjukkan format batch tidak didukung secara default. Anda harus menerapkan kelas transformasi batch kustom, dan menggunakannya dengan menentukannya ke batch_transforms
parameterSiftingDataloader
kelas.Tidak dapat mengindeks kumpulan jenis
{type(batch)}
Kesalahan ini menunjukkan objek batch tidak dapat diindeks secara normal. Pengguna harus menerapkan transformasi batch khusus dan meneruskan ini menggunakan batch_transforms
parameter.Ukuran batch
{batch_size}
tidak cocok dengan dimensi 0 atau dimensi 1 ukuranKesalahan ini terjadi ketika ukuran batch yang disediakan tidak sesuai dengan dimensi ke-0 atau ke-1 dari batch. Pengguna harus menerapkan transformasi batch khusus dan meneruskan ini menggunakan batch_transforms
parameter.Dimensi 0 dan dimensi 1 cocok dengan ukuran batch
Kesalahan ini menunjukkan bahwa karena beberapa dimensi cocok dengan ukuran batch yang disediakan, informasi lebih lanjut diperlukan untuk menyaring batch. Pengguna dapat memberikan batch_format_index
parameter untuk menunjukkan apakah batch dapat diindeks berdasarkan sampel atau fitur. Pengguna juga dapat menerapkan transformasi batch khusus, tetapi ini lebih banyak pekerjaan daripada yang diperlukan.Untuk mengatasi masalah yang disebutkan di atas, Anda perlu membuat kelas transformasi batch khusus menggunakan
SiftingBatchTransform
modul. Kelas transformasi batch harus terdiri dari sepasang fungsi transformasi dan reverse-transform. Pasangan fungsi mengonversi format data Anda ke format yang dapat diproses oleh algoritme penyaringan SageMaker cerdas. Setelah Anda membuat kelas transformasi batch, kelas mengembalikanSiftingBatch
objek yang akan Anda berikan keSiftingDataloader
kelas di langkah 4.Berikut ini adalah contoh kelas transformasi batch kustom
SiftingBatchTransform
modul.-
Contoh implementasi transformasi batch daftar kustom dengan penyaringan SageMaker cerdas untuk kasus di mana potongan dataloader memiliki input, mask, dan label.
from typing import Any import torch from smart_sifting.data_model.data_model_interface import SiftingBatchTransform from smart_sifting.data_model.list_batch import ListBatch class
ListBatchTransform
(SiftingBatchTransform): def transform(self, batch: Any): inputs = batch[0].tolist() labels = batch[-1].tolist() # assume the last one is the list of labels return ListBatch(inputs, labels) def reverse_transform(self, list_batch: ListBatch): a_batch = [torch.tensor(list_batch.inputs), torch.tensor(list_batch.labels)] return a_batch -
Contoh implementasi transformasi batch daftar kustom dengan penyaringan SageMaker cerdas untuk kasus di mana tidak ada label yang diperlukan untuk transformasi terbalik.
class
ListBatchTransformNoLabels
(SiftingBatchTransform): def transform(self, batch: Any): return ListBatch(batch[0].tolist()) def reverse_transform(self, list_batch: ListBatch): a_batch = [torch.tensor(list_batch.inputs)] return a_batch -
Contoh implementasi batch tensor khusus dengan penyaringan SageMaker cerdas untuk kasus di mana potongan pemuat data memiliki input, masker, dan label.
from typing import Any from smart_sifting.data_model.data_model_interface import SiftingBatchTransform from smart_sifting.data_model.tensor_batch import TensorBatch class
TensorBatchTransform
(SiftingBatchTransform): def transform(self, batch: Any): a_tensor_batch = TensorBatch( batch[0], batch[-1] ) # assume the last one is the list of labels return a_tensor_batch def reverse_transform(self, tensor_batch: TensorBatch): a_batch = [tensor_batch.inputs, tensor_batch.labels] return a_batch
Setelah Anda membuat
SiftingBatchTransform
kelas transformasi batch yang diimplementasikan, Anda menggunakan kelas ini di langkah 4 untuk menyiapkan kelas.SiftingDataloader
Sisa panduan ini mengasumsikan bahwaListBatchTransform
kelas dibuat. Pada langkah 4, kelas ini diteruskan kebatch_transforms
. -
-
-
Buat kelas untuk mengimplementasikan
Loss
antarmuka penyaringan SageMaker cerdas. Tutorial ini mengasumsikan bahwa kelas diberi namaSiftingImplementedLoss
. Saat menyiapkan kelas ini, kami sarankan Anda menggunakan fungsi kerugian yang sama dalam loop pelatihan model. Ikuti sublangkah berikut untuk membuat kelasLoss
implementasi penyaringan SageMaker cerdas.-
SageMaker smart sifting menghitung nilai kerugian untuk setiap sampel data pelatihan, sebagai lawan menghitung nilai kerugian tunggal untuk batch. Untuk memastikan bahwa penyaringan SageMaker cerdas menggunakan logika perhitungan kerugian yang sama, buat fungsi smart-sifting-implemented kerugian menggunakan
Loss
modul penyaringan SageMaker pintar yang menggunakan fungsi kerugian Anda dan hitung kerugian per sampel pelatihan.Tip
SageMaker algoritma penyaringan cerdas berjalan pada setiap sampel data, bukan pada seluruh batch, jadi Anda harus menambahkan fungsi inisialisasi untuk mengatur fungsi PyTorch kerugian tanpa strategi pengurangan apa pun.
class
SiftingImplementedLoss
(Loss): def __init__(self): self.loss =torch.nn.CrossEntropyLoss
(reduction='none')Ini juga ditunjukkan dalam contoh kode berikut.
-
Tentukan fungsi kerugian yang menerima
original_batch
(atautransformed_batch
jika Anda telah menyiapkan transformasi batch pada langkah 2) dan PyTorch model. Menggunakan fungsi kerugian yang ditentukan tanpa pengurangan, SageMaker smart sifting menjalankan forward pass untuk setiap sampel data untuk mengevaluasi nilai kerugiannya.
Kode berikut adalah contoh dari smart-sifting-implemented
Loss
antarmuka bernamaSiftingImplementedLoss
.from typing import Any import torch import torch.nn as nn from torch import Tensor from smart_sifting.data_model.data_model_interface import SiftingBatch from smart_sifting.loss.abstract_sift_loss_module import Loss model=... # a PyTorch model based on torch.nn.Module class
SiftingImplementedLoss
(Loss): # You should add the following initializaztion function # to calculate loss per sample, not per batch. def __init__(self): self.loss_no_reduction
=torch.nn.CrossEntropyLoss
(reduction='none') def loss( self, model: torch.nn.Module, transformed_batch: SiftingBatch, original_batch: Any = None, ) -> torch.Tensor: device = next(model.parameters()).device batch = [t.to(device) for t in original_batch] # use this if you use original batch and skipped step 2 # batch = [t.to(device) for t in transformed_batch] # use this if you transformed batches in step 2 # compute loss outputs = model(batch) return self.loss_no_reduction
(outputs.logits, batch[2])Sebelum loop pelatihan mencapai pass maju yang sebenarnya, perhitungan kerugian penyaringan ini dilakukan selama fase pemuatan data pengambilan batch di setiap iterasi. Nilai kerugian individu kemudian dibandingkan dengan nilai kerugian sebelumnya, dan persentil relatifnya diperkirakan per objek yang telah
RelativeProbabilisticSiftConfig
Anda atur pada langkah 1. -
-
Bungkus pemuat PyTroch data dengan SageMaker
SiftingDataloader
kelas.Terakhir, gunakan semua kelas implementasi penyaringan SageMaker cerdas yang Anda konfigurasikan pada langkah sebelumnya ke kelas SageMaker
SiftingDataloder
konfigurasi. Kelas ini adalah pembungkus untuk PyTorchDataLoader
. Dengan membungkus PyTorch DataLoader
, SageMaker smart sifting terdaftar untuk dijalankan sebagai bagian dari pemuatan data di setiap iterasi pekerjaan pelatihan. PyTorch Contoh kode berikut menunjukkan penerapan penyaringan SageMaker data ke a. PyTorchDataLoader
from smart_sifting.dataloader.sift_dataloader import SiftingDataloader from torch.utils.data import DataLoader train_dataloader = DataLoader(...) # PyTorch data loader # Wrap the PyTorch data loader by SiftingDataloder train_dataloader = SiftingDataloader( sift_config=
sift_config
, # config object of RelativeProbabilisticSiftConfig orig_dataloader=train_dataloader
, batch_transforms=ListBatchTransform
(), # Optional, this is the custom class from step 2 loss_impl=SiftingImplementedLoss
(), # PyTorch loss function wrapped by the Sifting Loss interface model=model
, log_batch_data=False
)