SageMaker モデル並列処理ライブラリ v2 のリファレンス - Amazon SageMaker AI

翻訳は機械翻訳により提供されています。提供された翻訳内容と英語版の間で齟齬、不一致または矛盾がある場合、英語版が優先します。

SageMaker モデル並列処理ライブラリ v2 のリファレンス

以下は、SageMaker モデル並列処理ライブラリ v2 (SMP v2) のリファレンスです。

SMP v2 の主要機能の設定パラメータ

以下は、SageMaker モデル並列処理ライブラリ v2 の主要機能を有効にし、設定するためのパラメータを網羅したリストです。これらは JSON 形式で記述し、SageMaker Python SDK の PyTorch 推定器に渡すか、SageMaker HyperPod 用の JSON ファイルとして保存する必要があります。

{ "hybrid_shard_degree": Integer, "sm_activation_offloading": Boolean, "activation_loading_horizon": Integer, "fsdp_cache_flush_warnings": Boolean, "allow_empty_shards": Boolean, "tensor_parallel_degree": Integer, "context_parallel_degree": Integer, "expert_parallel_degree": Integer, "random_seed": Integer }
  • hybrid_shard_degree (整数) – シャーディング並列処理の並列度を指定します。0 から world_size の間の整数を指定する必要があります。デフォルト値は 0 です。

    • 0 に設定すると、tensor_parallel_degree が 1 の場合は、スクリプトのネイティブの PyTorch 実装と API にフォールバックします。それ以外の場合は、tensor_parallel_degreeworld_size に基づいて可能な限り最大の hybrid_shard_degree が計算されます。ネイティブの PyTorch FSDP ユースケースにフォールバックする場合、使用する戦略が FULL_SHARD であれば、GPU のクラスター全体にシャーディング (分割) が行われます。戦略が HYBRID_SHARD または _HYBRID_SHARD_ZERO2 の場合は、hybrid_shard_degree が 8 である場合と同等です。テンソル並列処理が有効な場合は、見直された hybrid_shard_degree に基づいてシャーディングが行われます。

    • 1 に設定すると、tensor_parallel_degree が 1 の場合、スクリプトの NO_SHARD のネイティブ PyTorch 実装と API にフォールバックします。それ以外の場合は、特定のテンソル並列グループ内で NO_SHARD と同等の動作をします。

    • 2 から world_size までの整数に設定すると、指定数の GPU 間でシャーディングが行われます。FSDP スクリプトで sharding_strategy を設定していない場合は、HYBRID_SHARD に設定されます。_HYBRID_SHARD_ZERO2 を設定した場合は、指定した sharding_strategy が使用されます。

  • sm_activation_offloading (ブール値) - SMP アクティベーションオフロードの実装を有効にするかどうかを指定します。False の場合は、ネイティブの PyTorch 実装を使用します。True の場合は、SMP のアクティベーションオフロードの実装を使用します。また、PyTorch のアクティベーションオフロードラッパー (torch.distributed.algorithms._checkpoint.checkpoint_wrapper.offload_wrapper) もスクリプトで使用する必要があります。詳細については、「アクティベーションオフロード」を参照してください。デフォルト値は True です。

  • activation_loading_horizon (整数) – FSDP のアクティベーションオフロードの範囲 (horizon) タイプを指定する整数。これは、チェックポイントまたはオフロードされた層の入力が GPU メモリに同時に存在できる最大数です。詳細については、「アクティベーションオフロード」を参照してください。入力値は正の整数である必要があります。デフォルト値は 2 です。

  • fsdp_cache_flush_warnings (ブール値) – PyTorch メモリマネージャーで発生したキャッシュフラッシュを検知し、警告するかどうかを指定します。キャッシュフラッシュは計算性能を低下させる可能性があります。デフォルト値は True です。

  • allow_empty_shards (ブール値) – テンソルのシャーディング時にテンソルを分割できない場合に、空のシャードを許容するかどうかを指定します。これは、特定のシナリオにおける、チェックポイント中のクラッシュに対する実験的な修正です。これを無効にすると、元の PyTorch の動作に戻ります。デフォルト値は False です。

  • tensor_parallel_degree (整数) – テンソル並列処理の並列度を指定します。1 から world_size の間の値を指定する必要があります。デフォルト値は 1 です。1 より大きい値を渡した場合に、コンテキスト並列処理が自動的に有効になるわけではありません。torch.sagemaker.transform API を使用して、トレーニングスクリプトでモデルをラップすることも必要です。詳細については、「テンソル並列性」を参照してください。

  • context_parallel_degree (整数) – コンテキスト並列処理の並列度を指定します。1 から world_size の間の値を指定する必要があり、<= hybrid_shard_degree であることが必要です。デフォルト値は 1 です。1 より大きい値を渡した場合に、コンテキスト並列処理が自動的に有効になるわけではありません。torch.sagemaker.transform API を使用して、トレーニングスクリプトでモデルをラップすることも必要です。詳細については、「コンテキスト並列処理」を参照してください。

  • expert_parallel_degree (整数) – エキスパート並列処理の並列度を指定します。1 から world_size の間の値を指定する必要があります。デフォルト値は 1 です。1 より大きい値を渡した場合に、コンテキスト並列処理が自動的に有効になるわけではありません。torch.sagemaker.transform API を使用して、トレーニングスクリプトでモデルをラップすることも必要です。詳細については、「エキスパート並列処理」を参照してください。

  • random_seed (整数) – SMP のテンソル並列処理またはエキスパート並列処理による分散型モジュールにおける、ランダムオペレーションに使用されるシード番号。このシードはテンソル並列ランクまたはエキスパート並列ランクに追加され、各ランクの実際のシードを設定します。この値は、テンソル並列ランクおよびエキスパート並列ランクごとに一意です。SMP v2 では、テンソル並列ランクやエキスパート並列ランク間で生成される乱数が、それぞれテンソル並列処理やエキスパート並列処理を使用しない場合と一致するようになっています。

