Les traductions sont fournies par des outils de traduction automatique. En cas de conflit entre le contenu d'une traduction et celui de la version originale en anglais, la version anglaise prévaudra.
Point de contrôle et optimisation d'un modèle grâce au parallélisme de modèles
La bibliothèque de parallélisme des SageMaker modèles fournit des API de point de contrôle pour enregistrer l'état du modèle et l'état de l'optimiseur divisés par les différentes stratégies de parallélisme des modèles, et pour charger des points de contrôle pour la formation continue à partir desquels vous souhaitez reprendre l'entraînement et le peaufiner. Les API prennent également en charge des options permettant d'enregistrer partiellement ou totalement les états du modèle et de l'optimiseur.
Point de contrôle d'un modèle distribué
Choisissez l'une des rubriques suivantes en fonction du framework entre PyTorch TensorFlow et de la version de la bibliothèque de parallélisme de SageMaker modèles que vous utilisez.
Rubriques
- Vérification d'un PyTorch modèle distribué (pour la bibliothèque de parallélisme des SageMaker modèles v1.10.0 et versions ultérieures)
- Vérification d'un PyTorch modèle distribué (pour la bibliothèque de parallélisme des SageMaker modèles entre v1.6.0 et v1.9.0)
- Contrôle d'un modèle distribué TensorFlow
Vérification d'un PyTorch modèle distribué (pour la bibliothèque de parallélisme des SageMaker modèles v1.10.0 et versions ultérieures)
La bibliothèque de parallélisme des SageMaker modèles fournit des API de point de contrôle pour enregistrer et charger des points de contrôle complets ou partiels de l'état du modèle distribué et de son état d'optimiseur.
Note
Cette méthode de point de contrôle est recommandée si vous utilisez PyTorch et SageMaker modélisez la bibliothèque de parallélisme v1.10.0 ou version ultérieure.
Point de contrôle partiel
Pour enregistrer les points de contrôle d'un modèle entraîné avec le parallélisme de modèles, utilisez l'API smdistributed.modelparallel.torch.save_checkpoint
partial=True
). Cela permet d'enregistrer chaque partition de modèle individuellement. Outre le modèle et l'état de l'optimiseur, vous pouvez également enregistrer des données personnalisées supplémentaires via l'argument user_content
. Le modèle de point de contrôle, l'optimiseur et le contenu utilisateur sont enregistrés dans des fichiers séparés. L'appel d'API save_checkpoint
crée des dossiers de points de contrôle selon la structure suivante.
- path - ${tag}_partial (folder for partial checkpoints) - model_rankinfo.pt - optimizer_rankinfo.pt - fp16_states_rankinfo.pt - user_content.pt - $tag (checkpoint file for full checkpoints) - user_content_$tag (user_content file for full checkpoints) - newest (a file that indicates the newest checkpoint)
Pour reprendre l'entraînement à partir de points de contrôle partiels, utilisez l'API smdistributed.modelparallel.torch.resume_from_checkpoint
partial=True
et spécifiez le répertoire du point de contrôle et la balise utilisée lors de l'enregistrement des points de contrôle partiels. Notez que le chargement réel des poids du modèle se produit après le partitionnement du modèle, lors de la première exécution de la fonction d'étape d'entraînement décorée par smdistributed.modelparallel.torch.step
.
Lors de l'enregistrement d'un point de contrôle partiel, la bibliothèque enregistre également la décision de partition de modèle sous forme de fichiers avec extension de fichier .pt
. Inversement, lors de la reprise à partir du point de contrôle partiel, la bibliothèque charge les fichiers de décision de partition. Une fois la décision de partition chargée, vous ne pouvez pas la modifier.
L'extrait de code suivant montre comment définir les API de point de contrôle dans un PyTorch script de formation.
import smdistributed.modelparallel.torch as smp model = ... model = smp.DistributedModel(model) optimizer = ... optimizer = smp.DistributedOptimizer(optimizer) user_content = ... # additional custom data checkpoint_path = "
/opt/ml/checkpoint/model_parallel
" # Save a checkpoint. smp.save_checkpoint( path=checkpoint_path
, tag=f"total_steps{total_steps}
", partial=True
, model=model
, optimizer=optimizer
, user_content=user_content
num_kept_partial_checkpoints=5
) # Load a checkpoint. # This automatically loads the most recently saved checkpoint. smp_checkpoint = smp.resume_from_checkpoint( path=checkpoint_path
, partial=True
)
Point de contrôle complet
Pour enregistrer l'artefact du modèle final à des fins d'inférence, utilisez l'API smdistributed.modelparallel.torch.save_checkpoint
avec partial=False
, qui combine les partitions du modèle pour créer un artefact de modèle unique. Notez que cela ne combine pas les états de l'optimiseur.
Pour initialiser l'entraînement avec des poids particuliers, à partir d'un point de contrôle complet du modèle, vous pouvez utiliser l'API smdistributed.modelparallel.torch.resume_from_checkpoint
avec partial=False
. Notez que cela ne charge pas les états de l'optimiseur.
Note
Avec le parallélisme des tenseurs, en général, state_dict
doit être traduit entre l'implémentation du modèle d'origine et l'implémentation DistributedModel
. Vous pouvez éventuellement fournir la fonction de traduction state_dict
en tant qu'argument à smdistributed.modelparallel.torch.resume_from_checkpoint
. Cependant, pour Modèles pris en charge prêts à l'emploi, la bibliothèque se charge de cette traduction automatiquement.
Le code suivant montre un exemple d'utilisation des API de point de contrôle pour vérifier complètement un PyTorch modèle entraîné avec le parallélisme des modèles.
import smdistributed.modelparallel.torch as smp model = ... model = smp.DistributedModel(model) optimizer = ... optimizer = smp.DistributedOptimizer(optimizer) user_content = ... # additional custom data checkpoint_path = "
/opt/ml/checkpoint/model_parallel
" # Save a checkpoint. smp.save_checkpoint( path=checkpoint_path
, tag=f"total_steps{total_steps}
", partial=False
, model=model
, optimizer=optimizer
, user_content=user_content
num_kept_partial_checkpoints=5
) # Load a checkpoint. # This automatically loads the most recently saved checkpoint. smp_checkpoint = smp.resume_from_checkpoint( path=checkpoint_path
, partial=False
)
Vérification d'un PyTorch modèle distribué (pour la bibliothèque de parallélisme des SageMaker modèles entre v1.6.0 et v1.9.0)
La bibliothèque de parallélisme des SageMaker modèles fournit des fonctions Python permettant d'enregistrer des points de contrôle partiels ou complets pour les tâches d'entraînement avec le parallélisme des tenseurs. La procédure suivante explique comment utiliser smp.save()
smp.load()
Note
Cette méthode de point de contrôle est recommandée si vous utilisez PyTorchParallélisme de tenseur, et la bibliothèque de parallélisme du SageMaker modèle entre les versions v1.6.0 et v1.9.0.
-
Préparez un objet de modèle et enveloppez-le avec la fonction wrapper
smp.DistributedModel()
de la bibliothèque.model = MyModel(...) model = smp.DistributedModel(model)
-
Préparez un optimiseur pour le modèle. Un ensemble de paramètres de modèle est un argument itérable requis par les fonctions de l'optimiseur. Pour préparer un ensemble de paramètres de modèle, vous devez traiter
model.parameters()
pour attribuer des ID uniques à des paramètres de modèle individuels.Si plusieurs paramètres partagent le même ID dans l'argument itérable de paramètres de modèle, le chargement de l'état de l'optimiseur à points de contrôle échoue. Pour créer un argument itérable de paramètres de modèle avec des ID uniques pour l'optimiseur, consultez le code suivant :
unique_params = [] unique_params_set = set() for p in model.parameters(): if p not in unique_params_set: unique_params.append(p) unique_params_set.add(p) del unique_params_set optimizer = MyOpt(unique_params, ...)
-
Enveloppez l'optimiseur à l'aide de la fonction wrapper
smp.DistributedOptimizer()
de la bibliothèque.optimizer = smp.DistributedOptimizer(optimizer)
-
Enregistrez le modèle et l'état de l'optimiseur à l'aide de
smp.save()
. Selon la manière dont vous souhaitez enregistrer les points de contrôle, choisissez l'une des deux options suivantes : -
Option 1 : enregistrez un modèle partiel sur chaque
mp_rank
pour unMP_GROUP
unique.model_dict = model.local_state_dict() # save a partial model opt_dict = optimizer.local_state_dict() # save a partial optimizer state # Save the dictionaries at rdp_rank 0 as a checkpoint if smp.rdp_rank() == 0: smp.save( {"model_state_dict": model_dict, "optimizer_state_dict": opt_dict}, f"/checkpoint.pt", partial=True, )
Avec le parallélisme de tenseur, la bibliothèque enregistre les fichiers à points de contrôle nommés selon le format suivant :
checkpoint.pt_{pp_rank}_{tp_rank}
.Note
Avec le parallélisme de tenseur, assurez-vous de définir l'instruction if comme
if smp.rdp_rank() == 0
et non commeif smp.dp_rank() == 0
. Si l'état de l'optimiseur est partitionné avec un parallélisme de tenseur, tous les rangs parallèles aux données réduites doivent enregistrer leur propre partition de l'état de l'optimiseur. L'utilisation d'une mauvaise instruction if pour les points de contrôle peut entraîner un blocage de la tâche d'entraînement. Pour plus d'informations sur l'utilisation du parallélismeif smp.dp_rank() == 0
sans tenseur, consultez les instructions générales pour l'enregistrement et le chargement dans la documentationdu SDK SageMaker Python.
-
Option 2 : enregistrez le modèle complet.
if smp.rdp_rank() == 0: model_dict = model.state_dict(gather_to_rank0=True) # save the full model if smp.rank() == 0: smp.save( {"model_state_dict": model_dict}, "/checkpoint.pt", partial=False, )
Note
Tenez compte des points suivants pour la création de points de contrôle complets :
-
Si vous définissez
gather_to_rank0=True
, tous les rangs autres que0
renvoient des dictionnaires vides. -
Pour la création de points de contrôle complets, vous ne pouvez créer des points de contrôle que pour le modèle. La création de points de contrôle complets des états de l'optimiseur n'est actuellement pas prise en charge.
-
Le modèle complet doit uniquement être enregistré sur
smp.rank() == 0
.
-
-
-
Chargez les points de contrôle à l'aide de
smp.load()
. Selon la manière dont vous avez enregistré les points de contrôle à l'étape précédente, choisissez l'une des deux options suivantes : -
Option 1 : chargez les points de contrôle partiels.
checkpoint = smp.load("/checkpoint.pt", partial=True) model.load_state_dict(checkpoint["model_state_dict"], same_partition_load=False) optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
Vous pouvez définir
same_partition_load=True
dansmodel.load_state_dict()
pour une charge plus rapide si vous savez que la partition ne changera pas. -
Option 2 : chargez les points de contrôle complets.
if smp.rdp_rank() == 0: checkpoint = smp.load("/checkpoint.pt", partial=False) model.load_state_dict(checkpoint["model_state_dict"])
La condition
if smp.rdp_rank() == 0
n'est pas nécessaire, mais elle peut aider à éviter un chargement redondant entre différentsMP_GROUP
. La création de points de contrôle complets du dictionnaire des états de l'optimiseur n'est actuellement pas prise en charge avec le parallélisme de tenseur.
-
Contrôle d'un modèle distribué TensorFlow
Pour enregistrer un TensorFlow modèle pendant l'entraînement au parallélisme des modèles, utilisez les fonctions suivantes fournies par la bibliothèque de parallélisme des SageMaker modèles.
Optimisation d'un modèle distribué
L'optimisation doit être configurée dans votre script d'entraînement. L'extrait de code suivant montre un exemple de structure de script d'entraînement utilisant la classe AutoModelForCausalLMsmdistributed.model.parallel.torch
réglage précis.
Note
Pour optimiser un transformateur distribué (un modèle de transformateur encapsulé par smp.DistributedModel()
) avec la fonction smp.delayed_param_initialization
import argparse from transformers import AutoModelForCausalLM import smdistributed.modelparallel import smdistributed.modelparallel.torch as smp def parse_args(): parser = argparse.ArgumentParser() # set an arg group for model model_grp = parser.add_argument_group( title="model", description="arguments to describe model configuration" ) ... # set up numerous args to parse from the configuration dictionary to the script for training # add arg for activating fine-tuning model_grp.add_argument( "--fine_tune", type=int, default=0, help="Fine-tune model from checkpoint or pretrained model", ) def main(): """Main function to train GPT.""" args = parse_args() ... # parse numerous args if args.fine_tune > 0 and args.delayed_param > 0 and smp.rank() == 0: pretrained_model = AutoModelForCausalLM.from_pretrained( args.model_name or args.model_dir ) model_state_dict = pretrained_model.state_dict() path = os.path.join(args.model_dir, "fullmodel.pt") torch.save(model_state_dict, path) # create a Transformer model and wrap by smp.model_creation() # with options to configure model parallelism parameters offered by SageMaker with smp.model_creation( tensor_parallelism=smp.tp_size() > 1 or args.use_distributed_transformer > 0, zero_init=args.use_distributed_transformer == 0, dtype=dtype, distribute_embedding=args.sharded_data_parallel_degree > 1 and smp.tp_size() > 1, use_alibi=args.alibi > 0, attention_in_fp32=args.attention_in_fp32 > 0, fp32_residual_addition=args.residual_addition_in_fp32 > 0, query_key_layer_scaling=args.query_key_layer_scaling > 0 and args.bf16 < 1, fused_softmax=args.fused_softmax > 0, fused_dropout=args.fused_dropout > 0, fused_bias_gelu=args.fused_bias_gelu > 0, flash_attention=args.flash_attention > 0, ): if args.fine_tune > 0 and args.delayed_param == 0: model = AutoModelForCausalLM.from_pretrained( args.model_name or args.model_dir ) else: model = AutoModelForCausalLM.from_config(model_config) # wrap the model by smp.DistributedModel() to apply SageMaker model parallelism model = smp.DistributedModel( model, trace_device="gpu", backward_passes_per_step=args.gradient_accumulation ) # wrap the optimizer by smp.DistributedOptimizer() to apply SageMaker model parallelism optimizer= ... # define an optimizer optimizer = smp.DistributedOptimizer( optimizer, static_loss_scale=None, dynamic_loss_scale=True, dynamic_loss_args={"scale_window": 1000, "min_scale": 1, "delayed_shift": 2}, ) # for fine-tuning, use smp.resume_from_checkpoint() to load a pre-trained model if args.fine_tune > 0 and args.delayed_param > 0: smp.resume_from_checkpoint(args.model_dir, tag="fullmodel.pt", partial=False)
Pour un exemple complet de scripts d'entraînement et de blocs-notes Jupyter, consultez les exemples GPT-2 disponibles dans le référentiel d' PyTorchexemples