アクティベーションチェックポイント - Amazon SageMaker

翻訳は機械翻訳により提供されています。提供された翻訳内容と英語版の間で齟齬、不一致または矛盾がある場合、英語版が優先します。

アクティベーションチェックポイント

アクティベーションチェックポイント (または勾配チェックポイント) は、特定のレイヤーのアクティベーションをクリアし、バックワードパス時に再計算することでメモリ使用量を削減する手法です。これは事実上、余分な計算時間を抑え、メモリ使用量の削減につながります。モジュールがチェックポイントされると、フォワードパスの終了時に、モジュールへの入力とモジュールからの出力はメモリに残ります。そのモジュール内の計算に含まるはずの中間テンソルがフォワードパス中に解放されます。チェックポイントされたモジュールのバックワードパス中に、これらのテンソルは再計算されます。この時点で、このチェックポイントモジュールを超えるレイヤーはバックワードパスを終了しているため、チェックポイントを使用した場合のピークメモリ使用量は、これより少なくなる可能性があります。

注記

この機能は、 SageMaker モデル並列処理ライブラリ v1.6.0 以降 PyTorch で で使用できます。

アクティベーションチェックポイントの使用方法

smdistributed.modelparallel では、モジュール単位の精度でアクティベーションチェックポイントを使用できます。torch.nn.Sequential を除くすべての torch.nn モジュールでは、パイプライン並列処理の観点から、モジュールツリーが 1 つのパーティション内にある場合にのみモジュールツリーをチェックポイントできます。torch.nn.Sequential モジュールの場合、アクティベーションチェックポイントを機能させるためには、シーケンシャルモジュール内の各モジュールツリーが完全に 1 つのパーティション内にある必要があります。手動パーティショニングを使用するときは、次の制限に注意してください。

自動モデルパーティショニングを使用する場合は、トレーニングジョブログの Partition assignments: で始まるパーティション割り当てログを探すことができます。モジュールが複数のランクにわたってパーティション化されている場合 (例えば、ある子孫があるランクに、別の子孫が別のランクにある場合)、ライブラリはモジュールのチェックポイントの試みを無視し、モジュールはチェックポイントされないという警告メッセージを表示します。

注記

SageMaker モデル並列処理ライブラリは、チェックポイントと組み合わせて、重複するオペレーションと重複しないallreduceオペレーションの両方をサポートします。

注記

PyTorchのネイティブチェックポイントAPIは と互換性がありませんsmdistributed.modelparallel

例 1: 次のサンプルコードは、スクリプトにモデル定義がある場合にアクティベーションチェックポイントを使用する方法を示しています。

import torch.nn as nn import torch.nn.functional as F from smdistributed.modelparallel.torch.patches.checkpoint import checkpoint class Net(nn.Module): def __init__(self): super(Net, self).__init__() self.conv1 = nn.Conv2d(1, 32, 3, 1) self.conv2 = nn.Conv2d(32, 64, 3, 1) self.fc1 = nn.Linear(9216, 128) self.fc2 = nn.Linear(128, 10) def forward(self, x): x = self.conv1(x) x = self.conv2(x) x = F.max_pool2d(x, 2) x = torch.flatten(x, 1) # This call of fc1 will be checkpointed x = checkpoint(self.fc1, x) x = self.fc2(x) return F.log_softmax(x, 1)

例 2: 次のサンプルコードは、スクリプトにシーケンシャルモデルがある場合にアクティベーションチェックポイントを使用する方法を示しています。

import torch.nn as nn from smdistributed.modelparallel.torch.patches.checkpoint import checkpoint_sequential class Net(nn.Module): def __init__(self): super(Net, self).__init__() self.seq = nn.Sequential( nn.Conv2d(1,20,5), nn.ReLU(), nn.Conv2d(20,64,5), nn.ReLU() ) def forward(self, x): # This call of self.seq will be checkpointed x = checkpoint_sequential(self.seq, x) return F.log_softmax(x, 1)

例 3: 次のサンプルコードは、 PyTorch や Hugging Face Transformers などのライブラリから構築済みモデルをインポートするときにアクティベーションチェックポイントを使用する方法を示しています。シーケンシャルモジュールをチェックポイントするかどうかにかかわらず、次の操作を行います。

  1. smp.DistributedModel() でモデルをラップする。

  2. シーケンシャルレイヤーのオブジェクトを定義する。

  3. シーケンシャルレイヤーのオブジェクトを smp.set_activation_checkpointig() でラップする。

import smdistributed.modelparallel.torch as smp from transformers import AutoModelForCausalLM smp.init() model = AutoModelForCausalLM(*args, **kwargs) model = smp.DistributedModel(model) # Call set_activation_checkpointing API transformer_layers = model.module.module.module.transformer.seq_layers smp.set_activation_checkpointing( transformer_layers, pack_args_as_tuple=True, strategy='each')