Checkpoint di attivazione - Amazon SageMaker

Le traduzioni sono generate tramite traduzione automatica. In caso di conflitto tra il contenuto di una traduzione e la versione originale in Inglese, quest'ultima prevarrà.

Checkpoint di attivazione

Il checkpoint di attivazione è una tecnica per ridurre l'utilizzo della memoria cancellando le attivazioni di determinati livelli e ricalcolandole durante il passaggio all'indietro. In effetti, ciò consente di risparmiare tempo di calcolo aggiuntivo per ridurre l'utilizzo della memoria. Se un modulo è sottoposto a checkpoint, alla fine di un passaggio in avanti, rimangono in memoria solo gli input iniziali al modulo e le uscite finali del modulo. PyTorch rilascia tutti i tensori intermedi che fanno parte del calcolo all'interno di quel modulo durante il passaggio in avanti. Durante il passaggio all'indietro dei moduli checkpoint, ricalcola questi tensori. PyTorch A questo punto, i livelli oltre questo modulo di checkpoint hanno terminato il passaggio all'indietro, quindi il picco di utilizzo della memoria con il checkpoint si riduce.

SMPv2 supporta il modulo di checkpoint di attivazione, PyTorch . apply_activation_checkpointing Di seguito sono riportati alcuni esempi di checkpoint di attivazione del modello GPT Hugging Face -NeoX.

Strati Checkpointing Transformer del modello Hugging Face -NeoX GPT

from transformers.models.gpt_neox import GPTNeoXLayer from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( apply_activation_checkpointing ) # check_fn receives a module as the arg, # and it needs to return whether the module is to be checkpointed def is_transformer_layer(module): from transformers.models.gpt_neox import GPTNeoXLayer return isinstance(submodule, GPTNeoXLayer) apply_activation_checkpointing(model, check_fn=is_transformer_layer)

Controllo di ogni altro strato Transformer del modello Hugging Face -NeoX GPT

# check_fn receives a module as arg, # and it needs to return whether the module is to be checkpointed # here we define that function based on global variable (transformer_layers) from transformers.models.gpt_neox import GPTNeoXLayer from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( apply_activation_checkpointing ) transformer_layers = [ m for m model.modules() if isinstance(m, GPTNeoXLayer) ] def is_odd_transformer_layer(module): return transformer_layers.index(module) % 2 == 0 apply_activation_checkpointing(model, check_fn=is_odd_transformer_layer)

In alternativa, ha PyTorch anche il torch.utils.checkpoint modulo per il checkpointing, che viene utilizzato da un sottoinsieme di modelli Hugging Face Transformers. Questo modulo funziona anche con v2. SMP Tuttavia, richiede l'accesso alla definizione del modello per aggiungere il checkpoint wrapper. Pertanto, ti consigliamo di utilizzare il metodo. apply_activation_checkpointing