参数初始化延迟 - Amazon SageMaker

本文属于机器翻译版本。若本译文内容与英语原文存在差异,则一律以英文原文为准。

参数初始化延迟

在GPU内存有限的情况下,并非总是可以初始化用于训练的大型模型。要解决GPU内存不足的问题,可以在CPU内存上初始化模型。但是,对于参数超过200亿或400亿的大型号,即使是CPU内存也可能不够。在这种情况下,我们建议您在所 PyTorch 谓的元设备上初始化模型,这样就可以在不附加任何数据的情况下创建张量。元设备上的张量只需要形状信息,这允许在元设备上创建带有其参数的大型模型。Hugging Fac e Accelerate 提供了上下文init_empty_weights管理器,可帮助在元设备上创建此类模型,同时在普通设备上初始化缓冲区。在训练开始之前, PyTorch FSDP初始化模型参数。SMPv2 的延迟参数初始化功能延迟了模型参数的创建,使其在 PyTorch FSDP执行参数分片之后发生。 PyTorch FSDP在对模块进行分片时接受参数初始化函数 (param_init_fn),它会调param_init_fn用每个模块。将模块作为param_init_fnAPI参数并初始化其中的所有参数,不包括任何子模块的参数。请注意,此行为与原生 PyTorch v2.0.1 不同,后者存在导致参数多次初始化的错误。

SMPv2 提供了torch.sagemaker.delayed_param.DelayedParamIniterAPI用于应用延迟参数初始化的。

以下代码片段展示了如何将应用torch.sagemaker.delayed_param.DelayedParamIniterAPI于您的训练脚本。

假设你有一个 PyTorch FSDP训练脚本,如下所示。

# Creation of model on meta device from accelerate import init_empty_weights with init_empty_weights(): model = create_model() # Define a param init fn, below is an example for Hugging Face GPTNeoX. def init_weights(module): d = torch.cuda.current_device() # Note that below doesn't work if you have buffers in the model # buffers will need to reinitialized after this call module.to_empty(device=d, recurse=False) if isinstance(module, (nn.Linear, Conv1D)): module.weight.data.normal_(mean=0.0, std=args.initializer_range) if module.bias: module.bias.data.zero_() elif isinstance(module, nn.Embedding): module.weight.data.normal_(mean=0.0, std=args.initializer_range) if module.padding_idx: module.weight.data[module.padding_idx].zero_() elif isinstance(module, nn.LayerNorm): module.bias.data.zero_() module.weight.data.fill_(1.0) # Changes to FSDP wrapper. model = FSDP( model, ..., param_init_fn=init_weights ) # At this point model is initialized and sharded for sharded data parallelism.

请注意,延迟参数初始化方法与模型无关。要解决此问题,您需要编写一个init_weights函数,如前面的示例所示,以匹配原始模型定义中的初始化,并且该函数应涵盖模型的所有参数。为了简化准备此类init_weights函数的过程,SMPv2 为以下模型实现了此初始化函数:Hugging Face Transformers 中的 GPT-GPT 2、GPT-J、-neoX 和 Llama。torch.sagemaker.delayed_param.DelayedParamIniterAPI也可以与SMP张量并行实现 m torch.sagemaker.tensor_parallel.transformer.TransformerLMHead odel 一起使用,你可以在调用后调用它torch.sagemaker.transformAPI。

使用 torch.sagemaker.delayed_param.DelayedParamIniterAPI,您可以按如下方式调整 PyTorch FSDP脚本。创建具有空权重的模型后,将其注册torch.sagemaker.delayed_param.DelayedParamIniterAPI到模型中,然后定义其对象。将对象传递给该param_init_fn PyTorch FSDP类的。

from torch.sagemaker.delayed_param import DelayedParamIniter from accelerate import init_empty_weights with init_empty_weights(): model = create_model() delayed_initer = DelayedParamIniter(model) with delayed_initer.validate_params_and_buffers_inited(): model = FSDP( model, ..., param_init_fn=delayed_initer.get_param_init_fn() )

关于绑定砝码的注意事项

在训练带有绑定权重的模型时,我们需要特别注意在使用延迟参数初始化权重后绑定权重。 PyTorchFSDP在使用上述方法初始化权重后没有绑定权重param_init_fn的机制。为了解决此类情况,我们添加了 API allow apost_init_hook_fn,它可以用来绑定权重。你可以在其中传递任何接受模块作为参数的函数,但我们也有一个预post_param_init_fn定义的,如果模块存在则调用DelayedParamIniter该模块tie_weights的方法。请注意,post_param_init_fn即使该模块没有tie_weights方法,也始终可以安全地传入。

with delayed_initer.validate_params_and_buffers_inited(): model = FSDP( model, ..., param_init_fn=delayed_initer.get_param_init_fn(), post_param_init_fn=delayed_initer.get_post_param_init_fn() )