混合精度トレーニング - Amazon SageMaker

混合精度トレーニング

SageMaker モデル並列処理 (SMP) ライブラリ v2 は、PyTorch FSDP や Transformer Engine などのオープンソースフレームワークと統合することで、そのまますぐに混合精度トレーニングに対応します。詳細については、以下のトピックを参照してください。

P5 インスタンスでの Transformer Engine を使用した FP8 混合精度トレーニング

SageMaker モデル並列処理 (SMP) ライブラリ v2.2.0 以降では、SMP ライブラリが Transformer Engine と統合されており、そのまますぐに FP8 混合精度トレーニングに対応します。PyTorch FSDP MixedPrecision との互換性は維持されます。つまり、混合精度トレーニング用の PyTorch FSDP と、FP8 トレーニング用の Transformer Engine を両方とも使用できます。Transformer Engine の FP8 トレーニング機能でサポートされていないモデル層には、PyTorch FSDP の混合精度が使用されます。

注記

SMP v2 は、次の Hugging Face Transformer モデルに対して FP8 をサポートしています。

  • GPT-NeoX (SMP v2.2.0 以降で利用可能)

  • Llama 2 (SMP v2.2.0 以降で利用可能)

  • Mixtral 8x7b および Mixtral 8x22b (SMP v2.5.0 以降で利用可能)

注記

P5 におけるこの FP8 トレーニングの機能は、SageMaker ライブラリと PyTorch ライブラリの次の組み合わせで使用できます。

  • SageMaker Python SDK v2.212.0 以降

  • PyTorch v2.2.0 以降

FP8 (8 ビット浮動小数点精度) は、LLM モデルの深層学習トレーニングを加速するための新たなパラダイムとして浮上したデータ型です。FP8 データ型をサポートする NVIDIA H100 GPU がリリースされたことで、H100 GPU を搭載した P5 インスタンスでパフォーマンス向上の恩恵を受けつつ、FP8 混合精度トレーニングによって分散トレーニングを加速できます。

FP8 データ型は、さらに E4M3 形式と E5M2 形式に分かれます。E4M3 は精度が高く、動的範囲が制限されており、モデルトレーニングのフォワードパス (順伝播) に最適です。E5M2 は、動的範囲が広くなる代わりに精度が低下しますが、精度はそれほど重視されず、動的範囲が広い方が有利なバックワードパス (逆伝播) に適しています。したがって、ハイブリッド FP8 戦略レシピを使用して、両者の特性を効果的に活用することをお勧めします。

半精度データ型 (FP16 および BF16) の場合は、グローバルロススケーリング手法 (静的ロススケーリングや動的ロススケーリングなど) によって、半精度での勾配の丸めによる情報損失に起因する収束の問題に対処します。しかし、FP8 の動的範囲はさらに狭く、グローバルロススケーリング手法では不十分です。この場合は、よりきめ細かなテンソルごとのスケーリング手法が必要です。遅延スケーリングは、以前のイテレーションの多数のテンソルで観測された最大絶対値に基づいて、スケーリング係数を選択する戦略です。この戦略にはトレードオフがあり、FP8 計算のパフォーマンスのメリットをフルに活かせる一方で、テンソルの最大値の履歴を保持するためのメモリが必要です。遅延スケーリング戦略全般の詳細については、「FP8 Formats for Deep Learning」を参照してください。

実際、すべてのトレーニングシナリオにおいて、P5 インスタンスで FP8 を使用することは有益です。トレーニングのパフォーマンスを向上させるために、可能な限り FP8 を有効にすることを強くお勧めします。

SMP v2 は、Transformer Engine を標準でサポートしています。そのため、SageMaker の P5 インスタンス (ml.p5.48xlarge) で SMP v2 を使用して FP8 トレーニングを実行する場合、トレーニングスクリプトで torch.sagemaker をインポートし、ネイティブの Transformer Engine Python パッケージをそのまま使用するだけで済みます。Transformer Engine を使用した FP8 トレーニング全般の詳細については、NVIDIA Transformer Engine ドキュメントの「Using FP8 with Transformer Engine」を参照してください。次のコードスニペットは、SMP ライブラリをインポートし、トレーニングスクリプトで FP8 を設定するためのコード行の例を示しています。

import torch.sagemaker as tsm import transformer_engine.pytorch as te from transformer_engine.common.recipe import DelayedScaling, Format # Initialize the SMP torch.sagemaker API. tsm.init() # Define a transformer model and wrap it with the torch.sagemaker.transform API. from transformers import AutoModelForCausalLM model = AutoModelForCausalLM.from_config(ModelConfig) model = tsm.transform(model) # Enable E4M3 during forward pass, E5M2 during backward pass. fp8_format = Format.HYBRID # Create an FP8 recipe. fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=32, amax_compute_algo="max") # Enable FP8 autocasting. with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, fp8_group=tsm.state.world_process_group): out = model(inp) loss = out.sum() loss.backward()

P5 インスタンスでの SMP v2 を使用した FP8 トレーニングの実践的な例については、「Accelerate SageMaker PyTorch FSDP Training of Llama-v2 (or GPT-NeoX) with FP8 on P5 instances」のサンプルノートブックを参照してください。

PyTorch FSDP を使用した半精度データ型による混合精度トレーニング

SMP v2 は、P4 インスタンスと P5 インスタンスでのトレーニングジョブで PyTorch FSDP MixedPrecision に対応します。PyTorch FSDP は、パフォーマンス向上とメモリ削減の両方を目的として、混合精度のさまざまな設定を提供しています。

注記

PyTorch FSDP の機能を使用したこの混合精度トレーニングは、SageMaker ライブラリと PyTorch ライブラリの次の組み合わせで使用できます。

  • SMP v2.0.0 以降

  • SageMaker Python SDK v2.200.0 以降

  • PyTorch v2.0.1 以降

モデルを混合精度で設定する標準的な方法としては、float32 でモデルを作成したうえで、MixedPrecision ポリシーを渡し、FSDP がパラメータを float16 または bfloat16 にその場でキャストできるようにします。次のコードスニペットを参照してください。PyTorch における混合精度のパラメータ、集約 (reduction)、バッファの dtype を変更するオプションの詳細については、PyTorch ドキュメントの「PyTorch FSDP MixedPrecision API」を参照してください。

# Native PyTorch API from torch.distributed.fsdp import MixedPrecision dtype = torch.bfloat16 mixed_precision_policy = MixedPrecision( param_dtype=dtype, reduce_dtype=dtype, buffer_dtype=dtype ) model = FSDP( model, ..., mixed_precision=mixed_precision_policy )

特定のモデル (Hugging Face Transformers Llama モデルなど) では、バッファは float32 として想定されています。float32 を使用するには、dtype オブジェクトを定義する行で torch.bfloat16torch.float32 に置き換えてください。