As traduções são geradas por tradução automática. Em caso de conflito entre o conteúdo da tradução e da versão original em inglês, a versão em inglês prevalecerá.
Verificação de ativação
O ponto de verificação de ativação (ou ponto de verificação de gradiente) é uma técnica para reduzir o uso de memória limpando as ativações de determinadas camadas e recomputando-as durante uma passagem para trás. Efetivamente, isso troca o tempo extra de computação pelo uso reduzido da memória. Se um módulo for verificado, no final de uma passagem direta, as entradas e saídas do módulo permanecerão na memória. Quaisquer tensores intermediários que teriam feito parte da computação dentro desse módulo são liberados durante a passagem para frente. Durante a passagem para trás dos módulos com pontos de verificação, esses tensores são recalculados. Nesse ponto, as camadas além desse módulo de ponto de verificação concluíram sua passagem para trás, portanto, o pico de uso da memória com o ponto de verificação pode ser menor.
nota
Esse recurso está disponível PyTorch na biblioteca de paralelismo de SageMaker modelos v1.6.0 e versões posteriores.
Como usar o ponto de verificação de ativação
Com smdistributed.modelparallel
, você pode usar o ponto de verificação de ativação na granularidade de um módulo. Para todos os módulos torch.nn
, exceto torch.nn.Sequential
, você só pode verificar uma árvore de módulos se ela estiver dentro de uma partição do ponto de vista do paralelismo do pipeline. No caso do módulo torch.nn.Sequential
, cada árvore de módulos dentro do módulo sequencial deve estar completamente dentro de uma partição para que o ponto de verificação de ativação funcione. Ao usar o particionamento manual, esteja ciente dessas restrições.
Ao usar o particionamento automatizado de modelos, você pode encontrar os registros de atribuição de particionamento começando com os registros Partition assignments:
do trabalho de treinamento. Se um módulo for particionado em várias classificações (por exemplo, com um descendente em uma classificação e outro descendente em uma classificação diferente), a biblioteca ignora a tentativa de verificar o módulo e gera uma mensagem de aviso de que o módulo não será verificado.
nota
A biblioteca de paralelismo de SageMaker modelos suporta operações sobrepostas e não allreduce
sobrepostas em combinação com pontos de verificação.
nota
PyTorchO ponto de verificação nativo do não API é compatível comsmdistributed.modelparallel
.
Exemplo 1: O código de amostra a seguir mostra como usar o ponto de verificação de ativação quando você tem uma definição de modelo em seu script.
import torch.nn as nn import torch.nn.functional as F from smdistributed.modelparallel.torch.patches.checkpoint import checkpoint class Net(nn.Module): def __init__(self): super(Net, self).__init__() self.conv1 = nn.Conv2d(1, 32, 3, 1) self.conv2 = nn.Conv2d(32, 64, 3, 1) self.fc1 = nn.Linear(9216, 128) self.fc2 = nn.Linear(128, 10) def forward(self, x): x = self.conv1(x) x = self.conv2(x) x = F.max_pool2d(x, 2) x = torch.flatten(x, 1) # This call of fc1 will be checkpointed x = checkpoint(self.fc1, x) x = self.fc2(x) return F.log_softmax(x, 1)
Exemplo 2: O código de amostra a seguir mostra como usar o ponto de verificação de ativação quando você tem um modelo sequencial em seu script.
import torch.nn as nn from smdistributed.modelparallel.torch.patches.checkpoint import checkpoint_sequential class Net(nn.Module): def __init__(self): super(Net, self).__init__() self.seq = nn.Sequential( nn.Conv2d(1,20,5), nn.ReLU(), nn.Conv2d(20,64,5), nn.ReLU() ) def forward(self, x): # This call of self.seq will be checkpointed x = checkpoint_sequential(self.seq, x) return F.log_softmax(x, 1)
Exemplo 3: O código de exemplo a seguir mostra como usar o ponto de verificação de ativação ao importar um modelo pré-construído de uma biblioteca, como Hugging Face PyTorch Transformers. Independentemente de você verificar os módulos sequenciais ou não, faça o seguinte:
-
Embrulhe o modelo em
smp.DistributedModel()
. -
Defina um objeto para camadas sequenciais.
-
Enrole o objeto da camada sequencial por
smp.set_activation_checkpointig()
.
import smdistributed.modelparallel.torch as smp from transformers import AutoModelForCausalLM smp.init() model = AutoModelForCausalLM(*args, **kwargs) model = smp.DistributedModel(model) # Call set_activation_checkpointing API transformer_layers = model.module.module.module.transformer.seq_layers smp.set_activation_checkpointing( transformer_layers, pack_args_as_tuple=True, strategy='each')