Checkpointing mit SMP - Amazon SageMaker

Die vorliegende Übersetzung wurde maschinell erstellt. Im Falle eines Konflikts oder eines Widerspruchs zwischen dieser übersetzten Fassung und der englischen Fassung (einschließlich infolge von Verzögerungen bei der Übersetzung) ist die englische Fassung maßgeblich.

Checkpointing mit SMP

Die SageMaker Modellparallelism (SMP) -Bibliothek unterstützt PyTorch APIs Checkpoints und bietet APIs diese Hilfe bei der korrekten Verwendung der Bibliothek. SMP

PyTorch FSDP(Fully Sharded Data Parallelism) unterstützt drei Arten von Checkpoints: vollständige Checkpoints, Sharded und Local, die jeweils unterschiedlichen Zwecken dienen. Vollständige Checkpoints werden verwendet, wenn das Modell nach Abschluss des Trainings exportiert wird, da die Generierung eines vollständigen Checkpoints ein rechenintensiver Prozess ist. Mit Hilfe von Sharded Checkpoints kann der Status eines Modells, das für jeden einzelnen Rang geteilt wurde, gespeichert und geladen werden. Mithilfe von Prüfpunkten in mehreren Gruppen können Sie das Training mit unterschiedlichen Hardwarekonfigurationen fortsetzen, z. B. mit einer anderen Anzahl von. GPUs Das Laden von Shard-Checkpoints kann jedoch aufgrund der Kommunikation zwischen mehreren Geräten langsam sein. Die SMP Bibliothek bietet lokale Checkpoint-Funktionen, die ein schnelleres Abrufen des Modellstatus ohne zusätzlichen Kommunikationsaufwand ermöglichen. Beachten Sie, dass Checkpoints, die von erstellt wurden, das Schreiben in ein gemeinsam genutztes Netzwerk-Dateisystem wie Amazon FSDP FSx erfordern.

Asynchrone lokale Checkpoints

Beim Training von Modellen für maschinelles Lernen müssen keine nachfolgenden Iterationen durchgeführt werden, um darauf zu warten, dass die Checkpoint-Dateien auf der Festplatte gespeichert werden. Mit der Veröffentlichung von SMP v2.5 unterstützt die Bibliothek das asynchrone Speichern von Checkpoint-Dateien. Das bedeutet, dass die nachfolgende Trainingsiteration gleichzeitig mit den Eingabe- und Ausgabeoperationen (I/O) zur Erstellung von Checkpoints ausgeführt werden kann, ohne durch diese I/O-Operationen verlangsamt oder behindert zu werden. Außerdem PyTorch kann das Abrufen von Shard-Modell- und Optimizer-Parametern zeitaufwändig sein, da zusätzliche kollektive Kommunikation erforderlich ist, um verteilte Tensor-Metadaten zwischen Rängen auszutauschen. Selbst wenn es verwendet wird, StateDictType.LOCAL_STATE_DICT um lokale Checkpoints für jeden Rang zu speichern, ruft es PyTorch immer noch Hooks auf, die kollektive Kommunikation durchführen. Um dieses Problem zu beheben und den Zeitaufwand für das Abrufen von Checkpoints zu reduzieren, SMP wird eingeführtSMStateDictType.SM_LOCAL_STATE_DICT, was einen schnelleren Abruf von Modell- und Optimierer-Checkpoints ermöglicht, indem der kollektive Kommunikationsaufwand umgangen wird.

Anmerkung

Die Aufrechterhaltung der Konsistenz ist eine Voraussetzung für die Nutzung von FSDPSHARD_DEGREE. SMStateDictType.SM_LOCAL_STATE_DICT Stellen Sie sicher, dass das unverändert SHARD_DEGREE bleibt. Die Anzahl der Modellreplikationen kann zwar variieren, der Grad der Modellsplitter muss jedoch mit dem vorherigen Trainingsaufbau identisch sein, wenn der Vorgang von einem Checkpoint aus fortgesetzt wird.

