Point de contrôle et optimisation d'un modèle grâce au parallélisme de modèles - Amazon SageMaker

Les traductions sont fournies par des outils de traduction automatique. En cas de conflit entre le contenu d'une traduction et celui de la version originale en anglais, la version anglaise prévaudra.

Point de contrôle et optimisation d'un modèle grâce au parallélisme de modèles

La bibliothèque de parallélisme des SageMaker modèles fournit des API de point de contrôle pour enregistrer l'état du modèle et l'état de l'optimiseur divisés par les différentes stratégies de parallélisme des modèles, et pour charger des points de contrôle pour la formation continue à partir desquels vous souhaitez reprendre l'entraînement et le peaufiner. Les API prennent également en charge des options permettant d'enregistrer partiellement ou totalement les états du modèle et de l'optimiseur.

Point de contrôle d'un modèle distribué

Choisissez l'une des rubriques suivantes en fonction du framework entre PyTorch TensorFlow et de la version de la bibliothèque de parallélisme de SageMaker modèles que vous utilisez.

Vérification d'un PyTorch modèle distribué (pour la bibliothèque de parallélisme des SageMaker modèles v1.10.0 et versions ultérieures)

La bibliothèque de parallélisme des SageMaker modèles fournit des API de point de contrôle pour enregistrer et charger des points de contrôle complets ou partiels de l'état du modèle distribué et de son état d'optimiseur.

Note

Cette méthode de point de contrôle est recommandée si vous utilisez PyTorch et SageMaker modélisez la bibliothèque de parallélisme v1.10.0 ou version ultérieure.

Point de contrôle partiel

Pour enregistrer les points de contrôle d'un modèle entraîné avec le parallélisme de modèles, utilisez l'API smdistributed.modelparallel.torch.save_checkpoint avec l'option de point de contrôle partiel définie sur true (partial=True). Cela permet d'enregistrer chaque partition de modèle individuellement. Outre le modèle et l'état de l'optimiseur, vous pouvez également enregistrer des données personnalisées supplémentaires via l'argument user_content. Le modèle de point de contrôle, l'optimiseur et le contenu utilisateur sont enregistrés dans des fichiers séparés. L'appel d'API save_checkpoint crée des dossiers de points de contrôle selon la structure suivante.

- path - ${tag}_partial (folder for partial checkpoints) - model_rankinfo.pt - optimizer_rankinfo.pt - fp16_states_rankinfo.pt - user_content.pt - $tag (checkpoint file for full checkpoints) - user_content_$tag (user_content file for full checkpoints) - newest (a file that indicates the newest checkpoint)

Pour reprendre l'entraînement à partir de points de contrôle partiels, utilisez l'API smdistributed.modelparallel.torch.resume_from_checkpoint avec partial=True et spécifiez le répertoire du point de contrôle et la balise utilisée lors de l'enregistrement des points de contrôle partiels. Notez que le chargement réel des poids du modèle se produit après le partitionnement du modèle, lors de la première exécution de la fonction d'étape d'entraînement décorée par smdistributed.modelparallel.torch.step.

Lors de l'enregistrement d'un point de contrôle partiel, la bibliothèque enregistre également la décision de partition de modèle sous forme de fichiers avec extension de fichier .pt. Inversement, lors de la reprise à partir du point de contrôle partiel, la bibliothèque charge les fichiers de décision de partition. Une fois la décision de partition chargée, vous ne pouvez pas la modifier.

L'extrait de code suivant montre comment définir les API de point de contrôle dans un PyTorch script de formation.

import smdistributed.modelparallel.torch as smp model = ... model = smp.DistributedModel(model) optimizer = ... optimizer = smp.DistributedOptimizer(optimizer) user_content = ... # additional custom data checkpoint_path = "/opt/ml/checkpoint/model_parallel" # Save a checkpoint. smp.save_checkpoint( path=checkpoint_path, tag=f"total_steps{total_steps}", partial=True, model=model, optimizer=optimizer, user_content=user_content num_kept_partial_checkpoints=5 ) # Load a checkpoint. # This automatically loads the most recently saved checkpoint. smp_checkpoint = smp.resume_from_checkpoint( path=checkpoint_path, partial=True )

Point de contrôle complet

Pour enregistrer l'artefact du modèle final à des fins d'inférence, utilisez l'API smdistributed.modelparallel.torch.save_checkpoint avec partial=False, qui combine les partitions du modèle pour créer un artefact de modèle unique. Notez que cela ne combine pas les états de l'optimiseur.

Pour initialiser l'entraînement avec des poids particuliers, à partir d'un point de contrôle complet du modèle, vous pouvez utiliser l'API smdistributed.modelparallel.torch.resume_from_checkpoint avec partial=False. Notez que cela ne charge pas les états de l'optimiseur.

Note

Avec le parallélisme des tenseurs, en général, state_dict doit être traduit entre l'implémentation du modèle d'origine et l'implémentation DistributedModel. Vous pouvez éventuellement fournir la fonction de traduction state_dict en tant qu'argument à smdistributed.modelparallel.torch.resume_from_checkpoint. Cependant, pour Modèles pris en charge prêts à l'emploi, la bibliothèque se charge de cette traduction automatiquement.

