Les traductions sont fournies par des outils de traduction automatique. En cas de conflit entre le contenu d'une traduction et celui de la version originale en anglais, la version anglaise prévaudra.
Point de contrôle à l'aide du SMP
La bibliothèque de parallélisme des SageMaker modèles (SMP) prend en charge les points PyTorch APIs de contrôle et permet APIs de vérifier correctement les points de contrôle lors de l'utilisation de la bibliothèque SMP.
PyTorch Le FSDP (Fully Sharded Data Parallelism) prend en charge trois types de points de contrôle : complets, fragmentés et locaux, chacun ayant des objectifs différents. Des points de contrôle complets sont utilisés lors de l'exportation du modèle une fois l'entraînement terminé, car la génération d'un point de contrôle complet est un processus coûteux en termes de calcul. Les points de contrôle fragmentés permettent de sauvegarder et de charger l'état d'un modèle fragmenté pour chaque rang individuel. Grâce aux points de contrôle fragmentés, vous pouvez reprendre l'entraînement avec différentes configurations matérielles, par exemple un nombre différent de. GPUs Cependant, le chargement des points de contrôle fragmentés peut être lent en raison de la communication requise entre plusieurs appareils. La bibliothèque SMP fournit des fonctionnalités de point de contrôle local, qui permettent de récupérer plus rapidement l'état du modèle sans surcharger les communications. Notez que les points de contrôle créés par FSDP nécessitent d'écrire dans un système de fichiers réseau partagé tel qu'Amazon. FSx
Points de contrôle locaux asynchrones
Lors de l'entraînement de modèles d'apprentissage automatique, il n'est pas nécessaire d'effectuer les itérations suivantes pour attendre que les fichiers de points de contrôle soient enregistrés sur disque. Avec la sortie de SMP v2.5, la bibliothèque prend en charge l'enregistrement des fichiers de point de contrôle de manière asynchrone. Cela signifie que l'itération d'entraînement suivante peut être exécutée simultanément avec les opérations d'entrée et de sortie (E/S) pour créer des points de contrôle, sans être ralentie ou freinée par ces opérations d'E/S. De plus, le processus de récupération des paramètres du modèle fragmenté et de l'optimiseur PyTorch peut prendre du temps en raison de la communication collective supplémentaire requise pour échanger des métadonnées tensorielles distribuées entre les grades. Même lorsque vous l'utilisez StateDictType.LOCAL_STATE_DICT
pour enregistrer des points de contrôle locaux pour chaque rang, elle invoque PyTorch toujours des hooks qui effectuent une communication collective. Pour atténuer ce problème et réduire le temps nécessaire à la récupération des points de contrôle, SMP introduit SMStateDictType.SM_LOCAL_STATE_DICT
un système qui permet de récupérer plus rapidement les points de contrôle du modèle et de l'optimiseur en contournant la surcharge de communication collective.
Note
Le maintien de la cohérence du FSDP SHARD_DEGREE
est une condition préalable à l'utilisation du. SMStateDictType.SM_LOCAL_STATE_DICT
Assurez-vous que le SHARD_DEGREE
reste inchangé. Bien que le nombre de réplications du modèle puisse varier, le degré de fragmentation du modèle doit être identique à celui de la configuration d'entraînement précédente lorsque vous reprenez un point de contrôle.
import os import torch.distributed as dist import torch.sagemaker as tsm from torch.sagemaker import state from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.sagemaker.distributed.checkpoint.state_dict_saver import ( async_save, maybe_finalize_async_calls, ) from torch.sagemaker.distributed.checkpoint.state_dict_utils import ( sm_state_dict_type, SMStateDictType, ) global_rank = dist.get_rank() save_dir = "/opt/ml/checkpoints" sub_dir = f"tp{state.tp_rank}_ep{state.ep_rank}_fsdp{model.rank}" # 1. Get replication ranks and group current_replication_group = None current_replication_ranks = None for replication_ranks in state.ranker.get_rep_groups(): rep_group = dist.new_group(replication_ranks) if global_rank in replication_ranks: current_replication_group = rep_group current_replication_ranks = replication_ranks coordinator_rank = min(current_replication_ranks) # 2. Wait for the previous checkpointing done maybe_finalize_async_calls( blocking=True, process_group=current_replication_group ) # 3. Get model local checkpoint with sm_state_dict_type(model, SMStateDictType.SM_LOCAL_STATE_DICT): state_dict = { "model": model.state_dict(), "optimizer": optimizer.state_dict(), # Potentially add more customized state dicts. } # 4. Save a local checkpoint async_save( state_dict, checkpoint_id=os.path.join(save_dir, sub_dir), process_group=current_replication_group, coordinator_rank=coordinator_rank, )
L'extrait de code suivant montre comment charger un point de contrôle en utilisant. SMStateDictType.SM_LOCAL_STATE_DICT
import os import torch.sagemaker as tsm from torch.sagemaker import state from torch.sagemaker.distributed.checkpoint.state_dict_loader import load from torch.sagemaker.distributed.checkpoint.state_dict_utils import ( sm_state_dict_type, SMStateDictType, init_optim_state ) from torch.sagemaker.distributed.checkpoint.filesystem import ( DistributedFileSystemReader, ) load_dir = "/opt/ml/checkpoints" sub_dir = f"tp{state.tp_rank}_ep{state.ep_rank}_fsdp{model.rank}" global_rank = dist.get_rank() checkpoint_id = os.path.join(load_dir, sub_dir) storage_reader = DistributedFileSystemReader(checkpoint_id) # 1. Get replication ranks and group current_replication_group = None current_replication_ranks = None for replication_ranks in state.ranker.get_rep_groups(): rep_group = dist.new_group(replication_ranks) if global_rank in replication_ranks: current_replication_group = rep_group current_replication_ranks = replication_ranks coordinator_rank = min(current_replication_ranks) # 2. Create local state_dict with sm_state_dict_type(model, SMStateDictType.SM_LOCAL_STATE_DICT): state_dict = { "model": model.state_dict(), # Potentially add more customized state dicts. } # Init optimizer state_dict states by setting zero grads and step. init_optim_state(optimizer, skip_empty_param=True) state_dict["optimizer"] = optimizer.state_dict() # 3. Load a checkpoint load( state_dict=state_dict, process_group=current_replication_group, coordinator_rank=coordinator_rank, storage_reader=storage_reader, )
Le stockage de points de contrôle pour les grands modèles de langage (LLMs) peut s'avérer coûteux car cela nécessite souvent la création d'un volume de système de fichiers important. Pour réduire les coûts, vous avez la possibilité d'enregistrer les points de contrôle directement dans Amazon S3 sans avoir besoin de services de système de fichiers supplémentaires tels qu'Amazon. FSx Vous pouvez utiliser l'exemple précédent avec l'extrait de code suivant pour enregistrer des points de contrôle dans S3 en spécifiant une URL S3 comme destination.
key = os.path.join(checkpoint_dir, sub_dir) checkpoint_id= f"
s3://{your_s3_bucket}/{key}
" async_save(state_dict, checkpoint_id=checkpoint_id, **kw) load(state_dict, checkpoint_id=checkpoint_id, **kw)
Points de contrôle partitionnés asynchrones
Dans certaines situations, vous devrez peut-être poursuivre votre formation avec différentes configurations matérielles, par exemple en modifiant le nombre de GPUs. Dans ces cas, vos processus de formation doivent charger des points de contrôle lors du repartage, ce qui implique de reprendre l'entraînement suivant avec un nombre différent de. SHARD_DEGREE
Afin de résoudre le scénario dans lequel vous devez reprendre l'entraînement avec un nombre différent deSHARD_DEGREE
, vous devez enregistrer les points de contrôle de votre modèle à l'aide du type de dictionnaire d'états fragmenté, représenté par. StateDictType.SHARDED_STATE_DICT
L'enregistrement des points de contrôle dans ce format vous permet de gérer correctement le processus de repartage lorsque vous poursuivez la formation avec une configuration matérielle modifiée. L'extrait de code fourni montre comment utiliser l'tsm
API pour enregistrer des points de contrôle fragmentés de manière asynchrone, permettant ainsi un processus de formation plus efficace et rationalisé.
import os import torch.sagemaker as tsm from torch.sagemaker import state from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp import StateDictType from torch.sagemaker.utils.process_group_utils import get_global_ranks from torch.sagemaker.distributed.checkpoint.state_dict_saver import ( async_save, maybe_finalize_async_calls, ) save_dir = "/opt/ml/checkpoints" sub_dir = f"tp{state.tp_rank}_ep{state.ep_rank}" checkpoint_id = os.path.join(save_dir, sub_dir) # To determine whether curreto take part in checkpointing. global_rank = dist.get_rank() action_rank = state.ranker.get_rep_rank(global_rank) == 0 process_group = model.process_group coordinator_rank = min(get_global_ranks(process_group)) # 1. wait for the previous checkpointing done maybe_finalize_async_calls(blocking=True, process_group=process_group) # 2. retrieve model & optimizer sharded state_dict with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT): state_dict = { "model": model.state_dict(), "optimizer": FSDP.optim_state_dict(model, optimizer), # Potentially add more customized state dicts. } # 3. save checkpoints asynchronously using async_save if action_rank: async_save( state_dict, checkpoint_id=checkpoint_id, process_group=process_group, coordinator_rank=coordinator_rank, )
Le processus de chargement des points de contrôle partagés est similaire à celui de la section précédente, mais il implique l'utilisation de la méthode torch.sagemaker.distributed.checkpoint.filesystem.DistributedFileSystemReader
et de sa load
méthode. La load
méthode de cette classe permet de charger les données de point de contrôle partagées, en suivant un processus analogue à celui décrit précédemment.
import os from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp import StateDictType from torch.distributed.checkpoint.optimizer import load_sharded_optimizer_state_dict from torch.sagemaker.distributed.checkpoint.state_dict_loader import load from torch.sagemaker.utils.process_group_utils import get_global_ranks from torch.sagemaker.distributed.checkpoint.filesystem import ( DistributedFileSystemReader, ) load_dir = "/opt/ml/checkpoints" sub_dir = f"tp{state.tp_rank}_ep{state.ep_rank}" checkpoint_id = os.path.join(load_dir, sub_dir) reader = DistributedFileSystemReader(checkpoint_id) process_group = model.process_group coordinator_rank = min(get_global_ranks(process_group)) with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT): # 1. Load model and everything else except the optimizer. state_dict = { "model": model.state_dict() # Potentially more customized state dicts. } load( state_dict, storage_reader=reader, process_group=process_group, coordinator_rank=coordinator_rank, ) model.load_state_dict(state_dict["model"]) # 2. Load optimizer. optim_state = load_sharded_optimizer_state_dict( model_state_dict=state_dict["model"], optimizer_key="optimizer", storage_reader=reader, process_group=process_group, ) flattened_optimizer_state = FSDP.optim_state_dict_to_load( optim_state["optimizer"], model, optimizer, group=model.process_group ) optimizer.load_state_dict(flattened_optimizer_state)
Modèles complets de points de contrôle
À la fin de la formation, vous pouvez enregistrer un point de contrôle complet qui combine tous les fragments d'un modèle dans un seul fichier de point de contrôle du modèle. La bibliothèque SMP prend entièrement en charge l'API des points de contrôle du modèle PyTorch complet, vous n'avez donc pas besoin d'apporter de modifications.
Notez que si vous utilisez le SMPParallélisme de tenseur, la bibliothèque SMP transforme le modèle. Dans ce cas, lorsque vous vérifiez le modèle complet, la bibliothèque SMP retraduit le modèle au format de point de contrôle Hugging Face Transformers par défaut.
Dans les cas où vous vous entraînez avec le parallélisme des tenseurs SMP et que vous désactivez le processus de traduction SMP, vous pouvez utiliser l'translate_on_save
argument de l' PyTorch FullStateDictConfig
API pour activer ou désactiver la traduction automatique SMP selon vos besoins. Par exemple, si vous vous concentrez sur la formation d'un modèle, vous n'avez pas besoin d'ajouter le processus de traduction, ce qui entraîne des frais supplémentaires. Dans ce cas, nous vous recommandons de définirtranslate_on_save=False
. De plus, si vous prévoyez de continuer à utiliser la traduction SMP du modèle pour une formation continue à l'avenir, vous pouvez la désactiver pour enregistrer la traduction SMP du modèle pour une utilisation ultérieure. Il est nécessaire de retraduire le modèle au format de point de contrôle du modèle Hugging Face Transformers lorsque vous terminez l'entraînement de votre modèle et que vous l'utilisez à des fins d'inférence.
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp import FullStateDictConfig import torch.sagemaker as tsm # Save checkpoints. with FSDP.state_dict_type( model, StateDictType.FULL_STATE_DICT, FullStateDictConfig( rank0_only=True, offload_to_cpu=True, # Default value is to translate back to Hugging Face Transformers format, # when saving full checkpoints for models trained with SMP tensor parallelism. # translate_on_save=True ), ): state_dict = model.state_dict() if dist.get_rank() == 0: logger.info("Processed state dict to save. Starting write to disk now.") os.makedirs(
save_dir
, exist_ok=True) # This name is needed for HF from_pretrained API to work. torch.save(state_dict, os.path.join(save_dir
, "pytorch_model.bin")) hf_model_config.save_pretrained(save_dir
) dist.barrier()
Notez que l'option FullStateDictConfig(rank0_only=True,
offload_to_cpu=True)
consiste à rassembler le modèle sur le processeur du périphérique de 0e rang pour économiser de la mémoire lors de l'entraînement de grands modèles.
Pour recharger le modèle à des fins d'inférence, procédez comme indiqué dans l'exemple de code suivant. Notez que la classe AutoModelForCausalLM
peut être remplacée par d'autres classes de création de facteurs dans Hugging Face Transformers, par exemple AutoModelForSeq2SeqLM
en fonction de votre modèle. Pour plus d'informations, consultez la documentation de Hugging Face Transformers
from transformers import AutoModelForCausalLM model = AutoModelForCausalLM.from_pretrained(
save_dir
)