Terjemahan disediakan oleh mesin penerjemah. Jika konten terjemahan yang diberikan bertentangan dengan versi bahasa Inggris aslinya, utamakan versi bahasa Inggris.
Checkpointing aktivasi (atau checkpointing gradien) adalah teknik untuk mengurangi penggunaan memori dengan membersihkan aktivasi lapisan tertentu dan mengkomputernya kembali selama lintasan mundur. Secara efektif, ini memperdagangkan waktu komputasi ekstra untuk mengurangi penggunaan memori. Jika modul diperiksa, di akhir pass maju, input ke dan output dari modul tetap berada di memori. Tensor perantara apa pun yang akan menjadi bagian dari perhitungan di dalam modul itu dibebaskan selama pass maju. Selama lintasan mundur modul checkpoint, tensor ini dihitung ulang. Pada titik ini, lapisan di luar modul checkpointed ini telah menyelesaikan backward pass mereka, sehingga penggunaan memori puncak dengan checkpointing bisa lebih rendah.
catatan
Fitur ini tersedia untuk PyTorch di pustaka paralelisme SageMaker model v1.6.0 dan yang lebih baru.
Cara Menggunakan Checkpointing Aktivasi
Dengansmdistributed.modelparallel
, Anda dapat menggunakan pos pemeriksaan aktivasi pada perincian modul. Untuk semua torch.nn
modul kecualitorch.nn.Sequential
, Anda hanya dapat memeriksa pohon modul jika terletak dalam satu partisi dari perspektif paralelisme pipa. Dalam kasus torch.nn.Sequential
modul, setiap pohon modul di dalam modul sekuensial harus terletak sepenuhnya dalam satu partisi agar pos pemeriksaan aktivasi berfungsi. Saat Anda menggunakan partisi manual, perhatikan batasan ini.
Saat Anda menggunakan partisi model otomatis, Anda dapat menemukan log tugas partisi yang dimulai dengan Partition assignments:
di log pekerjaan pelatihan. Jika modul dipartisi di beberapa peringkat (misalnya, dengan satu keturunan pada satu peringkat dan keturunan lain pada peringkat yang berbeda), perpustakaan mengabaikan upaya untuk memeriksa modul dan memunculkan pesan peringatan bahwa modul tidak akan diperiksa.
catatan
Pustaka paralelisme SageMaker model mendukung operasi yang tumpang tindih dan tidak tumpang tindih dalam kombinasi dengan pos allreduce
pemeriksaan.
catatan
PyTorchAPI checkpointing asli tidak kompatibel dengan. smdistributed.modelparallel
Contoh 1: Kode contoh berikut menunjukkan cara menggunakan checkpointing aktivasi ketika Anda memiliki definisi model dalam skrip Anda.
import torch.nn as nn
import torch.nn.functional as F
from smdistributed.modelparallel.torch.patches.checkpoint import checkpoint
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 32, 3, 1)
self.conv2 = nn.Conv2d(32, 64, 3, 1)
self.fc1 = nn.Linear(9216, 128)
self.fc2 = nn.Linear(128, 10)
def forward(self, x):
x = self.conv1(x)
x = self.conv2(x)
x = F.max_pool2d(x, 2)
x = torch.flatten(x, 1)
# This call of fc1 will be checkpointed
x = checkpoint(self.fc1, x)
x = self.fc2(x)
return F.log_softmax(x, 1)
Contoh 2: Kode contoh berikut menunjukkan cara menggunakan checkpointing aktivasi ketika Anda memiliki model sekuensial dalam skrip Anda.
import torch.nn as nn
from smdistributed.modelparallel.torch.patches.checkpoint import checkpoint_sequential
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.seq = nn.Sequential(
nn.Conv2d(1,20,5),
nn.ReLU(),
nn.Conv2d(20,64,5),
nn.ReLU()
)
def forward(self, x):
# This call of self.seq will be checkpointed
x = checkpoint_sequential(self.seq, x)
return F.log_softmax(x, 1)
Contoh 3: Contoh kode berikut menunjukkan cara menggunakan checkpointing aktivasi saat Anda mengimpor model bawaan dari pustaka, seperti dan PyTorch Hugging Face Transformers. Apakah Anda memeriksa modul sekuensial atau tidak, lakukan hal berikut:
-
Bungkus model dengan
smp.DistributedModel()
. -
Tentukan objek untuk lapisan berurutan.
-
Bungkus objek layer sekuensial dengan
smp.set_activation_checkpointig()
.
import smdistributed.modelparallel.torch as smp
from transformers import AutoModelForCausalLM
smp.init()
model = AutoModelForCausalLM(*args, **kwargs)
model = smp.DistributedModel(model)
# Call set_activation_checkpointing API
transformer_layers = model.module.module.module.transformer.seq_layers
smp.set_activation_checkpointing(
transformer_layers, pack_args_as_tuple=True, strategy='each')