Las traducciones son generadas a través de traducción automática. En caso de conflicto entre la traducción y la version original de inglés, prevalecerá la version en inglés.
Los puntos de comprobación de activación son una técnica para reducir el uso de memoria al borrar las activaciones de determinadas capas y volver a calcularlas durante una pasada hacia atrás. De hecho, esto cambia el tiempo de cálculo adicional por un menor uso de memoria. Si se comprueba un módulo, al final de una transferencia hacia adelante, solo permanecen en la memoria las entradas iniciales del módulo y las salidas finales del módulo. PyTorch libera cualquier tensor intermedio que forme parte del cálculo dentro de ese módulo durante la pasada hacia adelante. Al pasar hacia atrás los módulos con puntos de control, PyTorch vuelve a calcular estos tensores. En este punto, las capas situadas más allá de este módulo de puntos de comprobación han terminado su pasada hacia atrás, por lo que el uso máximo de memoria con los puntos de comprobación es menor.
SMP v2 admite el módulo de puntos de control de activación, PyTorch . apply_activation_checkpointing
Capas del transformador de puntos de comprobación del modelo de Hugging Face GPT-NeoX
from transformers.models.gpt_neox import GPTNeoXLayer
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
apply_activation_checkpointing
)
# check_fn receives a module as the arg,
# and it needs to return whether the module is to be checkpointed
def is_transformer_layer(module):
from transformers.models.gpt_neox import GPTNeoXLayer
return isinstance(submodule, GPTNeoXLayer)
apply_activation_checkpointing(model, check_fn=is_transformer_layer)
Aplicación de puntos de comprobación a una capa del transformador de cada dos del modelo de Hugging Face GPT-NeoX
# check_fn receives a module as arg,
# and it needs to return whether the module is to be checkpointed
# here we define that function based on global variable (transformer_layers)
from transformers.models.gpt_neox import GPTNeoXLayer
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
apply_activation_checkpointing
)
transformer_layers = [
m for m model.modules() if isinstance(m, GPTNeoXLayer)
]
def is_odd_transformer_layer(module):
return transformer_layers.index(module) % 2 == 0
apply_activation_checkpointing(model, check_fn=is_odd_transformer_layer)
Como alternativa, PyTorch también cuenta con el torch.utils.checkpoint
módulo de puntos de control, que es utilizado por un subconjunto de modelos de Hugging Face Transformers. Este módulo también funciona con SMP v2. Sin embargo, requiere que tenga acceso a la definición del modelo para añadir el encapsulador de puntos de comprobación. Por tanto, le recomendamos que utilice el método apply_activation_checkpointing
.