Checkpointing usando SMP - Amazon SageMaker AI

Le traduzioni sono generate tramite traduzione automatica. In caso di conflitto tra il contenuto di una traduzione e la versione originale in Inglese, quest'ultima prevarrà.

Checkpointing usando SMP

La libreria SageMaker model parallelism (SMP) supporta i checkpoint e fornisce APIs questo supporto PyTorch APIs per il corretto funzionamento del checkpoint durante l'utilizzo della libreria. SMP

PyTorch FSDP(Fully Sharded Data Parallelism) supporta tre tipi di checkpoint: completi, frammentati e locali, ciascuno con scopi diversi. I checkpoint completi vengono utilizzati quando si esporta il modello dopo il completamento dell'addestramento, poiché la generazione di un checkpoint completo è un processo computazionalmente costoso. I checkpoint suddivisi aiutano a salvare e caricare lo stato di un modello suddiviso per ogni singolo rango. Con i checkpoint sharded, puoi riprendere l'allenamento con diverse configurazioni hardware, ad esempio un numero diverso di. GPUs Tuttavia, il caricamento di checkpoint frammentati può essere lento a causa della comunicazione tra più dispositivi. La SMP libreria fornisce funzionalità di checkpoint locali, che consentono un recupero più rapido dello stato del modello senza sovraccarichi di comunicazione aggiuntivi. Tieni presente che i checkpoint creati da FSDP richiedono la scrittura su un file system di rete condiviso come AmazonFSx.

Checkpoint locali asincroni

Durante l'addestramento dei modelli di machine learning, non è necessario attendere che i file di checkpoint vengano salvati su disco nelle iterazioni successive. Con la versione SMP v2.5, la libreria supporta il salvataggio asincrono dei file di checkpoint. Ciò significa che la successiva iterazione di addestramento può essere eseguita contemporaneamente alle operazioni di input e output (I/O) per la creazione di checkpoint, senza essere rallentata o frenata da tali operazioni di I/O. Inoltre, il processo di recupero dei parametri del modello condiviso e dell'ottimizzatore PyTorch può richiedere molto tempo a causa della comunicazione collettiva aggiuntiva necessaria per lo scambio di metadati tensoriali distribuiti tra i ranghi. Anche quando viene utilizzato StateDictType.LOCAL_STATE_DICT per salvare i checkpoint locali per ogni rango, richiama comunque gli hook che eseguono comunicazioni collettive. PyTorch Per mitigare questo problema e ridurre il tempo necessario per il recupero dei checkpoint, SMP introduceSMStateDictType.SM_LOCAL_STATE_DICT, che consente un recupero più rapido del modello e ottimizza i checkpoint aggirando il sovraccarico di comunicazione collettiva.

Nota

Mantenere la coerenza in è un requisito per l'utilizzo di. FSDP SHARD_DEGREE SMStateDictType.SM_LOCAL_STATE_DICT Assicurati che SHARD_DEGREE rimanga invariato. Sebbene il numero di repliche del modello possa variare, il grado di frammentazione del modello deve essere identico alla configurazione di formazione precedente quando si riprende da un checkpoint.

import os import torch.distributed as dist import torch.sagemaker as tsm from torch.sagemaker import state from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.sagemaker.distributed.checkpoint.state_dict_saver import ( async_save, maybe_finalize_async_calls, ) from torch.sagemaker.distributed.checkpoint.state_dict_utils import ( sm_state_dict_type, SMStateDictType, ) global_rank = dist.get_rank() save_dir = "/opt/ml/checkpoints" sub_dir = f"tp{state.tp_rank}_ep{state.ep_rank}_fsdp{model.rank}" # 1. Get replication ranks and group current_replication_group = None current_replication_ranks = None for replication_ranks in state.ranker.get_rep_groups(): rep_group = dist.new_group(replication_ranks) if global_rank in replication_ranks: current_replication_group = rep_group current_replication_ranks = replication_ranks coordinator_rank = min(current_replication_ranks) # 2. Wait for the previous checkpointing done maybe_finalize_async_calls( blocking=True, process_group=current_replication_group ) # 3. Get model local checkpoint with sm_state_dict_type(model, SMStateDictType.SM_LOCAL_STATE_DICT): state_dict = { "model": model.state_dict(), "optimizer": optimizer.state_dict(), # Potentially add more customized state dicts. } # 4. Save a local checkpoint async_save( state_dict, checkpoint_id=os.path.join(save_dir, sub_dir), process_group=current_replication_group, coordinator_rank=coordinator_rank, )

Il seguente frammento di codice mostra come caricare un checkpoint utilizzando. SMStateDictType.SM_LOCAL_STATE_DICT

