Point de contrôle à l'aide de SMP - Amazon SageMaker

Les traductions sont fournies par des outils de traduction automatique. En cas de conflit entre le contenu d'une traduction et celui de la version originale en anglais, la version anglaise prévaudra.

Point de contrôle à l'aide de SMP

La bibliothèque SageMaker model parallelism (SMP) prend en charge les points PyTorch APIs de contrôle et permet APIs de vérifier correctement les points de contrôle lors de l'utilisation de la bibliothèque. SMP

PyTorch FSDP(Fully Sharded Data Parallelism) prend en charge trois types de points de contrôle : complets, fragmentés et locaux, chacun ayant des objectifs différents. Des points de contrôle complets sont utilisés lors de l'exportation du modèle une fois l'entraînement terminé, car la génération d'un point de contrôle complet est un processus coûteux en termes de calcul. Les points de contrôle fragmentés permettent de sauvegarder et de charger l'état d'un modèle fragmenté pour chaque rang individuel. Grâce aux points de contrôle fragmentés, vous pouvez reprendre l'entraînement avec différentes configurations matérielles, par exemple un nombre différent de. GPUs Cependant, le chargement des points de contrôle fragmentés peut être lent en raison de la communication requise entre plusieurs appareils. La SMP bibliothèque fournit des fonctionnalités de point de contrôle local, qui permettent de récupérer plus rapidement l'état du modèle sans surcharge de communication supplémentaire. Notez que les points de contrôle créés par FSDP nécessitent d'écrire sur un système de fichiers réseau partagé tel qu'AmazonFSx.

Points de contrôle locaux asynchrones

Lors de l'entraînement de modèles d'apprentissage automatique, il n'est pas nécessaire d'effectuer les itérations suivantes pour attendre que les fichiers de points de contrôle soient enregistrés sur disque. Avec la sortie de la SMP version 2.5, la bibliothèque prend en charge l'enregistrement des fichiers de point de contrôle de manière asynchrone. Cela signifie que l'itération d'entraînement suivante peut être exécutée simultanément avec les opérations d'entrée et de sortie (E/S) pour créer des points de contrôle, sans être ralentie ou freinée par ces opérations d'E/S. De plus, le processus de récupération des paramètres du modèle fragmenté et de l'optimiseur PyTorch peut prendre du temps en raison de la communication collective supplémentaire requise pour échanger des métadonnées tensorielles distribuées entre les grades. Même lorsque vous l'utilisez StateDictType.LOCAL_STATE_DICT pour enregistrer des points de contrôle locaux pour chaque rang, elle invoque PyTorch toujours des hooks qui effectuent une communication collective. Pour atténuer ce problème et réduire le temps nécessaire à la récupération des points de contrôle, SMP introduitSMStateDictType.SM_LOCAL_STATE_DICT, qui permet de récupérer plus rapidement les points de contrôle du modèle et de l'optimiseur en contournant la surcharge de communication collective.

Note

Le maintien de la cohérence dans le FSDP SHARD_DEGREE est une condition préalable à l'utilisation duSMStateDictType.SM_LOCAL_STATE_DICT. Assurez-vous que le SHARD_DEGREE reste inchangé. Bien que le nombre de réplications du modèle puisse varier, le degré de fragmentation du modèle doit être identique à celui de la configuration d'entraînement précédente lorsque vous reprenez un point de contrôle.

import os import torch.distributed as dist import torch.sagemaker as tsm from torch.sagemaker import state from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.sagemaker.distributed.checkpoint.state_dict_saver import ( async_save, maybe_finalize_async_calls, ) from torch.sagemaker.distributed.checkpoint.state_dict_utils import ( sm_state_dict_type, SMStateDictType, ) global_rank = dist.get_rank() save_dir = "/opt/ml/checkpoints" sub_dir = f"tp{state.tp_rank}_ep{state.ep_rank}_fsdp{model.rank}" # 1. Get replication ranks and group current_replication_group = None current_replication_ranks = None for replication_ranks in state.ranker.get_rep_groups(): rep_group = dist.new_group(replication_ranks) if global_rank in replication_ranks: current_replication_group = rep_group current_replication_ranks = replication_ranks coordinator_rank = min(current_replication_ranks) # 2. Wait for the previous checkpointing done maybe_finalize_async_calls( blocking=True, process_group=current_replication_group ) # 3. Get model local checkpoint with sm_state_dict_type(model, SMStateDictType.SM_LOCAL_STATE_DICT): state_dict = { "model": model.state_dict(), "optimizer": optimizer.state_dict(), # Potentially add more customized state dicts. } # 4. Save a local checkpoint async_save( state_dict, checkpoint_id=os.path.join(save_dir, sub_dir), process_group=current_replication_group, coordinator_rank=coordinator_rank, )

