Puntos de control y ajuste de un modelo con paralelismo de modelos - Amazon SageMaker

Las traducciones son generadas a través de traducción automática. En caso de conflicto entre la traducción y la version original de inglés, prevalecerá la version en inglés.

Puntos de control y ajuste de un modelo con paralelismo de modelos

La biblioteca de paralelismo de SageMaker modelos proporciona API de puntos de control para guardar el estado del modelo y el estado del optimizador divididos por las distintas estrategias de paralelismo del modelo, y para cargar los puntos de control para el entrenamiento continuo desde donde desee reiniciar el entrenamiento y ajustarlo con precisión. Las API también admiten opciones para guardar parcial o totalmente los estados del modelo y del optimizador.

Punto de control de un modelo distribuido

Elija uno de los siguientes temas en función del marco PyTorch y TensorFlow de la versión de la biblioteca de paralelismo de modelos que utilice. SageMaker

Verificación de un PyTorch modelo distribuido (para la biblioteca de paralelismo de SageMaker modelos v1.10.0 y versiones posteriores)

La biblioteca de paralelismo de SageMaker modelos proporciona API de puntos de control para guardar y cargar puntos de control totales o parciales del estado del modelo distribuido y su estado del optimizador.

nota

Se recomienda utilizar este método de puntos de control si se utiliza la biblioteca de paralelismo de modelos PyTorch v1.10.0 o posterior SageMaker .

Puntos de control parciales

Para guardar los puntos de control de un modelo entrenado con paralelismo de modelos, usa la API smdistributed.modelparallel.torch.save_checkpoint con la opción de puntos de control parcial establecida en true (partial=True). Esto guarda cada partición del modelo de forma individual. Además del modelo y el estado del optimizador, también puede guardar cualquier dato personalizado adicional mediante el argumento user_content. El modelo de puntos de control, el optimizador y el contenido del usuario se guardan como archivos independientes. La llamada a la API save_checkpoint crea carpetas de puntos de control con la siguiente estructura.

- path - ${tag}_partial (folder for partial checkpoints) - model_rankinfo.pt - optimizer_rankinfo.pt - fp16_states_rankinfo.pt - user_content.pt - $tag (checkpoint file for full checkpoints) - user_content_$tag (user_content file for full checkpoints) - newest (a file that indicates the newest checkpoint)

Para reanudar el entrenamiento a partir de puntos de control parciales, utilice la API smdistributed.modelparallel.torch.resume_from_checkpoint con partial=True y especifique el directorio de puntos de control y la etiqueta utilizada mientras guarda los puntos de control parciales. Tenga en cuenta que la carga real de los pesos del modelo ocurre después de la partición del modelo, durante la primera ejecución de la función de paso de entrenamiento decorada por smdistributed.modelparallel.torch.step.

Al guardar un punto de control parcial, la biblioteca también guarda la decisión de partición del modelo como archivos con la extensión del archivo .pt. Por el contrario, al reanudar desde el punto de control parcial, la biblioteca carga los archivos de decisión de partición juntos. Una vez que la decisión de partición esté cargada, no podrá cambiar la partición.

El siguiente fragmento de código muestra cómo configurar las API de puntos de control en un script de entrenamiento. PyTorch

import smdistributed.modelparallel.torch as smp model = ... model = smp.DistributedModel(model) optimizer = ... optimizer = smp.DistributedOptimizer(optimizer) user_content = ... # additional custom data checkpoint_path = "/opt/ml/checkpoint/model_parallel" # Save a checkpoint. smp.save_checkpoint( path=checkpoint_path, tag=f"total_steps{total_steps}", partial=True, model=model, optimizer=optimizer, user_content=user_content num_kept_partial_checkpoints=5 ) # Load a checkpoint. # This automatically loads the most recently saved checkpoint. smp_checkpoint = smp.resume_from_checkpoint( path=checkpoint_path, partial=True )

Puntos de control completos

Para guardar el artefacto del modelo final con fines de inferencia, utilice la API smdistributed.modelparallel.torch.save_checkpoint con partial=False, que combina las particiones del modelo para crear un único artefacto modelo. Tenga en cuenta que esto no combina los estados del optimizador.

Para iniciar el entrenamiento con pesos específicos, con un punto de control del modelo completo, puede usar la API smdistributed.modelparallel.torch.resume_from_checkpoint con partial=False. Tenga en cuenta que esto no carga los estados del optimizador.

nota

Con el paralelismo de tensores, en general, el state_dict debe traducirse entre la implementación del modelo original y la implementación de DistributedModel. Si lo desea, puede proporcionar la función de traducción state_dict como argumento para el smdistributed.modelparallel.torch.resume_from_checkpoint. Sin embargo, para Modelos compatibles listos para usar, la biblioteca se encarga de esta traducción automáticamente.

