Puntos de control de activación - Amazon SageMaker AI

Las traducciones son generadas a través de traducción automática. En caso de conflicto entre la traducción y la version original de inglés, prevalecerá la version en inglés.

Puntos de control de activación

Los puntos de control de activación (o puntos de control de gradiente) son una técnica para reducir el uso de memoria al borrar las activaciones de determinadas capas y volver a calcularlas durante una pasada hacia atrás. De hecho, esto cambia el tiempo de cálculo adicional por un menor uso de memoria. Si se comprueba un módulo, al final de una pasada hacia adelante, las entradas y salidas del módulo permanecen en la memoria. Todos los tensores intermedios que hubieran formado parte del cálculo dentro de ese módulo se liberan durante la pasada hacia adelante. Durante la pasada hacia atrás de los módulos con puntos de control, estos tensores vuelven a calcularse. En este punto, las capas situadas más allá de este módulo de puntos de control han terminado su pasada hacia atrás, por lo que el uso máximo de memoria con los puntos de control puede ser menor.

nota

Esta función está disponible en la biblioteca de paralelismo de modelos, versión PyTorch 1.6.0 y versiones posteriores SageMaker .

Cómo utilizar los puntos de control de activación

Con smdistributed.modelparallel, puede utilizar los puntos de control de activación en la granularidad de un módulo. En todos los módulos torch.nn, excepto torch.nn.Sequential, solo puede controlar un árbol de módulos si se encuentra dentro de una partición desde la perspectiva del paralelismo de canalización. En el caso del módulo torch.nn.Sequential, cada árbol de módulos del módulo secuencial debe estar completamente dentro de una partición para que los puntos de control de activación funcionen. Cuando utilice la división manual, tenga en cuenta estas restricciones.

Cuando utiliza la división automatizada de modelos, puede encontrar los registros de asignación de particiones que comienzan con Partition assignments: en los registros de trabajos de entrenamiento. Si un módulo está dividido en varios rangos (por ejemplo, con un descendiente en un rango y otro descendiente en un rango diferente), la biblioteca ignora el intento de poner puntos de control al módulo y muestra un mensaje de advertencia que indica que el módulo no estará sujeto a puntos de verificación.

nota

La biblioteca de paralelismo de SageMaker modelos admite operaciones de superposición y no superposición en combinación con puntos de control. allreduce

nota

PyTorchSu API de puntos de control nativa no es compatible con. smdistributed.modelparallel

Ejemplo 1: el siguiente código de ejemplo muestra cómo utilizar los puntos de control de activación cuando se tiene una definición de modelo en el script.

import torch.nn as nn import torch.nn.functional as F from smdistributed.modelparallel.torch.patches.checkpoint import checkpoint class Net(nn.Module): def __init__(self): super(Net, self).__init__() self.conv1 = nn.Conv2d(1, 32, 3, 1) self.conv2 = nn.Conv2d(32, 64, 3, 1) self.fc1 = nn.Linear(9216, 128) self.fc2 = nn.Linear(128, 10) def forward(self, x): x = self.conv1(x) x = self.conv2(x) x = F.max_pool2d(x, 2) x = torch.flatten(x, 1) # This call of fc1 will be checkpointed x = checkpoint(self.fc1, x) x = self.fc2(x) return F.log_softmax(x, 1)

Ejemplo 2: el siguiente código de ejemplo muestra cómo utilizar los puntos de control de activación cuando se tiene un modelo secuencial en el script.

import torch.nn as nn from smdistributed.modelparallel.torch.patches.checkpoint import checkpoint_sequential class Net(nn.Module): def __init__(self): super(Net, self).__init__() self.seq = nn.Sequential( nn.Conv2d(1,20,5), nn.ReLU(), nn.Conv2d(20,64,5), nn.ReLU() ) def forward(self, x): # This call of self.seq will be checkpointed x = checkpoint_sequential(self.seq, x) return F.log_softmax(x, 1)

Ejemplo 3: El siguiente código de ejemplo muestra cómo utilizar los puntos de control de activación al importar un modelo prediseñado de una biblioteca, como Hugging Face PyTorch Transformers. Tanto si comprueba los módulos secuenciales como si no, haga lo siguiente:

  1. Encapsule el modelo con smp.DistributedModel().

  2. Defina un objeto para las capas secuenciales.

  3. Encapsule el objeto de la capa secuencial con smp.set_activation_checkpointig().

import smdistributed.modelparallel.torch as smp from transformers import AutoModelForCausalLM smp.init() model = AutoModelForCausalLM(*args, **kwargs) model = smp.DistributedModel(model) # Call set_activation_checkpointing API transformer_layers = model.module.module.module.transformer.seq_layers smp.set_activation_checkpointing( transformer_layers, pack_args_as_tuple=True, strategy='each')