L'extrait de code suivant montre comment charger un point de contrôle en utilisant. SMStateDictType.SM_LOCAL_STATE_DICT

import os import torch.sagemaker as tsm from torch.sagemaker import state from torch.sagemaker.distributed.checkpoint.state_dict_loader import load from torch.sagemaker.distributed.checkpoint.state_dict_utils import ( sm_state_dict_type, SMStateDictType, init_optim_state ) from torch.sagemaker.distributed.checkpoint.filesystem import ( DistributedFileSystemReader, ) load_dir = "/opt/ml/checkpoints" sub_dir = f"tp{state.tp_rank}_ep{state.ep_rank}_fsdp{model.rank}" global_rank = dist.get_rank() checkpoint_id = os.path.join(load_dir, sub_dir) storage_reader = DistributedFileSystemReader(checkpoint_id) # 1. Get replication ranks and group current_replication_group = None current_replication_ranks = None for replication_ranks in state.ranker.get_rep_groups(): rep_group = dist.new_group(replication_ranks) if global_rank in replication_ranks: current_replication_group = rep_group current_replication_ranks = replication_ranks coordinator_rank = min(current_replication_ranks) # 2. Create local state_dict with sm_state_dict_type(model, SMStateDictType.SM_LOCAL_STATE_DICT): state_dict = { "model": model.state_dict(), # Potentially add more customized state dicts. } # Init optimizer state_dict states by setting zero grads and step. init_optim_state(optimizer, skip_empty_param=True) state_dict["optimizer"] = optimizer.state_dict() # 3. Load a checkpoint load( state_dict=state_dict, process_group=current_replication_group, coordinator_rank=coordinator_rank, storage_reader=storage_reader, )

Le stockage de points de contrôle pour les grands modèles de langage (LLMs) peut s'avérer coûteux car cela nécessite souvent la création d'un volume de système de fichiers important. Pour réduire les coûts, vous avez la possibilité d'enregistrer les points de contrôle directement dans Amazon S3 sans avoir besoin de services de système de fichiers supplémentaires tels qu'Amazon. FSx Vous pouvez utiliser l'exemple précédent avec l'extrait de code suivant pour enregistrer des points de contrôle dans S3 en spécifiant un S3 URL comme destination.

key = os.path.join(checkpoint_dir, sub_dir) checkpoint_id= f"s3://{your_s3_bucket}/{key}" async_save(state_dict, checkpoint_id=checkpoint_id, **kw) load(state_dict, checkpoint_id=checkpoint_id, **kw)

Points de contrôle partitionnés asynchrones

Dans certaines situations, vous devrez peut-être poursuivre votre formation avec différentes configurations matérielles, par exemple en modifiant le nombre deGPUs. Dans ces cas, vos processus de formation doivent charger des points de contrôle lors du repartage, ce qui implique de reprendre l'entraînement suivant avec un nombre différent de. SHARD_DEGREE Afin de résoudre le scénario dans lequel vous devez reprendre l'entraînement avec un nombre différent deSHARD_DEGREE, vous devez enregistrer les points de contrôle de votre modèle à l'aide du type de dictionnaire d'états fragmenté, représenté par. StateDictType.SHARDED_STATE_DICT L'enregistrement des points de contrôle dans ce format vous permet de gérer correctement le processus de repartage lorsque vous poursuivez la formation avec une configuration matérielle modifiée. L'extrait de code fourni illustre comment utiliser le pour enregistrer des points de contrôle tsm API partitionnés de manière asynchrone, permettant ainsi un processus de formation plus efficace et rationalisé.

import os import torch.sagemaker as tsm from torch.sagemaker import state from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp import StateDictType from torch.sagemaker.utils.process_group_utils import get_global_ranks from torch.sagemaker.distributed.checkpoint.state_dict_saver import ( async_save, maybe_finalize_async_calls, ) save_dir = "/opt/ml/checkpoints" sub_dir = f"tp{state.tp_rank}_ep{state.ep_rank}" checkpoint_id = os.path.join(save_dir, sub_dir) # To determine whether curreto take part in checkpointing. global_rank = dist.get_rank() action_rank = state.ranker.get_rep_rank(global_rank) == 0 process_group = model.process_group coordinator_rank = min(get_global_ranks(process_group)) # 1. wait for the previous checkpointing done maybe_finalize_async_calls(blocking=True, process_group=process_group) # 2. retrieve model & optimizer sharded state_dict with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT): state_dict = { "model": model.state_dict(), "optimizer": FSDP.optim_state_dict(model, optimizer), # Potentially add more customized state dicts. } # 3. save checkpoints asynchronously using async_save if action_rank: async_save( state_dict, checkpoint_id=checkpoint_id, process_group=process_group, coordinator_rank=coordinator_rank, )