El siguiente código muestra un ejemplo de cómo utilizar las API de puntos de control para controlar completamente un PyTorch modelo entrenado con paralelismo de modelos.

import smdistributed.modelparallel.torch as smp model = ... model = smp.DistributedModel(model) optimizer = ... optimizer = smp.DistributedOptimizer(optimizer) user_content = ... # additional custom data checkpoint_path = "/opt/ml/checkpoint/model_parallel" # Save a checkpoint. smp.save_checkpoint( path=checkpoint_path, tag=f"total_steps{total_steps}", partial=False, model=model, optimizer=optimizer, user_content=user_content num_kept_partial_checkpoints=5 ) # Load a checkpoint. # This automatically loads the most recently saved checkpoint. smp_checkpoint = smp.resume_from_checkpoint( path=checkpoint_path, partial=False )

Controlar un PyTorch modelo distribuido (para la biblioteca de paralelismo de modelos entre las versiones 1.6.0 y SageMaker 1.9.0)

La biblioteca de SageMaker modelos de paralelismo proporciona funciones de Python para guardar puntos de control parciales o completos para trabajos de entrenamiento con paralelismo tensorial. El siguiente procedimiento muestra cómo usar smp.save() and smp.load() para guardar y cargar un punto de control cuando usa el paralelismo de tensores.

nota

Se recomienda utilizar este método de puntos de control si se utiliza PyTorch, y la biblioteca de paralelismo de modelos entre las versiones 1.6.0 y Paralelismo de tensores 1.9.0. SageMaker

  1. Prepare un objeto modelo y encapsularlo con la función envolvente de la biblioteca smp.DistributedModel().

    model = MyModel(...) model = smp.DistributedModel(model)
  2. Prepare un optimizador para el modelo. Un conjunto de parámetros de modelo es un argumento iterable que requieren las funciones del optimizador. Para preparar un conjunto de parámetros de modelo, debe procesarmodel.parameters() para asignar ID exclusivos a parámetros de modelo individuales.

    Si hay parámetros con ID duplicados en el parámetro iterable del modelo, falla la carga del estado del optimizador con puntos de control. Para crear un iterable de parámetros de modelo con ID exclusivos para el optimizador, consulte lo siguiente:

    unique_params = [] unique_params_set = set() for p in model.parameters(): if p not in unique_params_set: unique_params.append(p) unique_params_set.add(p) del unique_params_set optimizer = MyOpt(unique_params, ...)
  3. Encapsule el optimizador utilizando la función de envoltura de la biblioteca smp.DistributedOptimizer().

    optimizer = smp.DistributedOptimizer(optimizer)
  4. Guarde el modelo y el estado del optimizador mediante smp.save(). En función de cómo quiera guardar los puntos de control, seleccione una de las dos opciones siguientes:

    • Opción 1: guarde un modelo parcial en cada mp_rank para un soloMP_GROUP.

      model_dict = model.local_state_dict() # save a partial model opt_dict = optimizer.local_state_dict() # save a partial optimizer state # Save the dictionaries at rdp_rank 0 as a checkpoint if smp.rdp_rank() == 0: smp.save( {"model_state_dict": model_dict, "optimizer_state_dict": opt_dict}, f"/checkpoint.pt", partial=True, )

      Con el paralelismo de tensores, la biblioteca guarda los archivos marcados con el siguiente formato: checkpoint.pt_{pp_rank}_{tp_rank}.

      nota

      Con el paralelismo de tensores, asegúrese de configurar la instrucción if como if smp.rdp_rank() == 0 en lugar de if smp.dp_rank() == 0. Cuando el estado del optimizador está dividido con paralelismo de tensores, todos los rangos paralelos de datos reducidos deben guardar su propia partición del estado del optimizador. El uso de una declaración if incorrecta para los puntos de control podría provocar un estancamiento del trabajo de entrenamiento. Para obtener más información sobre el uso if smp.dp_rank() == 0 sin paralelismo tensorial, consulta la Instrucción general para guardar y cargar en la documentación del SDK de PythonSageMaker .

    • Opción 2: guarde el modelo completo.

      if smp.rdp_rank() == 0: model_dict = model.state_dict(gather_to_rank0=True) # save the full model if smp.rank() == 0: smp.save( {"model_state_dict": model_dict}, "/checkpoint.pt", partial=False, )
      nota

      Tenga en cuenta lo siguiente para el control completo:

      • Si establece gather_to_rank0=True, todos los rangos excepto 0 devuelven diccionarios vacíos.

      • Para un control completo, solo puede controlar el modelo. Actualmente no se admite el control completo de los estados del optimizador.

      • El modelo completo solo necesita guardarse en smp.rank() == 0.

  5. Cargar los puntos de control mediante smp.load(). En función de cómo haya seleccionado en el paso anterior, seleccione una de las dos opciones siguientes:

    • Opción 1: cargue los puntos de control parciales.

      checkpoint = smp.load("/checkpoint.pt", partial=True) model.load_state_dict(checkpoint["model_state_dict"], same_partition_load=False) optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

      Puede establecer same_partition_load=True en model.load_state_dict() para una carga más rápida, si sabe que la partición no cambiará.

    • Opción 2: carga los puntos de control completos.

      if smp.rdp_rank() == 0: checkpoint = smp.load("/checkpoint.pt", partial=False) model.load_state_dict(checkpoint["model_state_dict"])

      La condición if smp.rdp_rank() == 0 no es necesaria, pero puede ayudar a evitar la carga redundante entre diferentes MP_GROUPs. Actualmente, el dictado de estado del optimizador de puntos de control completo no es compatible con el paralelismo de tensores.

