モデル並列処理によるモデルのチェックポイントと微調整 - Amazon SageMaker

翻訳は機械翻訳により提供されています。提供された翻訳内容と英語版の間で齟齬、不一致または矛盾がある場合、英語版が優先します。

モデル並列処理によるモデルのチェックポイントと微調整

SageMaker モデル並列処理ライブラリは、さまざまなモデル並列処理戦略によって分割されたモデルの状態とオプティマイザの状態を保存し、トレーニングを再開して微調整する場所から継続的なトレーニングのチェックポイントをロードするためのチェックポイント APIs を提供します。この API は、モデルとオプティマイザの状態の一部または全体を保存するオプションもサポートしています。

分散モデルのチェックポイント機能

PyTorch と の間のフレームワーク TensorFlow と、使用する SageMaker モデル並列処理ライブラリのバージョンに応じて、次のいずれかのトピックを選択します。

分散 PyTorch モデルのチェックポイント ( SageMaker モデル並列処理ライブラリ v1.10.0 以降用)

SageMaker モデル並列処理ライブラリは、分散モデルの状態とそのオプティマイザの状態のチェックポイント全体または一部を保存およびロードするためのチェックポイント APIs を提供します。

注記

このチェックポイント方法は、 PyTorch と SageMaker モデル並列処理ライブラリ v1.10.0 以降を使用する場合に推奨されます。

部分チェックポイント

モデル並列処理を使用してトレーニングされたモデルのチェックポイントを保存するには、部分チェックポイントオプションを true (partial=True) に設定して、smdistributed.modelparallel.torch.save_checkpoint API を使用します。これにより、各モデルのパーティションが個別に保存されます。モデルとオプティマイザの状態に加えて、user_content 引数を介して追加のカスタムデータを保存することもできます。チェックポイントモデル、オプティマイザ、およびユーザーコンテンツは個別のファイルとして保存されます。save_checkpoint API コールにより、次の構造でチェックポイントフォルダが作成されます。

- path - ${tag}_partial (folder for partial checkpoints) - model_rankinfo.pt - optimizer_rankinfo.pt - fp16_states_rankinfo.pt - user_content.pt - $tag (checkpoint file for full checkpoints) - user_content_$tag (user_content file for full checkpoints) - newest (a file that indicates the newest checkpoint)

部分チェックポイントからトレーニングを再開するには、partial=True として smdistributed.modelparallel.torch.resume_from_checkpoint API を使用し、チェックポイントディレクトリと部分チェックポイントの保存時に使用したタグを指定します。モデルの重みが実際に読み込まれるのは、モデルパーティショニングの後、smdistributed.modelparallel.torch.step - 修飾トレーニングステップ関数の初回実行時に行われることにご注意ください。

部分チェックポイントを保存するとき、ライブラリは .pt ファイル拡張子のファイルとして、モデルパーティションの決定も保存します。逆に、部分チェックポイントから再開するとき、ライブラリはパーティション決定ファイルを一緒に読み込みます。パーティション決定が読み込まれたら、パーティションを変更することはできません。

次のコードスニペットは、トレーニングスクリプトで PyTorchチェックポイント APIsを設定する方法を示しています。

import smdistributed.modelparallel.torch as smp model = ... model = smp.DistributedModel(model) optimizer = ... optimizer = smp.DistributedOptimizer(optimizer) user_content = ... # additional custom data checkpoint_path = "/opt/ml/checkpoint/model_parallel" # Save a checkpoint. smp.save_checkpoint( path=checkpoint_path, tag=f"total_steps{total_steps}", partial=True, model=model, optimizer=optimizer, user_content=user_content num_kept_partial_checkpoints=5 ) # Load a checkpoint. # This automatically loads the most recently saved checkpoint. smp_checkpoint = smp.resume_from_checkpoint( path=checkpoint_path, partial=True )

完全チェックポイント

推論の目的で最終モデルアーティファクトを保存するには、partial=False として smdistributed.modelparallel.torch.save_checkpoint API を使用します。これにより、モデルパーティションが組み合わされて単一のモデルアーティファクトが作成されます。これはオプティマイザの状態は組み合わせないことに注意してください。

完全なモデルチェックポイントを指定して、特定の重みでトレーニングを初期化するには、partial=False として smdistributed.modelparallel.torch.resume_from_checkpoint API を使用できます。これはオプティマイザの状態は読み込まないことに注意してください。

注記

テンソル並列処理では、一般的に元のモデル実装と DistributedModel 実装との間で state_dict を変換する必要があります。オプションで、smdistributed.modelparallel.torch.resume_from_checkpoint への引数として state_dict 変換関数を指定できます。ただし、すぐに利用可能なサポート対象モデル の場合、この変換はライブラリが自動的に処理します。