SMP v2 torch.sagemaker パッケージのリファレンス

このセクションは、SMP v2 が提供する torch.sagemaker パッケージのリファレンスです。

torch.sagemaker.delayed_param.DelayedParamIniter

PyTorch モデルに パラメータの遅延初期化 を適用するための API。

class torch.sagemaker.delayed_param.DelayedParamIniter( model: nn.Module, init_method_using_config : Callable = None, verbose: bool = False, )

パラメータ

  • model (nn.Module) – SMP v2 のパラメータの遅延初期化機能をラップおよび適用する PyTorch モデル。

  • init_method_using_config (Callable) – SMP v2 またはサポート対象の SMP テンソル並列処理と互換性のあるHugging Face Transformer モデル のテンソル並列実装を使用する場合は、このパラメータをデフォルト値 (None) のままにしてください。デフォルトでは、DelayedParamIniter API は、指定されたモデルを正しく初期化する方法を判定します。他のモデルでは、カスタムのパラメータ初期化関数を作成し、スクリプトに追加する必要があります。次のコードスニペットは、SMP v2 が SMP テンソル並列処理と互換性のあるHugging Face Transformer モデル 向けに実装したデフォルトの init_method_using_config 関数です。次のコードスニペットを参照して独自の初期化設定関数を作成し、スクリプトに追加して、SMP DelayedParamIniter API の init_method_using_config パラメータに渡してください。

    from torch.sagemaker.utils.module_utils import empty_module_params, move_buffers_to_device # Define a custom init config function. def custom_init_method_using_config(module): d = torch.cuda.current_device() empty_module_params(module, device=d) if isinstance(module, (nn.Linear, Conv1D)): module.weight.data.normal_(mean=0.0, std=config.initializer_range) if module.bias is not None: module.bias.data.zero_() elif isinstance(module, nn.Embedding): module.weight.data.normal_(mean=0.0, std=config.initializer_range) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() elif isinstance(module, nn.LayerNorm): module.weight.data.fill_(1.0) module.bias.data.zero_() elif isinstance(module, LlamaRMSNorm): module.weight.data.fill_(1.0) move_buffers_to_device(module, device=d) delayed_initer = DelayedParamIniter(model, init_method_using_config=custom_init_method_using_config)

    前述のコードスニペットの torch.sagemaker.module_util 関数の詳細については、「torch.sagemaker ユーティリティ関数とプロパティ」を参照してください。

  • verbose (ブール値) – 初期化と検証の最中に詳細なログ記録を有効にするかどうかを指定します。デフォルト値は False です。

方法

  • get_param_init_fn() – PyTorch FSDP ラッパークラスの param_init_fn 引数に渡すことができるパラメータ初期化関数を返します。

  • get_post_param_init_fn() – PyTorch FSDP ラッパークラスの post_param_init_fn 引数に渡すことができるパラメータ初期化関数を返します。これは、モデルで重みを共有 (tied weights) している場合に必要です。モデルは tie_weights メソッドを実装する必要があります。詳細については、「パラメータの遅延初期化」の「重み共有に関する注意事項」を参照してください。

  • count_num_params (module: nn.Module, *args: Tuple[nn.Parameter]) – パラメータ初期化関数によって初期化されるパラメータの数をカウントします。これは、以下の validate_params_and_buffers_inited メソッドを実装する際に役立ちます。通常、この関数を明示的に呼び出す必要はありません。validate_params_and_buffers_inited メソッドがバックエンドでこのメソッドを暗黙的に呼び出すからです。

  • validate_params_and_buffers_inited (enabled: bool=True) – 初期化されたパラメータ数が、モデル内のパラメータの合計数と一致することを検証するためのコンテキストマネージャーです。すべてのパラメータとバッファが、メタデバイスではなく GPU デバイス上にあることも検証します。これらの条件が満たされない場合は、AssertionErrors が発生します。このコンテキストマネージャーの使用は任意です。パラメータを初期化する際に必須ではありません。

