Ponto de verificação usando SMP - Amazon SageMaker

As traduções são geradas por tradução automática. Em caso de conflito entre o conteúdo da tradução e da versão original em inglês, a versão em inglês prevalecerá.

Ponto de verificação usando SMP

A biblioteca SageMaker model parallelism (SMP) suporta pontos de verificação e fornece APIs ajuda PyTorch APIs para fazer o checkpoint de forma adequada ao usar a biblioteca. SMP

PyTorch FSDP(Paralelismo de dados totalmente fragmentado) suporta três tipos de pontos de verificação: completos, fragmentados e locais, cada um com finalidades diferentes. Pontos de verificação completos são usados ao exportar o modelo após a conclusão do treinamento, pois gerar um ponto de verificação completo é um processo computacionalmente caro. Os pontos de verificação fragmentados ajudam a salvar e carregar o estado de um modelo fragmentado para cada classificação individual. Com pontos de verificação fragmentados, você pode retomar o treinamento com diferentes configurações de hardware, como um número diferente de. GPUs No entanto, o carregamento de pontos de verificação fragmentados pode ser lento devido à comunicação envolvida entre vários dispositivos. A SMP biblioteca fornece funcionalidades de ponto de verificação local, que permitem uma recuperação mais rápida do estado do modelo sem sobrecarga adicional de comunicação. Observe que os pontos de verificação criados por FSDP exigem a gravação em um sistema de arquivos de rede compartilhado, como a AmazonFSx.

Pontos de verificação locais assíncronos

Ao treinar modelos de aprendizado de máquina, não é necessário que as iterações subsequentes aguardem que os arquivos do ponto de verificação sejam salvos em disco. Com o lançamento da versão SMP 2.5, a biblioteca suporta o salvamento de arquivos de ponto de verificação de forma assíncrona. Isso significa que a iteração de treinamento subsequente pode ser executada simultaneamente com as operações de entrada e saída (E/S) para criar pontos de verificação, sem ser retardada ou retida por essas operações de E/S. Além disso, o processo de recuperação dos parâmetros fragmentados do modelo e do otimizador PyTorch pode ser demorado devido à comunicação coletiva adicional necessária para trocar metadados de tensores distribuídos entre as classificações. Mesmo quando usado StateDictType.LOCAL_STATE_DICT para salvar pontos de verificação locais para cada classificação, PyTorch ainda invoca ganchos que realizam comunicação coletiva. Para mitigar esse problema e reduzir o tempo necessário para a recuperação do ponto de verificação, SMP introduzSMStateDictType.SM_LOCAL_STATE_DICT, que permite uma recuperação mais rápida dos pontos de verificação do modelo e do otimizador, contornando a sobrecarga de comunicação coletiva.

nota

Manter a consistência no FSDP SHARD_DEGREE é um requisito para utilizar o. SMStateDictType.SM_LOCAL_STATE_DICT Certifique-se de que SHARD_DEGREE permaneça inalterado. Embora o número de replicações do modelo possa variar, o grau de fragmento do modelo precisa ser idêntico à configuração de treinamento anterior ao sair de um ponto de verificação.

import os import torch.distributed as dist import torch.sagemaker as tsm from torch.sagemaker import state from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.sagemaker.distributed.checkpoint.state_dict_saver import ( async_save, maybe_finalize_async_calls, ) from torch.sagemaker.distributed.checkpoint.state_dict_utils import ( sm_state_dict_type, SMStateDictType, ) global_rank = dist.get_rank() save_dir = "/opt/ml/checkpoints" sub_dir = f"tp{state.tp_rank}_ep{state.ep_rank}_fsdp{model.rank}" # 1. Get replication ranks and group current_replication_group = None current_replication_ranks = None for replication_ranks in state.ranker.get_rep_groups(): rep_group = dist.new_group(replication_ranks) if global_rank in replication_ranks: current_replication_group = rep_group current_replication_ranks = replication_ranks coordinator_rank = min(current_replication_ranks) # 2. Wait for the previous checkpointing done maybe_finalize_async_calls( blocking=True, process_group=current_replication_group ) # 3. Get model local checkpoint with sm_state_dict_type(model, SMStateDictType.SM_LOCAL_STATE_DICT): state_dict = { "model": model.state_dict(), "optimizer": optimizer.state_dict(), # Potentially add more customized state dicts. } # 4. Save a local checkpoint async_save( state_dict, checkpoint_id=os.path.join(save_dir, sub_dir), process_group=current_replication_group, coordinator_rank=coordinator_rank, )

