啟用檢查點 - Amazon SageMaker

本文為英文版的機器翻譯版本,如內容有任何歧義或不一致之處,概以英文版為準。

啟用檢查點

啟用檢查點是一種技術,透過清除某些層的啟用,並在向後傳遞期間重新計算它們來減少記憶體用量。實際上,這會轉換額外的運算時間,以減少記憶體用量。如果對模組進行檢查點,則在向前傳遞結束時,只有模組的初始輸入和模組的最終輸出會保留在記憶體中。 在向前傳遞期間, 會 PyTorch 釋放該模組內部運算中包含的任何中繼張數。在檢查點模組的向後通過期間, PyTorch 會重新計算這些張量。此時,此檢查點模組之外的圖層已完成向後傳遞,因此檢查點的尖峰記憶體用量會降低。

SMP v2 支援 PyTorch 啟用檢查點模組 apply_activation_checkpointing。以下是 Hugging Face GPT-NeoX 模型啟用檢查點的範例。

Hugging Face GPT-NeoX 模型的檢查點轉換器層

from transformers.models.gpt_neox import GPTNeoXLayer from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( apply_activation_checkpointing ) # check_fn receives a module as the arg, # and it needs to return whether the module is to be checkpointed def is_transformer_layer(module): from transformers.models.gpt_neox import GPTNeoXLayer return isinstance(submodule, GPTNeoXLayer) apply_activation_checkpointing(model, check_fn=is_transformer_layer)

檢查 Hugging Face GPT-NeoX 模型的所有其他轉換器層

# check_fn receives a module as arg, # and it needs to return whether the module is to be checkpointed # here we define that function based on global variable (transformer_layers) from transformers.models.gpt_neox import GPTNeoXLayer from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( apply_activation_checkpointing ) transformer_layers = [ m for m model.modules() if isinstance(m, GPTNeoXLayer) ] def is_odd_transformer_layer(module): return transformer_layers.index(module) % 2 == 0 apply_activation_checkpointing(model, check_fn=is_odd_transformer_layer)

或者, PyTorch 也有用於檢查點的模組,Hugging Face Transformer 模型的子集會使用該torch.utils.checkpoint模組。此模組也適用於 SMP v2。不過,它需要您存取模型定義以新增檢查點包裝。因此,我們建議您使用 apply_activation_checkpointing方法。