torch.sagemaker.distributed.checkpoint.state_dict_saver.async_save

非同期保存用のエントリ API。このメソッドを使用して、指定された checkpoint_id に非同期的に state_dict を保存します。

def async_save( state_dict: STATE_DICT_TYPE, *, checkpoint_id: Union[str, os.PathLike, None] = None, storage_writer: Optional[StorageWriter] = None, planner: Optional[SavePlanner] = None, process_group: Optional[dist.ProcessGroup] = None, coordinator_rank: int = 0, queue : AsyncCallsQueue = None, sharded_strategy: Union[SaveShardedStrategy, Tuple[str, int], None] = None, wait_error_handling: bool = True, force_check_all_plans: bool = True, s3_region: Optional[str] = None, s3client_config: Optional[S3ClientConfig] = None ) -> None:

パラメータ

  • state_dict (dict) – 必須。保存対象の状態ディクショナリ。

  • checkpoint_id (str) - 必須。チェックポイントの保存先のストレージパス。

  • storage_writer (StorageWriter) - オプション。書き込みオペレーションを実行する、PyTorch の StorageWriter のインスタンス。指定しない場合は、デフォルト設定の StorageWriter が使用されます。

  • planner (SavePlanner) - オプション。PyTorch の SavePlanner のインスタンス。指定しない場合は、デフォルト設定の SavePlanner が使用されます。

  • process_group (ProcessGroup) - オプション。使用するプロセスグループ。None の場合は、デフォルト (グローバル) のプロセスグループが使用されます。

  • coordinator_rank (int) - オプション。集合通信演算 (AllReduce など) を実行する場合のコーディネーターのランク。

  • queue (AsyncRequestQueue) - オプション。使用する非同期スケジューラ。デフォルトでは、グローバルパラメータ DEFAULT_ASYNC_REQUEST_QUEUE を取ります。

  • sharded_strategy (PyTorchDistSaveShardedStrategy) - オプション。チェックポイントの保存に使用するシャーディング戦略。指定しない場合は、デフォルトで torch.sagemaker.distributed.checkpoint.state_dict_saver.PyTorchDistSaveShardedStrategy が使用されます。

  • wait_error_handling (bool) - オプション。すべてのランクがエラー処理を終了するまで待つかどうかを指定するフラグ。デフォルト値は True です。

  • force_check_all_plans (bool) - オプション。キャッシュヒットが発生した場合でも、ランク間でプランを強制的に同期するかどうかを決定するフラグ。デフォルト値は True です。

  • s3_region (str) - オプション。S3 バケットが位置するリージョン。指定しない場合、checkpoint_id を基にリージョンが推測されます。

  • s3client_config (S3ClientConfig) - オプション。S3 クライアントの設定可能なパラメータを公開するデータクラス。指定しない場合、デフォルト設定の S3ClientConfig が使用されます。part_size パラメータは、デフォルトでは 64MB に設定されています。

torch.sagemaker.distributed.checkpoint.state_dict_saver.maybe_finalize_async_calls

この関数を使用して、トレーニングプロセスで複数の非同期リクエストが完了するのを監視できます。

def maybe_finalize_async_calls( blocking=True, process_group=None ) -> List[int]:

パラメータ

  • blocking (bool) - オプション。True の場合は、アクティブなリクエストがすべて完了するまで待機します。それ以外の場合は、既に完了した非同期リクエストのみが確定されます。デフォルト値は True です。

  • process_group (ProcessGroup) - オプション。操作対象のプロセスグループ。None に設定すると、デフォルト (グローバル) のプロセスグループが使用されます。

戻り値

  • 正常に終了した非同期呼び出しのインデックスを含むリスト。

torch.sagemaker.distributed.checkpoint.state_dict_saver.save

このメソッドを使用して、指定された checkpoint_id に同期的に state_dict を保存します。

