本文属于机器翻译版本。若本译文内容与英语原文存在差异,则一律以英文原文为准。
激活检查点
激活检查点(或梯度检查点)技术通过清除某些层的激活并在向后传递期间重新计算它们,来减少内存使用量。实际上,这是用额外的计算时间来换取内存使用量的减少。如果对模块执行了检查点操作,则在向前传递结束时,该模块的输入和输出将保留在内存中。在向前传递期间,任何本应是模块内部计算一部分的中间张量都会被释放。在有检查点的模块的向后传递过程中,会重新计算这些张量。此时,有检查点的模块之外的层已经完成其向后传递,因此检查点操作的峰值内存使用量可能会更低。
注意
此功能可在 SageMaker 模型并行度库 v1.6.0 及更高版本 PyTorch 中使用。
如何使用激活检查点
使用 smdistributed.modelparallel
,您可以按模块使用激活检查点。对于除 torch.nn.Sequential
之外的所有 torch.nn
模块,只有当从管道并行性的角度来看,模块树位于一个分区内时,才能对模块树执行检查点操作。对于 torch.nn.Sequential
模块,顺序模块内的每个模块树必须完全位于一个分区内,激活检查点才能起作用。使用手动分区时,请注意这些限制。
使用自动模型分区时,您可在训练作业日志中找到以 Partition assignments:
开头的分区分配日志。如果一个模块在多个秩(例如,一个后代属于一个秩,另一个后代处于不同的秩)上分区,则库会忽略对模块执行检查点的尝试,并发出一条警告消息,说明该模块没有检查点。
注意
SageMaker 模型并行度库支持重叠和非重叠操作以及检查点allreduce
操作。
注意
PyTorch的本机检查点与API不兼容。smdistributed.modelparallel
示例 1:以下示例代码演示了当脚本中有模型定义时,如何使用激活检查点操作。
import torch.nn as nn import torch.nn.functional as F from smdistributed.modelparallel.torch.patches.checkpoint import checkpoint class Net(nn.Module): def __init__(self): super(Net, self).__init__() self.conv1 = nn.Conv2d(1, 32, 3, 1) self.conv2 = nn.Conv2d(32, 64, 3, 1) self.fc1 = nn.Linear(9216, 128) self.fc2 = nn.Linear(128, 10) def forward(self, x): x = self.conv1(x) x = self.conv2(x) x = F.max_pool2d(x, 2) x = torch.flatten(x, 1) # This call of fc1 will be checkpointed x = checkpoint(self.fc1, x) x = self.fc2(x) return F.log_softmax(x, 1)
示例 2:以下示例代码演示了当脚本中有顺序模型时,如何使用激活检查点操作。
import torch.nn as nn from smdistributed.modelparallel.torch.patches.checkpoint import checkpoint_sequential class Net(nn.Module): def __init__(self): super(Net, self).__init__() self.seq = nn.Sequential( nn.Conv2d(1,20,5), nn.ReLU(), nn.Conv2d(20,64,5), nn.ReLU() ) def forward(self, x): # This call of self.seq will be checkpointed x = checkpoint_sequential(self.seq, x) return F.log_softmax(x, 1)
示例 3:以下示例代码显示了在从库(例如和 Hugging Face Transformers PyTorch )导入预建模型时如何使用激活检查点。无论您是否对顺序模型执行检查点操作,请完成以下过程:
-
使用
smp.DistributedModel()
包装模型。 -
为顺序层定义一个对象。
-
使用
smp.set_activation_checkpointig()
包装顺序层对象。
import smdistributed.modelparallel.torch as smp from transformers import AutoModelForCausalLM smp.init() model = AutoModelForCausalLM(*args, **kwargs) model = smp.DistributedModel(model) # Call set_activation_checkpointing API transformer_layers = model.module.module.module.transformer.seq_layers smp.set_activation_checkpointing( transformer_layers, pack_args_as_tuple=True, strategy='each')