Affinement - Amazon SageMaker

Les traductions sont fournies par des outils de traduction automatique. En cas de conflit entre le contenu d'une traduction et celui de la version originale en anglais, la version anglaise prévaudra.

Affinement

Le réglage fin est un processus de formation continue de modèles préentraînés afin d'améliorer les performances pour des cas d'utilisation spécifiques.

CPUsIl est simple de peaufiner les petits modèles qui s'adaptent parfaitement à un seul GPU modèle ou ceux qui s'adaptent entièrement à 8 copies du modèle. Il ne nécessite aucune modification particulière par rapport à l'FSDPentraînement régulier. Dans le domaine des modèles plus grands, vous devez envisager d'utiliser la fonctionnalité d'initialisation différée des paramètres, qui peut s'avérer délicate.

Pour résoudre ce problème, la SMP bibliothèque charge le modèle complet sur l'un des rangs tandis que les autres rangs créent des modèles avec des poids vides sur un méta-appareil. PyTorch FSDPInitialise ensuite les poids sur les rangs non nuls à l'aide de la init_weights fonction, et synchronise les poids de tous les rangs avec les poids du 0e rang avec défini sur. sync_module_states True L'extrait de code suivant montre comment le configurer dans votre script d'entraînement.

import torch.distributed as dist from transformers import AutoModelForCasalLM from accelerate import init_empty_weights from torch.sagemaker.delayed_param import DelayedParamIniter if dist.get_rank() == 0: model = AutoModelForCasalLM.from_pretrained(..., low_cpu_mem_usage=True) else: with init_empty_weights(): model = AutoModelForCasalLM.from_config(AutoConfig.from_pretrained(...)) delayed_initer = DelayedParamIniter(model) model = FSDP( model, ..., sync_module_states=True, param_init_fn=delayed_initer.get_param_init_fn() if dist.get_rank() > 0 else None )

Réglage précis d'un modèle de transformateur Hugging Face préentraîné avec parallélisme des tenseurs SMP

Cette section décrit le chargement des modèles de transformateurs pour deux cas d'utilisation : le réglage fin des petits modèles de transformateurs et le réglage fin des grands modèles de transformateurs. Pour les modèles plus petits sans initialisation différée des paramètres, enveloppez le modèle avec le torch.sagemaker.transform API avant de l'enrouler avec PyTorchFSDP.

import functools from transformers import AutoModelForCausalLM from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy from torch.sagemaker import transform model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", low_cpu_mem_usage=True) # Transform model while loading state dictionary from rank 0. tp_model = transform(model, load_state_dict_from_rank0=True) # Wrap with FSDP. model = FSDP( tp_model, ... sync_module_states=True, )

Pour les modèles plus grands, l'approche précédente entraîne un manque de CPU mémoire. Nous vous recommandons d'utiliser l'initialisation différée des paramètres pour éviter de tels problèmes CPU de mémoire. Dans ce cas, vous pouvez appliquer le torch.sagemaker.transform API et torch.sagemaker.delayed_param.DelayedParamIniter API comme indiqué dans l'exemple de code suivant.

from transformers import AutoModelForCausalLM from torch.sagemaker import transform from torch.sagemaker.delayed_param import DelayedParamIniter # Create one instance of model without delayed param # on CPU, on one rank. if dist.get_rank() == 0: model = AutoModelForCasalLM.from_pretrained(...,low_cpu_mem_usage=True) else: with init_empty_weights(): model = AutoModelForCasalLM.from_config(AutoConfig.from_pretrained(...)) # Transform model while loading state dictionary from rank 0 model = transform(model, load_state_dict_from_rank0=True) if dist.get_rank() != 0: # For fine-tuning, delayed parameter on non-zero ranks delayed_initer = DelayedParamIniter(model) else: delayed_initer = None with ( delayed_initer.validate_params_and_buffers_inited() if delayed_initer else nullcontext() ): # Wrap the model with FSDP model = FSDP( model, ..., sync_module_states=True, param_init_fn=delayed_initer.get_param_init_fn() if delayed_initer else None )