翻訳は機械翻訳により提供されています。提供された翻訳内容と英語版の間で齟齬、不一致または矛盾がある場合、英語版が優先します。
SageMaker モデル並列処理ライブラリ v2 のリファレンス
以下は、SageMaker モデル並列処理ライブラリ v2 (SMP v2) のリファレンスです。
SMP v2 の主要機能の設定パラメータ
以下は、SageMaker モデル並列処理ライブラリ v2 の主要機能を有効にし、設定するためのパラメータを網羅したリストです。これらは JSON 形式で記述し、SageMaker Python SDK の PyTorch 推定器に渡すか、SageMaker HyperPod 用の JSON ファイルとして保存する必要があります。
{ "hybrid_shard_degree":
Integer
, "sm_activation_offloading":Boolean
, "activation_loading_horizon":Integer
, "fsdp_cache_flush_warnings":Boolean
, "allow_empty_shards":Boolean
, "tensor_parallel_degree":Integer
, "context_parallel_degree":Integer
, "expert_parallel_degree":Integer
, "random_seed":Integer
}
-
hybrid_shard_degree
(整数) – シャーディング並列処理の並列度を指定します。0
からworld_size
の間の整数を指定する必要があります。デフォルト値は0
です。-
0
に設定すると、tensor_parallel_degree
が 1 の場合は、スクリプトのネイティブの PyTorch 実装と API にフォールバックします。それ以外の場合は、tensor_parallel_degree
とworld_size
に基づいて可能な限り最大のhybrid_shard_degree
が計算されます。ネイティブの PyTorch FSDP ユースケースにフォールバックする場合、使用する戦略がFULL_SHARD
であれば、GPU のクラスター全体にシャーディング (分割) が行われます。戦略がHYBRID_SHARD
または_HYBRID_SHARD_ZERO2
の場合は、hybrid_shard_degree
が 8 である場合と同等です。テンソル並列処理が有効な場合は、見直されたhybrid_shard_degree
に基づいてシャーディングが行われます。 -
1
に設定すると、tensor_parallel_degree
が 1 の場合、スクリプトのNO_SHARD
のネイティブ PyTorch 実装と API にフォールバックします。それ以外の場合は、特定のテンソル並列グループ内でNO_SHARD
と同等の動作をします。 -
2 から
world_size
までの整数に設定すると、指定数の GPU 間でシャーディングが行われます。FSDP スクリプトでsharding_strategy
を設定していない場合は、HYBRID_SHARD
に設定されます。_HYBRID_SHARD_ZERO2
を設定した場合は、指定したsharding_strategy
が使用されます。
-
-
sm_activation_offloading
(ブール値) - SMP アクティベーションオフロードの実装を有効にするかどうかを指定します。False
の場合は、ネイティブの PyTorch 実装を使用します。True
の場合は、SMP のアクティベーションオフロードの実装を使用します。また、PyTorch のアクティベーションオフロードラッパー (torch.distributed.algorithms._checkpoint.checkpoint_wrapper.offload_wrapper
) もスクリプトで使用する必要があります。詳細については、「アクティベーションオフロード」を参照してください。デフォルト値はTrue
です。 -
activation_loading_horizon
(整数) – FSDP のアクティベーションオフロードの範囲 (horizon) タイプを指定する整数。これは、チェックポイントまたはオフロードされた層の入力が GPU メモリに同時に存在できる最大数です。詳細については、「アクティベーションオフロード」を参照してください。入力値は正の整数である必要があります。デフォルト値は2
です。 -
fsdp_cache_flush_warnings
(ブール値) – PyTorch メモリマネージャーで発生したキャッシュフラッシュを検知し、警告するかどうかを指定します。キャッシュフラッシュは計算性能を低下させる可能性があります。デフォルト値はTrue
です。 -
allow_empty_shards
(ブール値) – テンソルのシャーディング時にテンソルを分割できない場合に、空のシャードを許容するかどうかを指定します。これは、特定のシナリオにおける、チェックポイント中のクラッシュに対する実験的な修正です。これを無効にすると、元の PyTorch の動作に戻ります。デフォルト値はFalse
です。 -
tensor_parallel_degree
(整数) – テンソル並列処理の並列度を指定します。1
からworld_size
の間の値を指定する必要があります。デフォルト値は1
です。1 より大きい値を渡した場合に、コンテキスト並列処理が自動的に有効になるわけではありません。torch.sagemaker.transform API を使用して、トレーニングスクリプトでモデルをラップすることも必要です。詳細については、「テンソル並列性」を参照してください。 -
context_parallel_degree
(整数) – コンテキスト並列処理の並列度を指定します。1
からworld_size
の間の値を指定する必要があり、<= hybrid_shard_degree
であることが必要です。デフォルト値は1
です。1 より大きい値を渡した場合に、コンテキスト並列処理が自動的に有効になるわけではありません。torch.sagemaker.transform API を使用して、トレーニングスクリプトでモデルをラップすることも必要です。詳細については、「コンテキスト並列処理」を参照してください。 -
expert_parallel_degree
(整数) – エキスパート並列処理の並列度を指定します。1 からworld_size
の間の値を指定する必要があります。デフォルト値は1
です。1 より大きい値を渡した場合に、コンテキスト並列処理が自動的に有効になるわけではありません。torch.sagemaker.transform API を使用して、トレーニングスクリプトでモデルをラップすることも必要です。詳細については、「エキスパート並列処理」を参照してください。 -
random_seed
(整数) – SMP のテンソル並列処理またはエキスパート並列処理による分散型モジュールにおける、ランダムオペレーションに使用されるシード番号。このシードはテンソル並列ランクまたはエキスパート並列ランクに追加され、各ランクの実際のシードを設定します。この値は、テンソル並列ランクおよびエキスパート並列ランクごとに一意です。SMP v2 では、テンソル並列ランクやエキスパート並列ランク間で生成される乱数が、それぞれテンソル並列処理やエキスパート並列処理を使用しない場合と一致するようになっています。
SMP v2 torch.sagemaker
パッケージのリファレンス
このセクションは、SMP v2 が提供する torch.sagemaker
パッケージのリファレンスです。
トピック
torch.sagemaker.distributed.checkpoint.state_dict_saver.async_save
torch.sagemaker.distributed.checkpoint.state_dict_saver.maybe_finalize_async_calls
torch.sagemaker.distributed.checkpoint.state_dict_saver.save
torch.sagemaker.distributed.checkpoint.state_dict_loader.load
torch.sagemaker.nn.huggingface.llama_flashattn.LlamaFlashAttention
torch.sagemaker.delayed_param.DelayedParamIniter
PyTorch モデルに パラメータの遅延初期化 を適用するための API。
class torch.sagemaker.delayed_param.DelayedParamIniter( model: nn.Module, init_method_using_config : Callable = None, verbose: bool = False, )
パラメータ
-
model
(nn.Module
) – SMP v2 のパラメータの遅延初期化機能をラップおよび適用する PyTorch モデル。 -
init_method_using_config
(Callable) – SMP v2 またはサポート対象の SMP テンソル並列処理と互換性のあるHugging Face Transformer モデル のテンソル並列実装を使用する場合は、このパラメータをデフォルト値 (None
) のままにしてください。デフォルトでは、DelayedParamIniter
API は、指定されたモデルを正しく初期化する方法を判定します。他のモデルでは、カスタムのパラメータ初期化関数を作成し、スクリプトに追加する必要があります。次のコードスニペットは、SMP v2 が SMP テンソル並列処理と互換性のあるHugging Face Transformer モデル 向けに実装したデフォルトのinit_method_using_config
関数です。次のコードスニペットを参照して独自の初期化設定関数を作成し、スクリプトに追加して、SMPDelayedParamIniter
API のinit_method_using_config
パラメータに渡してください。from torch.sagemaker.utils.module_utils import empty_module_params, move_buffers_to_device # Define a custom init config function. def
custom_init_method_using_config
(module): d = torch.cuda.current_device() empty_module_params(module, device=d) if isinstance(module, (nn.Linear, Conv1D)): module.weight.data.normal_(mean=0.0, std=config.initializer_range) if module.bias is not None: module.bias.data.zero_() elif isinstance(module, nn.Embedding): module.weight.data.normal_(mean=0.0, std=config.initializer_range) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() elif isinstance(module, nn.LayerNorm): module.weight.data.fill_(1.0) module.bias.data.zero_() elif isinstance(module, LlamaRMSNorm): module.weight.data.fill_(1.0) move_buffers_to_device(module, device=d) delayed_initer = DelayedParamIniter(model, init_method_using_config=custom_init_method_using_config
)前述のコードスニペットの
torch.sagemaker.module_util
関数の詳細については、「torch.sagemaker ユーティリティ関数とプロパティ」を参照してください。 -
verbose
(ブール値) – 初期化と検証の最中に詳細なログ記録を有効にするかどうかを指定します。デフォルト値はFalse
です。
方法
-
get_param_init_fn()
– PyTorch FSDP ラッパークラスのparam_init_fn
引数に渡すことができるパラメータ初期化関数を返します。 -
get_post_param_init_fn()
– PyTorch FSDP ラッパークラスのpost_param_init_fn
引数に渡すことができるパラメータ初期化関数を返します。これは、モデルで重みを共有 (tied weights) している場合に必要です。モデルはtie_weights
メソッドを実装する必要があります。詳細については、「パラメータの遅延初期化」の「重み共有に関する注意事項」を参照してください。 -
count_num_params
(module: nn.Module, *args: Tuple[nn.Parameter]
) – パラメータ初期化関数によって初期化されるパラメータの数をカウントします。これは、以下のvalidate_params_and_buffers_inited
メソッドを実装する際に役立ちます。通常、この関数を明示的に呼び出す必要はありません。validate_params_and_buffers_inited
メソッドがバックエンドでこのメソッドを暗黙的に呼び出すからです。 -
validate_params_and_buffers_inited
(enabled: bool=True
) – 初期化されたパラメータ数が、モデル内のパラメータの合計数と一致することを検証するためのコンテキストマネージャーです。すべてのパラメータとバッファが、メタデバイスではなく GPU デバイス上にあることも検証します。これらの条件が満たされない場合は、AssertionErrors
が発生します。このコンテキストマネージャーの使用は任意です。パラメータを初期化する際に必須ではありません。
torch.sagemaker.distributed.checkpoint.state_dict_saver.async_save
非同期保存用のエントリ API。このメソッドを使用して、指定された checkpoint_id
に非同期的に state_dict
を保存します。
def async_save( state_dict: STATE_DICT_TYPE, *, checkpoint_id: Union[str, os.PathLike, None] = None, storage_writer: Optional[StorageWriter] = None, planner: Optional[SavePlanner] = None, process_group: Optional[dist.ProcessGroup] = None, coordinator_rank: int = 0, queue : AsyncCallsQueue = None, sharded_strategy: Union[SaveShardedStrategy, Tuple[str, int], None] = None, wait_error_handling: bool = True, force_check_all_plans: bool = True, s3_region: Optional[str] = None, s3client_config: Optional[S3ClientConfig] = None ) -> None:
パラメータ
-
state_dict
(dict) – 必須。保存対象の状態ディクショナリ。 -
checkpoint_id
(str) - 必須。チェックポイントの保存先のストレージパス。 -
storage_writer
(StorageWriter) - オプション。書き込みオペレーションを実行する、PyTorch のStorageWriter
のインスタンス。指定しない場合は、デフォルト設定の StorageWriter
が使用されます。 -
planner
(SavePlanner) - オプション。PyTorch のSavePlanner
のインスタンス。指定しない場合は、デフォルト設定の SavePlanner
が使用されます。 -
process_group
(ProcessGroup) - オプション。使用するプロセスグループ。None
の場合は、デフォルト (グローバル) のプロセスグループが使用されます。 -
coordinator_rank
(int) - オプション。集合通信演算 (AllReduce
など) を実行する場合のコーディネーターのランク。 -
queue
(AsyncRequestQueue) - オプション。使用する非同期スケジューラ。デフォルトでは、グローバルパラメータDEFAULT_ASYNC_REQUEST_QUEUE
を取ります。 -
sharded_strategy
(PyTorchDistSaveShardedStrategy) - オプション。チェックポイントの保存に使用するシャーディング戦略。指定しない場合は、デフォルトでtorch.sagemaker.distributed.checkpoint.state_dict_saver.PyTorchDistSaveShardedStrategy
が使用されます。 -
wait_error_handling
(bool) - オプション。すべてのランクがエラー処理を終了するまで待つかどうかを指定するフラグ。デフォルト値はTrue
です。 -
force_check_all_plans
(bool) - オプション。キャッシュヒットが発生した場合でも、ランク間でプランを強制的に同期するかどうかを決定するフラグ。デフォルト値はTrue
です。 -
s3_region
(str) - オプション。S3 バケットが位置するリージョン。指定しない場合、checkpoint_id
を基にリージョンが推測されます。 -
s3client_config
(S3ClientConfig) - オプション。S3 クライアントの設定可能なパラメータを公開するデータクラス。指定しない場合、デフォルト設定の S3ClientConfigが使用されます。 part_size
パラメータは、デフォルトでは 64MB に設定されています。
torch.sagemaker.distributed.checkpoint.state_dict_saver.maybe_finalize_async_calls
この関数を使用して、トレーニングプロセスで複数の非同期リクエストが完了するのを監視できます。
def maybe_finalize_async_calls( blocking=True, process_group=None ) -> List[int]:
パラメータ
-
blocking
(bool) - オプション。True
の場合は、アクティブなリクエストがすべて完了するまで待機します。それ以外の場合は、既に完了した非同期リクエストのみが確定されます。デフォルト値はTrue
です。 -
process_group
(ProcessGroup) - オプション。操作対象のプロセスグループ。None
に設定すると、デフォルト (グローバル) のプロセスグループが使用されます。
戻り値
-
正常に終了した非同期呼び出しのインデックスを含むリスト。
torch.sagemaker.distributed.checkpoint.state_dict_saver.save
このメソッドを使用して、指定された checkpoint_id
に同期的に state_dict
を保存します。
def save( state_dict: STATE_DICT_TYPE, *, checkpoint_id: Union[str, os.PathLike, None] = None, storage_writer: Optional[StorageWriter] = None, planner: Optional[SavePlanner] = None, process_group: Optional[dist.ProcessGroup] = None, coordinator_rank: int = 0, wait_error_handling: bool = True, force_check_all_plans: bool = True, s3_region: Optional[str] = None, s3client_config: Optional[S3ClientConfig] = None ) -> None:
パラメータ
-
state_dict
(dict) – 必須。保存対象の状態ディクショナリ。 -
checkpoint_id
(str) - 必須。チェックポイントの保存先のストレージパス。 -
storage_writer
(StorageWriter) - オプション。書き込みオペレーションを実行する、PyTorch のStorageWriter
のインスタンス。指定しない場合は、デフォルト設定の StorageWriter
が使用されます。 -
planner
(SavePlanner) - オプション。PyTorch のSavePlanner
のインスタンス。指定しない場合は、デフォルト設定の SavePlanner
が使用されます。 -
process_group
(ProcessGroup) - オプション。使用するプロセスグループ。None
の場合は、デフォルト (グローバル) のプロセスグループが使用されます。 -
coordinator_rank
(int) - オプション。集合通信演算 (AllReduce
など) を実行する場合のコーディネーターのランク。 -
wait_error_handling
(bool) - オプション。すべてのランクがエラー処理を終了するまで待つかどうかを指定するフラグ。デフォルト値はTrue
です。 -
force_check_all_plans
(bool) - オプション。キャッシュヒットが発生した場合でも、ランク間でプランを強制的に同期するかどうかを決定するフラグ。デフォルト値はTrue
です。 -
s3_region
(str) - オプション。S3 バケットが位置するリージョン。指定しない場合、checkpoint_id
を基にリージョンが推測されます。 -
s3client_config
(S3ClientConfig) - オプション。S3 クライアントの設定可能なパラメータを公開するデータクラス。指定しない場合、デフォルト設定の S3ClientConfigが使用されます。 part_size
パラメータは、デフォルトでは 64MB に設定されています。
torch.sagemaker.distributed.checkpoint.state_dict_loader.load
分散モデルの状態ディクショナリ (state_dict
) をロードします。
def load( state_dict: Dict[str, Any], *, checkpoint_id: Union[str, os.PathLike, None] = None, storage_reader: Optional[StorageReader] = None, planner: Optional[LoadPlanner] = None, process_group: Optional[dist.ProcessGroup] = None, check_keys_matched: bool = True, coordinator_rank: int = 0, s3_region: Optional[str] = None, s3client_config: Optional[S3ClientConfig] = None ) -> None:
パラメータ
-
state_dict
(dict) – 必須。ロード対象のstate_dict
。 -
checkpoint_id
(str) - 必須。チェックポイントの ID。checkpoint_id
の指すものは、ストレージによって異なります。場合によっては、フォルダまたはファイルのパスになり、また、ストレージがキーと値のストアである場合は、キーになります。 -
storage_reader
(StorageReader) - オプション。読み取りオペレーションを実行する、PyTorch のStorageReader
のインスタンス。指定しない場合、分散チェックポイントでは checkpoint_id
に基づいてリーダーが自動的に推測されます。checkpoint_id
もNone
である場合は、例外エラーが発生します。 -
planner
(StorageReader) - オプション。PyTorch のLoadPlanner
のインスタンス。指定しない場合、デフォルト設定の LoadPlanner
が使用されます。 -
check_keys_matched
(bool) - オプション。有効にすると、すべてのランクのstate_dict
キーが一致しているかどうかをAllGather
を使用して確認します。 -
s3_region
(str) - オプション。S3 バケットが位置するリージョン。指定しない場合、checkpoint_id
を基にリージョンが推測されます。 -
s3client_config
(S3ClientConfig) - オプション。S3 クライアントの設定可能なパラメータを公開するデータクラス。指定しない場合、デフォルト設定の S3ClientConfigが使用されます。 part_size
パラメータは、デフォルトでは 64MB に設定されています。
torch.sagemaker.moe.moe_config.MoEConfig
Mixture-of-Experts (MoE) の SMP 実装を設定するための設定クラス。このクラスを通じて MoE 設定値を指定し、torch.sagemaker.transform
API コールに渡すことができます。このクラスを使用して MoE モデルをトレーニングする方法については、「エキスパート並列処理」を参照してください。
class torch.sagemaker.moe.moe_config.MoEConfig( smp_moe=True, random_seed=12345, moe_load_balancing="sinkhorn", global_token_shuffle=False, moe_all_to_all_dispatcher=True, moe_aux_loss_coeff=0.001, moe_z_loss_coeff=0.001 )
パラメータ
-
smp_moe
(ブール値) - MoE の SMP 実装を使用するかどうかを指定します。デフォルト値はTrue
です。 -
random_seed
(整数) - エキスパート並列分散モジュールにおけるランダムオペレーションに使用されるシード番号。このシードはエキスパート並列ランクに追加され、各ランクの実際のシードを設定します。この値は、エキスパート並列ランクごとに一意です。デフォルト値は12345
です。 -
moe_load_balancing
(文字列) - MoE ルーターの負荷分散タイプを指定します。有効なオプションは、aux_loss
、sinkhorn
、balanced
、none
です。デフォルト値はsinkhorn
です。 -
global_token_shuffle
(ブール値) - 同じ EP グループ内の EP ランク間でトークンをシャッフルするかどうかを指定します。デフォルト値はFalse
です。 -
moe_all_to_all_dispatcher
(ブール値) - MoE の通信に All-to-all ディスパッチャーを使用するかどうかを指定します。デフォルト値はTrue
です。 -
moe_aux_loss_coeff
(浮動小数点) - 負荷分散の補助損失の係数。デフォルト値は0.001
です。 -
moe_z_loss_coeff
(浮動小数点) - z-loss の係数。デフォルト値は0.001
です。
torch.sagemaker.nn.attn.FlashSelfAttention
SMP v2 で FlashAttention を使用するための API。
class torch.sagemaker.nn.attn.FlashSelfAttention( attention_dropout_prob: float = 0.0, scale: Optional[float] = None, triton_flash_attention: bool = False, use_alibi: bool = False, )
パラメータ
-
attention_dropout_prob
(float) – アテンションに適用するドロップアウト率。デフォルト値は0.0
です。 -
scale
(float) – このパラメータを渡した場合、指定したスケール係数がソフトマックスに適用されます。None
(デフォルト値) に設定すると、スケール係数は1 / sqrt(attention_head_size)
になります。デフォルト値はNone
です。 -
triton_flash_attention
(bool) – このパラメータを渡した場合、フラッシュアテンションの Triton 実装が使用されます。Attention with Linear Biases (ALiBi) をサポートする場合は必須です (次のuse_alibi
パラメータを参照)。このバージョンのカーネルはドロップアウトをサポートしていません。デフォルト値はFalse
です。 -
use_alibi
(bool) – このパラメータを渡した場合、指定したマスクを使用して Attention with Linear Biases (ALiBi) が有効になります。ALiBi を使用する場合は、次のようなアテンションマスクを用意する必要があります。デフォルト値はFalse
です。def generate_alibi_attn_mask(attention_mask, batch_size, seq_length, num_attention_heads, alibi_bias_max=8): device, dtype = attention_mask.device, attention_mask.dtype alibi_attention_mask = torch.zeros( 1, num_attention_heads, 1, seq_length, dtype=dtype, device=device ) alibi_bias = torch.arange(1 - seq_length, 1, dtype=dtype, device=device).view( 1, 1, 1, seq_length ) m = torch.arange(1, num_attention_heads + 1, dtype=dtype, device=device) m.mul_(alibi_bias_max / num_attention_heads) alibi_bias = alibi_bias * (1.0 / (2 ** m.view(1, num_attention_heads, 1, 1))) alibi_attention_mask.add_(alibi_bias) alibi_attention_mask = alibi_attention_mask[..., :seq_length, :seq_length] if attention_mask is not None and attention_mask.bool().any(): alibi_attention_mask.masked_fill( attention_mask.bool().view(batch_size, 1, 1, seq_length), float("-inf") ) return alibi_attention_mask
方法
-
forward(self, qkv, attn_mask=None, causal=False, cast_dtype=None, layout="b h s d")
– 一般的な PyTorch モジュール関数。module(x)
が呼び出されると、SMP はこの関数を自動的に実行します。-
qkv
– 形式が(batch_size x seqlen x (3 x num_heads) x head_size)
または(batch_size, (3 x num_heads) x seqlen x head_size)
であるtorch.Tensor
。torch.Tensors
のタプルである場合は、それぞれの形状が(batch_size x seqlen x num_heads x head_size)
または(batch_size x num_heads x seqlen x head_size)
である場合があります。その形状に基づいて適切なレイアウト引数を渡す必要があります。 -
attn_mask
– 形式が(batch_size x 1 x 1 x seqlen)
であるtorch.Tensor
。このアテンションマスクパラメータを有効にするには、triton_flash_attention=True
とuse_alibi=True
が必須です。このメソッドを使用してアテンションマスクを生成する方法については、「FlashAttention」のコード例を参照してください。デフォルト値はNone
です。 -
causal
–False
(この引数のデフォルト値) に設定すると、マスクは適用されません。True
に設定すると、forward
メソッドは標準の下三角形状のマスクを使用します。デフォルト値はFalse
です。 -
cast_dtype
– 特定のdtype
に設定すると、attn
の前にqkv
テンソルがそのdtype
にキャストされます。これは、回転埋め込み後にq
とk
がfp32
である Hugging Face Transformer GPT-NeoX モデルなどの実装に役立ちます。None
に設定すると、キャストは適用されません。デフォルト値はNone
です。 -
layout
(文字列) – 指定可能な値はb h s d
またはb s h d
です。attn
に適した変換を適用できるように、渡されたqkv
テンソルのレイアウトを設定する必要があります。デフォルト値はb h s d
です。
-
戻り値
形状が (batch_size x num_heads x
seq_len x head_size)
である単一の torch.Tensor
。
torch.sagemaker.nn.attn.FlashGroupedQueryAttention
SMP v2 で FlashGroupedQueryAttention
を使用するための API。この API の使用方法の詳細については、「グループ化クエリアテンションに FlashAttention カーネルを使用する」を参照してください。
class torch.sagemaker.nn.attn.FlashGroupedQueryAttention( attention_dropout_prob: float = 0.0, scale: Optional[float] = None, )
パラメータ
-
attention_dropout_prob
(float) – アテンションに適用するドロップアウト率。デフォルト値は0.0
です。 -
scale
(float) – このパラメータを渡した場合、指定したスケール係数がソフトマックスに適用されます。None
に設定した場合は、1 / sqrt(attention_head_size)
がスケール係数として使用されます。デフォルト値はNone
です。
方法
-
forward(self, q, kv, causal=False, cast_dtype=None, layout="b s h d")
– 一般的な PyTorch モジュール関数。module(x)
が呼び出されると、SMP はこの関数を自動的に実行します。-
q
– 形式が(batch_size x seqlen x num_heads x head_size)
または(batch_size x num_heads x seqlen x head_size)
であるtorch.Tensor
。その形状に基づいて適切なレイアウト引数を渡す必要があります。 -
kv
– 形式が(batch_size x seqlen x (2 x num_heads) x head_size)
または(batch_size, (2 x num_heads) x seqlen x head_size)
であるtorch.Tensor
。または、2 つのtorch.Tensor
のタプルである場合は、それぞれの形状が(batch_size x seqlen x num_heads x head_size)
または(batch_size x num_heads x seqlen x head_size)
である場合があります。その形状に基づいて適切なlayout
引数を渡す必要もあります。 -
causal
–False
(この引数のデフォルト値) に設定すると、マスクは適用されません。True
に設定すると、forward
メソッドは標準の下三角形状のマスクを使用します。デフォルト値はFalse
です。 -
cast_dtype
– 特定の dtype に設定すると、attn
の前にqkv
テンソルがその dtype にキャストされます。これは、回転埋め込み後にq,k
がfp32
である Hugging Face Transformers GPT-NeoX などの実装に役立ちます。None
に設定すると、キャストは適用されません。デフォルト値はNone
です。 -
layout (文字列) – 指定可能な値は
"b h s d"
または"b s h d"
です。attn
に適した変換を適用できるように、渡されたqkv
テンソルのレイアウトを設定する必要があります。デフォルト値は"b h s d"
です。
-
戻り値
アテンション計算の出力を表す単一の torch.Tensor (batch_size x num_heads x seq_len x
head_size)
を返します。
torch.sagemaker.nn.huggingface.llama_flashattn.LlamaFlashAttention
Llama モデルの FlashAttention をサポートする API。この API は、低レベルで torch.sagemaker.nn.attn.FlashGroupedQueryAttention API を使用します。使用方法については、「グループ化クエリアテンションに FlashAttention カーネルを使用する」を参照してください。
class torch.sagemaker.nn.huggingface.llama_flashattn.LlamaFlashAttention( config: LlamaConfig )
パラメータ
-
config
– Llama モデルの FlashAttention 設定。
方法
-
forward(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache)
-
hidden_states
(torch.Tensor
) –(batch_size x seq_len x num_heads x head_size)
の形式で表される、テンソルの非表示状態。 -
attention_mask
(torch.LongTensor
) – パディングトークンインデックスに注意 (アテンション) が払われないようにするマスク。形状は(batch_size x seqlen)
です。デフォルト値はNone
です。 -
position_ids
(torch.LongTensor
) –None
以外の場合は、(batch_size x seqlen)
の形式で、位置埋め込みの各入力シーケンストークンの位置のインデックスを指定します。デフォルト値はNone
です。 -
past_key_value
(Cache) – 計算済みの非表示状態 (セルフアテンションブロックとクロスアテンションブロックのキーと値)。デフォルト値はNone
です。 -
output_attentions
(bool) – すべてのアテンション層のアテンションテンソルを返すかどうかを指定します。デフォルト値はFalse
です。 -
use_cache
(bool) –past_key_values
のキー値の状態を返すかどうかを指定します。デフォルト値はFalse
です。
-
戻り値
アテンション計算の出力を表す単一の torch.Tensor (batch_size x num_heads x seq_len x
head_size)
を返します。
torch.sagemaker.transform
SMP v2 は、Hugging Face Transformer モデルを SMP モデル実装に変換し、SMP テンソル並列処理を有効にするために、この torch.sagemaker.transform()
API を提供しています。
torch.sagemaker.transform( model: nn.Module, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, config: Optional[Dict] = None, load_state_dict_from_rank0: bool = False, cp_comm_type: str = "p2p" )
SMP v2 は、Hugging Face Transformer モデルの設定を SMP トランスフォーマーの設定に変換することで、SMP テンソル並列処理と互換性のあるHugging Face Transformer モデル の変換ポリシーを維持しています。
パラメータ
-
model
(torch.nn.Module
) – 変換し、SMP ライブラリのテンソル並列処理機能を適用する対象の SMP テンソル並列処理と互換性のあるHugging Face Transformer モデル モデル。 -
device
(torch.device
) – このパラメータを渡した場合、そのデバイスに新しいモデルが作成されます。変換前のモジュールのいずれかのパラメータがメタデバイス上にある場合 (「パラメータの遅延初期化」を参照)、変換後のモジュールもメタデバイス上に作成され、この引数で渡した値は無視されます。デフォルト値はNone
です。 -
dtype
(torch.dtype
) – このパラメータを渡した場合は、その値がモデル作成時の dtype コンテキストマネージャーとして設定され、その dtype でモデルが作成されます。通常は、指定の必要はありません。fp32
を使用する場合、MixedPrecision
でモデルを作成するのが普通であり、PyTorch ではfp32
がデフォルトの dtype になっています。デフォルト値はNone
です。 -
config
(dict) – これは SMP トランスフォーマーを設定するためのディクショナリです。デフォルト値はNone
です。 -
load_state_dict_from_rank0
(ブール値) – デフォルトでは、このモジュールは、モデルの新しいインスタンスを新しい重みで作成します。この引数をTrue
に設定すると、SMP は変換前の PyTorch モデルの状態ディクショナリを 0 番ランクから取得し、0 番ランクが属するテンソル並列グループの変換後のモデルにロードしようとします。True
に設定されている場合、ランク 0 がメタデバイス上にパラメータを持つことはできません。この変換呼び出しの後、最初のテンソル並列グループのみが 0 番ランクの重みを取り込みます。これらの重みを最初のテンソル並列グループから取得し、他のすべてのプロセスに適用するには、FSDP ラッパーでsync_module_states
をTrue
に設定する必要があります。これを有効にすると、SMP ライブラリは元のモデルから状態ディクショナリをロードします。SMP ライブラリは、変換前のモデルのstate_dict
を取得し、変換後のモデルの構造に適合するように変換して、テンソル並列ランクごとにシャーディングします。この状態を 0 番ランクから、0 番ランクが属するテンソル並列グループ内の他のランクに伝達し、ロードします。デフォルト値はFalse
です。 cp_comm_type
(str) – コンテキスト並列処理の実装を指定します。context_parallel_degree
が 1 より大きい場合にのみ適用されます。このパラメータで指定可能な値はp2p
とall_gather
です。p2p
実装では、アテンションの計算中にピアツーピアの送受信呼び出しを使用してキーと値 (KV) のテンソルを蓄積します。非同期的に実行され、通信と計算をオーバーラップさせることができます。一方のall_gather
実装は、KV テンソルの蓄積にAllGather
集合通信演算を使用します。デフォルト値は"p2p"
です。
戻り値
PyTorch FSDP でラップできる変換後のモデルを返します。load_state_dict_from_rank0
が True
に設定されている場合、ランク 0 が属するテンソル並列グループには、ランク 0 の元の状態ディクショナリから重みがロードされます。元のモデルで パラメータの遅延初期化 を使用している場合、これらのランクのみが、変換後のモデルのパラメータとバッファを表す実際のテンソルを CPU 上に保持します。それ以外のランクでは、パラメータとバッファは引き続きメタデバイス上に保持され、メモリを節約します。
torch.sagemaker
ユーティリティ関数とプロパティ
torch.sagemaker ユーティリティ関数
-
torch.sagemaker.init(config: Optional[Union[str, Dict[str, Any]]] = None) -> None
- PyTorch トレーニングジョブを SMP で初期化します。 -
torch.sagemaker.is_initialized() -> bool
– トレーニングジョブが SMP で初期化されているかどうかを確認します。ジョブが SMP で初期化されている場合にネイティブの PyTorch にフォールバックすると、以下の「プロパティ」リストで説明しているとおり、一部のプロパティが適切でなくなり、None
になります。 -
torch.sagemaker.utils.module_utils.empty_module_params(module: nn.Module, device: Optional[torch.device] = None, recurse: bool = False) -> nn.Module
– 指定されたデバイス (device
) で空のパラメータを作成します。ネストされたすべてのモジュールに再帰的 (recurse) に適用するかどうかも指定できます。 -
torch.sagemaker.utils.module_utils.move_buffers_to_device(module: nn.Module, device: torch.device, recurse: bool = False) -> nn.Module
– 指定されたデバイス (device
) にモジュールバッファを移動します。ネストされたすべてのモジュールに再帰的 (recurse) に適用するかどうかも指定できます。
プロパティ
torch.sagemaker.state
は、torch.sagemaker.init
で SMP を初期化した後、複数の有用なプロパティを保持します。
-
torch.sagemaker.state.hybrid_shard_degree
(int) – シャーディングデータ並列処理の並列度。SMP の設定でtorch.sagemaker.init()
に渡されたユーザー入力のコピーです。詳細については、「SageMaker モデル並列処理ライブラリ v2 を使用する」を参照してください。 -
torch.sagemaker.state.rank
(int) – デバイスのグローバルランク。[0, world_size)
の範囲内の値です。 -
torch.sagemaker.state.rep_rank_process_group
(torch.distributed.ProcessGroup
) – レプリケーションランクが同じすべてのデバイスを含むプロセスグループ。torch.sagemaker.state.tp_process_group
とは、微妙ではあるが基本的な違いがあります。ネイティブ PyTorch にフォールバックすると、None
を返します。 -
torch.sagemaker.state.tensor_parallel_degree
(int) – テンソル並列処理の並列度。SMP の設定でtorch.sagemaker.init()
に渡されたユーザー入力のコピーです。詳細については、「SageMaker モデル並列処理ライブラリ v2 を使用する」を参照してください。 -
torch.sagemaker.state.tp_size
(int) –torch.sagemaker.state.tensor_parallel_degree
へのエイリアス。 -
torch.sagemaker.state.tp_rank
(int) – テンソル並列度とランキングメカニズムによって決定される、特定のデバイスのテンソル並列処理ランク。[0, tp_size)
の範囲内の値です。 -
torch.sagemaker.state.tp_process_group
(torch.distributed.ProcessGroup
) – 他の次元 (シャーディングデータ並列処理やレプリケーションなど) で同じランクを持つが、テンソル並列ランクは一意であるデバイスをすべて含むテンソル並列プロセスグループ。ネイティブ PyTorch にフォールバックすると、None
を返します。 -
torch.sagemaker.state.world_size
(int) – トレーニングで使用されるデバイスの総数。
SMP v1 から SMP v2 へのアップグレード
SMP v1 から SMP v2 に移行するには、スクリプトを変更して SMP v1 の API を削除し、SMP v2 の API を適用する必要があります。SMP v1 のスクリプトから着手するのではなく、PyTorch FSDP スクリプトから着手し、「SageMaker モデル並列処理ライブラリ v2 を使用する」の指示に従うことを推奨します。
SMP v1 のモデルを SMP v2 に移行するには、SMP v1 で完全なモデル状態ディクショナリを収集し、そのモデル状態ディクショナリに変換関数を適用して、Hugging Face Transformers モデルのチェックポイント形式に変換する必要があります。その後、SMP v2 では、「SMP を使用したチェックポイント」で説明しているとおり、Hugging Face Transformers モデルのチェックポイントをロードし、PyTorch のチェックポイント API を SMP v2 でも引き続き使用できます。PyTorch FSDP モデルで SMP を使用するには、SMP v2 に移行し、PyTorch FSDP やその他の最新機能を使用するようにトレーニングスクリプトを変更してください。
import smdistributed.modelparallel.torch as smp # Create model model = ... model = smp.DistributedModel(model) # Run training ... # Save v1 full checkpoint if smp.rdp_rank() == 0: model_dict = model.state_dict(gather_to_rank0=True) # save the full model # Get the corresponding translation function in smp v1 and translate if model_type == "gpt_neox": from smdistributed.modelparallel.torch.nn.huggingface.gptneox import translate_state_dict_to_hf_gptneox translated_state_dict = translate_state_dict_to_hf_gptneox(state_dict, max_seq_len=None) # Save the checkpoint checkpoint_path = "checkpoint.pt" if smp.rank() == 0: smp.save( {"model_state_dict": translated_state_dict}, checkpoint_path, partial=False, )
SMP v1 で使用可能な変換関数については、「Hugging Face Transformer モデルのサポート」を参照してください。
SMP v2 でのモデルチェックポイントの保存とロードの手順については、「SMP を使用したチェックポイント」を参照してください。