を使用したチェックポイント SMP - Amazon SageMaker

翻訳は機械翻訳により提供されています。提供された翻訳内容と英語版の間で齟齬、不一致または矛盾がある場合、英語版が優先します。

を使用したチェックポイント SMP

SageMaker モデル並列処理 (SMP) ライブラリは PyTorch APIsチェックポイントをサポートしており、SMPライブラリの使用中にチェックポイントを適切に行うAPIsのに役立ちます。

PyTorch FSDP (完全シャードデータ並列処理) は、フルチェックポイント、シャードチェックポイント、ローカルチェックポイントの 3 種類のチェックポイントをサポートし、それぞれ異なる目的に対応します。完全なチェックポイントの生成は計算コストがかかるプロセスであるため、トレーニング完了後にモデルをエクスポートするときは、完全なチェックポイントが使用されます。シャードチェックポイントは、個々のランクごとにシャードされたモデルの状態を保存してロードするのに役立ちます。シャードチェックポイントを使用すると、異なる数の など、異なるハードウェア設定でトレーニングを再開できますGPUs。ただし、複数のデバイス間の通信により、シャードチェックポイントのロードが遅くなる可能性があります。SMP ライブラリにはローカルチェックポイント機能があり、追加の通信オーバーヘッドなしでモデルの状態をすばやく取得できます。によって作成されたチェックポイントは、Amazon などの共有ネットワークファイルシステムに書き込むFSDP必要があることに注意してくださいFSx。

非同期ローカルチェックポイント

機械学習モデルをトレーニングする場合、チェックポイントファイルがディスクに保存されるのを後続の反復で待つ必要はありません。SMP v2.5 のリリースでは、ライブラリはチェックポイントファイルの非同期保存をサポートしています。つまり、後続のトレーニング反復は、チェックポイントを作成するための入出力 (I/O) オペレーションと同時に実行でき、それらの I/O オペレーションによって速度が低下したり、抑制されたりすることはありません。また、 でシャードモデルとオプティマイザパラメータを取得するプロセスは、分散テンソルメタデータをランク間で交換するために必要な追加の集合通信のため、時間がかかる PyTorch 場合があります。StateDictType.LOCAL_STATE_DICT を使用して各ランクのローカルチェックポイントを保存しても、 は集合通信を実行するフックを呼び出し PyTorch ます。この問題を軽減し、チェックポイントの取得に必要な時間を短縮するために、 に SMP が導入されました。これによりSMStateDictType.SM_LOCAL_STATE_DICT、集合的な通信オーバーヘッドをバイパスすることで、モデルチェックポイントとオプティマイザチェックポイントをすばやく取得できます。

注記

の一貫性を維持することFSDPSHARD_DEGREEは、 を利用するための要件ですSMStateDictType.SM_LOCAL_STATE_DICTSHARD_DEGREE に変更がないことを確認します。モデルレプリケーションの数は異なる場合がありますが、チェックポイントから再開する場合、モデルシャードの度合いは以前のトレーニング設定と同じである必要があります。

import os import torch.distributed as dist import torch.sagemaker as tsm from torch.sagemaker import state from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.sagemaker.distributed.checkpoint.state_dict_saver import ( async_save, maybe_finalize_async_calls, ) from torch.sagemaker.distributed.checkpoint.state_dict_utils import ( sm_state_dict_type, SMStateDictType, ) global_rank = dist.get_rank() save_dir = "/opt/ml/checkpoints" sub_dir = f"tp{state.tp_rank}_ep{state.ep_rank}_fsdp{model.rank}" # 1. Get replication ranks and group current_replication_group = None current_replication_ranks = None for replication_ranks in state.ranker.get_rep_groups(): rep_group = dist.new_group(replication_ranks) if global_rank in replication_ranks: current_replication_group = rep_group current_replication_ranks = replication_ranks coordinator_rank = min(current_replication_ranks) # 2. Wait for the previous checkpointing done maybe_finalize_async_calls( blocking=True, process_group=current_replication_group ) # 3. Get model local checkpoint with sm_state_dict_type(model, SMStateDictType.SM_LOCAL_STATE_DICT): state_dict = { "model": model.state_dict(), "optimizer": optimizer.state_dict(), # Potentially add more customized state dicts. } # 4. Save a local checkpoint async_save( state_dict, checkpoint_id=os.path.join(save_dir, sub_dir), process_group=current_replication_group, coordinator_rank=coordinator_rank, )

