Apontando pontos de verificação e ajustando um modelo com paralelismo de modelos - Amazon SageMaker

As traduções são geradas por tradução automática. Em caso de conflito entre o conteúdo da tradução e da versão original em inglês, a versão em inglês prevalecerá.

Apontando pontos de verificação e ajustando um modelo com paralelismo de modelos

A biblioteca de paralelismo de SageMaker modelos fornece APIs de ponto de verificação para salvar o estado do modelo e o estado do otimizador divididos pelas várias estratégias de paralelismo do modelo e para carregar pontos de verificação para treinamento contínuo de onde você deseja reiniciar o treinamento e ajustar. As APIs também oferecem opções de suporte para salvar parcialmente ou totalmente os estados do modelo e do otimizador.

Pontos de verificação de um modelo distribuído

Escolha um dos tópicos a seguir, dependendo da estrutura entre PyTorch e TensorFlow e da versão da biblioteca de paralelismo de SageMaker modelos que você usa.

Apontando um PyTorch modelo distribuído (para a biblioteca de paralelismo de SageMaker modelos v1.10.0 e posterior)

A biblioteca de paralelismo de SageMaker modelos fornece APIs de ponto de verificação para salvar e carregar pontos de verificação completos ou parciais do estado do modelo distribuído e do estado do otimizador.

nota

Esse método de ponto de verificação é recomendado se você usar PyTorch a biblioteca de paralelismo de SageMaker modelos v1.10.0 ou posterior.

Pontos de verificação parciais

Para salvar pontos de verificação de um treinamento de modelos com paralelismo de modelos, use a API smdistributed.modelparallel.torch.save_checkpoint com a opção de ponto de verificação parcial definida como true (partial=True). Isto salva cada partição de modelos individualmente. Além do modelo e do estado do otimizador, você também pode salvar quaisquer dados personalizados adicionais por meio do argumento user_content. O modelo com ponto de verificação, o otimizador e o conteúdo do usuário são salvos como arquivos separados. A chamada de API save_checkpoint cria pastas de pontos de verificação na estrutura a seguir.

- path - ${tag}_partial (folder for partial checkpoints) - model_rankinfo.pt - optimizer_rankinfo.pt - fp16_states_rankinfo.pt - user_content.pt - $tag (checkpoint file for full checkpoints) - user_content_$tag (user_content file for full checkpoints) - newest (a file that indicates the newest checkpoint)

Para retomar o treinamento a partir de pontos de verificação parciais, use a API smdistributed.modelparallel.torch.resume_from_checkpoint com partial=True e especifique o diretório do ponto de verificação e a tag usada ao salvar os pontos de verificação parciais. Observe que o carregamento real dos pesos do modelo ocorre após o particionamento do modelo, durante a primeira execução da step function de treinamento decorada smdistributed.modelparallel.torch.step.

Ao salvar um ponto de verificação parcial, a biblioteca também salva a decisão da partição do modelo como arquivos com extensão de arquivo .pt. Por outro lado, ao retomar o ponto de verificação parcial, a biblioteca carrega os arquivos de decisão de partição juntos. Depois que a decisão de partição é carregada, não é possível alterar a partição.

O trecho de código a seguir mostra como definir as APIs do ponto de verificação em um script de treinamento. PyTorch

import smdistributed.modelparallel.torch as smp model = ... model = smp.DistributedModel(model) optimizer = ... optimizer = smp.DistributedOptimizer(optimizer) user_content = ... # additional custom data checkpoint_path = "/opt/ml/checkpoint/model_parallel" # Save a checkpoint. smp.save_checkpoint( path=checkpoint_path, tag=f"total_steps{total_steps}", partial=True, model=model, optimizer=optimizer, user_content=user_content num_kept_partial_checkpoints=5 ) # Load a checkpoint. # This automatically loads the most recently saved checkpoint. smp_checkpoint = smp.resume_from_checkpoint( path=checkpoint_path, partial=True )

Pontos de verificação totais

Para salvar o artefato do modelo final para fins de inferência, use a API smdistributed.modelparallel.torch.save_checkpoint com partial=False, que combinam as partições do modelo para criar um único artefato do modelo. Observe que isso não combina os estados do otimizador.

Para inicializar o treinamento com pesos específicos, considerando um ponto de verificação completo do modelo, você pode usar a API smdistributed.modelparallel.torch.resume_from_checkpoint com partial=False. Observe que isso não combina os estados de carregamento do otimizador.

nota

Com o paralelismo do tensor, em geral, o state_dict deve ser traduzido entre a implantação do modelo original e a implantação DistributedModel. Opcionalmente, você pode fornecer a função de tradução state_dict como um argumento para o smdistributed.modelparallel.torch.resume_from_checkpoint. No entanto, para Modelos compatíveis prontos para uso, a biblioteca cuida dessa tradução automaticamente.

