Verificação de ativação - Amazon SageMaker

As traduções são geradas por tradução automática. Em caso de conflito entre o conteúdo da tradução e da versão original em inglês, a versão em inglês prevalecerá.

Verificação de ativação

O ponto de verificação de ativação (ou ponto de verificação de gradiente) é uma técnica para reduzir o uso de memória limpando as ativações de determinadas camadas e recomputando-as durante uma passagem para trás. Efetivamente, isso troca o tempo extra de computação pelo uso reduzido da memória. Se um módulo for verificado, no final de uma passagem direta, as entradas e saídas do módulo permanecerão na memória. Quaisquer tensores intermediários que teriam feito parte da computação dentro desse módulo são liberados durante a passagem para frente. Durante a passagem para trás dos módulos com pontos de verificação, esses tensores são recalculados. Nesse ponto, as camadas além desse módulo de ponto de verificação concluíram sua passagem para trás, portanto, o pico de uso da memória com o ponto de verificação pode ser menor.

nota

Esse recurso está disponível PyTorch na biblioteca de paralelismo de SageMaker modelos v1.6.0 e versões posteriores.

Como usar o ponto de verificação de ativação

Com smdistributed.modelparallel, você pode usar o ponto de verificação de ativação na granularidade de um módulo. Para todos os módulos torch.nn, exceto torch.nn.Sequential, você só pode verificar uma árvore de módulos se ela estiver dentro de uma partição do ponto de vista do paralelismo do pipeline. No caso do módulo torch.nn.Sequential, cada árvore de módulos dentro do módulo sequencial deve estar completamente dentro de uma partição para que o ponto de verificação de ativação funcione. Ao usar o particionamento manual, esteja ciente dessas restrições.

Ao usar o particionamento automatizado de modelos, você pode encontrar os registros de atribuição de particionamento começando com os registros Partition assignments: do trabalho de treinamento. Se um módulo for particionado em várias classificações (por exemplo, com um descendente em uma classificação e outro descendente em uma classificação diferente), a biblioteca ignora a tentativa de verificar o módulo e gera uma mensagem de aviso de que o módulo não será verificado.

nota

A biblioteca de paralelismo de SageMaker modelos suporta operações sobrepostas e não allreduce sobrepostas em combinação com pontos de verificação.

nota

PyTorchO ponto de verificação nativo do não API é compatível comsmdistributed.modelparallel.

Exemplo 1: O código de amostra a seguir mostra como usar o ponto de verificação de ativação quando você tem uma definição de modelo em seu script.

import torch.nn as nn import torch.nn.functional as F from smdistributed.modelparallel.torch.patches.checkpoint import checkpoint class Net(nn.Module): def __init__(self): super(Net, self).__init__() self.conv1 = nn.Conv2d(1, 32, 3, 1) self.conv2 = nn.Conv2d(32, 64, 3, 1) self.fc1 = nn.Linear(9216, 128) self.fc2 = nn.Linear(128, 10) def forward(self, x): x = self.conv1(x) x = self.conv2(x) x = F.max_pool2d(x, 2) x = torch.flatten(x, 1) # This call of fc1 will be checkpointed x = checkpoint(self.fc1, x) x = self.fc2(x) return F.log_softmax(x, 1)

Exemplo 2: O código de amostra a seguir mostra como usar o ponto de verificação de ativação quando você tem um modelo sequencial em seu script.

import torch.nn as nn from smdistributed.modelparallel.torch.patches.checkpoint import checkpoint_sequential class Net(nn.Module): def __init__(self): super(Net, self).__init__() self.seq = nn.Sequential( nn.Conv2d(1,20,5), nn.ReLU(), nn.Conv2d(20,64,5), nn.ReLU() ) def forward(self, x): # This call of self.seq will be checkpointed x = checkpoint_sequential(self.seq, x) return F.log_softmax(x, 1)

Exemplo 3: O código de exemplo a seguir mostra como usar o ponto de verificação de ativação ao importar um modelo pré-construído de uma biblioteca, como Hugging Face PyTorch Transformers. Independentemente de você verificar os módulos sequenciais ou não, faça o seguinte:

  1. Embrulhe o modelo em smp.DistributedModel().

  2. Defina um objeto para camadas sequenciais.

  3. Enrole o objeto da camada sequencial por smp.set_activation_checkpointig().

import smdistributed.modelparallel.torch as smp from transformers import AutoModelForCausalLM smp.init() model = AutoModelForCausalLM(*args, **kwargs) model = smp.DistributedModel(model) # Call set_activation_checkpointing API transformer_layers = model.module.module.module.transformer.seq_layers smp.set_activation_checkpointing( transformer_layers, pack_args_as_tuple=True, strategy='each')