def save( state_dict: STATE_DICT_TYPE, *, checkpoint_id: Union[str, os.PathLike, None] = None, storage_writer: Optional[StorageWriter] = None, planner: Optional[SavePlanner] = None, process_group: Optional[dist.ProcessGroup] = None, coordinator_rank: int = 0, wait_error_handling: bool = True, force_check_all_plans: bool = True, s3_region: Optional[str] = None, s3client_config: Optional[S3ClientConfig] = None ) -> None:

パラメータ

  • state_dict (dict) – 必須。保存対象の状態ディクショナリ。

  • checkpoint_id (str) - 必須。チェックポイントの保存先のストレージパス。

  • storage_writer (StorageWriter) - オプション。書き込みオペレーションを実行する、PyTorch の StorageWriter のインスタンス。指定しない場合は、デフォルト設定の StorageWriter が使用されます。

  • planner (SavePlanner) - オプション。PyTorch の SavePlanner のインスタンス。指定しない場合は、デフォルト設定の SavePlanner が使用されます。

  • process_group (ProcessGroup) - オプション。使用するプロセスグループ。None の場合は、デフォルト (グローバル) のプロセスグループが使用されます。

  • coordinator_rank (int) - オプション。集合通信演算 (AllReduce など) を実行する場合のコーディネーターのランク。

  • wait_error_handling (bool) - オプション。すべてのランクがエラー処理を終了するまで待つかどうかを指定するフラグ。デフォルト値は True です。

  • force_check_all_plans (bool) - オプション。キャッシュヒットが発生した場合でも、ランク間でプランを強制的に同期するかどうかを決定するフラグ。デフォルト値は True です。

  • s3_region (str) - オプション。S3 バケットが位置するリージョン。指定しない場合、checkpoint_id を基にリージョンが推測されます。

  • s3client_config (S3ClientConfig) - オプション。S3 クライアントの設定可能なパラメータを公開するデータクラス。指定しない場合、デフォルト設定の S3ClientConfig が使用されます。part_size パラメータは、デフォルトでは 64MB に設定されています。

torch.sagemaker.distributed.checkpoint.state_dict_loader.load

分散モデルの状態ディクショナリ (state_dict) をロードします。

def load( state_dict: Dict[str, Any], *, checkpoint_id: Union[str, os.PathLike, None] = None, storage_reader: Optional[StorageReader] = None, planner: Optional[LoadPlanner] = None, process_group: Optional[dist.ProcessGroup] = None, check_keys_matched: bool = True, coordinator_rank: int = 0, s3_region: Optional[str] = None, s3client_config: Optional[S3ClientConfig] = None ) -> None:

パラメータ

  • state_dict (dict) – 必須。ロード対象の state_dict

  • checkpoint_id (str) - 必須。チェックポイントの ID。checkpoint_id の指すものは、ストレージによって異なります。場合によっては、フォルダまたはファイルのパスになり、また、ストレージがキーと値のストアである場合は、キーになります。

  • storage_reader (StorageReader) - オプション。読み取りオペレーションを実行する、PyTorch の StorageReader のインスタンス。指定しない場合、分散チェックポイントでは checkpoint_id に基づいてリーダーが自動的に推測されます。checkpoint_idNone である場合は、例外エラーが発生します。

  • planner (StorageReader) - オプション。PyTorch の LoadPlanner のインスタンス。指定しない場合、デフォルト設定の LoadPlanner が使用されます。

  • check_keys_matched (bool) - オプション。有効にすると、すべてのランクの state_dict キーが一致しているかどうかを AllGather を使用して確認します。

  • s3_region (str) - オプション。S3 バケットが位置するリージョン。指定しない場合、checkpoint_id を基にリージョンが推測されます。

  • s3client_config (S3ClientConfig) - オプション。S3 クライアントの設定可能なパラメータを公開するデータクラス。指定しない場合、デフォルト設定の S3ClientConfig が使用されます。part_size パラメータは、デフォルトでは 64MB に設定されています。

torch.sagemaker.moe.moe_config.MoEConfig

Mixture-of-Experts (MoE) の SMP 実装を設定するための設定クラス。このクラスを通じて MoE 設定値を指定し、torch.sagemaker.transform API コールに渡すことができます。このクラスを使用して MoE モデルをトレーニングする方法については、「エキスパート並列処理」を参照してください。

class torch.sagemaker.moe.moe_config.MoEConfig( smp_moe=True, random_seed=12345, moe_load_balancing="sinkhorn", global_token_shuffle=False, moe_all_to_all_dispatcher=True, moe_aux_loss_coeff=0.001, moe_z_loss_coeff=0.001 )