次のコードスニペットは、 を使用してチェックポイントをロードする方法を示していますSMStateDictType.SM_LOCAL_STATE_DICT

import os import torch.sagemaker as tsm from torch.sagemaker import state from torch.sagemaker.distributed.checkpoint.state_dict_loader import load from torch.sagemaker.distributed.checkpoint.state_dict_utils import ( sm_state_dict_type, SMStateDictType, init_optim_state ) from torch.sagemaker.distributed.checkpoint.filesystem import ( DistributedFileSystemReader, ) load_dir = "/opt/ml/checkpoints" sub_dir = f"tp{state.tp_rank}_ep{state.ep_rank}_fsdp{model.rank}" global_rank = dist.get_rank() checkpoint_id = os.path.join(load_dir, sub_dir) storage_reader = DistributedFileSystemReader(checkpoint_id) # 1. Get replication ranks and group current_replication_group = None current_replication_ranks = None for replication_ranks in state.ranker.get_rep_groups(): rep_group = dist.new_group(replication_ranks) if global_rank in replication_ranks: current_replication_group = rep_group current_replication_ranks = replication_ranks coordinator_rank = min(current_replication_ranks) # 2. Create local state_dict with sm_state_dict_type(model, SMStateDictType.SM_LOCAL_STATE_DICT): state_dict = { "model": model.state_dict(), # Potentially add more customized state dicts. } # Init optimizer state_dict states by setting zero grads and step. init_optim_state(optimizer, skip_empty_param=True) state_dict["optimizer"] = optimizer.state_dict() # 3. Load a checkpoint load( state_dict=state_dict, process_group=current_replication_group, coordinator_rank=coordinator_rank, storage_reader=storage_reader, )

大規模な言語モデル (LLMs) のチェックポイントの保存には、大きなファイルシステムボリュームの作成が必要になることが多いため、コストがかかる場合があります。コストを削減するために、Amazon などの追加のファイルシステムサービスを使用せずに、チェックポイントを Amazon S3 に直接保存することもできますFSx。前の例を次のコードスニペットで活用して、送信先URLとして S3 を指定することでチェックポイントを S3 に保存できます。

key = os.path.join(checkpoint_dir, sub_dir) checkpoint_id= f"s3://{your_s3_bucket}/{key}" async_save(state_dict, checkpoint_id=checkpoint_id, **kw) load(state_dict, checkpoint_id=checkpoint_id, **kw)

非同期シャードチェックポイント

の数の変更など、さまざまなハードウェア設定でトレーニングを継続する必要がある場合がありますGPUs。このような場合、トレーニングプロセスはリシャーディング中にチェックポイントをロードする必要があります。つまり、後続のトレーニングを異なる数の で再開しますSHARD_DEGREE。別の数の でトレーニングを再開する必要があるシナリオに対処するにはSHARD_DEGREE、 で表されるシャード状態ディクショナリタイプを使用してモデルチェックポイントを保存する必要がありますStateDictType.SHARDED_STATE_DICT。この形式でチェックポイントを保存すると、変更されたハードウェア設定でトレーニングを続行するときに、再シャーディングプロセスを適切に処理できます。提供されたコードスニペットは、 を使用してシャードチェックポイントtsmAPIを非同期的に保存する方法を説明し、より効率的で合理化されたトレーニングプロセスを可能にします。

import os import torch.sagemaker as tsm from torch.sagemaker import state from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp import StateDictType from torch.sagemaker.utils.process_group_utils import get_global_ranks from torch.sagemaker.distributed.checkpoint.state_dict_saver import ( async_save, maybe_finalize_async_calls, ) save_dir = "/opt/ml/checkpoints" sub_dir = f"tp{state.tp_rank}_ep{state.ep_rank}" checkpoint_id = os.path.join(save_dir, sub_dir) # To determine whether curreto take part in checkpointing. global_rank = dist.get_rank() action_rank = state.ranker.get_rep_rank(global_rank) == 0 process_group = model.process_group coordinator_rank = min(get_global_ranks(process_group)) # 1. wait for the previous checkpointing done maybe_finalize_async_calls(blocking=True, process_group=process_group) # 2. retrieve model & optimizer sharded state_dict with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT): state_dict = { "model": model.state_dict(), "optimizer": FSDP.optim_state_dict(model, optimizer), # Potentially add more customized state dicts. } # 3. save checkpoints asynchronously using async_save if action_rank: async_save( state_dict, checkpoint_id=checkpoint_id, process_group=process_group, coordinator_rank=coordinator_rank, )

