啟用檢查點 - Amazon SageMaker

本文為英文版的機器翻譯版本,如內容有任何歧義或不一致之處,概以英文版為準。

啟用檢查點

啟用檢查點 (或漸層檢查點) 是減少記憶體使用量的技術,方法是清除某些圖層的啟用,並在向後傳遞期間重新加以運算。實際上,這是以額外運算時間換取減少記憶體使用量。如果模組進行檢查點作業,則在向前傳遞結束時,模組的輸入與輸出都將保留在記憶體。在向前傳遞期間,原本會成為該模組內部運算一部分的任何中級張量都將被釋放。在檢查點模組的向後傳遞期間,會重新運算這些張量。此時,超出此檢查點模組的圖層已完成其向後傳遞,因此可降低運用檢查點的最高記憶體使用量。

注意

此功能適用於 PyTorch SageMaker 模型平行程式庫 v1.6.0 及更新版本。

如何使用啟用檢查點

當使用 smdistributed.modelparallel 時,您可以在模組的精細程度使用啟用檢查點。對於除 torch.nn.Sequential 外的所有 torch.nn 模組,僅當從管道平行處理的角度而言,模組樹狀目錄位於單一分割內時,您才能對其進行檢查點作業。對於 torch.nn.Sequential 模組,循序模組內部的每個模組樹狀目錄必須完全位於單一分割內,以便啟用檢查點作業。當您使用手動分割時,請注意這些限制。

當您使用自動化模型分割時,您可以在訓練任務日誌找到開頭為 Partition assignments: 的分割指派日誌。如果跨多個等級分割模組 (例如,其中一個子代位於某一等級,另一子代位於不同等級),程式庫會忽略而不嘗試對模組進行檢查點作業,並提出警告訊息,指出不會檢查該模組。

注意

SageMaker 模型平行程式庫支援重疊和非重疊allreduce作業,並結合檢查點。

注意

PyTorch的本機檢查點API與smdistributed.modelparallel.

範例 1:下列範例程式碼示範當指令碼具模型定義時,如何使用啟用檢查點。

import torch.nn as nn import torch.nn.functional as F from smdistributed.modelparallel.torch.patches.checkpoint import checkpoint class Net(nn.Module): def __init__(self): super(Net, self).__init__() self.conv1 = nn.Conv2d(1, 32, 3, 1) self.conv2 = nn.Conv2d(32, 64, 3, 1) self.fc1 = nn.Linear(9216, 128) self.fc2 = nn.Linear(128, 10) def forward(self, x): x = self.conv1(x) x = self.conv2(x) x = F.max_pool2d(x, 2) x = torch.flatten(x, 1) # This call of fc1 will be checkpointed x = checkpoint(self.fc1, x) x = self.fc2(x) return F.log_softmax(x, 1)

範例 2:下列範例程式碼示範當指令碼具循序模型時,如何使用啟用檢查點。

import torch.nn as nn from smdistributed.modelparallel.torch.patches.checkpoint import checkpoint_sequential class Net(nn.Module): def __init__(self): super(Net, self).__init__() self.seq = nn.Sequential( nn.Conv2d(1,20,5), nn.ReLU(), nn.Conv2d(20,64,5), nn.ReLU() ) def forward(self, x): # This call of self.seq will be checkpointed x = checkpoint_sequential(self.seq, x) return F.log_softmax(x, 1)

例 3:下列範例程式碼顯示如何在從程式庫匯入預先建置的模型 (例如「Hugging Face 變壓器」) 時使用啟動檢查點。 PyTorch 無論您是否針對循序模組進行檢查點作業,請執行以下操作:

  1. smp.DistributedModel() 包裝模型。

  2. 定義循序圖層物件。

  3. smp.set_activation_checkpointig() 包裝循序圖層物件。

import smdistributed.modelparallel.torch as smp from transformers import AutoModelForCausalLM smp.init() model = AutoModelForCausalLM(*args, **kwargs) model = smp.DistributedModel(model) # Call set_activation_checkpointing API transformer_layers = model.module.module.module.transformer.seq_layers smp.set_activation_checkpointing( transformer_layers, pack_args_as_tuple=True, strategy='each')