アクティベーションチェックポイント - Amazon SageMaker AI

翻訳は機械翻訳により提供されています。提供された翻訳内容と英語版の間で齟齬、不一致または矛盾がある場合、英語版が優先します。

アクティベーションチェックポイント

アクティベーションチェックポイントは、特定の層のアクティベーションを消去し、バックワードパス (逆伝播) の間に再計算することでメモリ使用量を削減する手法です。実質的に、計算時間が増える代わりに、メモリ使用量が削減します。モジュールをチェックポイントする場合、フォワードパス (順伝播) の終了時に、モジュールへの初期入力とモジュールからの最終出力だけがメモリに保持されます。そのモジュール内の計算に関するすべての中間テンソルは、フォワードパスの間に解放されます。チェックポイントしたモジュールのバックワードパスの間に、これらのテンソルは再計算されます。この時点で、チェックポイントしたこのモジュール以降の層ではバックワードパスが終了しているため、チェックポイントの使用時はピークメモリ使用量が低くなります。

SMP v2 は、PyTorch のアクティベーションチェックポイントモジュール apply_activation_checkpointing をサポートしています。以下は、Hugging Face GPT-NeoX モデルのアクティベーションチェックポイントの例です。

Hugging Face GPT-NeoX モデルの Transformer 層をチェックポイントする

from transformers.models.gpt_neox import GPTNeoXLayer from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( apply_activation_checkpointing ) # check_fn receives a module as the arg, # and it needs to return whether the module is to be checkpointed def is_transformer_layer(module): from transformers.models.gpt_neox import GPTNeoXLayer return isinstance(submodule, GPTNeoXLayer) apply_activation_checkpointing(model, check_fn=is_transformer_layer)

Hugging Face GPT-NeoX モデルの Transformer 層を 1 つおきにチェックポイントする

# check_fn receives a module as arg, # and it needs to return whether the module is to be checkpointed # here we define that function based on global variable (transformer_layers) from transformers.models.gpt_neox import GPTNeoXLayer from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( apply_activation_checkpointing ) transformer_layers = [ m for m model.modules() if isinstance(m, GPTNeoXLayer) ] def is_odd_transformer_layer(module): return transformer_layers.index(module) % 2 == 0 apply_activation_checkpointing(model, check_fn=is_odd_transformer_layer)

または、PyTorch にはチェックポイント用の torch.utils.checkpoint モジュールもあり、Hugging Face Transformers モデルの一部で使用されています。このモジュールは SMP v2 でも動作します。ただし、チェックポイントラッパーを追加するために、モデル定義にアクセスできる必要があります。そのため、apply_activation_checkpointing メソッドの使用を推奨します。