SageMaker 分散モデル並列処理ライブラリ設定のヒントと落とし穴 - Amazon SageMaker

SageMaker 分散モデル並列処理ライブラリ設定のヒントと落とし穴

Amazon SageMaker のモデル並列処理ライブラリを使用する前に、次のヒントと落とし穴を確認してください。このリストには、すべてのフレームワークに通じるヒントが含まれています。TensorFlow と PyTorch 固有のヒントについては、それぞれ「TensorFlow トレーニングスクリプトを変更する」と「PyTorch トレーニングスクリプトを変更する」を参照してください。

バッチサイズとマイクロバッチ数

  • このライブラリは、バッチサイズを増やすと非常に効率的になります。モデルが 1 つのデバイス内に収まるが、小さなバッチサイズでしかトレーニングできないユースケースの場合、ライブラリの統合後にバッチサイズを増やすことができ、また、そのようにしてください。モデル並列処理は大規模モデルに対するメモリを節約するので、以前はメモリに収まらなかったバッチサイズを使ってトレーニングを実行できるようになります。

  • 過少または過大なマイクロバッチ数を選択すると、パフォーマンスが低下する可能性があります。ライブラリは各デバイスで各マイクロバッチを順次実行するため、マイクロバッチサイズ (バッチサイズをマイクロバッチ数で割った値) は、各 GPU を十分に利用できる大きさである必要があります。同時に、パイプラインの効率はマイクロバッチの数とともに向上するため、適切なバランスを取ることが重要です。通常は、まず、2 つまたは 4 つのマイクロバッチを試し、バッチサイズをメモリ制限まで増やしてから、より大きなバッチサイズとマイクロバッチ数を試してみることをお勧めします。マイクロバッチ数を増やすと、インターリーブパイプラインを使う場合に、より大きなバッチサイズが可能になる可能性があります。

  • バッチサイズは、常にマイクロバッチ数で割り切れなければなりません。データセットのサイズによって、すべてのエポックの最後のバッチが残りのバッチよりも小さなサイズになることがあり、このより小さなバッチもマイクロバッチ数で割り切れる必要があることに注意してください。そうでない場合は、tf.Dataset.batch() の呼び出しで drop_remainder=True を設定するか (TensorFlow の場合)、DataLoaderdrop_last=True を設定して (PyTorch の場合)、この最後の小さなバッチが使われないようにします。データパイプラインに異なる API を使っている場合、最後のバッチがマイクロバッチ数で割り切れないときは常に、手動でそのバッチをスキップする必要があります。

手動パーティショニング

  • 手動パーティショニングを行う場合、トランスフォーマーアーキテクチャの埋め込みテーブルなど、モデル内の複数のオペレーションやモジュールによって使用されるパラメータに注意してください。同じパラメータを共有するモジュールは、正確性確保のために同じデバイスに配置する必要があります。自動パーティショニングが使われる場合、ライブラリはこの制約を自動的に適用します。

データ準備

  • モデルが複数の入力を取り込む場合は、smp.dp_rank() を使ってデータパイプラインにランダムなオペレーション (シャッフルなど) をシードするようにしてください。データセットがデータ並列デバイス間で決定的にシャードされている場合、シャードが smp.dp_rank() でインデックス付けされるようにしてください。これは、モデルパーティションを形成するすべてのランクで見られるデータの順序が必ず一致するようにするためです。

smp.DistributedModel から返されるテンソル

  • smp.DistributedModel.call (TensorFlow の場合) または smp.DistributedModel.forward (PyTorch の場合) の関数から返されるテンソルはすべて、その特定のテンソルを計算したランクから、他のすべてのランクにブロードキャストされます。その結果、call メソッドと forward メソッドの範囲外では必要とされないテンソル (中間アクティベーションなど) はどれも、不要な通信やメモリのオーバーヘッドを発生させ、パフォーマンスを低下させるため、返さないようにしてください。

@smp.step デコレータ

  • smp.step で修飾された関数に、バッチディメンションを持たないテンソル引数がある場合、smp.step を呼び出すときに、引数名を non_split_inputs リストで与える必要があります。これにより、ライブラリがテンソルをマイクロバッチに分割しようとするのを防げます。詳細については、 API ドキュメントの「smp.step」を参照してください。

パラメータの初期化の遅延

パラメータが 1,000 億個を超える非常に大規模なモデルでは、CPU メモリを使用した重みの初期化によってメモリ不足エラーが発生する可能性があります。この問題を回避するために、ライブラリには smp.delay_param_initialization コンテキストマネージャーが用意されています。これにより、smp.step で修飾された関数の初回実行時に GPU に移動するまで、パラメータの物理的な割り当てを遅らせることができます。これにより、トレーニング初期化中に CPU の不必要なメモリを使用することを避けられます。次のコードに示すように、モデルオブジェクトを作成するときは、コンテキストマネージャーを使用してください。

with smp.delay_param_initialization(enabled=True): model = MyModel()

