SMP を使用したチェックポイント - Amazon SageMaker

SMP を使用したチェックポイント

SageMaker モデル並列処理 (SMP) ライブラリは、チェックポイント用の PyTorch API に対応し、SMP ライブラリの使用中にチェックポイントを適切に行うのに役立つ API を提供しています。

PyTorch FSDP (Fully Sharded Data Parallelism) は、フル (full)、シャード (sharded)、ローカル (local) の 3 種類のチェックポイントをサポートしていますが、それぞれ目的が異なります。フルチェックポイントの生成は計算コストがかかるプロセスであるため、フルチェックポイントは、トレーニングの完了後にモデルをエクスポートする際に使用されます。シャードチェックポイントは、個々のランクごとに分割されたモデルの状態を保存し、ロードするのに役立ちます。シャードチェックポイントを使用すると、異なるハードウェア構成 (GPU の数が違うなど) でトレーニングを再開できます。ただし、シャードチェックポイントのロードは、複数のデバイス間の通信が必要になるため、遅くなる可能性があります。SMP ライブラリにはローカルチェックポイント機能があり、追加の通信オーバーヘッドを伴わずにモデルの状態をすばやく取得できます。FSDP によって作成されたチェックポイントは、Amazon FSx などの共有ネットワークファイルシステムへの書き込みが必要です。

非同期ローカルチェックポイント

機械学習モデルをトレーニングする際、チェックポイントファイルがディスクに保存されるのを後続のイテレーションが待つ必要はありません。SMP v2.5 のリリースで、チェックポイントファイルの非同期保存に対応しました。つまり、後続のトレーニングイテレーションは、チェックポイント作成時の入出力 (I/O) 演算と同時に実行でき、それらの I/O 演算によって低速化したり、中断されたりすることはありません。また、PyTorch でシャーディングされたモデルとオプティマイザのパラメータを取得するプロセスは、分散テンソルのメタデータをランク間で交換するために追加の集合通信が必要になるため、時間がかかる場合があります。StateDictType.LOCAL_STATE_DICT を使用して各ランクのローカルチェックポイントを保存した場合でも、PyTorch は依然として集合通信を実行するフックを呼び出します。この問題を軽減し、チェックポイントの取得にかかる時間を短縮するために、SMP は SMStateDictType.SM_LOCAL_STATE_DICT を導入しました。これにより、集合通信のオーバーヘッドを回避することで、モデルおよびオプティマイザのチェックポイントをより迅速に取得できます。

注記

SMStateDictType.SM_LOCAL_STATE_DICT を使用するには、FSDP SHARD_DEGREE の一貫性を維持する必要があります。SHARD_DEGREE が変更されないように徹底してください。モデルレプリケーションの数は変更されても大丈夫ですが、チェックポイントから再開する際は、モデルのシャード度合いが以前のトレーニング設定と同一である必要があります。

import os import torch.distributed as dist import torch.sagemaker as tsm from torch.sagemaker import state from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.sagemaker.distributed.checkpoint.state_dict_saver import ( async_save, maybe_finalize_async_calls, ) from torch.sagemaker.distributed.checkpoint.state_dict_utils import ( sm_state_dict_type, SMStateDictType, ) global_rank = dist.get_rank() save_dir = "/opt/ml/checkpoints" sub_dir = f"tp{state.tp_rank}_ep{state.ep_rank}_fsdp{model.rank}" # 1. Get replication ranks and group current_replication_group = None current_replication_ranks = None for replication_ranks in state.ranker.get_rep_groups(): rep_group = dist.new_group(replication_ranks) if global_rank in replication_ranks: current_replication_group = rep_group current_replication_ranks = replication_ranks coordinator_rank = min(current_replication_ranks) # 2. Wait for the previous checkpointing done maybe_finalize_async_calls( blocking=True, process_group=current_replication_group ) # 3. Get model local checkpoint with sm_state_dict_type(model, SMStateDictType.SM_LOCAL_STATE_DICT): state_dict = { "model": model.state_dict(), "optimizer": optimizer.state_dict(), # Potentially add more customized state dicts. } # 4. Save a local checkpoint async_save( state_dict, checkpoint_id=os.path.join(save_dir, sub_dir), process_group=current_replication_group, coordinator_rank=coordinator_rank, )

次のコードスニペットは、SMStateDictType.SM_LOCAL_STATE_DICT を使用してチェックポイントをロードする方法を示しています。