import os import torch.distributed as dist import torch.sagemaker as tsm from torch.sagemaker import state from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.sagemaker.distributed.checkpoint.state_dict_saver import ( async_save, maybe_finalize_async_calls, ) from torch.sagemaker.distributed.checkpoint.state_dict_utils import ( sm_state_dict_type, SMStateDictType, ) global_rank = dist.get_rank() save_dir = "/opt/ml/checkpoints" sub_dir = f"tp{state.tp_rank}_ep{state.ep_rank}_fsdp{model.rank}" # 1. Get replication ranks and group current_replication_group = None current_replication_ranks = None for replication_ranks in state.ranker.get_rep_groups(): rep_group = dist.new_group(replication_ranks) if global_rank in replication_ranks: current_replication_group = rep_group current_replication_ranks = replication_ranks coordinator_rank = min(current_replication_ranks) # 2. Wait for the previous checkpointing done maybe_finalize_async_calls( blocking=True, process_group=current_replication_group ) # 3. Get model local checkpoint with sm_state_dict_type(model, SMStateDictType.SM_LOCAL_STATE_DICT): state_dict = { "model": model.state_dict(), "optimizer": optimizer.state_dict(), # Potentially add more customized state dicts. } # 4. Save a local checkpoint async_save( state_dict, checkpoint_id=os.path.join(save_dir, sub_dir), process_group=current_replication_group, coordinator_rank=coordinator_rank, )

Der folgende Codeausschnitt zeigt, wie ein Checkpoint mithilfe von geladen wird. SMStateDictType.SM_LOCAL_STATE_DICT

import os import torch.sagemaker as tsm from torch.sagemaker import state from torch.sagemaker.distributed.checkpoint.state_dict_loader import load from torch.sagemaker.distributed.checkpoint.state_dict_utils import ( sm_state_dict_type, SMStateDictType, init_optim_state ) from torch.sagemaker.distributed.checkpoint.filesystem import ( DistributedFileSystemReader, ) load_dir = "/opt/ml/checkpoints" sub_dir = f"tp{state.tp_rank}_ep{state.ep_rank}_fsdp{model.rank}" global_rank = dist.get_rank() checkpoint_id = os.path.join(load_dir, sub_dir) storage_reader = DistributedFileSystemReader(checkpoint_id) # 1. Get replication ranks and group current_replication_group = None current_replication_ranks = None for replication_ranks in state.ranker.get_rep_groups(): rep_group = dist.new_group(replication_ranks) if global_rank in replication_ranks: current_replication_group = rep_group current_replication_ranks = replication_ranks coordinator_rank = min(current_replication_ranks) # 2. Create local state_dict with sm_state_dict_type(model, SMStateDictType.SM_LOCAL_STATE_DICT): state_dict = { "model": model.state_dict(), # Potentially add more customized state dicts. } # Init optimizer state_dict states by setting zero grads and step. init_optim_state(optimizer, skip_empty_param=True) state_dict["optimizer"] = optimizer.state_dict() # 3. Load a checkpoint load( state_dict=state_dict, process_group=current_replication_group, coordinator_rank=coordinator_rank, storage_reader=storage_reader, )

Das Speichern von Checkpoints für umfangreiche Sprachmodelle (LLMs) kann teuer sein, da dafür oft ein großes Dateisystemvolumen erstellt werden muss. Um die Kosten zu senken, haben Sie die Möglichkeit, Checkpoints direkt in Amazon S3 zu speichern, ohne dass zusätzliche Dateisystemdienste wie Amazon erforderlich sind. FSx Sie können das vorherige Beispiel mit dem folgenden Codeausschnitt nutzen, um Checkpoints in S3 zu speichern, indem Sie S3 als Ziel angeben. URL