Le code suivant montre un exemple d'utilisation des API de point de contrôle pour vérifier complètement un PyTorch modèle entraîné avec le parallélisme des modèles.

import smdistributed.modelparallel.torch as smp model = ... model = smp.DistributedModel(model) optimizer = ... optimizer = smp.DistributedOptimizer(optimizer) user_content = ... # additional custom data checkpoint_path = "/opt/ml/checkpoint/model_parallel" # Save a checkpoint. smp.save_checkpoint( path=checkpoint_path, tag=f"total_steps{total_steps}", partial=False, model=model, optimizer=optimizer, user_content=user_content num_kept_partial_checkpoints=5 ) # Load a checkpoint. # This automatically loads the most recently saved checkpoint. smp_checkpoint = smp.resume_from_checkpoint( path=checkpoint_path, partial=False )

Vérification d'un PyTorch modèle distribué (pour la bibliothèque de parallélisme des SageMaker modèles entre v1.6.0 et v1.9.0)

La bibliothèque de parallélisme des SageMaker modèles fournit des fonctions Python permettant d'enregistrer des points de contrôle partiels ou complets pour les tâches d'entraînement avec le parallélisme des tenseurs. La procédure suivante explique comment utiliser smp.save() et smp.load() pour enregistrer et charger un point de contrôle lors de l'utilisation du parallélisme de tenseur.

Note

Cette méthode de point de contrôle est recommandée si vous utilisez PyTorchParallélisme de tenseur, et la bibliothèque de parallélisme du SageMaker modèle entre les versions v1.6.0 et v1.9.0.

  1. Préparez un objet de modèle et enveloppez-le avec la fonction wrapper smp.DistributedModel() de la bibliothèque.

    model = MyModel(...) model = smp.DistributedModel(model)
  2. Préparez un optimiseur pour le modèle. Un ensemble de paramètres de modèle est un argument itérable requis par les fonctions de l'optimiseur. Pour préparer un ensemble de paramètres de modèle, vous devez traiter model.parameters() pour attribuer des ID uniques à des paramètres de modèle individuels.

    Si plusieurs paramètres partagent le même ID dans l'argument itérable de paramètres de modèle, le chargement de l'état de l'optimiseur à points de contrôle échoue. Pour créer un argument itérable de paramètres de modèle avec des ID uniques pour l'optimiseur, consultez le code suivant :

    unique_params = [] unique_params_set = set() for p in model.parameters(): if p not in unique_params_set: unique_params.append(p) unique_params_set.add(p) del unique_params_set optimizer = MyOpt(unique_params, ...)
  3. Enveloppez l'optimiseur à l'aide de la fonction wrapper smp.DistributedOptimizer() de la bibliothèque.

    optimizer = smp.DistributedOptimizer(optimizer)
  4. Enregistrez le modèle et l'état de l'optimiseur à l'aide de smp.save(). Selon la manière dont vous souhaitez enregistrer les points de contrôle, choisissez l'une des deux options suivantes :

    • Option 1 : enregistrez un modèle partiel sur chaque mp_rank pour un MP_GROUP unique.

      model_dict = model.local_state_dict() # save a partial model opt_dict = optimizer.local_state_dict() # save a partial optimizer state # Save the dictionaries at rdp_rank 0 as a checkpoint if smp.rdp_rank() == 0: smp.save( {"model_state_dict": model_dict, "optimizer_state_dict": opt_dict}, f"/checkpoint.pt", partial=True, )

      Avec le parallélisme de tenseur, la bibliothèque enregistre les fichiers à points de contrôle nommés selon le format suivant : checkpoint.pt_{pp_rank}_{tp_rank}.

      Note

      Avec le parallélisme de tenseur, assurez-vous de définir l'instruction if comme if smp.rdp_rank() == 0 et non comme if smp.dp_rank() == 0. Si l'état de l'optimiseur est partitionné avec un parallélisme de tenseur, tous les rangs parallèles aux données réduites doivent enregistrer leur propre partition de l'état de l'optimiseur. L'utilisation d'une mauvaise instruction if pour les points de contrôle peut entraîner un blocage de la tâche d'entraînement. Pour plus d'informations sur l'utilisation du parallélisme if smp.dp_rank() == 0 sans tenseur, consultez les instructions générales pour l'enregistrement et le chargement dans la documentation du SDK SageMaker Python.

    • Option 2 : enregistrez le modèle complet.

      if smp.rdp_rank() == 0: model_dict = model.state_dict(gather_to_rank0=True) # save the full model if smp.rank() == 0: smp.save( {"model_state_dict": model_dict}, "/checkpoint.pt", partial=False, )
      Note

      Tenez compte des points suivants pour la création de points de contrôle complets :

      • Si vous définissez gather_to_rank0=True, tous les rangs autres que 0 renvoient des dictionnaires vides.

      • Pour la création de points de contrôle complets, vous ne pouvez créer des points de contrôle que pour le modèle. La création de points de contrôle complets des états de l'optimiseur n'est actuellement pas prise en charge.

      • Le modèle complet doit uniquement être enregistré sur smp.rank() == 0.

  5. Chargez les points de contrôle à l'aide de smp.load(). Selon la manière dont vous avez enregistré les points de contrôle à l'étape précédente, choisissez l'une des deux options suivantes :

    • Option 1 : chargez les points de contrôle partiels.

      checkpoint = smp.load("/checkpoint.pt", partial=True) model.load_state_dict(checkpoint["model_state_dict"], same_partition_load=False) optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

      Vous pouvez définir same_partition_load=True dans model.load_state_dict() pour une charge plus rapide si vous savez que la partition ne changera pas.

    • Option 2 : chargez les points de contrôle complets.

      if smp.rdp_rank() == 0: checkpoint = smp.load("/checkpoint.pt", partial=False) model.load_state_dict(checkpoint["model_state_dict"])

      La condition if smp.rdp_rank() == 0 n'est pas nécessaire, mais elle peut aider à éviter un chargement redondant entre différents MP_GROUP. La création de points de contrôle complets du dictionnaire des états de l'optimiseur n'est actuellement pas prise en charge avec le parallélisme de tenseur.