import os import torch.sagemaker as tsm from torch.sagemaker import state from torch.sagemaker.distributed.checkpoint.state_dict_loader import load from torch.sagemaker.distributed.checkpoint.state_dict_utils import ( sm_state_dict_type, SMStateDictType, init_optim_state ) from torch.sagemaker.distributed.checkpoint.filesystem import ( DistributedFileSystemReader, ) load_dir = "/opt/ml/checkpoints" sub_dir = f"tp{state.tp_rank}_ep{state.ep_rank}_fsdp{model.rank}" global_rank = dist.get_rank() checkpoint_id = os.path.join(load_dir, sub_dir) storage_reader = DistributedFileSystemReader(checkpoint_id) # 1. Get replication ranks and group current_replication_group = None current_replication_ranks = None for replication_ranks in state.ranker.get_rep_groups(): rep_group = dist.new_group(replication_ranks) if global_rank in replication_ranks: current_replication_group = rep_group current_replication_ranks = replication_ranks coordinator_rank = min(current_replication_ranks) # 2. Create local state_dict with sm_state_dict_type(model, SMStateDictType.SM_LOCAL_STATE_DICT): state_dict = { "model": model.state_dict(), # Potentially add more customized state dicts. } # Init optimizer state_dict states by setting zero grads and step. init_optim_state(optimizer, skip_empty_param=True) state_dict["optimizer"] = optimizer.state_dict() # 3. Load a checkpoint load( state_dict=state_dict, process_group=current_replication_group, coordinator_rank=coordinator_rank, storage_reader=storage_reader, )

大規模言語モデル (LLM) のチェックポイントの保存には、多くの場合、大容量のファイルシステムの作成が必要になるため、コストがかかる可能性があります。コストを削減するために、Amazon FSx などの追加のファイルシステムサービスを使用する必要なく、チェックポイントを Amazon S3 に直接保存することができます。前述の例を基にして次のコードスニペットを使用し、S3 の URL を保存先として指定することで、チェックポイントを S3 に保存できます。

key = os.path.join(checkpoint_dir, sub_dir) checkpoint_id= f"s3://{your_s3_bucket}/{key}" async_save(state_dict, checkpoint_id=checkpoint_id, **kw) load(state_dict, checkpoint_id=checkpoint_id, **kw)

非同期シャードチェックポイント

GPU の数を変えるなど、異なるハードウェア構成でトレーニングを続行しなければならない場合があります。このような場合、トレーニングプロセスではリシャーディング中にチェックポイントをロードする必要があります。つまり、後続のトレーニングを異なる数の SHARD_DEGREE で再開することになります。異なる数の SHARD_DEGREE でトレーニングを再開する必要がある状況に対応するには、シャーディングされた状態のディクショナリタイプ (StateDictType.SHARDED_STATE_DICT) でモデルのチェックポイントを保存する必要があります。この形式でチェックポイントを保存すると、変更後のハードウェア構成でトレーニングを続行するときに、リシャーディングプロセスを適切に処理できます。提示されているコードスニペットは、tsm API を使用してシャードチェックポイントを非同期的に保存し、より効率的で合理的なトレーニングプロセスを実現する方法を示しています。

import os import torch.sagemaker as tsm from torch.sagemaker import state from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp import StateDictType from torch.sagemaker.utils.process_group_utils import get_global_ranks from torch.sagemaker.distributed.checkpoint.state_dict_saver import ( async_save, maybe_finalize_async_calls, ) save_dir = "/opt/ml/checkpoints" sub_dir = f"tp{state.tp_rank}_ep{state.ep_rank}" checkpoint_id = os.path.join(save_dir, sub_dir) # To determine whether curreto take part in checkpointing. global_rank = dist.get_rank() action_rank = state.ranker.get_rep_rank(global_rank) == 0 process_group = model.process_group coordinator_rank = min(get_global_ranks(process_group)) # 1. wait for the previous checkpointing done maybe_finalize_async_calls(blocking=True, process_group=process_group) # 2. retrieve model & optimizer sharded state_dict with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT): state_dict = { "model": model.state_dict(), "optimizer": FSDP.optim_state_dict(model, optimizer), # Potentially add more customized state dicts. } # 3. save checkpoints asynchronously using async_save if action_rank: async_save( state_dict, checkpoint_id=checkpoint_id, process_group=process_group, coordinator_rank=coordinator_rank, )