Le processus de chargement des points de contrôle partagés est similaire à celui décrit dans la section précédente, mais il implique l'utilisation de la méthode torch.sagemaker.distributed.checkpoint.filesystem.DistributedFileSystemReader et de sa load méthode. La load méthode de cette classe permet de charger les données de point de contrôle partagées, en suivant un processus analogue à celui décrit précédemment.

import os from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp import StateDictType from torch.distributed.checkpoint.optimizer import load_sharded_optimizer_state_dict from torch.sagemaker.distributed.checkpoint.state_dict_loader import load from torch.sagemaker.utils.process_group_utils import get_global_ranks from torch.sagemaker.distributed.checkpoint.filesystem import ( DistributedFileSystemReader, ) load_dir = "/opt/ml/checkpoints" sub_dir = f"tp{state.tp_rank}_ep{state.ep_rank}" checkpoint_id = os.path.join(load_dir, sub_dir) reader = DistributedFileSystemReader(checkpoint_id) process_group = model.process_group coordinator_rank = min(get_global_ranks(process_group)) with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT): # 1. Load model and everything else except the optimizer. state_dict = { "model": model.state_dict() # Potentially more customized state dicts. } load( state_dict, storage_reader=reader, process_group=process_group, coordinator_rank=coordinator_rank, ) model.load_state_dict(state_dict["model"]) # 2. Load optimizer. optim_state = load_sharded_optimizer_state_dict( model_state_dict=state_dict["model"], optimizer_key="optimizer", storage_reader=reader, process_group=process_group, ) flattened_optimizer_state = FSDP.optim_state_dict_to_load( optim_state["optimizer"], model, optimizer, group=model.process_group ) optimizer.load_state_dict(flattened_optimizer_state)

Modèles complets de points de contrôle

À la fin de la formation, vous pouvez enregistrer un point de contrôle complet qui combine tous les fragments d'un modèle dans un seul fichier de point de contrôle du modèle. La SMP bibliothèque prend entièrement en charge les points de contrôle du modèle PyTorch completAPI, vous n'avez donc pas besoin d'apporter de modifications.

Notez que si vous utilisez le SMPParallélisme de tenseur, la SMP bibliothèque transforme le modèle. Dans ce cas, lorsque vous vérifiez le modèle complet, la SMP bibliothèque retraduit le modèle au format de point de contrôle Hugging Face Transformers par défaut.

Dans les cas où vous vous entraînez avec le parallélisme des SMP tenseurs et que vous désactivez le processus de SMP traduction, vous pouvez utiliser l'translate_on_saveargument de PyTorch FullStateDictConfig API pour activer ou désactiver la SMP traduction automatique selon vos besoins. Par exemple, si vous vous concentrez sur la formation d'un modèle, vous n'avez pas besoin d'ajouter le processus de traduction, ce qui entraîne des frais supplémentaires. Dans ce cas, nous vous recommandons de définirtranslate_on_save=False. De plus, si vous prévoyez de continuer à utiliser la SMP traduction du modèle pour une formation continue à l'avenir, vous pouvez la désactiver pour enregistrer la SMP traduction du modèle pour une utilisation ultérieure. Il est nécessaire de retraduire le modèle au format de point de contrôle du modèle Hugging Face Transformers lorsque vous terminez l'entraînement de votre modèle et que vous l'utilisez à des fins d'inférence.

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp import FullStateDictConfig import torch.sagemaker as tsm # Save checkpoints. with FSDP.state_dict_type( model, StateDictType.FULL_STATE_DICT, FullStateDictConfig( rank0_only=True, offload_to_cpu=True, # Default value is to translate back to Hugging Face Transformers format, # when saving full checkpoints for models trained with SMP tensor parallelism. # translate_on_save=True ), ): state_dict = model.state_dict() if dist.get_rank() == 0: logger.info("Processed state dict to save. Starting write to disk now.") os.makedirs(save_dir, exist_ok=True) # This name is needed for HF from_pretrained API to work. torch.save(state_dict, os.path.join(save_dir, "pytorch_model.bin")) hf_model_config.save_pretrained(save_dir) dist.barrier()

Notez que l'option FullStateDictConfig(rank0_only=True, offload_to_cpu=True) consiste à rassembler le modèle sur un appareil CPU de 0e rang pour économiser de la mémoire lors de l'entraînement de grands modèles.

Pour recharger le modèle à des fins d'inférence, procédez comme indiqué dans l'exemple de code suivant. Notez que la classe AutoModelForCausalLM peut être remplacée par d'autres classes de création de facteurs dans Hugging Face Transformers, par exemple AutoModelForSeq2SeqLM en fonction de votre modèle. Pour plus d'informations, consultez la documentation de Hugging Face Transformers.

from transformers import AutoModelForCausalLM model = AutoModelForCausalLM.from_pretrained(save_dir)