パラメータ

  • smp_moe (ブール値) - MoE の SMP 実装を使用するかどうかを指定します。デフォルト値は True です。

  • random_seed (整数) - エキスパート並列分散モジュールにおけるランダムオペレーションに使用されるシード番号。このシードはエキスパート並列ランクに追加され、各ランクの実際のシードを設定します。この値は、エキスパート並列ランクごとに一意です。デフォルト値は 12345 です。

  • moe_load_balancing (文字列) - MoE ルーターの負荷分散タイプを指定します。有効なオプションは、aux_losssinkhornbalancednone です。デフォルト値は sinkhorn です。

  • global_token_shuffle (ブール値) - 同じ EP グループ内の EP ランク間でトークンをシャッフルするかどうかを指定します。デフォルト値は False です。

  • moe_all_to_all_dispatcher (ブール値) - MoE の通信に All-to-all ディスパッチャーを使用するかどうかを指定します。デフォルト値は True です。

  • moe_aux_loss_coeff (浮動小数点) - 負荷分散の補助損失の係数。デフォルト値は 0.001 です。

  • moe_z_loss_coeff (浮動小数点) - z-loss の係数。デフォルト値は 0.001 です。

torch.sagemaker.nn.attn.FlashSelfAttention

SMP v2 で FlashAttention を使用するための API。

class torch.sagemaker.nn.attn.FlashSelfAttention( attention_dropout_prob: float = 0.0, scale: Optional[float] = None, triton_flash_attention: bool = False, use_alibi: bool = False, )

パラメータ

  • attention_dropout_prob (float) – アテンションに適用するドロップアウト率。デフォルト値は 0.0 です。

  • scale (float) – このパラメータを渡した場合、指定したスケール係数がソフトマックスに適用されます。None (デフォルト値) に設定すると、スケール係数は 1 / sqrt(attention_head_size) になります。デフォルト値は None です。

  • triton_flash_attention (bool) – このパラメータを渡した場合、フラッシュアテンションの Triton 実装が使用されます。Attention with Linear Biases (ALiBi) をサポートする場合は必須です (次の use_alibi パラメータを参照)。このバージョンのカーネルはドロップアウトをサポートしていません。デフォルト値は False です。

  • use_alibi (bool) – このパラメータを渡した場合、指定したマスクを使用して Attention with Linear Biases (ALiBi) が有効になります。ALiBi を使用する場合は、次のようなアテンションマスクを用意する必要があります。デフォルト値は False です。

    def generate_alibi_attn_mask(attention_mask, batch_size, seq_length, num_attention_heads, alibi_bias_max=8): device, dtype = attention_mask.device, attention_mask.dtype alibi_attention_mask = torch.zeros( 1, num_attention_heads, 1, seq_length, dtype=dtype, device=device ) alibi_bias = torch.arange(1 - seq_length, 1, dtype=dtype, device=device).view( 1, 1, 1, seq_length ) m = torch.arange(1, num_attention_heads + 1, dtype=dtype, device=device) m.mul_(alibi_bias_max / num_attention_heads) alibi_bias = alibi_bias * (1.0 / (2 ** m.view(1, num_attention_heads, 1, 1))) alibi_attention_mask.add_(alibi_bias) alibi_attention_mask = alibi_attention_mask[..., :seq_length, :seq_length] if attention_mask is not None and attention_mask.bool().any(): alibi_attention_mask.masked_fill( attention_mask.bool().view(batch_size, 1, 1, seq_length), float("-inf") ) return alibi_attention_mask

方法

  • forward(self, qkv, attn_mask=None, causal=False, cast_dtype=None, layout="b h s d") – 一般的な PyTorch モジュール関数。module(x) が呼び出されると、SMP はこの関数を自動的に実行します。

    • qkv – 形式が (batch_size x seqlen x (3 x num_heads) x head_size) または (batch_size, (3 x num_heads) x seqlen x head_size) である torch.Tensortorch.Tensors のタプルである場合は、それぞれの形状が (batch_size x seqlen x num_heads x head_size) または (batch_size x num_heads x seqlen x head_size) である場合があります。その形状に基づいて適切なレイアウト引数を渡す必要があります。

    • attn_mask – 形式が (batch_size x 1 x 1 x seqlen) である torch.Tensor。このアテンションマスクパラメータを有効にするには、triton_flash_attention=Trueuse_alibi=True が必須です。このメソッドを使用してアテンションマスクを生成する方法については、「FlashAttention」のコード例を参照してください。デフォルト値は None です。

    • causalFalse (この引数のデフォルト値) に設定すると、マスクは適用されません。True に設定すると、forward メソッドは標準の下三角形状のマスクを使用します。デフォルト値は False です。

    • cast_dtype – 特定の dtype に設定すると、attn の前に qkv テンソルがその dtype にキャストされます。これは、回転埋め込み後に qkfp32 である Hugging Face Transformer GPT-NeoX モデルなどの実装に役立ちます。None に設定すると、キャストは適用されません。デフォルト値は None です。

    • layout (文字列) – 指定可能な値は b h s d または b s h d です。attn に適した変換を適用できるように、渡された qkv テンソルのレイアウトを設定する必要があります。デフォルト値は b h s d です。