Verificación de un modelo distribuido TensorFlow

Para guardar un TensorFlow modelo mientras se entrena con el paralelismo de modelos, utilice las siguientes funciones que proporciona la biblioteca de paralelismo de modelos. SageMaker

Ajustar con precisión un modelo distribuido

El ajuste preciso debe configurarse en su script de entrenamiento. El siguiente fragmento de código muestra un ejemplo de estructura de un guion de entrenamiento que utiliza la clase AutoModelForCausalLM de Hugging Face Transformers con modificaciones para registrar smdistributed.model.parallel.torch los módulos y ajustes para su ajuste.

nota

Para ajustar con precisión un transformador distribuido (un modelo Transformer encapsulador por smp.DistributedModel()) con la función smp.delay ed_param_initialization activada, es necesario configurar el trabajo de ajuste con un sistema de archivos FSx para Lustre. En los casos en los que desee ajustar un modelo a gran escala con la opción de inicialización retardada de parámetros, debe configurar un sistema de archivos FSx para Lustre.

import argparse from transformers import AutoModelForCausalLM import smdistributed.modelparallel import smdistributed.modelparallel.torch as smp def parse_args(): parser = argparse.ArgumentParser() # set an arg group for model model_grp = parser.add_argument_group( title="model", description="arguments to describe model configuration" ) ... # set up numerous args to parse from the configuration dictionary to the script for training # add arg for activating fine-tuning model_grp.add_argument( "--fine_tune", type=int, default=0, help="Fine-tune model from checkpoint or pretrained model", ) def main(): """Main function to train GPT.""" args = parse_args() ... # parse numerous args if args.fine_tune > 0 and args.delayed_param > 0 and smp.rank() == 0: pretrained_model = AutoModelForCausalLM.from_pretrained( args.model_name or args.model_dir ) model_state_dict = pretrained_model.state_dict() path = os.path.join(args.model_dir, "fullmodel.pt") torch.save(model_state_dict, path) # create a Transformer model and wrap by smp.model_creation() # with options to configure model parallelism parameters offered by SageMaker with smp.model_creation( tensor_parallelism=smp.tp_size() > 1 or args.use_distributed_transformer > 0, zero_init=args.use_distributed_transformer == 0, dtype=dtype, distribute_embedding=args.sharded_data_parallel_degree > 1 and smp.tp_size() > 1, use_alibi=args.alibi > 0, attention_in_fp32=args.attention_in_fp32 > 0, fp32_residual_addition=args.residual_addition_in_fp32 > 0, query_key_layer_scaling=args.query_key_layer_scaling > 0 and args.bf16 < 1, fused_softmax=args.fused_softmax > 0, fused_dropout=args.fused_dropout > 0, fused_bias_gelu=args.fused_bias_gelu > 0, flash_attention=args.flash_attention > 0, ): if args.fine_tune > 0 and args.delayed_param == 0: model = AutoModelForCausalLM.from_pretrained( args.model_name or args.model_dir ) else: model = AutoModelForCausalLM.from_config(model_config) # wrap the model by smp.DistributedModel() to apply SageMaker model parallelism model = smp.DistributedModel( model, trace_device="gpu", backward_passes_per_step=args.gradient_accumulation ) # wrap the optimizer by smp.DistributedOptimizer() to apply SageMaker model parallelism optimizer= ... # define an optimizer optimizer = smp.DistributedOptimizer( optimizer, static_loss_scale=None, dynamic_loss_scale=True, dynamic_loss_args={"scale_window": 1000, "min_scale": 1, "delayed_shift": 2}, ) # for fine-tuning, use smp.resume_from_checkpoint() to load a pre-trained model if args.fine_tune > 0 and args.delayed_param > 0: smp.resume_from_checkpoint(args.model_dir, tag="fullmodel.pt", partial=False)

Para ver un ejemplo completo de guiones de entrenamiento y cuadernos de Jupyter, consulta los ejemplos de la GPT-2 en el repositorio de ejemplos. PyTorch SageMaker GitHub