共有チェックポイントをロードするプロセスは前のセクションと似ていますが、ここでは torch.sagemaker.distributed.checkpoint.filesystem.DistributedFileSystemReader とその load メソッドを使用しています。このクラスの load メソッドを使用すると、前述のプロセスに類似したプロセスに従って、共有チェックポイントデータをロードできます。

import os from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp import StateDictType from torch.distributed.checkpoint.optimizer import load_sharded_optimizer_state_dict from torch.sagemaker.distributed.checkpoint.state_dict_loader import load from torch.sagemaker.utils.process_group_utils import get_global_ranks from torch.sagemaker.distributed.checkpoint.filesystem import ( DistributedFileSystemReader, ) load_dir = "/opt/ml/checkpoints" sub_dir = f"tp{state.tp_rank}_ep{state.ep_rank}" checkpoint_id = os.path.join(load_dir, sub_dir) reader = DistributedFileSystemReader(checkpoint_id) process_group = model.process_group coordinator_rank = min(get_global_ranks(process_group)) with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT): # 1. Load model and everything else except the optimizer. state_dict = { "model": model.state_dict() # Potentially more customized state dicts. } load( state_dict, storage_reader=reader, process_group=process_group, coordinator_rank=coordinator_rank, ) model.load_state_dict(state_dict["model"]) # 2. Load optimizer. optim_state = load_sharded_optimizer_state_dict( model_state_dict=state_dict["model"], optimizer_key="optimizer", storage_reader=reader, process_group=process_group, ) flattened_optimizer_state = FSDP.optim_state_dict_to_load( optim_state["optimizer"], model, optimizer, group=model.process_group ) optimizer.load_state_dict(flattened_optimizer_state)

フルモデルチェックポイント

トレーニングの最後には、モデルのすべてのシャードを 1 つのモデルチェックポイントファイルにまとめたフルチェックポイントを保存できます。SMP ライブラリは PyTorch のフルモデルチェックポイント API を完全にサポートしているため、変更を加える必要はありません。

SMP テンソル並列性 を使用する場合、SMP ライブラリはモデルを変換します。この場合、完全なモデルをチェックポイントする際に、モデルがデフォルトで元の Hugging Face Transformers チェックポイント形式に変換されます。

SMP のテンソル並列処理でトレーニングを行い、SMP の変換プロセスをオフにする場合は、PyTorch の FullStateDictConfig API の translate_on_save 引数を使用して、SMP の自動変換のオン/オフを適宜切り替えることができます。例えば、モデルのトレーニングに集中する場合は、追加のオーバーヘッドが生じる変換プロセスを追加する必要はありません。この場合は、translate_on_save=False を設定することをお勧めします。また、今後のトレーニングでモデルの SMP 変換を使用する予定がある場合は、オフに切り替えて、モデルの SMP 変換を後で使用するために保存できます。モデルのトレーニングを終了し、その後推論に使用する場合は、モデルを Hugging Face Transformers モデルチェックポイント形式に戻す必要があります。

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp import FullStateDictConfig import torch.sagemaker as tsm # Save checkpoints. with FSDP.state_dict_type( model, StateDictType.FULL_STATE_DICT, FullStateDictConfig( rank0_only=True, offload_to_cpu=True, # Default value is to translate back to Hugging Face Transformers format, # when saving full checkpoints for models trained with SMP tensor parallelism. # translate_on_save=True ), ): state_dict = model.state_dict() if dist.get_rank() == 0: logger.info("Processed state dict to save. Starting write to disk now.") os.makedirs(save_dir, exist_ok=True) # This name is needed for HF from_pretrained API to work. torch.save(state_dict, os.path.join(save_dir, "pytorch_model.bin")) hf_model_config.save_pretrained(save_dir) dist.barrier()

FullStateDictConfig(rank0_only=True, offload_to_cpu=True) オプションは、大型モデルのトレーニング時にメモリを節約するために、0 番ランクのデバイスの CPU にモデルを集約します。

推論のためにモデルを再ロードするには、次のコード例のように処理します。AutoModelForSeq2SeqLM クラスは、モデルによっては、Hugging Face Transformer の他のファクタービルダークラス (AutoModelForCausalLM など) に変更される場合があります。詳細については、Hugging Face Transformers のドキュメントを参照してください。

from transformers import AutoModelForCausalLM model = AutoModelForCausalLM.from_pretrained(save_dir)