O trecho de código a seguir demonstra como carregar um ponto de verificação utilizando. SMStateDictType.SM_LOCAL_STATE_DICT

import os import torch.sagemaker as tsm from torch.sagemaker import state from torch.sagemaker.distributed.checkpoint.state_dict_loader import load from torch.sagemaker.distributed.checkpoint.state_dict_utils import ( sm_state_dict_type, SMStateDictType, init_optim_state ) from torch.sagemaker.distributed.checkpoint.filesystem import ( DistributedFileSystemReader, ) load_dir = "/opt/ml/checkpoints" sub_dir = f"tp{state.tp_rank}_ep{state.ep_rank}_fsdp{model.rank}" global_rank = dist.get_rank() checkpoint_id = os.path.join(load_dir, sub_dir) storage_reader = DistributedFileSystemReader(checkpoint_id) # 1. Get replication ranks and group current_replication_group = None current_replication_ranks = None for replication_ranks in state.ranker.get_rep_groups(): rep_group = dist.new_group(replication_ranks) if global_rank in replication_ranks: current_replication_group = rep_group current_replication_ranks = replication_ranks coordinator_rank = min(current_replication_ranks) # 2. Create local state_dict with sm_state_dict_type(model, SMStateDictType.SM_LOCAL_STATE_DICT): state_dict = { "model": model.state_dict(), # Potentially add more customized state dicts. } # Init optimizer state_dict states by setting zero grads and step. init_optim_state(optimizer, skip_empty_param=True) state_dict["optimizer"] = optimizer.state_dict() # 3. Load a checkpoint load( state_dict=state_dict, process_group=current_replication_group, coordinator_rank=coordinator_rank, storage_reader=storage_reader, )

Armazenar pontos de verificação para modelos de linguagem grandes (LLMs) pode ser caro, pois geralmente requer a criação de um grande volume de sistema de arquivos. Para reduzir custos, você tem a opção de salvar pontos de verificação diretamente no Amazon S3 sem a necessidade de serviços adicionais de sistema de arquivos, como o Amazon. FSx Você pode aproveitar o exemplo anterior com o seguinte trecho de código para salvar pontos de verificação no S3 especificando um S3 como destino. URL

key = os.path.join(checkpoint_dir, sub_dir) checkpoint_id= f"s3://{your_s3_bucket}/{key}" async_save(state_dict, checkpoint_id=checkpoint_id, **kw) load(state_dict, checkpoint_id=checkpoint_id, **kw)

Pontos de verificação compartilhados assíncronos

Pode haver situações em que você precise continuar treinando com diferentes configurações de hardware, como alterar o número deGPUs. Nesses casos, seus processos de treinamento devem carregar os pontos de verificação durante a fragmentação, o que significa retomar o treinamento subsequente com um número diferente de. SHARD_DEGREE Para resolver o cenário em que você precisa retomar o treinamento com um número diferente deSHARD_DEGREE, você deve salvar os pontos de verificação do modelo usando o tipo de dicionário de estado fragmentado, representado por. StateDictType.SHARDED_STATE_DICT Salvar pontos de verificação nesse formato permite que você gerencie adequadamente o processo de refragmentação ao continuar o treinamento com uma configuração de hardware modificada. O trecho de código fornecido ilustra como usar o para salvar pontos de verificação fragmentados tsm API de forma assíncrona, permitindo um processo de treinamento mais eficiente e simplificado.

import os import torch.sagemaker as tsm from torch.sagemaker import state from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp import StateDictType from torch.sagemaker.utils.process_group_utils import get_global_ranks from torch.sagemaker.distributed.checkpoint.state_dict_saver import ( async_save, maybe_finalize_async_calls, ) save_dir = "/opt/ml/checkpoints" sub_dir = f"tp{state.tp_rank}_ep{state.ep_rank}" checkpoint_id = os.path.join(save_dir, sub_dir) # To determine whether curreto take part in checkpointing. global_rank = dist.get_rank() action_rank = state.ranker.get_rep_rank(global_rank) == 0 process_group = model.process_group coordinator_rank = min(get_global_ranks(process_group)) # 1. wait for the previous checkpointing done maybe_finalize_async_calls(blocking=True, process_group=process_group) # 2. retrieve model & optimizer sharded state_dict with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT): state_dict = { "model": model.state_dict(), "optimizer": FSDP.optim_state_dict(model, optimizer), # Potentially add more customized state dicts. } # 3. save checkpoints asynchronously using async_save if action_rank: async_save( state_dict, checkpoint_id=checkpoint_id, process_group=process_group, coordinator_rank=coordinator_rank, )

