Las traducciones son generadas a través de traducción automática. En caso de conflicto entre la traducción y la version original de inglés, prevalecerá la version en inglés.
Compruebe la posición utilizando SMP
La biblioteca SageMaker model parallelism (SMP) es compatible con los puntos de control y proporciona APIs esa ayuda PyTorch APIs para comprobar correctamente los puntos de control mientras se utiliza la biblioteca. SMP
PyTorch FSDP(Paralelismo de datos totalmente fragmentado) admite tres tipos de puntos de control: completos, fragmentados y locales, cada uno con diferentes propósitos. Al exportar el modelo una vez finalizado el entrenamiento, se utilizan puntos de control completos, ya que generar un punto de control completo es un proceso costoso desde el punto de vista computacional. Los puntos de control fragmentados ayudan a guardar y cargar el estado de un modelo fragmentado para cada rango individual. Con los puntos de control fragmentados, puede reanudar el entrenamiento con diferentes configuraciones de hardware, por ejemplo, un número diferente de. GPUs Sin embargo, la carga de los puntos de control fragmentados puede resultar lenta debido a la comunicación entre varios dispositivos. La SMP biblioteca proporciona funcionalidades de puntos de control locales, que permiten recuperar más rápidamente el estado del modelo sin sobrecargas de comunicación adicionales. Tenga en cuenta que los puntos de control creados por FSDP requieren escribirse en un sistema de archivos de red compartido, como AmazonFSx.
Puntos de control locales asíncronos
Al entrenar modelos de aprendizaje automático, no es necesario que las iteraciones posteriores esperen a que los archivos de los puntos de control se guarden en el disco. Con el lanzamiento de la versión SMP 2.5, la biblioteca permite guardar los archivos de puntos de control de forma asíncrona. Esto significa que la siguiente iteración de entrenamiento puede ejecutarse simultáneamente con las operaciones de entrada y salida (E/S) para crear puntos de control, sin que esas operaciones de E/S la ralenticen ni la frenen. Además, el proceso de recuperación de los parámetros fragmentados del modelo y del optimizador PyTorch puede llevar mucho tiempo debido a la comunicación colectiva adicional que se requiere para intercambiar los metadatos de los tensores distribuidos entre los rangos. Incluso si se utiliza StateDictType.LOCAL_STATE_DICT
para guardar los puntos de control locales para cada rango, PyTorch sigue invocando ganchos que permiten la comunicación colectiva. Para mitigar este problema y reducir el tiempo necesario para recuperar los puntos de controlSMStateDictType.SM_LOCAL_STATE_DICT
, SMP introduce un sistema que permite recuperar más rápidamente los puntos de control del modelo y del optimizador al evitar la sobrecarga de comunicación colectiva.
nota
Mantener la coherencia en el FSDP SHARD_DEGREE
es un requisito para utilizar el. SMStateDictType.SM_LOCAL_STATE_DICT
Asegúrese de que SHARD_DEGREE
permanezca sin cambios. Si bien el número de réplicas del modelo puede variar, el grado de fragmentación del modelo debe ser idéntico al de la configuración de entrenamiento anterior cuando se reanuda desde un punto de control.
import os import torch.distributed as dist import torch.sagemaker as tsm from torch.sagemaker import state from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.sagemaker.distributed.checkpoint.state_dict_saver import ( async_save, maybe_finalize_async_calls, ) from torch.sagemaker.distributed.checkpoint.state_dict_utils import ( sm_state_dict_type, SMStateDictType, ) global_rank = dist.get_rank() save_dir = "/opt/ml/checkpoints" sub_dir = f"tp{state.tp_rank}_ep{state.ep_rank}_fsdp{model.rank}" # 1. Get replication ranks and group current_replication_group = None current_replication_ranks = None for replication_ranks in state.ranker.get_rep_groups(): rep_group = dist.new_group(replication_ranks) if global_rank in replication_ranks: current_replication_group = rep_group current_replication_ranks = replication_ranks coordinator_rank = min(current_replication_ranks) # 2. Wait for the previous checkpointing done maybe_finalize_async_calls( blocking=True, process_group=current_replication_group ) # 3. Get model local checkpoint with sm_state_dict_type(model, SMStateDictType.SM_LOCAL_STATE_DICT): state_dict = { "model": model.state_dict(), "optimizer": optimizer.state_dict(), # Potentially add more customized state dicts. } # 4. Save a local checkpoint async_save( state_dict, checkpoint_id=os.path.join(save_dir, sub_dir), process_group=current_replication_group, coordinator_rank=coordinator_rank, )
El siguiente fragmento de código muestra cómo cargar un punto de control utilizando. SMStateDictType.SM_LOCAL_STATE_DICT
import os import torch.sagemaker as tsm from torch.sagemaker import state from torch.sagemaker.distributed.checkpoint.state_dict_loader import load from torch.sagemaker.distributed.checkpoint.state_dict_utils import ( sm_state_dict_type, SMStateDictType, init_optim_state ) from torch.sagemaker.distributed.checkpoint.filesystem import ( DistributedFileSystemReader, ) load_dir = "/opt/ml/checkpoints" sub_dir = f"tp{state.tp_rank}_ep{state.ep_rank}_fsdp{model.rank}" global_rank = dist.get_rank() checkpoint_id = os.path.join(load_dir, sub_dir) storage_reader = DistributedFileSystemReader(checkpoint_id) # 1. Get replication ranks and group current_replication_group = None current_replication_ranks = None for replication_ranks in state.ranker.get_rep_groups(): rep_group = dist.new_group(replication_ranks) if global_rank in replication_ranks: current_replication_group = rep_group current_replication_ranks = replication_ranks coordinator_rank = min(current_replication_ranks) # 2. Create local state_dict with sm_state_dict_type(model, SMStateDictType.SM_LOCAL_STATE_DICT): state_dict = { "model": model.state_dict(), # Potentially add more customized state dicts. } # Init optimizer state_dict states by setting zero grads and step. init_optim_state(optimizer, skip_empty_param=True) state_dict["optimizer"] = optimizer.state_dict() # 3. Load a checkpoint load( state_dict=state_dict, process_group=current_replication_group, coordinator_rank=coordinator_rank, storage_reader=storage_reader, )
Almacenar puntos de control para modelos de lenguaje grandes (LLMs) puede resultar caro, ya que a menudo requiere crear un gran volumen de sistema de archivos. Para reducir los costes, tiene la opción de guardar los puntos de control directamente en Amazon S3 sin necesidad de servicios de sistema de archivos adicionales como Amazon. FSx Puede aprovechar el ejemplo anterior con el siguiente fragmento de código para guardar los puntos de control en S3 especificando un S3 como destino. URL
key = os.path.join(checkpoint_dir, sub_dir) checkpoint_id= f"
s3://{your_s3_bucket}/{key}
" async_save(state_dict, checkpoint_id=checkpoint_id, **kw) load(state_dict, checkpoint_id=checkpoint_id, **kw)
Puntos de control fragmentados asíncronos
Puede haber situaciones en las que necesite seguir entrenando con diferentes configuraciones de hardware, como cambiar el número de. GPUs En estos casos, sus procesos de entrenamiento deben cargar los puntos de control mientras se refragmentan, lo que significa reanudar el entrenamiento posterior con un número diferente de. SHARD_DEGREE
Para abordar el escenario en el que necesitas reanudar el entrenamiento con un número diferente de puntos de controlSHARD_DEGREE
, debes guardar los puntos de control del modelo utilizando el tipo de diccionario de estados fragmentados, que se representa por. StateDictType.SHARDED_STATE_DICT
Guardar los puntos de control en este formato le permite gestionar correctamente el proceso de refragmentación al continuar el entrenamiento con una configuración de hardware modificada. El fragmento de código proporcionado ilustra cómo utilizarlos para guardar de forma asíncrona los tsm
API puntos de control fragmentados, lo que permite un proceso de formación más eficiente y ágil.
import os import torch.sagemaker as tsm from torch.sagemaker import state from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp import StateDictType from torch.sagemaker.utils.process_group_utils import get_global_ranks from torch.sagemaker.distributed.checkpoint.state_dict_saver import ( async_save, maybe_finalize_async_calls, ) save_dir = "/opt/ml/checkpoints" sub_dir = f"tp{state.tp_rank}_ep{state.ep_rank}" checkpoint_id = os.path.join(save_dir, sub_dir) # To determine whether curreto take part in checkpointing. global_rank = dist.get_rank() action_rank = state.ranker.get_rep_rank(global_rank) == 0 process_group = model.process_group coordinator_rank = min(get_global_ranks(process_group)) # 1. wait for the previous checkpointing done maybe_finalize_async_calls(blocking=True, process_group=process_group) # 2. retrieve model & optimizer sharded state_dict with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT): state_dict = { "model": model.state_dict(), "optimizer": FSDP.optim_state_dict(model, optimizer), # Potentially add more customized state dicts. } # 3. save checkpoints asynchronously using async_save if action_rank: async_save( state_dict, checkpoint_id=checkpoint_id, process_group=process_group, coordinator_rank=coordinator_rank, )
El proceso de carga de los puntos de control compartidos es similar al de la sección anterior, pero implica el uso del método y su método. torch.sagemaker.distributed.checkpoint.filesystem.DistributedFileSystemReader
load
El load
método de esta clase permite cargar los datos de los puntos de control compartidos siguiendo un proceso análogo al descrito anteriormente.
import os from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp import StateDictType from torch.distributed.checkpoint.optimizer import load_sharded_optimizer_state_dict from torch.sagemaker.distributed.checkpoint.state_dict_loader import load from torch.sagemaker.utils.process_group_utils import get_global_ranks from torch.sagemaker.distributed.checkpoint.filesystem import ( DistributedFileSystemReader, ) load_dir = "/opt/ml/checkpoints" sub_dir = f"tp{state.tp_rank}_ep{state.ep_rank}" checkpoint_id = os.path.join(load_dir, sub_dir) reader = DistributedFileSystemReader(checkpoint_id) process_group = model.process_group coordinator_rank = min(get_global_ranks(process_group)) with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT): # 1. Load model and everything else except the optimizer. state_dict = { "model": model.state_dict() # Potentially more customized state dicts. } load( state_dict, storage_reader=reader, process_group=process_group, coordinator_rank=coordinator_rank, ) model.load_state_dict(state_dict["model"]) # 2. Load optimizer. optim_state = load_sharded_optimizer_state_dict( model_state_dict=state_dict["model"], optimizer_key="optimizer", storage_reader=reader, process_group=process_group, ) flattened_optimizer_state = FSDP.optim_state_dict_to_load( optim_state["optimizer"], model, optimizer, group=model.process_group ) optimizer.load_state_dict(flattened_optimizer_state)
Modelos completos de puntos de control
Al final del entrenamiento, puede guardar un punto de control completo que combine todos los fragmentos de un modelo en un único archivo de puntos de control del modelo. La SMP biblioteca es totalmente compatible con PyTorch todos los puntos de control del modeloAPI, por lo que no es necesario realizar ningún cambio.
Tenga en cuenta que si utiliza el SMPParalelismo de tensores, la SMP biblioteca transforma el modelo. Al comprobar el modelo completo en este caso, la SMP biblioteca vuelve a traducir el modelo al formato de punto de control Hugging Face Transformers de forma predeterminada.
En los casos en los que entrenes con el paralelismo SMP tensorial y desactives el proceso de SMP traducción, puedes usar el translate_on_save
argumento de PyTorch FullStateDictConfig
API para activar o desactivar la SMP traducción automática según sea necesario. Por ejemplo, si te estás centrando en entrenar un modelo, no necesitas añadir el proceso de traducción, lo que supone una sobrecarga. En ese caso, le recomendamos que configuretranslate_on_save=False
. Además, si planea seguir utilizando la SMP traducción del modelo para formación adicional en el futuro, puede desactivarla para guardar la SMP traducción del modelo para usarla más adelante. Es necesario volver a traducir el modelo al formato de punto de control del modelo Hugging Face Transformers cuando termines el entrenamiento de tu modelo y lo utilices como inferencia.
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp import FullStateDictConfig import torch.sagemaker as tsm # Save checkpoints. with FSDP.state_dict_type( model, StateDictType.FULL_STATE_DICT, FullStateDictConfig( rank0_only=True, offload_to_cpu=True, # Default value is to translate back to Hugging Face Transformers format, # when saving full checkpoints for models trained with SMP tensor parallelism. # translate_on_save=True ), ): state_dict = model.state_dict() if dist.get_rank() == 0: logger.info("Processed state dict to save. Starting write to disk now.") os.makedirs(
save_dir
, exist_ok=True) # This name is needed for HF from_pretrained API to work. torch.save(state_dict, os.path.join(save_dir
, "pytorch_model.bin")) hf_model_config.save_pretrained(save_dir
) dist.barrier()
Ten en cuenta que la opción FullStateDictConfig(rank0_only=True,
offload_to_cpu=True)
es recopilar el modelo en un dispositivo CPU de rango 0 para ahorrar memoria al entrenar modelos grandes.
Para volver a cargar el modelo para su inferencia, haga lo que se muestra en el siguiente ejemplo de código. Ten en cuenta que la clase AutoModelForCausalLM
podría cambiar a otras clases de creación de factores en Hugging Face Transformers, por ejemplo, AutoModelForSeq2SeqLM
según tu modelo. Para obtener más información, consulte la documentación de Hugging Face Transformers
from transformers import AutoModelForCausalLM model = AutoModelForCausalLM.from_pretrained(
save_dir
)