PyTorch のテンソル並列処理

  • 決定論的な結果を得るためにシードを使用する場合は、smp.dp_rank() に基づいてシードを設定します (例えば、torch.manual_seed(42 + smp.dp_rank()))。これを行わないと、nn.Parameter の異なるパーティションが同じ方法で初期化され、収束に影響します。

  • SageMaker のモデル並列処理ライブラリは、NCCL を使用してモジュールの分配に必要な集合体を実装しています。特に小規模なモデルでは、GPU で同時にスケジュールされる NCCL 呼び出しが多すぎると、NCCL が使用するスペースが増えるため、メモリ使用量が増加する可能性があります。これを打ち消すために、smp は NCCL 呼び出しを抑制して、実行中の NCCL 操作数が常に特定の制限値以下になるようにします。デフォルトの上限は 8 ですが、これは環境変数 SMP_NCCL_THROTTLE_LIMIT を使用して調整できます。テンソル並列処理の使用中にメモリ使用量が予想以上に多い場合は、この制限を減らすことができます。ただし、制限値が小さすぎると、スループットが低下する可能性があります。スロットリングを完全に無効にするには、SMP_NCCL_THROTTLE_LIMIT=-1 を設定します。

  • 次の等式は、テンソルの並列処理度が 1 の場合に成立しますが、テンソルの並列処理度が 1 より大きい場合には成立しません: smp.mp_size() * smp.dp_size() == smp.size()。これは、テンソル並列グループが、モデル並列処理グループとデータ並列処理グループの両方に属しているためです。コードに mp_rankmp_sizeMP_GROUP などの既存の参照があり、パイプライン並列グループのみで作業する場合は、参照を smp.pp_size() に置き換える必要がある場合があります。以下の等式は常に当てはまります:

    • smp.mp_size() * smp.rdp_size() == smp.size()

    • smp.pp_size() * smp.dp_size() == smp.size()

    • smp.pp_size() * smp.tp_size() * smp.rdp_size() == smp.size()

  • テンソル並列処理が有効になっている場合、smp.DistributedModel ラッパーはモデルパラメータを変更するため、オプティマイザは smp.DistributedModel を呼び出した後に分散パラメータを使用して作成する必要があります。例えば、以下は機能しません。

    ## WRONG model = MyModel() optimizer = SomeOptimizer(model.parameters()) model = smp.DistributedModel(model)  # optimizer now has outdated parameters! 

    代わりに、smp.DistributedModel パラメータを使用して、以下のようにオプティマイザを作成する必要があります。

    ## CORRECT model = smp.DistributedModel(MyModel()) optimizer = SomeOptimizer(model.optimizers())
  • あるモジュールをテンソル並列処理によって分散モジュールに置き換えると、分散モジュールは元のモジュールから重みを継承せず、新しい重みを初期化します。例えば、特定の呼び出しで重みを初期化する必要がある場合 (例: load_state_dict 呼び出しを通じて)、smp.DistributedModel 呼び出しの後、モジュールの分配が行われた後に行う必要があります。

  • 分散モジュールのパラメータに直接アクセスする場合、重みの形状は元のモジュールと同じではないことに注意してください。例えば、 

    with smp.tensor_parallelism():     linear = nn.Linear(60, 60) # will pass assert tuple(linear.weight.shape) == (60, 60) distributed_linear = smp.DistributedModel(linear) # will fail. the number of input channels will have been divided by smp.tp_size() assert tuple(distributed_linear.module.weight.shape) == (60, 60)
  • テンソルの並列処理には torch.utils.data.distributed.DistributedSampler を使用することを強くお勧めします。これにより、すべてのデータ並列ランクが同じ数のデータサンプルを受け取ることができるようになり、異なる dp_rank が異なるステップ数を取ることで生じるハングを防ぐことができます。

  • PyTorch の DistributedDataParallel クラスの join API を使用して、異なるデータ並列ランクのバッチ数が異なるケースを処理する場合でも、同じ TP_GROUP 内のランクのバッチ数が同じであることを確認する必要があります。それ以外の場合、モジュールの分散実行に使用される通信集合がハングする可能性があります。join API を使用する限り、異なる TP_GROUP ランクのバッチ数が異なっていても構いません。

  • モデルをチェックポイントしてテンソル並列処理を使用する場合は、以下の点を考慮してください。

    • テンソル並列処理を使用する際、モデルの保存およびロード中に遅延が発生したり競合状態にならないようにするには、必ず適切な関数を以下のモデルとオプティマイザの状態から縮小データ並列処理ランク内で呼び出すようにしてください。

    • 既存のパイプライン並列スクリプトを移行し、スクリプトのテンソル並列を有効にするには、if smp.rdp_rank() == 0 ブロックの保存とロードに使用される if smp.dp_rank() == 0 ブロックを必ず変更してください。それ以外の場合、トレーニングジョブが停止する可能性があります。

    テンソル並列処理を使用したモデルのチェックポイントの詳細については、分散モデルのチェックポイント機能 を参照してください。