key = os.path.join(checkpoint_dir, sub_dir) checkpoint_id= f"s3://{your_s3_bucket}/{key}" async_save(state_dict, checkpoint_id=checkpoint_id, **kw) load(state_dict, checkpoint_id=checkpoint_id, **kw)

Asynchrone, fragmentierte Checkpoints

Es kann Situationen geben, in denen Sie das Training mit unterschiedlichen Hardwarekonfigurationen fortsetzen müssen, z. B. wenn Sie die Anzahl der Geräte ändern. GPUs In diesen Fällen müssen Ihre Trainingsprozesse beim Resharding Checkpoints laden, was bedeutet, dass nachfolgende Schulungen mit einer anderen Anzahl von wieder aufgenommen werden müssen. SHARD_DEGREE Um das Szenario zu lösen, in dem Sie das Training mit einer anderen Anzahl von fortsetzen müssenSHARD_DEGREE, müssen Sie Ihre Modell-Checkpoints mithilfe des Wörterbuchtyps Sharded State speichern, der durch dargestellt wird. StateDictType.SHARDED_STATE_DICT Wenn Sie Checkpoints in diesem Format speichern, können Sie den Resharding-Prozess ordnungsgemäß durchführen, wenn Sie das Training mit einer geänderten Hardwarekonfiguration fortsetzen. Der bereitgestellte Codeausschnitt veranschaulicht, wie Sie Shard-Checkpoints asynchron speichern können, was einen effizienteren und optimierten Trainingsprozess ermöglicht. tsm API

import os import torch.sagemaker as tsm from torch.sagemaker import state from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp import StateDictType from torch.sagemaker.utils.process_group_utils import get_global_ranks from torch.sagemaker.distributed.checkpoint.state_dict_saver import ( async_save, maybe_finalize_async_calls, ) save_dir = "/opt/ml/checkpoints" sub_dir = f"tp{state.tp_rank}_ep{state.ep_rank}" checkpoint_id = os.path.join(save_dir, sub_dir) # To determine whether curreto take part in checkpointing. global_rank = dist.get_rank() action_rank = state.ranker.get_rep_rank(global_rank) == 0 process_group = model.process_group coordinator_rank = min(get_global_ranks(process_group)) # 1. wait for the previous checkpointing done maybe_finalize_async_calls(blocking=True, process_group=process_group) # 2. retrieve model & optimizer sharded state_dict with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT): state_dict = { "model": model.state_dict(), "optimizer": FSDP.optim_state_dict(model, optimizer), # Potentially add more customized state dicts. } # 3. save checkpoints asynchronously using async_save if action_rank: async_save( state_dict, checkpoint_id=checkpoint_id, process_group=process_group, coordinator_rank=coordinator_rank, )

Der Vorgang zum Laden gemeinsam genutzter Checkpoints ähnelt dem vorherigen Abschnitt, beinhaltet jedoch die Verwendung der Methode und ihrer Methode. torch.sagemaker.distributed.checkpoint.filesystem.DistributedFileSystemReader load Die load Methode dieser Klasse ermöglicht es Ihnen, die gemeinsamen Checkpoint-Daten nach einem Verfahren zu laden, das dem zuvor beschriebenen Vorgang entspricht.

import os from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp import StateDictType from torch.distributed.checkpoint.optimizer import load_sharded_optimizer_state_dict from torch.sagemaker.distributed.checkpoint.state_dict_loader import load from torch.sagemaker.utils.process_group_utils import get_global_ranks from torch.sagemaker.distributed.checkpoint.filesystem import ( DistributedFileSystemReader, ) load_dir = "/opt/ml/checkpoints" sub_dir = f"tp{state.tp_rank}_ep{state.ep_rank}" checkpoint_id = os.path.join(load_dir, sub_dir) reader = DistributedFileSystemReader(checkpoint_id) process_group = model.process_group coordinator_rank = min(get_global_ranks(process_group)) with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT): # 1. Load model and everything else except the optimizer. state_dict = { "model": model.state_dict() # Potentially more customized state dicts. } load( state_dict, storage_reader=reader, process_group=process_group, coordinator_rank=coordinator_rank, ) model.load_state_dict(state_dict["model"]) # 2. Load optimizer. optim_state = load_sharded_optimizer_state_dict( model_state_dict=state_dict["model"], optimizer_key="optimizer", storage_reader=reader, process_group=process_group, ) flattened_optimizer_state = FSDP.optim_state_dict_to_load( optim_state["optimizer"], model, optimizer, group=model.process_group ) optimizer.load_state_dict(flattened_optimizer_state)

