延迟参数初始化 - 亚马逊 SageMaker AI

本文属于机器翻译版本。若本译文内容与英语原文存在差异,则一律以英文原文为准。

延迟参数初始化

在GPU内存有限的情况下,并非总是可以初始化用于训练的大型模型。要解决GPU内存不足的问题,可以在CPU内存上初始化模型。但是,对于参数超过200亿或400亿的大型号,即使是CPU内存也可能不够。在这种情况下,我们建议您在所 PyTorch 谓的元设备上初始化模型,这样就可以在不附加任何数据的情况下创建张量。元设备上的张量只需要形状信息,这样就可以在元设备上创建一个带有参数的大型模型。Hugging Face Accelerate 提供了上下文管理器 init_empty_weights,以帮助在元设备上创建此类模型,同时在普通设备上初始化缓冲区。在训练开始之前, PyTorch FSDP初始化模型参数。SMPv2 的延迟参数初始化功能延迟了模型参数的创建,使其在 PyTorch FSDP执行参数分片之后发生。 PyTorch FSDP在对模块进行分片时接受参数初始化函数 (param_init_fn),它会调param_init_fn用每个模块。将模块作为param_init_fnAPI参数并初始化其中的所有参数,不包括任何子模块的参数。请注意,此行为与原生 PyTorch v2.0.1 不同,后者存在导致参数多次初始化的错误。

SMPv2 提供了torch.sagemaker.delayed_param.DelayedParamIniterAPI用于应用延迟参数初始化的。

以下代码片段展示了如何将应用torch.sagemaker.delayed_param.DelayedParamIniterAPI于您的训练脚本。

假设你有一个 PyTorch FSDP训练脚本,如下所示。

# Creation of model on meta device from accelerate import init_empty_weights with init_empty_weights(): model = create_model() # Define a param init fn, below is an example for Hugging Face GPTNeoX. def init_weights(module): d = torch.cuda.current_device() # Note that below doesn't work if you have buffers in the model # buffers will need to reinitialized after this call module.to_empty(device=d, recurse=False) if isinstance(module, (nn.Linear, Conv1D)): module.weight.data.normal_(mean=0.0, std=args.initializer_range) if module.bias: module.bias.data.zero_() elif isinstance(module, nn.Embedding): module.weight.data.normal_(mean=0.0, std=args.initializer_range) if module.padding_idx: module.weight.data[module.padding_idx].zero_() elif isinstance(module, nn.LayerNorm): module.bias.data.zero_() module.weight.data.fill_(1.0) # Changes to FSDP wrapper. model = FSDP( model, ..., param_init_fn=init_weights ) # At this point model is initialized and sharded for sharded data parallelism.

请注意,延迟参数初始化方法与模型无关。要解决此问题,您需要编写一个 init_weights 函数,如前面的示例所示,以匹配原始模型定义中的初始化,并且此函数应涵盖模型的所有参数。为了简化准备此类init_weights函数的过程,SMPv2 为以下模型实现了此初始化函数:Hugging Face Transformers 中的 GPT-GPT 2、GPT-J、-neoX 和 Llama。torch.sagemaker.delayed_param.DelayedParamIniterAPI也可以与SMP张量并行实现 m torch.sagemaker.tensor_parallel.transformer.TransformerLMHead odel 一起使用,你可以在调用后调用它torch.sagemaker.transformAPI。

使用 torch.sagemaker.delayed_param.DelayedParamIniterAPI,您可以按如下方式调整 PyTorch FSDP脚本。创建具有空权重的模型后,将其注册torch.sagemaker.delayed_param.DelayedParamIniterAPI到模型中,然后定义其对象。将对象传递给该param_init_fn PyTorch FSDP类的。

from torch.sagemaker.delayed_param import DelayedParamIniter from accelerate import init_empty_weights with init_empty_weights(): model = create_model() delayed_initer = DelayedParamIniter(model) with delayed_initer.validate_params_and_buffers_inited(): model = FSDP( model, ..., param_init_fn=delayed_initer.get_param_init_fn() )

关于并列权重的注意事项

在训练带有绑定权重的模型时,我们需要特别注意在使用延迟参数初始化权重后绑定权重。 PyTorchFSDP在使用上述方法初始化权重后没有绑定权重param_init_fn的机制。为了解决此类情况,我们添加了 API allow apost_init_hook_fn,它可以用来绑定权重。您可以在其中传递任何接受模块作为参数的函数,但我们也在 DelayedParamIniter 中定义了一个预定义的 post_param_init_fn,如果模块中存在 tie_weights 方法,它就会调用此方法。请注意,即使模块没有 tie_weights 方法,在 post_param_init_fn 中传递也是安全的。

with delayed_initer.validate_params_and_buffers_inited(): model = FSDP( model, ..., param_init_fn=delayed_initer.get_param_init_fn(), post_param_init_fn=delayed_initer.get_post_param_init_fn() )