O processo de carregamento de pontos de verificação compartilhados é semelhante ao da seção anterior, mas envolve o uso do torch.sagemaker.distributed.checkpoint.filesystem.DistributedFileSystemReader e seu load método. O load método dessa classe permite carregar os dados compartilhados do ponto de verificação, seguindo um processo análogo ao descrito anteriormente.

import os from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp import StateDictType from torch.distributed.checkpoint.optimizer import load_sharded_optimizer_state_dict from torch.sagemaker.distributed.checkpoint.state_dict_loader import load from torch.sagemaker.utils.process_group_utils import get_global_ranks from torch.sagemaker.distributed.checkpoint.filesystem import ( DistributedFileSystemReader, ) load_dir = "/opt/ml/checkpoints" sub_dir = f"tp{state.tp_rank}_ep{state.ep_rank}" checkpoint_id = os.path.join(load_dir, sub_dir) reader = DistributedFileSystemReader(checkpoint_id) process_group = model.process_group coordinator_rank = min(get_global_ranks(process_group)) with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT): # 1. Load model and everything else except the optimizer. state_dict = { "model": model.state_dict() # Potentially more customized state dicts. } load( state_dict, storage_reader=reader, process_group=process_group, coordinator_rank=coordinator_rank, ) model.load_state_dict(state_dict["model"]) # 2. Load optimizer. optim_state = load_sharded_optimizer_state_dict( model_state_dict=state_dict["model"], optimizer_key="optimizer", storage_reader=reader, process_group=process_group, ) flattened_optimizer_state = FSDP.optim_state_dict_to_load( optim_state["optimizer"], model, optimizer, group=model.process_group ) optimizer.load_state_dict(flattened_optimizer_state)

Pontos de verificação do modelo completo

Ao final do treinamento, você pode salvar um ponto de verificação completo que combina todos os fragmentos de um modelo em um único arquivo de ponto de verificação do modelo. A SMP biblioteca oferece suporte total aos pontos de verificação PyTorch completos do modeloAPI, portanto, você não precisa fazer nenhuma alteração.

Observe que, se você usar o SMPParalelismo de tensores, a SMP biblioteca transformará o modelo. Ao verificar o modelo completo nesse caso, a SMP biblioteca traduz o modelo de volta para o formato de ponto de verificação Hugging Face Transformers por padrão.

Nos casos em que você treina com o paralelismo SMP tensorial e desativa o processo de SMP tradução, você pode usar o translate_on_save argumento do PyTorch FullStateDictConfig API para ativar ou desativar a SMP tradução automática conforme necessário. Por exemplo, se você está se concentrando em treinar um modelo, não precisa adicionar o processo de tradução, o que aumenta a sobrecarga. Nesse caso, recomendamos que você definatranslate_on_save=False. Além disso, se você planeja continuar usando a SMP tradução do modelo para treinamento adicional no futuro, você pode desativá-la para salvar a SMP tradução do modelo para uso posterior. É necessário traduzir o modelo de volta para o formato de ponto de verificação do modelo Hugging Face Transformers quando você encerra o treinamento do seu modelo e o usa para inferência.

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp import FullStateDictConfig import torch.sagemaker as tsm # Save checkpoints. with FSDP.state_dict_type( model, StateDictType.FULL_STATE_DICT, FullStateDictConfig( rank0_only=True, offload_to_cpu=True, # Default value is to translate back to Hugging Face Transformers format, # when saving full checkpoints for models trained with SMP tensor parallelism. # translate_on_save=True ), ): state_dict = model.state_dict() if dist.get_rank() == 0: logger.info("Processed state dict to save. Starting write to disk now.") os.makedirs(save_dir, exist_ok=True) # This name is needed for HF from_pretrained API to work. torch.save(state_dict, os.path.join(save_dir, "pytorch_model.bin")) hf_model_config.save_pretrained(save_dir) dist.barrier()

Observe que a opção FullStateDictConfig(rank0_only=True, offload_to_cpu=True) é reunir o modelo no dispositivo CPU do 0º nível para economizar memória ao treinar modelos grandes.

Para carregar o modelo de volta para inferência, faça isso conforme mostrado no exemplo de código a seguir. Observe que a classe AutoModelForCausalLM pode mudar para outras classes de construtor de fatores em Hugging Face Transformers, comoAutoModelForSeq2SeqLM, dependendo do seu modelo. Para obter mais informações, consulte a documentação do Hugging Face Transformers.

from transformers import AutoModelForCausalLM model = AutoModelForCausalLM.from_pretrained(save_dir)