Le traduzioni sono generate tramite traduzione automatica. In caso di conflitto tra il contenuto di una traduzione e la versione originale in Inglese, quest'ultima prevarrà.
Checkpoint di attivazione
Il checkpoint di attivazione (o checkpoint gradiente) è una tecnica per ridurre l'utilizzo della memoria cancellando le attivazioni di determinati livelli e ricalcolandole durante un passaggio all'indietro. In effetti, ciò consente di rinunciare a tempi di calcolo aggiuntivi per ridurre l'utilizzo della memoria. Se un modulo è sottoposto a checkpoint, alla fine di un passaggio in avanti, gli input e gli output dal modulo rimangono in memoria. Tutti i tensori intermedi che avrebbero fatto parte del calcolo all'interno di quel modulo vengono liberati durante il passaggio in avanti. Durante il passaggio all'indietro dei moduli di checkpoint, questi tensori vengono ricalcolati. A questo punto, i livelli oltre questo modulo di checkpoint hanno terminato il passaggio all'indietro, quindi il picco di utilizzo della memoria con il checkpoint può essere inferiore.
Nota
Questa funzionalità è disponibile PyTorch nella libreria di parallelismo dei modelli v1.6.0 e successive SageMaker .
Come usare il checkpoint di attivazione
Con smdistributed.modelparallel
, è possibile utilizzare il checkpoint di attivazione con la granularità di un modulo. Per tutti i moduli torch.nn
tranne torch.nn.Sequential
, è possibile effettuare il checkpoint di un albero di moduli solo se si trova all'interno di una partizione dal punto di vista del parallelismo di pipeline. Nel caso del modulo torch.nn.Sequential
, ogni albero dei moduli all'interno del modulo sequenziale deve trovarsi completamente all'interno di una partizione affinché il checkpoint di attivazione funzioni. Quando utilizzi il partizionamento manuale, tieni presente queste restrizioni.
Quando si utilizza il partizionamento automatico dei modelli, è possibile trovare i registri delle assegnazioni di partizionamento a partire da Partition assignments:
nei registri dei processi di addestramento. Se un modulo è partizionato su più livelli (ad esempio, con uno discendente su una classificazione e un altro discendente su una classificazione diversa), la libreria ignora il tentativo di checkpoint del modulo e genera un messaggio di avviso che indica che il modulo non verrà sottoposto a checkpoint.
Nota
La libreria di parallelismo dei SageMaker modelli supporta operazioni di sovrapposizione e non sovrapposizione in combinazione con il checkpoint. allreduce
Nota
PyTorchil API checkpoint nativo smdistributed.modelparallel
non è compatibile con.
Esempio 1: il seguente codice di esempio mostra come utilizzare il checkpoint di attivazione quando nello script è presente una definizione del modello.
import torch.nn as nn import torch.nn.functional as F from smdistributed.modelparallel.torch.patches.checkpoint import checkpoint class Net(nn.Module): def __init__(self): super(Net, self).__init__() self.conv1 = nn.Conv2d(1, 32, 3, 1) self.conv2 = nn.Conv2d(32, 64, 3, 1) self.fc1 = nn.Linear(9216, 128) self.fc2 = nn.Linear(128, 10) def forward(self, x): x = self.conv1(x) x = self.conv2(x) x = F.max_pool2d(x, 2) x = torch.flatten(x, 1) # This call of fc1 will be checkpointed x = checkpoint(self.fc1, x) x = self.fc2(x) return F.log_softmax(x, 1)
Esempio 2: il seguente codice di esempio mostra come utilizzare il checkpoint di attivazione quando nello script è presente un modello sequenziale.
import torch.nn as nn from smdistributed.modelparallel.torch.patches.checkpoint import checkpoint_sequential class Net(nn.Module): def __init__(self): super(Net, self).__init__() self.seq = nn.Sequential( nn.Conv2d(1,20,5), nn.ReLU(), nn.Conv2d(20,64,5), nn.ReLU() ) def forward(self, x): # This call of self.seq will be checkpointed x = checkpoint_sequential(self.seq, x) return F.log_softmax(x, 1)
Esempio 3: Il codice di esempio seguente mostra come utilizzare il checkpoint di attivazione quando si importa un modello predefinito da una libreria, ad esempio PyTorch Hugging Face Transformers. Indipendentemente dal fatto che tu faccia o meno il checkpoint dei i moduli sequenziali, procedi come segue:
-
Effettua il wrapping del modello per
smp.DistributedModel()
. -
Definisci un oggetto per i livelli sequenziali.
-
Effettua il wrapping dell'oggetto del livello sequenziale con
smp.set_activation_checkpointig()
.
import smdistributed.modelparallel.torch as smp from transformers import AutoModelForCausalLM smp.init() model = AutoModelForCausalLM(*args, **kwargs) model = smp.DistributedModel(model) # Call set_activation_checkpointing API transformer_layers = model.module.module.module.transformer.seq_layers smp.set_activation_checkpointing( transformer_layers, pack_args_as_tuple=True, strategy='each')