本文属于机器翻译版本。若本译文内容与英语原文存在差异,则一律以英文原文为准。
激活检查点技术通过清除某些层的激活并在向后传递期间重新计算它们,来减少内存使用量。实际上,这是用额外的计算时间来换取内存使用量的减少。如果对模块进行了检查点检查,则在正向传递结束时,只有该模块的初始输入和该模块的最终输出会保留在内存中。 PyTorch 在向前传递期间,释放作为该模块内部计算一部分的任何中间张量。在检查点模块的向后传递过程中, PyTorch 重新计算这些张量。此时,有检查点的模块之外的层已经完成其向后传递,因此检查点操作的峰值内存使用量会变得更低。
SMP v2 支持 PyTorch 激活检查点模块。apply_activation_checkpointing
Hugging Face GPT-NeoX 模型的检查点转换器层
from transformers.models.gpt_neox import GPTNeoXLayer
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
apply_activation_checkpointing
)
# check_fn receives a module as the arg,
# and it needs to return whether the module is to be checkpointed
def is_transformer_layer(module):
from transformers.models.gpt_neox import GPTNeoXLayer
return isinstance(submodule, GPTNeoXLayer)
apply_activation_checkpointing(model, check_fn=is_transformer_layer)
Hugging Face GPT-NeoX 模型的每个其他转换器层都要进行检查点检查
# check_fn receives a module as arg,
# and it needs to return whether the module is to be checkpointed
# here we define that function based on global variable (transformer_layers)
from transformers.models.gpt_neox import GPTNeoXLayer
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
apply_activation_checkpointing
)
transformer_layers = [
m for m model.modules() if isinstance(m, GPTNeoXLayer)
]
def is_odd_transformer_layer(module):
return transformer_layers.index(module) % 2 == 0
apply_activation_checkpointing(model, check_fn=is_odd_transformer_layer)
或者, PyTorch 还有检查点torch.utils.checkpoint
模块,Hugging Face Transformers 模型的子集使用该模块。此模块也适用于 SMP v2。但是,这需要您有权访问模型定义,才能添加检查点封装器。因此,我们建议您使用 apply_activation_checkpointing
方法。