Vollständige Modell-Checkpoints

Am Ende des Trainings können Sie einen vollständigen Checkpoint speichern, der alle Shards eines Modells in einer einzigen Modell-Checkpoint-Datei zusammenfasst. Die SMP Bibliothek unterstützt die PyTorch vollständigen Modell-Checkpoints vollständigAPI, sodass Sie keine Änderungen vornehmen müssen.

Beachten Sie, dass die SMP Bibliothek das Modell transformiert SMPTensor-Parallelität, wenn Sie die verwenden. Wenn in diesem Fall das vollständige Modell überprüft wird, übersetzt die SMP Bibliothek das Modell standardmäßig zurück in das Checkpoint-Format von Hugging Face Transformers.

In Fällen, in denen Sie mit der SMP Tensorparallelität trainieren und den SMP Übersetzungsprozess ausschalten, können Sie das translate_on_save Argument von verwenden, PyTorch FullStateDictConfig API um die SMP automatische Übersetzung nach Bedarf ein- oder auszuschalten. Wenn Sie sich beispielsweise darauf konzentrieren, ein Modell zu trainieren, müssen Sie den Übersetzungsprozess nicht hinzufügen, was den Overhead erhöht. In diesem Fall empfehlen wir Ihnen, die Einstellung vorzunehmentranslate_on_save=False. Wenn Sie die SMP Übersetzung des Modells auch in future für weitere Schulungen verwenden möchten, können Sie sie ausschalten, um die SMP Übersetzung des Modells für die spätere Verwendung zu speichern. Die Rückübersetzung des Modells in das Modell-Checkpoint-Format von Hugging Face Transformers ist erforderlich, wenn Sie das Training Ihres Modells abschließen und es für Inferenzen verwenden.

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp import FullStateDictConfig import torch.sagemaker as tsm # Save checkpoints. with FSDP.state_dict_type( model, StateDictType.FULL_STATE_DICT, FullStateDictConfig( rank0_only=True, offload_to_cpu=True, # Default value is to translate back to Hugging Face Transformers format, # when saving full checkpoints for models trained with SMP tensor parallelism. # translate_on_save=True ), ): state_dict = model.state_dict() if dist.get_rank() == 0: logger.info("Processed state dict to save. Starting write to disk now.") os.makedirs(save_dir, exist_ok=True) # This name is needed for HF from_pretrained API to work. torch.save(state_dict, os.path.join(save_dir, "pytorch_model.bin")) hf_model_config.save_pretrained(save_dir) dist.barrier()

Beachten Sie, dass die Option FullStateDictConfig(rank0_only=True, offload_to_cpu=True) darin besteht, das Modell auf dem Gerät CPU der obersten Stufe zu sammeln, um beim Training großer Modelle Speicherplatz zu sparen.

Um das Modell zur Inferenz wieder zu laden, gehen Sie wie im folgenden Codebeispiel gezeigt vor. Beachten Sie, dass die Klasse in Hugging Face Transformers AutoModelForCausalLM möglicherweise zu anderen Factor Builder-Klassen wechseltAutoModelForSeq2SeqLM, z. B. je nach Modell. Weitere Informationen finden Sie in der Dokumentation zu Hugging Face Transformers.

from transformers import AutoModelForCausalLM model = AutoModelForCausalLM.from_pretrained(save_dir)