import os import torch.sagemaker as tsm from torch.sagemaker import state from torch.sagemaker.distributed.checkpoint.state_dict_loader import load from torch.sagemaker.distributed.checkpoint.state_dict_utils import ( sm_state_dict_type, SMStateDictType, init_optim_state ) from torch.sagemaker.distributed.checkpoint.filesystem import ( DistributedFileSystemReader, ) load_dir = "/opt/ml/checkpoints" sub_dir = f"tp{state.tp_rank}_ep{state.ep_rank}_fsdp{model.rank}" global_rank = dist.get_rank() checkpoint_id = os.path.join(load_dir, sub_dir) storage_reader = DistributedFileSystemReader(checkpoint_id) # 1. Get replication ranks and group current_replication_group = None current_replication_ranks = None for replication_ranks in state.ranker.get_rep_groups(): rep_group = dist.new_group(replication_ranks) if global_rank in replication_ranks: current_replication_group = rep_group current_replication_ranks = replication_ranks coordinator_rank = min(current_replication_ranks) # 2. Create local state_dict with sm_state_dict_type(model, SMStateDictType.SM_LOCAL_STATE_DICT): state_dict = { "model": model.state_dict(), # Potentially add more customized state dicts. } # Init optimizer state_dict states by setting zero grads and step. init_optim_state(optimizer, skip_empty_param=True) state_dict["optimizer"] = optimizer.state_dict() # 3. Load a checkpoint load( state_dict=state_dict, process_group=current_replication_group, coordinator_rank=coordinator_rank, storage_reader=storage_reader, )

La memorizzazione di checkpoint per modelli di linguaggio di grandi dimensioni (LLMs) può essere costosa in quanto spesso richiede la creazione di un volume di file system di grandi dimensioni. Per ridurre i costi, hai la possibilità di salvare i checkpoint direttamente su Amazon S3 senza la necessità di servizi di file system aggiuntivi come Amazon. FSx Puoi sfruttare l'esempio precedente con il seguente frammento di codice per salvare i checkpoint su S3 specificando un S3 come destinazione. URL

key = os.path.join(checkpoint_dir, sub_dir) checkpoint_id= f"s3://{your_s3_bucket}/{key}" async_save(state_dict, checkpoint_id=checkpoint_id, **kw) load(state_dict, checkpoint_id=checkpoint_id, **kw)

Checkpoint asincroni e condivisi

Potrebbero verificarsi situazioni in cui è necessario continuare l'allenamento con diverse configurazioni hardware, ad esempio modificando il numero di. GPUs In questi casi, i processi di addestramento devono caricare i checkpoint durante il resharding, il che significa riprendere l'allenamento successivo con un numero diverso di. SHARD_DEGREE Per risolvere lo scenario in cui è necessario riprendere l'allenamento con un numero diverso diSHARD_DEGREE, è necessario salvare i checkpoint del modello utilizzando il tipo di dizionario sharded state, rappresentato da. StateDictType.SHARDED_STATE_DICT Il salvataggio dei checkpoint in questo formato consente di gestire correttamente il processo di resharding quando si continua l'addestramento con una configurazione hardware modificata. Il frammento di codice fornito illustra come utilizzare il per salvare in modo asincrono checkpoint tsm API suddivisi, consentendo un processo di formazione più efficiente e semplificato.

import os import torch.sagemaker as tsm from torch.sagemaker import state from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp import StateDictType from torch.sagemaker.utils.process_group_utils import get_global_ranks from torch.sagemaker.distributed.checkpoint.state_dict_saver import ( async_save, maybe_finalize_async_calls, ) save_dir = "/opt/ml/checkpoints" sub_dir = f"tp{state.tp_rank}_ep{state.ep_rank}" checkpoint_id = os.path.join(save_dir, sub_dir) # To determine whether curreto take part in checkpointing. global_rank = dist.get_rank() action_rank = state.ranker.get_rep_rank(global_rank) == 0 process_group = model.process_group coordinator_rank = min(get_global_ranks(process_group)) # 1. wait for the previous checkpointing done maybe_finalize_async_calls(blocking=True, process_group=process_group) # 2. retrieve model & optimizer sharded state_dict with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT): state_dict = { "model": model.state_dict(), "optimizer": FSDP.optim_state_dict(model, optimizer), # Potentially add more customized state dicts. } # 3. save checkpoints asynchronously using async_save if action_rank: async_save( state_dict, checkpoint_id=checkpoint_id, process_group=process_group, coordinator_rank=coordinator_rank, )