次のコードは、チェックポイント APIs を使用して、 PyTorch モデル並列処理でトレーニングされたモデルを完全にチェックポイントする方法の例を示しています。

import smdistributed.modelparallel.torch as smp model = ... model = smp.DistributedModel(model) optimizer = ... optimizer = smp.DistributedOptimizer(optimizer) user_content = ... # additional custom data checkpoint_path = "/opt/ml/checkpoint/model_parallel" # Save a checkpoint. smp.save_checkpoint( path=checkpoint_path, tag=f"total_steps{total_steps}", partial=False, model=model, optimizer=optimizer, user_content=user_content num_kept_partial_checkpoints=5 ) # Load a checkpoint. # This automatically loads the most recently saved checkpoint. smp_checkpoint = smp.resume_from_checkpoint( path=checkpoint_path, partial=False )

分散 PyTorch モデルのチェックポイント (v1.6.0 から v1.9.0 までの SageMaker モデル並列処理ライブラリ用)

SageMaker モデル並列処理ライブラリは、テンソル並列処理によるトレーニングジョブのチェックポイントの一部または全体を保存するための Python 関数を提供します。次の手順は、テンソル並列処理を使用するときにチェックポイントを保存および読み込むように smp.save()smp.load() を使用する方法を示しています。

注記

このチェックポイント方法は PyTorch、v1.6.0 から v1.9.0 までの テンソル並列処理、、および SageMaker モデル並列処理ライブラリを使用する場合に推奨されます。

  1. モデルオブジェクトを準備し、ライブラリのラッパー関数 smp.DistributedModel() でラップします。

    model = MyModel(...) model = smp.DistributedModel(model)
  2. モデルのオプティマイザを準備します。一連のモデルパラメータは、オプティマイザ関数で必要とされる反復可能な引数です。一連のモデルパラメータを準備するには、model.parameters() を処理して、個々のモデルパラメータに一意の ID を割り当てる必要があります。

    モデルパラメータイテラブルに重複した ID を持つパラメータがある場合、チェックポイントされたオプティマイザの状態の読み込みは失敗します。オプティマイザの一意の ID を持つモデルパラメータの反復可能オブジェクトを作成するには、以下を参照してください。

    unique_params = [] unique_params_set = set() for p in model.parameters(): if p not in unique_params_set: unique_params.append(p) unique_params_set.add(p) del unique_params_set optimizer = MyOpt(unique_params, ...)
  3. ライブラリのラッパー関数 smp.DistributedOptimizer() を使用してオプティマイザをラップします。

    optimizer = smp.DistributedOptimizer(optimizer)
  4. smp.save() を使用して、モデルとオプティマイザの状態を保存します。チェックポイントの保存方法に応じて、次の 2 つのオプションのいずれかを選択します:

    • オプション 1: 単一の MP_GROUP に対して各 mp_rank の部分モデルを保存する。

      model_dict = model.local_state_dict() # save a partial model opt_dict = optimizer.local_state_dict() # save a partial optimizer state # Save the dictionaries at rdp_rank 0 as a checkpoint if smp.rdp_rank() == 0: smp.save( {"model_state_dict": model_dict, "optimizer_state_dict": opt_dict}, f"/checkpoint.pt", partial=True, )

      テンソル並列処理では、ライブラリは次の形式の名前のチェックポイントファイルを保存します。checkpoint.pt_{pp_rank}_{tp_rank}

      注記

      テンソル並列処理では、if ステートメントを if smp.dp_rank() == 0 ではなく if smp.rdp_rank() == 0 のように設定していることを確認してください。オプティマイザの状態をテンソル並列処理でシャーディングする場合、すべてのリダクションデータの並列ランクは、オプティマイザの状態の独自のパーティションを保存する必要があります。チェックポイントに間違った if ステートメントを使用すると、トレーニングジョブが停止する場合があります。テンソル並列処理if smp.dp_rank() == 0を使用しない の使用の詳細については、SageMaker Python SDK ドキュメント「保存とロードの一般的な指示」を参照してください。

    • オプション 2: 完全モデルを保存する。

      if smp.rdp_rank() == 0: model_dict = model.state_dict(gather_to_rank0=True) # save the full model if smp.rank() == 0: smp.save( {"model_state_dict": model_dict}, "/checkpoint.pt", partial=False, )
      注記

      完全チェックポイントについては、次の点を考慮してください:

      • gather_to_rank0=True を設定すると、0 以外のすべてのランクは空のディクショナリを返します。

      • 完全チェックポイントの場合、モデルのチェックポイントのみ可能です。現在、オプティマイザの状態の完全チェックポイントはサポートされていません。

      • 完全モデルは smp.rank() == 0 で保存する必要があります。

  5. smp.load() を使用してチェックポイントを読み込みます。前のステップでチェックポイントした方法に応じて、次の 2 つのオプションのいずれかを選択します:

    • オプション 1: 部分チェックポイントを読み込む。

      checkpoint = smp.load("/checkpoint.pt", partial=True) model.load_state_dict(checkpoint["model_state_dict"], same_partition_load=False) optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

      パーティションが変更されないことがわかっている場合は、読み込みの高速化のために model.load_state_dict()same_partition_load=True を設定できます。

    • オプション 2: 完全チェックポイントを読み込む。

      if smp.rdp_rank() == 0: checkpoint = smp.load("/checkpoint.pt", partial=False) model.load_state_dict(checkpoint["model_state_dict"])

      if smp.rdp_rank() == 0 の条件は必須ではありませんが、異なる MP_GROUP 間での冗長な読み込みを回避できます。現在、完全チェックポイントのオプティマイザの状態 dict は、テンソル並列処理ではサポートされていません。

