Terapkan penyaringan SageMaker cerdas ke skrip Anda PyTorch - Amazon SageMaker

Terjemahan disediakan oleh mesin penerjemah. Jika konten terjemahan yang diberikan bertentangan dengan versi bahasa Inggris aslinya, utamakan versi bahasa Inggris.

Terapkan penyaringan SageMaker cerdas ke skrip Anda PyTorch

Instruksi ini menunjukkan cara mengaktifkan penyaringan SageMaker cerdas dengan skrip pelatihan Anda.

  1. Konfigurasikan antarmuka penyaringan SageMaker cerdas.

    Pustaka penyaringan SageMaker cerdas menerapkan teknik pengambilan sampel berbasis kerugian ambang batas relatif yang membantu menyaring sampel dengan dampak yang lebih rendah dalam mengurangi nilai kerugian. Algoritma penyaringan SageMaker cerdas menghitung nilai kerugian dari setiap sampel data input menggunakan pass maju, dan menghitung persentil relatifnya terhadap nilai kehilangan data sebelumnya.

    Dua parameter berikut adalah apa yang perlu Anda tentukan ke RelativeProbabilisticSiftConfig kelas untuk membuat objek konfigurasi penyaringan.

    • Tentukan proporsi data yang harus digunakan untuk pelatihan ke beta_value parameter.

    • Tentukan jumlah sampel yang digunakan dalam perbandingan dengan loss_history_length parameter.

    Contoh kode berikut menunjukkan pengaturan objek RelativeProbabilisticSiftConfig kelas.

    from smart_sifting.sift_config.sift_configs import ( RelativeProbabilisticSiftConfig LossConfig SiftingBaseConfig ) sift_config=RelativeProbabilisticSiftConfig( beta_value=0.5, loss_history_length=500, loss_based_sift_config=LossConfig( sift_config=SiftingBaseConfig(sift_delay=0) ) )

    Untuk informasi selengkapnya tentang loss_based_sift_config parameter dan kelas terkait, lihat SageMaker modul konfigurasi penyaringan cerdas di bagian referensi SDK Python penyaringan SageMaker cerdas.

    sift_configObjek dalam contoh kode sebelumnya digunakan pada langkah 4 untuk menyiapkan kelas. SiftingDataloader

  2. (Opsional) Konfigurasikan kelas transformasi batch penyaringan SageMaker cerdas.

    Kasus penggunaan pelatihan yang berbeda memerlukan format data pelatihan yang berbeda. Mengingat berbagai format data, algoritma penyaringan SageMaker cerdas perlu mengidentifikasi cara melakukan penyaringan pada batch tertentu. Untuk mengatasi hal ini, SageMaker smart sifting menyediakan modul transformasi batch yang membantu mengubah batch menjadi format standar yang dapat disaring secara efisien.

    1. SageMaker smart sifting menangani transformasi batch data pelatihan dalam format berikut: Daftar Python, kamus, tupel, dan tensor. Untuk format data ini, SageMaker smart sifting secara otomatis menangani konversi format data batch, dan Anda dapat melewati sisa langkah ini. Jika Anda melewati langkah ini, pada langkah 4 untuk mengkonfigurasiSiftingDataloader, biarkan batch_transforms parameter SiftingDataloader ke nilai defaultnya, yaituNone.

    2. Jika kumpulan data Anda tidak dalam format ini, Anda harus melanjutkan ke sisa langkah ini untuk membuat transformasi batch khusus menggunakanSiftingBatchTransform.

      Dalam kasus di mana kumpulan data Anda tidak berada dalam salah satu format yang didukung oleh penyaringan SageMaker cerdas, Anda mungkin mengalami kesalahan. Kesalahan format data tersebut dapat diatasi dengan menambahkan batch_transforms parameter batch_format_index or ke SiftingDataloader kelas, yang Anda atur di langkah 4. Berikut ini menunjukkan contoh kesalahan karena format data yang tidak kompatibel dan resolusi untuk mereka.

      Pesan Kesalahan Resolusi

      Jenis batch {type(batch)} tidak didukung secara default.

      Kesalahan ini menunjukkan format batch tidak didukung secara default. Anda harus menerapkan kelas transformasi batch kustom, dan menggunakannya dengan menentukannya ke batch_transforms parameter SiftingDataloader kelas.

      Tidak dapat mengindeks kumpulan jenis {type(batch)}

      Kesalahan ini menunjukkan objek batch tidak dapat diindeks secara normal. Pengguna harus menerapkan transformasi batch khusus dan meneruskan ini menggunakan batch_transforms parameter.

      Ukuran batch {batch_size} tidak cocok dengan dimensi 0 atau dimensi 1 ukuran

      Kesalahan ini terjadi ketika ukuran batch yang disediakan tidak sesuai dengan dimensi ke-0 atau ke-1 dari batch. Pengguna harus menerapkan transformasi batch khusus dan meneruskan ini menggunakan batch_transforms parameter.

      Dimensi 0 dan dimensi 1 cocok dengan ukuran batch

      Kesalahan ini menunjukkan bahwa karena beberapa dimensi cocok dengan ukuran batch yang disediakan, informasi lebih lanjut diperlukan untuk menyaring batch. Pengguna dapat memberikan batch_format_index parameter untuk menunjukkan apakah batch dapat diindeks berdasarkan sampel atau fitur. Pengguna juga dapat menerapkan transformasi batch khusus, tetapi ini lebih banyak pekerjaan daripada yang diperlukan.

      Untuk mengatasi masalah yang disebutkan di atas, Anda perlu membuat kelas transformasi batch khusus menggunakan SiftingBatchTransform modul. Kelas transformasi batch harus terdiri dari sepasang fungsi transformasi dan reverse-transform. Pasangan fungsi mengonversi format data Anda ke format yang dapat diproses oleh algoritme penyaringan SageMaker cerdas. Setelah Anda membuat kelas transformasi batch, kelas mengembalikan SiftingBatch objek yang akan Anda berikan ke SiftingDataloader kelas di langkah 4.

      Berikut ini adalah contoh kelas transformasi batch kustom SiftingBatchTransform modul.

      • Contoh implementasi transformasi batch daftar kustom dengan penyaringan SageMaker cerdas untuk kasus di mana potongan dataloader memiliki input, mask, dan label.

        from typing import Any import torch from smart_sifting.data_model.data_model_interface import SiftingBatchTransform from smart_sifting.data_model.list_batch import ListBatch class ListBatchTransform(SiftingBatchTransform): def transform(self, batch: Any): inputs = batch[0].tolist() labels = batch[-1].tolist() # assume the last one is the list of labels return ListBatch(inputs, labels) def reverse_transform(self, list_batch: ListBatch): a_batch = [torch.tensor(list_batch.inputs), torch.tensor(list_batch.labels)] return a_batch
      • Contoh implementasi transformasi batch daftar kustom dengan penyaringan SageMaker cerdas untuk kasus di mana tidak ada label yang diperlukan untuk transformasi terbalik.

        class ListBatchTransformNoLabels(SiftingBatchTransform): def transform(self, batch: Any): return ListBatch(batch[0].tolist()) def reverse_transform(self, list_batch: ListBatch): a_batch = [torch.tensor(list_batch.inputs)] return a_batch
      • Contoh implementasi batch tensor khusus dengan penyaringan SageMaker cerdas untuk kasus di mana potongan pemuat data memiliki input, masker, dan label.

        from typing import Any from smart_sifting.data_model.data_model_interface import SiftingBatchTransform from smart_sifting.data_model.tensor_batch import TensorBatch class TensorBatchTransform(SiftingBatchTransform): def transform(self, batch: Any): a_tensor_batch = TensorBatch( batch[0], batch[-1] ) # assume the last one is the list of labels return a_tensor_batch def reverse_transform(self, tensor_batch: TensorBatch): a_batch = [tensor_batch.inputs, tensor_batch.labels] return a_batch

      Setelah Anda membuat SiftingBatchTransform kelas transformasi batch yang diimplementasikan, Anda menggunakan kelas ini di langkah 4 untuk menyiapkan kelas. SiftingDataloader Sisa panduan ini mengasumsikan bahwa ListBatchTransform kelas dibuat. Pada langkah 4, kelas ini diteruskan kebatch_transforms.

  3. Buat kelas untuk mengimplementasikan Loss antarmuka penyaringan SageMaker cerdas. Tutorial ini mengasumsikan bahwa kelas diberi namaSiftingImplementedLoss. Saat menyiapkan kelas ini, kami sarankan Anda menggunakan fungsi kerugian yang sama dalam loop pelatihan model. Ikuti sublangkah berikut untuk membuat kelas Loss implementasi penyaringan SageMaker cerdas.

    1. SageMaker smart sifting menghitung nilai kerugian untuk setiap sampel data pelatihan, sebagai lawan menghitung nilai kerugian tunggal untuk batch. Untuk memastikan bahwa penyaringan SageMaker cerdas menggunakan logika perhitungan kerugian yang sama, buat fungsi smart-sifting-implemented kerugian menggunakan Loss modul penyaringan SageMaker pintar yang menggunakan fungsi kerugian Anda dan hitung kerugian per sampel pelatihan.

      Tip

      SageMaker algoritma penyaringan cerdas berjalan pada setiap sampel data, bukan pada seluruh batch, jadi Anda harus menambahkan fungsi inisialisasi untuk mengatur fungsi PyTorch kerugian tanpa strategi pengurangan apa pun.

      class SiftingImplementedLoss(Loss): def __init__(self): self.loss = torch.nn.CrossEntropyLoss(reduction='none')

      Ini juga ditunjukkan dalam contoh kode berikut.

    2. Tentukan fungsi kerugian yang menerima original_batch (atau transformed_batch jika Anda telah menyiapkan transformasi batch pada langkah 2) dan PyTorch model. Menggunakan fungsi kerugian yang ditentukan tanpa pengurangan, SageMaker smart sifting menjalankan forward pass untuk setiap sampel data untuk mengevaluasi nilai kerugiannya.

    Kode berikut adalah contoh dari smart-sifting-implemented Loss antarmuka bernamaSiftingImplementedLoss.

    from typing import Any import torch import torch.nn as nn from torch import Tensor from smart_sifting.data_model.data_model_interface import SiftingBatch from smart_sifting.loss.abstract_sift_loss_module import Loss model=... # a PyTorch model based on torch.nn.Module class SiftingImplementedLoss(Loss): # You should add the following initializaztion function # to calculate loss per sample, not per batch. def __init__(self): self.loss_no_reduction = torch.nn.CrossEntropyLoss(reduction='none') def loss( self, model: torch.nn.Module, transformed_batch: SiftingBatch, original_batch: Any = None, ) -> torch.Tensor: device = next(model.parameters()).device batch = [t.to(device) for t in original_batch] # use this if you use original batch and skipped step 2 # batch = [t.to(device) for t in transformed_batch] # use this if you transformed batches in step 2 # compute loss outputs = model(batch) return self.loss_no_reduction(outputs.logits, batch[2])

    Sebelum loop pelatihan mencapai pass maju yang sebenarnya, perhitungan kerugian penyaringan ini dilakukan selama fase pemuatan data pengambilan batch di setiap iterasi. Nilai kerugian individu kemudian dibandingkan dengan nilai kerugian sebelumnya, dan persentil relatifnya diperkirakan per objek yang telah RelativeProbabilisticSiftConfig Anda atur pada langkah 1.

  4. Bungkus pemuat PyTroch data dengan SageMaker SiftingDataloader kelas.

    Terakhir, gunakan semua kelas implementasi penyaringan SageMaker cerdas yang Anda konfigurasikan pada langkah sebelumnya ke kelas SageMaker SiftingDataloder konfigurasi. Kelas ini adalah pembungkus untuk PyTorch DataLoader. Dengan membungkus PyTorchDataLoader, SageMaker smart sifting terdaftar untuk dijalankan sebagai bagian dari pemuatan data di setiap iterasi pekerjaan pelatihan. PyTorch Contoh kode berikut menunjukkan penerapan penyaringan SageMaker data ke a. PyTorch DataLoader

    from smart_sifting.dataloader.sift_dataloader import SiftingDataloader from torch.utils.data import DataLoader train_dataloader = DataLoader(...) # PyTorch data loader # Wrap the PyTorch data loader by SiftingDataloder train_dataloader = SiftingDataloader( sift_config=sift_config, # config object of RelativeProbabilisticSiftConfig orig_dataloader=train_dataloader, batch_transforms=ListBatchTransform(), # Optional, this is the custom class from step 2 loss_impl=SiftingImplementedLoss(), # PyTorch loss function wrapped by the Sifting Loss interface model=model, log_batch_data=False )