기계 번역으로 제공되는 번역입니다. 제공된 번역과 원본 영어의 내용이 상충하는 경우에는 영어 버전이 우선합니다.
활성화 체크포인트(또는 그라디언트 체크포인트)는 특정 레이어의 활성화를 지우고 역방향 패스 중에 이를 다시 계산하여 메모리 사용량을 줄이는 기법입니다. 이렇게 하면 추가 계산 시간이 줄어들어 메모리 사용량이 줄어듭니다. 모듈이 체크포인트로 지정된 경우 순방향 패스가 끝날 때 모듈의 입력과 출력은 메모리에 남습니다. 해당 모듈 내 계산의 일부를 구성한 모든 중간 텐서는 순방향 패스 중에 비워집니다. 체크포인트 모듈을 역방향으로 패스하는 동안 이러한 텐서는 다시 계산됩니다. 이 시점에서 이 체크포인트 모듈 뒤의 레이어는 역방향 패스를 완료했으므로 체크포인트의 최대 메모리 사용량을 줄일 수 있습니다.
참고
이 기능은 SageMaker 모델 병렬 처리 라이브러리 v1.6.0 이상에서 PyTorch에 사용할 수 있습니다.
활성화 체크포인트 사용 방법
smdistributed.modelparallel
을 사용하면 모듈의 세부 수준에서 활성화 체크포인트를 사용할 수 있습니다. torch.nn.Sequential
을 제외한 모든 torch.nn
모듈의 경우 파이프라인 병렬 처리 관점에서 볼 때 모듈 트리가 한 파티션 내에 있는 경우에만 모듈 트리를 체크포인트할 수 있습니다. torch.nn.Sequential
모듈의 경우 활성화 체크포인트가 작동하려면 순차 모듈 내의 각 모듈 트리가 완전히 한 파티션 내에 있어야 합니다. 수동 분할을 사용할 때는 이러한 제한 사항에 유의하세요.
자동 모델 분할을 사용하는 경우 훈련 작업 로그에서 Partition assignments:
로 시작하는 분할 할당 로그를 확인할 수 있습니다. 모듈이 여러 랭크로 분할된 경우(예: 한 랭크에 하위 항목 하나가 있고 다른 랭크에 또 다른 하위 항목 하나) 라이브러리는 모듈을 체크포인트하려는 시도를 무시하고 모듈을 체크포인트하지 않을 것이라는 경고 메시지를 표시합니다.
참고
SageMaker 모델 병렬 처리 라이브러리는 겹치는 작업 및 겹치지 않는 allreduce
작업 모두를 체크포인트와 함께 지원합니다.
참고
PyTorch의 네이티브 체크포인트 API는 smdistributed.modelparallel
과 호환되지 않습니다.
예제 1: 다음 샘플 코드는 스크립트에 모델 정의가 있을 때 활성화 체크포인트를 사용하는 방법을 보여줍니다.
import torch.nn as nn
import torch.nn.functional as F
from smdistributed.modelparallel.torch.patches.checkpoint import checkpoint
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 32, 3, 1)
self.conv2 = nn.Conv2d(32, 64, 3, 1)
self.fc1 = nn.Linear(9216, 128)
self.fc2 = nn.Linear(128, 10)
def forward(self, x):
x = self.conv1(x)
x = self.conv2(x)
x = F.max_pool2d(x, 2)
x = torch.flatten(x, 1)
# This call of fc1 will be checkpointed
x = checkpoint(self.fc1, x)
x = self.fc2(x)
return F.log_softmax(x, 1)
예제 2: 다음 샘플 코드는 스크립트에 순차적 모델이 있을 때 활성화 체크포인트를 사용하는 방법을 보여줍니다.
import torch.nn as nn
from smdistributed.modelparallel.torch.patches.checkpoint import checkpoint_sequential
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.seq = nn.Sequential(
nn.Conv2d(1,20,5),
nn.ReLU(),
nn.Conv2d(20,64,5),
nn.ReLU()
)
def forward(self, x):
# This call of self.seq will be checkpointed
x = checkpoint_sequential(self.seq, x)
return F.log_softmax(x, 1)
예제 3: 다음 샘플 코드는 PyTorch 및 Hugging Face 변환기처럼 라이브러리에서 사전 구축된 모델을 가져올 때 활성화 체크포인트를 사용하는 방법을 보여줍니다. 순차 모듈 체크포인트 여부에 관계없이 다음을 수행합니다.
-
smp.DistributedModel()
로 모델을 래핑합니다. -
순차 계층용 객체를 정의합니다.
-
smp.set_activation_checkpointig()
로 순차 계층 객체를 래핑합니다.
import smdistributed.modelparallel.torch as smp
from transformers import AutoModelForCausalLM
smp.init()
model = AutoModelForCausalLM(*args, **kwargs)
model = smp.DistributedModel(model)
# Call set_activation_checkpointing API
transformer_layers = model.module.module.module.transformer.seq_layers
smp.set_activation_checkpointing(
transformer_layers, pack_args_as_tuple=True, strategy='each')