共有チェックポイントをロードするプロセスは前のセクションと似ていますが、 torch.sagemaker.distributed.checkpoint.filesystem.DistributedFileSystemReaderとその load メソッドの使用が含まれます。このクラスの loadメソッドを使用すると、前述のプロセスに類似したプロセスに従って、共有チェックポイントデータをロードできます。

import os from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp import StateDictType from torch.distributed.checkpoint.optimizer import load_sharded_optimizer_state_dict from torch.sagemaker.distributed.checkpoint.state_dict_loader import load from torch.sagemaker.utils.process_group_utils import get_global_ranks from torch.sagemaker.distributed.checkpoint.filesystem import ( DistributedFileSystemReader, ) load_dir = "/opt/ml/checkpoints" sub_dir = f"tp{state.tp_rank}_ep{state.ep_rank}" checkpoint_id = os.path.join(load_dir, sub_dir) reader = DistributedFileSystemReader(checkpoint_id) process_group = model.process_group coordinator_rank = min(get_global_ranks(process_group)) with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT): # 1. Load model and everything else except the optimizer. state_dict = { "model": model.state_dict() # Potentially more customized state dicts. } load( state_dict, storage_reader=reader, process_group=process_group, coordinator_rank=coordinator_rank, ) model.load_state_dict(state_dict["model"]) # 2. Load optimizer. optim_state = load_sharded_optimizer_state_dict( model_state_dict=state_dict["model"], optimizer_key="optimizer", storage_reader=reader, process_group=process_group, ) flattened_optimizer_state = FSDP.optim_state_dict_to_load( optim_state["optimizer"], model, optimizer, group=model.process_group ) optimizer.load_state_dict(flattened_optimizer_state)

完全なモデルチェックポイント

トレーニングの終了時に、モデルのすべてのシャードを 1 つのモデルチェックポイントファイルにまとめた完全なチェックポイントを保存できます。SMP ライブラリはモデルチェックポイント PyTorch 全体を完全にサポートしているためAPI、変更を加える必要はありません。

を使用する場合SMPテンソル並列性、SMPライブラリはモデルを変換することに注意してください。この場合、フルモデルをチェックポイントすると、SMPライブラリはモデルをデフォルトで Hugging Face Transformers チェックポイント形式に変換します。

SMP テンソル並列処理を使用してトレーニングし、SMP翻訳プロセスをオフにする場合は、 のtranslate_on_save引数を使用して、必要に応じてSMP自動翻訳のオン/オフ PyTorch FullStateDictConfigAPIを切り替えることができます。例えば、モデルのトレーニングに集中している場合、オーバーヘッドを追加する翻訳プロセスを追加する必要はありません。この場合、 を設定することをお勧めしますtranslate_on_save=False。また、今後さらなるトレーニングのためにモデルのSMP翻訳を引き続き使用する予定がある場合は、オフに切り替えると、後で使用するためにモデルのSMP翻訳を保存できます。モデルのトレーニングをまとめ、それを推論に使用する場合は、モデルを Hugging Face Transformers モデルチェックポイント形式に戻す必要があります。

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp import FullStateDictConfig import torch.sagemaker as tsm # Save checkpoints. with FSDP.state_dict_type( model, StateDictType.FULL_STATE_DICT, FullStateDictConfig( rank0_only=True, offload_to_cpu=True, # Default value is to translate back to Hugging Face Transformers format, # when saving full checkpoints for models trained with SMP tensor parallelism. # translate_on_save=True ), ): state_dict = model.state_dict() if dist.get_rank() == 0: logger.info("Processed state dict to save. Starting write to disk now.") os.makedirs(save_dir, exist_ok=True) # This name is needed for HF from_pretrained API to work. torch.save(state_dict, os.path.join(save_dir, "pytorch_model.bin")) hf_model_config.save_pretrained(save_dir) dist.barrier()

オプションFullStateDictConfig(rank0_only=True, offload_to_cpu=True)は、0 ランク目のデバイスの でモデルCPUを収集して、大きなモデルのトレーニング時にメモリを節約することです。

推論のためにモデルをロードバックするには、次のコード例に示すようにロードバックします。クラスはAutoModelForSeq2SeqLM、モデルによっては、 などの Hugging Face Transformer の他のファクタービルダークラスに変更AutoModelForCausalLMされる場合があることに注意してください。詳細については、「Hugging Face Transformers のドキュメント」を参照してください。

from transformers import AutoModelForCausalLM model = AutoModelForCausalLM.from_pretrained(save_dir)