O código a seguir mostra um exemplo de como usar as APIs de ponto de verificação para verificar totalmente um PyTorch modelo treinado com paralelismo de modelos.

import smdistributed.modelparallel.torch as smp model = ... model = smp.DistributedModel(model) optimizer = ... optimizer = smp.DistributedOptimizer(optimizer) user_content = ... # additional custom data checkpoint_path = "/opt/ml/checkpoint/model_parallel" # Save a checkpoint. smp.save_checkpoint( path=checkpoint_path, tag=f"total_steps{total_steps}", partial=False, model=model, optimizer=optimizer, user_content=user_content num_kept_partial_checkpoints=5 ) # Load a checkpoint. # This automatically loads the most recently saved checkpoint. smp_checkpoint = smp.resume_from_checkpoint( path=checkpoint_path, partial=False )

Apontando um PyTorch modelo distribuído (para a biblioteca de paralelismo de SageMaker modelos entre v1.6.0 e v1.9.0)

A biblioteca de paralelismo de SageMaker modelos fornece funções Python para salvar pontos de verificação parciais ou completos para treinar trabalhos com paralelismo de tensores. O procedimento a seguir mostra como usar o smp.save() e smp.load() para salvar e carregar um ponto de verificação ao usar o paralelismo de tensores.

nota

Esse método de ponto de verificação é recomendado se você usar PyTorchParalelismo tensorial, e a biblioteca de paralelismo de SageMaker modelos entre v1.6.0 e v1.9.0.

  1. Prepare um objeto de modelo e envolva-o com a função wrapper smp.DistributedModel() da biblioteca.

    model = MyModel(...) model = smp.DistributedModel(model)
  2. Prepare um otimizador para o modelo. Um conjunto de parâmetros do modelo é um argumento iterável exigido pelas funções do otimizador. Para preparar uma configuração de parâmetros do modelo, você deve processar model.parameters() para a atribuição de IDs exclusivos aos parâmetros individuais do modelo.

    Se houver parâmetros com IDs duplicadas no parâmetro do modelo iterável, o carregamento do estado do otimizador com ponto de verificação falhará. Para criar um item iterável de parâmetros de modelo com IDs exclusivas para seu otimizador, veja o seguinte:

    unique_params = [] unique_params_set = set() for p in model.parameters(): if p not in unique_params_set: unique_params.append(p) unique_params_set.add(p) del unique_params_set optimizer = MyOpt(unique_params, ...)
  3. Envolva o otimizador usando a função wrapper da biblioteca smp.DistributedOptimizer().

    optimizer = smp.DistributedOptimizer(optimizer)
  4. Salve o modelo e o estado do otimizador usando smp.save(). Dependendo de como deseja salvar os pontos de verificação, escolha uma das duas opções:

    • Opção 1: Salve um modelo parcial em cada mp_rank para um único MP_GROUP.

      model_dict = model.local_state_dict() # save a partial model opt_dict = optimizer.local_state_dict() # save a partial optimizer state # Save the dictionaries at rdp_rank 0 as a checkpoint if smp.rdp_rank() == 0: smp.save( {"model_state_dict": model_dict, "optimizer_state_dict": opt_dict}, f"/checkpoint.pt", partial=True, )

      Com paralelismo de tensores, a biblioteca salva arquivos com pontos de verificação nomeados no seguinte formato: checkpoint.pt_{pp_rank}_{tp_rank}.

      nota

      Com o paralelismo de tensores, certifique-se de configurar a instrução ‘if’ como if smp.rdp_rank() == 0 em vez de if smp.dp_rank() == 0. Quando o estado do otimizador é fragmentado com paralelismo de tensores, todas as classificações de paralelismo de dados reduzidos devem salvar suas próprias partições de estado do otimizador. Usar uma instrução if errada para os pontos de verificação pode resultar na paralisação do trabalho de treinamento. Para obter mais informações sobre como usar if smp.dp_rank() == 0 sem paralelismo de tensores, consulte Instruções gerais para salvar e carregar na documentação do SDK do PythonSageMaker .

    • Opção 2: Salve o modelo completo.

      if smp.rdp_rank() == 0: model_dict = model.state_dict(gather_to_rank0=True) # save the full model if smp.rank() == 0: smp.save( {"model_state_dict": model_dict}, "/checkpoint.pt", partial=False, )
      nota

      Considere o seguinte para um pontos de verificação completos:

      • Se você definir gather_to_rank0=True, todas as outras classificações, exceto 0, retornarão dicionários vazios.

      • Para um ponto de verificação completo, você só pode verificar o modelo. Atualmente, não há suporte para pontos de verificação completos dos estados do otimizador.

      • O modelo completo só precisa ser salvo no smp.rank() == 0.

  5. Carregue os pontos de verificação usando smp.load(). Dependendo de como verificação os pontos na etapa anterior, escolha uma das duas opções a seguir:

    • Opção 1: Carregue os pontos de verificação parciais.

      checkpoint = smp.load("/checkpoint.pt", partial=True) model.load_state_dict(checkpoint["model_state_dict"], same_partition_load=False) optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

      Você pode configurar same_partition_load=True no model.load_state_dict() para um carregamento mais rápido se souber que a partição não será alterada.

    • Opção 2: Carregue os pontos de verificação completos.

      if smp.rdp_rank() == 0: checkpoint = smp.load("/checkpoint.pt", partial=False) model.load_state_dict(checkpoint["model_state_dict"])

      A condição if smp.rdp_rank() == 0 não é obrigatória, mas pode ajudar a evitar o carregamento redundante entre diferentes MP_GROUPs. O estado completo do otimizador de ponto de verificação atualmente não é suportado pelo paralelismo de tensores.