Contrôle d'un modèle distribué TensorFlow

Pour enregistrer un TensorFlow modèle pendant l'entraînement au parallélisme des modèles, utilisez les fonctions suivantes fournies par la bibliothèque de parallélisme des SageMaker modèles.

Optimisation d'un modèle distribué

L'optimisation doit être configurée dans votre script d'entraînement. L'extrait de code suivant montre un exemple de structure de script d'entraînement utilisant la classe AutoModelForCausalLM de Hugging Face Transformers avec des modifications pour l'enregistrement des modules et des paramètres pour un smdistributed.model.parallel.torch réglage précis.

Note

Pour optimiser un transformateur distribué (un modèle de transformateur encapsulé par smp.DistributedModel()) avec la fonction smp.delayed_param_initialization activée, la tâche d'optimisation doit être configurée avec un système de fichiers FSx pour Lustre. Si vous souhaitez optimiser un modèle à grande échelle à l'aide de l'option d'initialisation différée des paramètres, vous devez configurer un système de fichiers FSx pour Lustre.

import argparse from transformers import AutoModelForCausalLM import smdistributed.modelparallel import smdistributed.modelparallel.torch as smp def parse_args(): parser = argparse.ArgumentParser() # set an arg group for model model_grp = parser.add_argument_group( title="model", description="arguments to describe model configuration" ) ... # set up numerous args to parse from the configuration dictionary to the script for training # add arg for activating fine-tuning model_grp.add_argument( "--fine_tune", type=int, default=0, help="Fine-tune model from checkpoint or pretrained model", ) def main(): """Main function to train GPT.""" args = parse_args() ... # parse numerous args if args.fine_tune > 0 and args.delayed_param > 0 and smp.rank() == 0: pretrained_model = AutoModelForCausalLM.from_pretrained( args.model_name or args.model_dir ) model_state_dict = pretrained_model.state_dict() path = os.path.join(args.model_dir, "fullmodel.pt") torch.save(model_state_dict, path) # create a Transformer model and wrap by smp.model_creation() # with options to configure model parallelism parameters offered by SageMaker with smp.model_creation( tensor_parallelism=smp.tp_size() > 1 or args.use_distributed_transformer > 0, zero_init=args.use_distributed_transformer == 0, dtype=dtype, distribute_embedding=args.sharded_data_parallel_degree > 1 and smp.tp_size() > 1, use_alibi=args.alibi > 0, attention_in_fp32=args.attention_in_fp32 > 0, fp32_residual_addition=args.residual_addition_in_fp32 > 0, query_key_layer_scaling=args.query_key_layer_scaling > 0 and args.bf16 < 1, fused_softmax=args.fused_softmax > 0, fused_dropout=args.fused_dropout > 0, fused_bias_gelu=args.fused_bias_gelu > 0, flash_attention=args.flash_attention > 0, ): if args.fine_tune > 0 and args.delayed_param == 0: model = AutoModelForCausalLM.from_pretrained( args.model_name or args.model_dir ) else: model = AutoModelForCausalLM.from_config(model_config) # wrap the model by smp.DistributedModel() to apply SageMaker model parallelism model = smp.DistributedModel( model, trace_device="gpu", backward_passes_per_step=args.gradient_accumulation ) # wrap the optimizer by smp.DistributedOptimizer() to apply SageMaker model parallelism optimizer= ... # define an optimizer optimizer = smp.DistributedOptimizer( optimizer, static_loss_scale=None, dynamic_loss_scale=True, dynamic_loss_args={"scale_window": 1000, "min_scale": 1, "delayed_shift": 2}, ) # for fine-tuning, use smp.resume_from_checkpoint() to load a pre-trained model if args.fine_tune > 0 and args.delayed_param > 0: smp.resume_from_checkpoint(args.model_dir, tag="fullmodel.pt", partial=False)

Pour un exemple complet de scripts d'entraînement et de blocs-notes Jupyter, consultez les exemples GPT-2 disponibles dans le référentiel d' PyTorchexemples. SageMaker GitHub