

Checkpointing using SMP

The SageMaker model parallelism (SMP) library supports PyTorch APIs for checkpointing, and provides APIs that help you checkpoint properly while using the SMP library.

PyTorch FSDP (Fully Sharded Data Parallelism) supports three types of checkpoints: full, sharded, and local, each serving a different purpose. Full checkpoints are used when exporting the model after training completes, because generating a full checkpoint is a computationally expensive process. Sharded checkpoints help save and load the state of a model sharded for each individual rank. With sharded checkpoints, you can resume training with a different hardware configuration, such as a different number of GPUs. However, loading sharded checkpoints can be slow due to the communication involved among multiple devices. The SMP library provides local checkpointing capabilities, which allow faster retrieval of the model's state without the additional communication overhead. Note that checkpoints created by FSDP require writing to a shared network file system such as Amazon FSx.
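For reference, the following minimal sketch shows how these three checkpoint types map onto the standard PyTorch FSDP state dict types. The model object and the output path are placeholder assumptions for illustration; the SMP-specific local checkpoint API is covered in the next section.

import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import StateDictType

# Sketch only: `model` is assumed to be an already FSDP-wrapped module.

# Full checkpoint: gather all shards into one state dict (expensive; used for export).
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT):
    full_state_dict = model.state_dict()
    if dist.get_rank() == 0:
        torch.save(full_state_dict, "/opt/ml/checkpoints/full_model.pt")  # placeholder path

# Sharded checkpoint: each rank keeps only its shard (supports resharding on resume).
with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
    sharded_state_dict = model.state_dict()

# Local checkpoint: per-rank state without resharding support (fast, minimal gathering).
with FSDP.state_dict_type(model, StateDictType.LOCAL_STATE_DICT):
    local_state_dict = model.state_dict()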

Asynchronous local checkpoint

When training machine learning models, subsequent iterations do not need to wait for the checkpoint files to be saved to disk. With the release of SMP v2.5, the library supports saving checkpoint files asynchronously. This means that subsequent training iterations can run simultaneously with the input and output (I/O) operations for creating checkpoints, without being slowed down or held back by those I/O operations. In addition, the process of retrieving sharded model and optimizer parameters in PyTorch can be time-consuming due to the extra collective communication required to exchange distributed tensor metadata across ranks. Even when using StateDictType.LOCAL_STATE_DICT to save local checkpoints for each rank, PyTorch still invokes hooks that perform collective communication. To mitigate this issue and reduce the time needed for checkpoint retrieval, SMP introduces SMStateDictType.SM_LOCAL_STATE_DICT, which allows faster retrieval of model and optimizer checkpoints by bypassing the collective communication overhead.

Note

Maintaining a consistent FSDP SHARD_DEGREE is a requirement for using SMStateDictType.SM_LOCAL_STATE_DICT. Make sure that SHARD_DEGREE stays the same. While the number of model replicas can vary, the model shard degree must be identical to that of the previous training setup when resuming from a checkpoint.
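As a hypothetical illustration of this rule (not an SMP API call): with hybrid sharding, the number of model replicas is the world size divided by the shard degree, so the replica count can change with the cluster size while the shard degree stays fixed.

# Hypothetical numbers for illustration only.
shard_degree = 8             # must stay the same across save and resume

world_size_at_save = 16      # 16 GPUs -> 16 // 8 = 2 model replicas
world_size_at_resume = 32    # 32 GPUs -> 32 // 8 = 4 model replicas

replicas_at_save = world_size_at_save // shard_degree      # 2
replicas_at_resume = world_size_at_resume // shard_degree  # 4

# The replica count may change (2 -> 4), but the shard degree (8) must not
# when resuming from SMStateDictType.SM_LOCAL_STATE_DICT checkpoints.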

The following code snippet shows how to save a checkpoint with SMStateDictType.SM_LOCAL_STATE_DICT.

import os

import torch.distributed as dist
import torch.sagemaker as tsm
from torch.sagemaker import state
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.sagemaker.distributed.checkpoint.state_dict_saver import (
    async_save,
    maybe_finalize_async_calls,
)
from torch.sagemaker.distributed.checkpoint.state_dict_utils import (
    sm_state_dict_type,
    SMStateDictType,
)

global_rank = dist.get_rank()
save_dir = "/opt/ml/checkpoints"
sub_dir = f"tp{state.tp_rank}_ep{state.ep_rank}_fsdp{model.rank}"

# 1. Get replication ranks and group
current_replication_group = None
current_replication_ranks = None
for replication_ranks in state.ranker.get_rep_groups():
    rep_group = dist.new_group(replication_ranks)
    if global_rank in replication_ranks:
        current_replication_group = rep_group
        current_replication_ranks = replication_ranks

coordinator_rank = min(current_replication_ranks)

# 2. Wait for the previous checkpointing done
maybe_finalize_async_calls(
    blocking=True, process_group=current_replication_group
)

# 3. Get model local checkpoint
with sm_state_dict_type(model, SMStateDictType.SM_LOCAL_STATE_DICT):
    state_dict = {
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        # Potentially add more customized state dicts.
    }

# 4. Save a local checkpoint
async_save(
    state_dict,
    checkpoint_id=os.path.join(save_dir, sub_dir),
    process_group=current_replication_group,
    coordinator_rank=coordinator_rank,
)

The following code snippet shows how to load a checkpoint with SMStateDictType.SM_LOCAL_STATE_DICT.

import os

import torch.distributed as dist
import torch.sagemaker as tsm
from torch.sagemaker import state
from torch.sagemaker.distributed.checkpoint.state_dict_loader import load
from torch.sagemaker.distributed.checkpoint.state_dict_utils import (
    sm_state_dict_type,
    SMStateDictType,
    init_optim_state,
)
from torch.sagemaker.distributed.checkpoint.filesystem import (
    DistributedFileSystemReader,
)

load_dir = "/opt/ml/checkpoints"
sub_dir = f"tp{state.tp_rank}_ep{state.ep_rank}_fsdp{model.rank}"
global_rank = dist.get_rank()
checkpoint_id = os.path.join(load_dir, sub_dir)
storage_reader = DistributedFileSystemReader(checkpoint_id)

# 1. Get replication ranks and group
current_replication_group = None
current_replication_ranks = None
for replication_ranks in state.ranker.get_rep_groups():
    rep_group = dist.new_group(replication_ranks)
    if global_rank in replication_ranks:
        current_replication_group = rep_group
        current_replication_ranks = replication_ranks

coordinator_rank = min(current_replication_ranks)

# 2. Create local state_dict
with sm_state_dict_type(model, SMStateDictType.SM_LOCAL_STATE_DICT):
    state_dict = {
        "model": model.state_dict(),
        # Potentially add more customized state dicts.
    }

    # Init optimizer state_dict states by setting zero grads and step.
    init_optim_state(optimizer, skip_empty_param=True)
    state_dict["optimizer"] = optimizer.state_dict()

    # 3. Load a checkpoint
    load(
        state_dict=state_dict,
        process_group=current_replication_group,
        coordinator_rank=coordinator_rank,
        storage_reader=storage_reader,
    )

Storing checkpoints for large language models (LLMs) can be expensive because it often requires creating a large file system capacity. To reduce costs, you can save checkpoints directly to Amazon S3 without additional file system services such as Amazon FSx. You can combine the previous example with the following code snippet to save checkpoints to S3 by specifying an S3 URL as the destination.

key = os.path.join(checkpoint_dir, sub_dir)
checkpoint_id = f"s3://{your_s3_bucket}/{key}"

async_save(state_dict, checkpoint_id=checkpoint_id, **kw)
load(state_dict, checkpoint_id=checkpoint_id, **kw)

Asynchronous sharded checkpoint

In some cases, you might need to resume training with a different hardware configuration, such as a different number of GPUs. In these cases, your training processes must load the checkpoints while resharding, which means resuming subsequent training with a different SHARD_DEGREE. To address the scenario where you need to resume training with a different SHARD_DEGREE, you must save your model checkpoints using the sharded state dictionary type, represented by StateDictType.SHARDED_STATE_DICT. Saving checkpoints in this format lets you properly handle the resharding process when you continue training with a modified hardware configuration. The following code snippet shows how to use the tsm API to asynchronously save sharded checkpoints, enabling a more efficient and streamlined training process.

import os

import torch.distributed as dist
import torch.sagemaker as tsm
from torch.sagemaker import state
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import StateDictType
from torch.sagemaker.utils.process_group_utils import get_global_ranks
from torch.sagemaker.distributed.checkpoint.state_dict_saver import (
    async_save,
    maybe_finalize_async_calls,
)

save_dir = "/opt/ml/checkpoints"
sub_dir = f"tp{state.tp_rank}_ep{state.ep_rank}"
checkpoint_id = os.path.join(save_dir, sub_dir)

# To determine whether the current rank should take part in checkpointing.
global_rank = dist.get_rank()
action_rank = state.ranker.get_rep_rank(global_rank) == 0

process_group = model.process_group
coordinator_rank = min(get_global_ranks(process_group))

# 1. Wait for the previous checkpointing done
maybe_finalize_async_calls(blocking=True, process_group=process_group)

# 2. Retrieve model & optimizer sharded state_dict
with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
    state_dict = {
        "model": model.state_dict(),
        "optimizer": FSDP.optim_state_dict(model, optimizer),
        # Potentially add more customized state dicts.
    }

    # 3. Save checkpoints asynchronously using async_save
    if action_rank:
        async_save(
            state_dict,
            checkpoint_id=checkpoint_id,
            process_group=process_group,
            coordinator_rank=coordinator_rank,
        )

The process of loading sharded checkpoints is similar to that of the previous section, but it involves using the torch.sagemaker.distributed.checkpoint.filesystem.DistributedFileSystemReader class together with the load method. Passing the reader as the storage_reader to load lets you load the sharded checkpoint data, following a process similar to the one described earlier.

import os

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import StateDictType
from torch.distributed.checkpoint.optimizer import load_sharded_optimizer_state_dict
from torch.sagemaker import state
from torch.sagemaker.distributed.checkpoint.state_dict_loader import load
from torch.sagemaker.utils.process_group_utils import get_global_ranks
from torch.sagemaker.distributed.checkpoint.filesystem import (
    DistributedFileSystemReader,
)

load_dir = "/opt/ml/checkpoints"
sub_dir = f"tp{state.tp_rank}_ep{state.ep_rank}"
checkpoint_id = os.path.join(load_dir, sub_dir)
reader = DistributedFileSystemReader(checkpoint_id)

process_group = model.process_group
coordinator_rank = min(get_global_ranks(process_group))

with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
    # 1. Load model and everything else except the optimizer.
    state_dict = {
        "model": model.state_dict()
        # Potentially more customized state dicts.
    }
    load(
        state_dict,
        storage_reader=reader,
        process_group=process_group,
        coordinator_rank=coordinator_rank,
    )
    model.load_state_dict(state_dict["model"])

    # 2. Load optimizer.
    optim_state = load_sharded_optimizer_state_dict(
        model_state_dict=state_dict["model"],
        optimizer_key="optimizer",
        storage_reader=reader,
        process_group=process_group,
    )
    flattened_optimizer_state = FSDP.optim_state_dict_to_load(
        optim_state["optimizer"], model, optimizer, group=model.process_group
    )
    optimizer.load_state_dict(flattened_optimizer_state)

Full model checkpoint

At the end of training, you can save a full checkpoint that combines all shards of the model into a single model checkpoint file. The SMP library fully supports the PyTorch full model checkpoint API, so you don't need to make any changes.

Note that if you use SMP tensor parallelism, the SMP library transforms the model. In this case, when checkpointing the full model, the SMP library translates the model back to the Hugging Face Transformers checkpoint format by default.

In cases where you train with SMP tensor parallelism and want to turn off the SMP translation process, you can use the translate_on_save argument of the PyTorch FullStateDictConfig API to switch the SMP auto-translation on or off as needed. For example, if you are focusing on training a model, you don't need to add the translation process, which adds overhead; in that case, we recommend that you set translate_on_save=False. Likewise, if you plan to keep using the SMP translation of the model for further training in the future, you can switch it off to save the SMP translation of the model for later use. Translating the model back to the Hugging Face Transformers model checkpoint format is needed when you wrap up the training of your model and use it for inference, as shown in the example after the following snippet.

import logging
import os

import torch
import torch.distributed as dist
import torch.sagemaker as tsm
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import FullStateDictConfig, StateDictType

logger = logging.getLogger(__name__)

# Save checkpoints.
with FSDP.state_dict_type(
    model,
    StateDictType.FULL_STATE_DICT,
    FullStateDictConfig(
        rank0_only=True,
        offload_to_cpu=True,
        # Default value is to translate back to Hugging Face Transformers format,
        # when saving full checkpoints for models trained with SMP tensor parallelism.
        # translate_on_save=True
    ),
):
    state_dict = model.state_dict()
    if dist.get_rank() == 0:
        logger.info("Processed state dict to save. Starting write to disk now.")
        os.makedirs(save_dir, exist_ok=True)
        # This name is needed for HF from_pretrained API to work.
        torch.save(state_dict, os.path.join(save_dir, "pytorch_model.bin"))
        hf_model_config.save_pretrained(save_dir)
    dist.barrier()
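If you want to skip the translation step as described above, a minimal variation of the previous snippet might look as follows. This is a sketch that assumes the same model, save_dir, and rank-0 saving setup as the snippet above; the output file name is a placeholder, not an SMP convention.

# Sketch: save a full checkpoint without translating back to the
# Hugging Face Transformers format (keeps the SMP-translated model).
with FSDP.state_dict_type(
    model,
    StateDictType.FULL_STATE_DICT,
    FullStateDictConfig(
        rank0_only=True,
        offload_to_cpu=True,
        translate_on_save=False,  # skip the SMP-to-Hugging-Face translation
    ),
):
    state_dict = model.state_dict()
    if dist.get_rank() == 0:
        os.makedirs(save_dir, exist_ok=True)
        # Placeholder file name for the untranslated (SMP-format) checkpoint.
        torch.save(state_dict, os.path.join(save_dir, "smp_full_model.pt"))
    dist.barrier()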

Note that, when training a large model, using FullStateDictConfig(rank0_only=True, offload_to_cpu=True) is an option for gathering the model onto the CPU of the rank 0 device to save memory.

To load the model back for inference, you can do so as shown in the following code example. Note that the AutoModelForCausalLM class might change to another factory class in Hugging Face Transformers, such as AutoModelForSeq2SeqLM, depending on your model. For more information, see the Hugging Face Transformers documentation.

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(save_dir)
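As a quick end-to-end check, a usage sketch might look like the following. The tokenizer and prompt are assumptions for illustration only: it presumes a tokenizer was also saved to save_dir, which is not part of the checkpointing code above.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Hypothetical usage sketch: assumes a tokenizer was saved alongside the model.
model = AutoModelForCausalLM.from_pretrained(save_dir)
tokenizer = AutoTokenizer.from_pretrained(save_dir)

inputs = tokenizer("Hello, world!", return_tensors="pt")  # placeholder prompt
with torch.no_grad():
    outputs = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))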