Il processo di caricamento dei checkpoint condivisi è simile alla sezione precedente, ma prevede l'utilizzo del metodo and its. torch.sagemaker.distributed.checkpoint.filesystem.DistributedFileSystemReader load Il load metodo di questa classe consente di caricare i dati dei checkpoint condivisi, seguendo un processo analogo a quello descritto in precedenza.

import os from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp import StateDictType from torch.distributed.checkpoint.optimizer import load_sharded_optimizer_state_dict from torch.sagemaker.distributed.checkpoint.state_dict_loader import load from torch.sagemaker.utils.process_group_utils import get_global_ranks from torch.sagemaker.distributed.checkpoint.filesystem import ( DistributedFileSystemReader, ) load_dir = "/opt/ml/checkpoints" sub_dir = f"tp{state.tp_rank}_ep{state.ep_rank}" checkpoint_id = os.path.join(load_dir, sub_dir) reader = DistributedFileSystemReader(checkpoint_id) process_group = model.process_group coordinator_rank = min(get_global_ranks(process_group)) with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT): # 1. Load model and everything else except the optimizer. state_dict = { "model": model.state_dict() # Potentially more customized state dicts. } load( state_dict, storage_reader=reader, process_group=process_group, coordinator_rank=coordinator_rank, ) model.load_state_dict(state_dict["model"]) # 2. Load optimizer. optim_state = load_sharded_optimizer_state_dict( model_state_dict=state_dict["model"], optimizer_key="optimizer", storage_reader=reader, process_group=process_group, ) flattened_optimizer_state = FSDP.optim_state_dict_to_load( optim_state["optimizer"], model, optimizer, group=model.process_group ) optimizer.load_state_dict(flattened_optimizer_state)

Checkpoint del modello completo

Al termine dell'addestramento, è possibile salvare un checkpoint completo che combini tutti i frammenti di un modello in un unico file di checkpoint del modello. La SMP libreria supporta completamente i checkpoint PyTorch completi del modelloAPI, quindi non è necessario apportare modifiche.

Nota che se usi la SMPParallelismo tensoriale, la SMP libreria trasforma il modello. In questo caso, quando si esegue il checkpoint del modello completo, la SMP libreria traduce il modello nel formato checkpoint Hugging Face Transformers per impostazione predefinita.

Nei casi in cui ti alleni con il parallelismo SMP tensoriale e disattivi il processo di SMP traduzione, puoi usare l'translate_on_saveargomento di per attivare o PyTorch FullStateDictConfig API disattivare la SMP traduzione automatica secondo necessità. Ad esempio, se vi state concentrando sulla formazione di un modello, non è necessario aggiungere il processo di traduzione che comporta costi aggiuntivi. In tal caso, ti consigliamo di impostaretranslate_on_save=False. Inoltre, se prevedi di continuare a utilizzare la SMP traduzione del modello per ulteriori corsi di formazione in futuro, puoi disattivarla per salvare la SMP traduzione del modello per un uso successivo. La traduzione del modello nel formato di checkpoint del modello Hugging Face Transformers è necessaria quando si conclude l'addestramento del modello e lo si utilizza per l'inferenza.

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp import FullStateDictConfig import torch.sagemaker as tsm # Save checkpoints. with FSDP.state_dict_type( model, StateDictType.FULL_STATE_DICT, FullStateDictConfig( rank0_only=True, offload_to_cpu=True, # Default value is to translate back to Hugging Face Transformers format, # when saving full checkpoints for models trained with SMP tensor parallelism. # translate_on_save=True ), ): state_dict = model.state_dict() if dist.get_rank() == 0: logger.info("Processed state dict to save. Starting write to disk now.") os.makedirs(save_dir, exist_ok=True) # This name is needed for HF from_pretrained API to work. torch.save(state_dict, os.path.join(save_dir, "pytorch_model.bin")) hf_model_config.save_pretrained(save_dir) dist.barrier()

Nota che l'opzione FullStateDictConfig(rank0_only=True, offload_to_cpu=True) è quella di raccogliere il modello sul dispositivo di livello 0 per risparmiare memoria durante l'CPUaddestramento di modelli di grandi dimensioni.

Per caricare nuovamente il modello per l'inferenza, procedete come mostrato nel seguente esempio di codice. Nota che la classe AutoModelForCausalLM potrebbe passare ad altre classi Factor Builder in Hugging Face Transformers, ad esempioAutoModelForSeq2SeqLM, a seconda del modello. Per ulteriori informazioni, consulta la documentazione di Hugging Face Transformers.

from transformers import AutoModelForCausalLM model = AutoModelForCausalLM.from_pretrained(save_dir)