戻り値

形状が (batch_size x num_heads x seq_len x head_size) である単一の torch.Tensor

torch.sagemaker.nn.attn.FlashGroupedQueryAttention

SMP v2 で FlashGroupedQueryAttention を使用するための API。この API の使用方法の詳細については、「グループ化クエリアテンションに FlashAttention カーネルを使用する」を参照してください。

class torch.sagemaker.nn.attn.FlashGroupedQueryAttention( attention_dropout_prob: float = 0.0, scale: Optional[float] = None, )

パラメータ

  • attention_dropout_prob (float) – アテンションに適用するドロップアウト率。デフォルト値は 0.0 です。

  • scale (float) – このパラメータを渡した場合、指定したスケール係数がソフトマックスに適用されます。None に設定した場合は、1 / sqrt(attention_head_size) がスケール係数として使用されます。デフォルト値は None です。

方法

  • forward(self, q, kv, causal=False, cast_dtype=None, layout="b s h d") – 一般的な PyTorch モジュール関数。module(x) が呼び出されると、SMP はこの関数を自動的に実行します。

    • q – 形式が (batch_size x seqlen x num_heads x head_size) または (batch_size x num_heads x seqlen x head_size) である torch.Tensor。その形状に基づいて適切なレイアウト引数を渡す必要があります。

    • kv – 形式が (batch_size x seqlen x (2 x num_heads) x head_size) または (batch_size, (2 x num_heads) x seqlen x head_size) である torch.Tensor。または、2 つの torch.Tensor のタプルである場合は、それぞれの形状が (batch_size x seqlen x num_heads x head_size) または (batch_size x num_heads x seqlen x head_size) である場合があります。その形状に基づいて適切な layout 引数を渡す必要もあります。

    • causalFalse (この引数のデフォルト値) に設定すると、マスクは適用されません。True に設定すると、forward メソッドは標準の下三角形状のマスクを使用します。デフォルト値は False です。

    • cast_dtype – 特定の dtype に設定すると、attn の前に qkv テンソルがその dtype にキャストされます。これは、回転埋め込み後に q,kfp32 である Hugging Face Transformers GPT-NeoX などの実装に役立ちます。None に設定すると、キャストは適用されません。デフォルト値は None です。

    • layout (文字列) – 指定可能な値は "b h s d" または "b s h d" です。attn に適した変換を適用できるように、渡された qkv テンソルのレイアウトを設定する必要があります。デフォルト値は "b h s d" です。

戻り値

アテンション計算の出力を表す単一の torch.Tensor (batch_size x num_heads x seq_len x head_size) を返します。

torch.sagemaker.nn.huggingface.llama_flashattn.LlamaFlashAttention

Llama モデルの FlashAttention をサポートする API。この API は、低レベルで torch.sagemaker.nn.attn.FlashGroupedQueryAttention API を使用します。使用方法については、「グループ化クエリアテンションに FlashAttention カーネルを使用する」を参照してください。

class torch.sagemaker.nn.huggingface.llama_flashattn.LlamaFlashAttention( config: LlamaConfig )

パラメータ

  • config – Llama モデルの FlashAttention 設定。

方法

  • forward(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache)

    • hidden_states (torch.Tensor) – (batch_size x seq_len x num_heads x head_size) の形式で表される、テンソルの非表示状態。

    • attention_mask (torch.LongTensor) – パディングトークンインデックスに注意 (アテンション) が払われないようにするマスク。形状は (batch_size x seqlen) です。デフォルト値は None です。

    • position_ids (torch.LongTensor) – None 以外の場合は、(batch_size x seqlen) の形式で、位置埋め込みの各入力シーケンストークンの位置のインデックスを指定します。デフォルト値は None です。

    • past_key_value (Cache) – 計算済みの非表示状態 (セルフアテンションブロックとクロスアテンションブロックのキーと値)。デフォルト値は None です。

    • output_attentions (bool) – すべてのアテンション層のアテンションテンソルを返すかどうかを指定します。デフォルト値は False です。

    • use_cache (bool) – past_key_values のキー値の状態を返すかどうかを指定します。デフォルト値は False です。