Verificando um modelo distribuído TensorFlow

Para salvar um TensorFlow modelo durante o treinamento com o paralelismo de modelos, use as seguintes funções fornecidas pela biblioteca de paralelismo de SageMaker modelos.

Ajuste de um modelo distribuído

O ajuste fino precisa ser configurado em seu script de treinamento. O trecho de código a seguir mostra um exemplo de estrutura de um script de treinamento usando a classe AutoModelForCausalLM de Hugging Face Transformers com modificações para registrar os módulos e as configurações para ajuste fino. smdistributed.model.parallel.torch

nota

O ajuste fino de um transformador distribuído (um modelo de transformador empacotado por smp.DistributedModel()) com a função smp.delayed_param_initialization ativada requer que o trabalho ajustado seja configurado com um sistema de arquivos FSx for Lustre. Nos casos em que você deseja ajustar um modelo em grande escala com a opção de inicialização atrasada de parâmetros, você deve configurar um sistema de arquivos FSx for Lustre.

import argparse from transformers import AutoModelForCausalLM import smdistributed.modelparallel import smdistributed.modelparallel.torch as smp def parse_args(): parser = argparse.ArgumentParser() # set an arg group for model model_grp = parser.add_argument_group( title="model", description="arguments to describe model configuration" ) ... # set up numerous args to parse from the configuration dictionary to the script for training # add arg for activating fine-tuning model_grp.add_argument( "--fine_tune", type=int, default=0, help="Fine-tune model from checkpoint or pretrained model", ) def main(): """Main function to train GPT.""" args = parse_args() ... # parse numerous args if args.fine_tune > 0 and args.delayed_param > 0 and smp.rank() == 0: pretrained_model = AutoModelForCausalLM.from_pretrained( args.model_name or args.model_dir ) model_state_dict = pretrained_model.state_dict() path = os.path.join(args.model_dir, "fullmodel.pt") torch.save(model_state_dict, path) # create a Transformer model and wrap by smp.model_creation() # with options to configure model parallelism parameters offered by SageMaker with smp.model_creation( tensor_parallelism=smp.tp_size() > 1 or args.use_distributed_transformer > 0, zero_init=args.use_distributed_transformer == 0, dtype=dtype, distribute_embedding=args.sharded_data_parallel_degree > 1 and smp.tp_size() > 1, use_alibi=args.alibi > 0, attention_in_fp32=args.attention_in_fp32 > 0, fp32_residual_addition=args.residual_addition_in_fp32 > 0, query_key_layer_scaling=args.query_key_layer_scaling > 0 and args.bf16 < 1, fused_softmax=args.fused_softmax > 0, fused_dropout=args.fused_dropout > 0, fused_bias_gelu=args.fused_bias_gelu > 0, flash_attention=args.flash_attention > 0, ): if args.fine_tune > 0 and args.delayed_param == 0: model = AutoModelForCausalLM.from_pretrained( args.model_name or args.model_dir ) else: model = AutoModelForCausalLM.from_config(model_config) # wrap the model by smp.DistributedModel() to apply SageMaker model parallelism model = smp.DistributedModel( model, trace_device="gpu", backward_passes_per_step=args.gradient_accumulation ) # wrap the optimizer by smp.DistributedOptimizer() to apply SageMaker model parallelism optimizer= ... # define an optimizer optimizer = smp.DistributedOptimizer( optimizer, static_loss_scale=None, dynamic_loss_scale=True, dynamic_loss_args={"scale_window": 1000, "min_scale": 1, "delayed_shift": 2}, ) # for fine-tuning, use smp.resume_from_checkpoint() to load a pre-trained model if args.fine_tune > 0 and args.delayed_param > 0: smp.resume_from_checkpoint(args.model_dir, tag="fullmodel.pt", partial=False)

Para obter um exemplo completo de scripts de treinamento e notebooks Jupyter, consulte os exemplos do GPT-2 no repositório Examples. PyTorch SageMaker GitHub