Le traduzioni sono generate tramite traduzione automatica. In caso di conflitto tra il contenuto di una traduzione e la versione originale in Inglese, quest'ultima prevarrà.
Allenamento misto di precisione
La SageMaker model parallelism (SMP) library v2 supporta l'addestramento di precisione misto pronto all'uso grazie all'integrazione con framework open source come Transformer Engine. PyTorch FSDP Per ulteriori informazioni, consulta i seguenti argomenti.
Argomenti
Addestramento di precisione misto con nessuna istanza P5 utilizzando Transformer Engine FP8
A partire dalla libreria SageMaker model parallelism (SMP) v2.2.0, la SMP libreria si integra con Transformer EngineMixedPrecision
Nota
SMPv2 offre FP8 supporto per i seguenti modelli Hugging Face Transformer:
-
GPT-NeoX (disponibile nella versione 2.2.0 e successive) SMP
-
Llama 2 (disponibile nella versione 2.2.0 e successive) SMP
-
Mixtral 8x7b e Mixtral 8x22b (disponibili nella versione 2.5.0 e successive) SMP
Nota
Questo FP8 corso di formazione sulla funzionalità P5 è disponibile nella seguente combinazione di librerie di e libreria: SageMaker PyTorch
-
SageMaker Python SDK v2.212.0 e versioni successive
-
PyTorch v2.2.0 e versioni successive
FP8(precisione in virgola mobile a 8 bit) è un tipo di dati emerso come un altro paradigma per accelerare l'addestramento dei modelli tramite deep learning. LLM Con il rilascio dei tipi di FP8 dati GPUs che supportano l'NVIDIAH100, puoi sfruttare i vantaggi derivanti dai miglioramenti delle prestazioni sulle istanze P5 dotate di H100GPUs, accelerando al contempo l'addestramento distribuito con un addestramento di precisione misto. FP8
Il tipo di FP8 dati si estende ulteriormente ai formati E4M3 ed E5M2. L'E4M3 offre una maggiore precisione, ha una gamma dinamica limitata ed è ideale per l'avanzamento nell'addestramento dei modelli. E5M2 ha una gamma dinamica più ampia, ma una precisione ridotta, ed è più adatto per il passaggio all'indietro, dove la precisione è meno critica e una gamma dinamica più ampia diventa vantaggiosa. Pertanto, ti consigliamo di utilizzare la ricetta della FP8 strategia ibrida
Per i tipi di dati a mezza precisione (FP16eBF16), le tecniche globali di scalabilità delle perdite come la scalabilità statica delle perdite o la scalabilità dinamica delle perdite gestiscono i problemi di convergenza derivanti dalla perdita di informazioni dovuta all'arrotondamento dei gradienti a semiprecisione. Tuttavia, l'intervallo dinamico di è ancora più ristretto e le tecniche di scala globale delle perdite non sono sufficienti. FP8 A questo punto, abbiamo bisogno di una tecnica di ridimensionamento per tensore a grana più fine. Il ridimensionamento ritardato è una strategia che seleziona un fattore di scala basato sui valori massimi assoluti osservati in una serie di tensori nelle iterazioni precedenti. Questa strategia presenta un compromesso: sfrutta tutti i vantaggi prestazionali del FP8 calcolo, ma richiede memoria per conservare la cronologia dei valori massimi dei tensori. Per saperne di più sulla strategia di scalabilità ritardata in generale, consulta il paper FP8Formats for Deep
In pratica, l'utilizzo FP8 è utile in tutti gli scenari di addestramento sulle istanze P5. Ti consigliamo vivamente di abilitarlo FP8 quando possibile per migliorare le prestazioni dell'allenamento.
SMPLa v2 supporta Transformer Engine fin dall'inizio. Pertanto, quando si esegue l'FP8allenamento con SMP v2 su istanze P5 di SageMaker (ml.p5.48xlarge
), l'unica cosa che devi fare è importare lo script di allenamento e continuare torch.sagemaker
a utilizzare il pacchetto Python Transformer Engine nativo. Per ulteriori informazioni sull'uso di Transformer Engine per la FP8 formazione in generale, consulta Using FP8 with Transformer Engine nella documentazione di Transformer
import torch.sagemaker as tsm import transformer_engine.pytorch as te from transformer_engine.common.recipe import DelayedScaling, Format # Initialize the SMP torch.sagemaker API. tsm.init() # Define a transformer model and wrap it with the torch.sagemaker.transform API. from transformers import AutoModelForCausalLM model = AutoModelForCausalLM.from_config(
ModelConfig
) model = tsm.transform(model) # Enable E4M3 during forward pass, E5M2 during backward pass. fp8_format = Format.HYBRID # Create an FP8 recipe. fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=32, amax_compute_algo="max") # Enable FP8 autocasting. with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, fp8_group=tsm.state.world_process_group): out = model(inp) loss = out.sum() loss.backward()
Per trovare un esempio pratico di FP8 addestramento con la SMP v2 sulle istanze P5, consultate il taccuino di esempio su Accelerate SageMaker PyTorch FSDP Training of Llama-v2
Addestramento di precisione misto con tipi di dati a semiprecisione utilizzando PyTorch FSDP
SMPv2 supporta i lavori PyTorch FSDPMixedPrecision
Nota
Questo addestramento misto di precisione con la PyTorch FSDP funzionalità è disponibile nella seguente combinazione di librerie di SageMaker e PyTorch libreria.
-
SMPv2.0.0 e versioni successive
-
SageMaker Python SDK v2.200.0 e versioni successive
-
PyTorch v2.0.1 e versioni successive
Il metodo standard per configurare un modello a precisione mista consiste nel creare il modello interno float32
e quindi consentire FSDP la trasmissione dei parametri bfloat16
su float16
o al volo mediante l'applicazione di una MixedPrecision
policy, come illustrato nel seguente frammento di codice. Per ulteriori informazioni sulle opzioni per modificare i parametri dtype
for, la riduzione o i buffer a precisione mista PyTorch, PyTorch FSDPMixedPrecision
API
# Native PyTorch API from torch.distributed.fsdp import MixedPrecision dtype = torch.bfloat16 mixed_precision_policy = MixedPrecision( param_dtype=dtype, reduce_dtype=dtype, buffer_dtype=dtype ) model = FSDP( model, ..., mixed_precision=mixed_precision_policy )
Nota che alcuni modelli (come il modello Hugging Face Transformers Llama) prevedono buffer come. float32
Per utilizzarlofloat32
, sostituitelo torch.bfloat16
con torch.float32
nella riga che definisce l'oggetto. dtype