戻り値

アテンション計算の出力を表す単一の torch.Tensor (batch_size x num_heads x seq_len x head_size) を返します。

torch.sagemaker.transform

SMP v2 は、Hugging Face Transformer モデルを SMP モデル実装に変換し、SMP テンソル並列処理を有効にするために、この torch.sagemaker.transform() API を提供しています。

torch.sagemaker.transform( model: nn.Module, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, config: Optional[Dict] = None, load_state_dict_from_rank0: bool = False, cp_comm_type: str = "p2p" )

SMP v2 は、Hugging Face Transformer モデルの設定を SMP トランスフォーマーの設定に変換することで、SMP テンソル並列処理と互換性のあるHugging Face Transformer モデル の変換ポリシーを維持しています。

パラメータ

  • model (torch.nn.Module) – 変換し、SMP ライブラリのテンソル並列処理機能を適用する対象の SMP テンソル並列処理と互換性のあるHugging Face Transformer モデル モデル。

  • device (torch.device) – このパラメータを渡した場合、そのデバイスに新しいモデルが作成されます。変換前のモジュールのいずれかのパラメータがメタデバイス上にある場合 (「パラメータの遅延初期化」を参照)、変換後のモジュールもメタデバイス上に作成され、この引数で渡した値は無視されます。デフォルト値は None です。

  • dtype (torch.dtype) – このパラメータを渡した場合は、その値がモデル作成時の dtype コンテキストマネージャーとして設定され、その dtype でモデルが作成されます。通常は、指定の必要はありません。fp32 を使用する場合、MixedPrecision でモデルを作成するのが普通であり、PyTorch では fp32 がデフォルトの dtype になっています。デフォルト値は None です。

  • config (dict) – これは SMP トランスフォーマーを設定するためのディクショナリです。デフォルト値は None です。

  • load_state_dict_from_rank0 (ブール値) – デフォルトでは、このモジュールは、モデルの新しいインスタンスを新しい重みで作成します。この引数を True に設定すると、SMP は変換前の PyTorch モデルの状態ディクショナリを 0 番ランクから取得し、0 番ランクが属するテンソル並列グループの変換後のモデルにロードしようとします。True に設定されている場合、ランク 0 がメタデバイス上にパラメータを持つことはできません。この変換呼び出しの後、最初のテンソル並列グループのみが 0 番ランクの重みを取り込みます。これらの重みを最初のテンソル並列グループから取得し、他のすべてのプロセスに適用するには、FSDP ラッパーで sync_module_statesTrue に設定する必要があります。これを有効にすると、SMP ライブラリは元のモデルから状態ディクショナリをロードします。SMP ライブラリは、変換前のモデルの state_dict を取得し、変換後のモデルの構造に適合するように変換して、テンソル並列ランクごとにシャーディングします。この状態を 0 番ランクから、0 番ランクが属するテンソル並列グループ内の他のランクに伝達し、ロードします。デフォルト値は False です。

  • cp_comm_type (str) – コンテキスト並列処理の実装を指定します。context_parallel_degree が 1 より大きい場合にのみ適用されます。このパラメータで指定可能な値は p2pall_gather です。p2p 実装では、アテンションの計算中にピアツーピアの送受信呼び出しを使用してキーと値 (KV) のテンソルを蓄積します。非同期的に実行され、通信と計算をオーバーラップさせることができます。一方の all_gather 実装は、KV テンソルの蓄積に AllGather 集合通信演算を使用します。デフォルト値は "p2p" です。

戻り値

PyTorch FSDP でラップできる変換後のモデルを返します。load_state_dict_from_rank0True に設定されている場合、ランク 0 が属するテンソル並列グループには、ランク 0 の元の状態ディクショナリから重みがロードされます。元のモデルで パラメータの遅延初期化 を使用している場合、これらのランクのみが、変換後のモデルのパラメータとバッファを表す実際のテンソルを CPU 上に保持します。それ以外のランクでは、パラメータとバッファは引き続きメタデバイス上に保持され、メモリを節約します。

torch.sagemaker ユーティリティ関数とプロパティ