分散 TensorFlow モデルのチェックポイント

TensorFlow モデル並列処理によるトレーニング中にモデルを保存するには、 SageMaker モデル並列処理ライブラリが提供する次の関数を使用します。

分散モデルの微調整

微調整はトレーニングスクリプトで設定する必要があります。次のコードスニペットは、モジュールsmdistributed.model.parallel.torchと微調整の設定を登録するための変更を加えた Hugging Face Transformer の AutoModelForCausalLM クラスを使用するトレーニングスクリプトの構造の例を示しています。

注記

smp.delayed_param_initialization 関数を有効にして分散トランスフォーマー (smp.DistributedModel() でラップされたトランスフォーマーモデル) を微調整するには、FSx for Lustre ファイルシステムで微調整ジョブを構成する必要があります。遅延パラメータ初期化オプションを使用して大規模モデルを微調整する場合は、FSx for Lustre ファイルシステムを設定する必要があります。

import argparse from transformers import AutoModelForCausalLM import smdistributed.modelparallel import smdistributed.modelparallel.torch as smp def parse_args(): parser = argparse.ArgumentParser() # set an arg group for model model_grp = parser.add_argument_group( title="model", description="arguments to describe model configuration" ) ... # set up numerous args to parse from the configuration dictionary to the script for training # add arg for activating fine-tuning model_grp.add_argument( "--fine_tune", type=int, default=0, help="Fine-tune model from checkpoint or pretrained model", ) def main(): """Main function to train GPT.""" args = parse_args() ... # parse numerous args if args.fine_tune > 0 and args.delayed_param > 0 and smp.rank() == 0: pretrained_model = AutoModelForCausalLM.from_pretrained( args.model_name or args.model_dir ) model_state_dict = pretrained_model.state_dict() path = os.path.join(args.model_dir, "fullmodel.pt") torch.save(model_state_dict, path) # create a Transformer model and wrap by smp.model_creation() # with options to configure model parallelism parameters offered by SageMaker with smp.model_creation( tensor_parallelism=smp.tp_size() > 1 or args.use_distributed_transformer > 0, zero_init=args.use_distributed_transformer == 0, dtype=dtype, distribute_embedding=args.sharded_data_parallel_degree > 1 and smp.tp_size() > 1, use_alibi=args.alibi > 0, attention_in_fp32=args.attention_in_fp32 > 0, fp32_residual_addition=args.residual_addition_in_fp32 > 0, query_key_layer_scaling=args.query_key_layer_scaling > 0 and args.bf16 < 1, fused_softmax=args.fused_softmax > 0, fused_dropout=args.fused_dropout > 0, fused_bias_gelu=args.fused_bias_gelu > 0, flash_attention=args.flash_attention > 0, ): if args.fine_tune > 0 and args.delayed_param == 0: model = AutoModelForCausalLM.from_pretrained( args.model_name or args.model_dir ) else: model = AutoModelForCausalLM.from_config(model_config) # wrap the model by smp.DistributedModel() to apply SageMaker model parallelism model = smp.DistributedModel( model, trace_device="gpu", backward_passes_per_step=args.gradient_accumulation ) # wrap the optimizer by smp.DistributedOptimizer() to apply SageMaker model parallelism optimizer= ... # define an optimizer optimizer = smp.DistributedOptimizer( optimizer, static_loss_scale=None, dynamic_loss_scale=True, dynamic_loss_args={"scale_window": 1000, "min_scale": 1, "delayed_shift": 2}, ) # for fine-tuning, use smp.resume_from_checkpoint() to load a pre-trained model if args.fine_tune > 0 and args.delayed_param > 0: smp.resume_from_checkpoint(args.model_dir, tag="fullmodel.pt", partial=False)

トレーニングスクリプトと Jupyter Notebook の完全な例については、「例」リポジトリの「 の GPT-2 の例 PyTorch」を参照してください。 SageMaker GitHub