torch.sagemaker ユーティリティ関数
  • torch.sagemaker.init(config: Optional[Union[str, Dict[str, Any]]] = None) -> None - PyTorch トレーニングジョブを SMP で初期化します。

  • torch.sagemaker.is_initialized() -> bool – トレーニングジョブが SMP で初期化されているかどうかを確認します。ジョブが SMP で初期化されている場合にネイティブの PyTorch にフォールバックすると、以下の「プロパティ」リストで説明しているとおり、一部のプロパティが適切でなくなり、None になります。

  • torch.sagemaker.utils.module_utils.empty_module_params(module: nn.Module, device: Optional[torch.device] = None, recurse: bool = False) -> nn.Module – 指定されたデバイス (device) で空のパラメータを作成します。ネストされたすべてのモジュールに再帰的 (recurse) に適用するかどうかも指定できます。

  • torch.sagemaker.utils.module_utils.move_buffers_to_device(module: nn.Module, device: torch.device, recurse: bool = False) -> nn.Module – 指定されたデバイス (device) にモジュールバッファを移動します。ネストされたすべてのモジュールに再帰的 (recurse) に適用するかどうかも指定できます。

プロパティ

torch.sagemaker.state は、torch.sagemaker.init で SMP を初期化した後、複数の有用なプロパティを保持します。

  • torch.sagemaker.state.hybrid_shard_degree (int) – シャーディングデータ並列処理の並列度。SMP の設定で torch.sagemaker.init() に渡されたユーザー入力のコピーです。詳細については、「SageMaker モデル並列処理ライブラリ v2 を使用する」を参照してください。

  • torch.sagemaker.state.rank (int) – デバイスのグローバルランク。[0, world_size) の範囲内の値です。

  • torch.sagemaker.state.rep_rank_process_group (torch.distributed.ProcessGroup) – レプリケーションランクが同じすべてのデバイスを含むプロセスグループ。torch.sagemaker.state.tp_process_group とは、微妙ではあるが基本的な違いがあります。ネイティブ PyTorch にフォールバックすると、None を返します。

  • torch.sagemaker.state.tensor_parallel_degree (int) – テンソル並列処理の並列度。SMP の設定で torch.sagemaker.init() に渡されたユーザー入力のコピーです。詳細については、「SageMaker モデル並列処理ライブラリ v2 を使用する」を参照してください。

  • torch.sagemaker.state.tp_size (int) – torch.sagemaker.state.tensor_parallel_degree へのエイリアス。

  • torch.sagemaker.state.tp_rank (int) – テンソル並列度とランキングメカニズムによって決定される、特定のデバイスのテンソル並列処理ランク。[0, tp_size) の範囲内の値です。

  • torch.sagemaker.state.tp_process_group (torch.distributed.ProcessGroup) – 他の次元 (シャーディングデータ並列処理やレプリケーションなど) で同じランクを持つが、テンソル並列ランクは一意であるデバイスをすべて含むテンソル並列プロセスグループ。ネイティブ PyTorch にフォールバックすると、None を返します。

  • torch.sagemaker.state.world_size (int) – トレーニングで使用されるデバイスの総数。

SMP v1 から SMP v2 へのアップグレード

SMP v1 から SMP v2 に移行するには、スクリプトを変更して SMP v1 の API を削除し、SMP v2 の API を適用する必要があります。SMP v1 のスクリプトから着手するのではなく、PyTorch FSDP スクリプトから着手し、「SageMaker モデル並列処理ライブラリ v2 を使用する」の指示に従うことを推奨します。

SMP v1 のモデルを SMP v2 に移行するには、SMP v1 で完全なモデル状態ディクショナリを収集し、そのモデル状態ディクショナリに変換関数を適用して、Hugging Face Transformers モデルのチェックポイント形式に変換する必要があります。その後、SMP v2 では、「SMP を使用したチェックポイント」で説明しているとおり、Hugging Face Transformers モデルのチェックポイントをロードし、PyTorch のチェックポイント API を SMP v2 でも引き続き使用できます。PyTorch FSDP モデルで SMP を使用するには、SMP v2 に移行し、PyTorch FSDP やその他の最新機能を使用するようにトレーニングスクリプトを変更してください。

import smdistributed.modelparallel.torch as smp # Create model model = ... model = smp.DistributedModel(model) # Run training ... # Save v1 full checkpoint if smp.rdp_rank() == 0: model_dict = model.state_dict(gather_to_rank0=True) # save the full model # Get the corresponding translation function in smp v1 and translate if model_type == "gpt_neox": from smdistributed.modelparallel.torch.nn.huggingface.gptneox import translate_state_dict_to_hf_gptneox translated_state_dict = translate_state_dict_to_hf_gptneox(state_dict, max_seq_len=None) # Save the checkpoint checkpoint_path = "checkpoint.pt" if smp.rank() == 0: smp.save( {"model_state_dict": translated_state_dict}, checkpoint_path, partial=False, )

SMP v1 で使用可能な変換関数については、「Hugging Face Transformer モデルのサポート」を参照してください。

SMP v2 でのモデルチェックポイントの保存とロードの手順については、「SMP を使用したチェックポイント」を参照してください。