混合精度トレーニング
SageMaker モデル並列処理 (SMP) ライブラリ v2 は、PyTorch FSDP や Transformer Engine などのオープンソースフレームワークと統合することで、そのまますぐに混合精度トレーニングに対応します。詳細については、以下のトピックを参照してください。
P5 インスタンスでの Transformer Engine を使用した FP8 混合精度トレーニング
SageMaker モデル並列処理 (SMP) ライブラリ v2.2.0 以降では、SMP ライブラリが Transformer EngineMixedPrecision
注記
SMP v2 は、次の Hugging Face Transformer モデルに対して FP8 をサポートしています。
-
GPT-NeoX (SMP v2.2.0 以降で利用可能)
-
Llama 2 (SMP v2.2.0 以降で利用可能)
-
Mixtral 8x7b および Mixtral 8x22b (SMP v2.5.0 以降で利用可能)
注記
P5 におけるこの FP8 トレーニングの機能は、SageMaker ライブラリと PyTorch ライブラリの次の組み合わせで使用できます。
-
SageMaker Python SDK v2.212.0 以降
-
PyTorch v2.2.0 以降
FP8 (8 ビット浮動小数点精度) は、LLM モデルの深層学習トレーニングを加速するための新たなパラダイムとして浮上したデータ型です。FP8 データ型をサポートする NVIDIA H100 GPU がリリースされたことで、H100 GPU を搭載した P5 インスタンスでパフォーマンス向上の恩恵を受けつつ、FP8 混合精度トレーニングによって分散トレーニングを加速できます。
FP8 データ型は、さらに E4M3 形式と E5M2 形式に分かれます。E4M3 は精度が高く、動的範囲が制限されており、モデルトレーニングのフォワードパス (順伝播) に最適です。E5M2 は、動的範囲が広くなる代わりに精度が低下しますが、精度はそれほど重視されず、動的範囲が広い方が有利なバックワードパス (逆伝播) に適しています。したがって、ハイブリッド FP8 戦略レシピ
半精度データ型 (FP16 および BF16) の場合は、グローバルロススケーリング手法 (静的ロススケーリングや動的ロススケーリングなど) によって、半精度での勾配の丸めによる情報損失に起因する収束の問題に対処します。しかし、FP8 の動的範囲はさらに狭く、グローバルロススケーリング手法では不十分です。この場合は、よりきめ細かなテンソルごとのスケーリング手法が必要です。遅延スケーリングは、以前のイテレーションの多数のテンソルで観測された最大絶対値に基づいて、スケーリング係数を選択する戦略です。この戦略にはトレードオフがあり、FP8 計算のパフォーマンスのメリットをフルに活かせる一方で、テンソルの最大値の履歴を保持するためのメモリが必要です。遅延スケーリング戦略全般の詳細については、「FP8 Formats for Deep Learning
実際、すべてのトレーニングシナリオにおいて、P5 インスタンスで FP8 を使用することは有益です。トレーニングのパフォーマンスを向上させるために、可能な限り FP8 を有効にすることを強くお勧めします。
SMP v2 は、Transformer Engine を標準でサポートしています。そのため、SageMaker の P5 インスタンス (ml.p5.48xlarge
) で SMP v2 を使用して FP8 トレーニングを実行する場合、トレーニングスクリプトで torch.sagemaker
をインポートし、ネイティブの Transformer Engine Python パッケージをそのまま使用するだけで済みます。Transformer Engine を使用した FP8 トレーニング全般の詳細については、NVIDIA Transformer Engine ドキュメントの「Using FP8 with Transformer Engine
import torch.sagemaker as tsm import transformer_engine.pytorch as te from transformer_engine.common.recipe import DelayedScaling, Format # Initialize the SMP torch.sagemaker API. tsm.init() # Define a transformer model and wrap it with the torch.sagemaker.transform API. from transformers import AutoModelForCausalLM model = AutoModelForCausalLM.from_config(
ModelConfig
) model = tsm.transform(model) # Enable E4M3 during forward pass, E5M2 during backward pass. fp8_format = Format.HYBRID # Create an FP8 recipe. fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=32, amax_compute_algo="max") # Enable FP8 autocasting. with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, fp8_group=tsm.state.world_process_group): out = model(inp) loss = out.sum() loss.backward()
P5 インスタンスでの SMP v2 を使用した FP8 トレーニングの実践的な例については、「Accelerate SageMaker PyTorch FSDP Training of Llama-v2 (or GPT-NeoX) with FP8 on P5 instances
PyTorch FSDP を使用した半精度データ型による混合精度トレーニング
SMP v2 は、P4 インスタンスと P5 インスタンスでのトレーニングジョブで PyTorch FSDP MixedPrecision
注記
PyTorch FSDP の機能を使用したこの混合精度トレーニングは、SageMaker ライブラリと PyTorch ライブラリの次の組み合わせで使用できます。
-
SMP v2.0.0 以降
-
SageMaker Python SDK v2.200.0 以降
-
PyTorch v2.0.1 以降
モデルを混合精度で設定する標準的な方法としては、float32
でモデルを作成したうえで、MixedPrecision
ポリシーを渡し、FSDP がパラメータを float16
または bfloat16
にその場でキャストできるようにします。次のコードスニペットを参照してください。PyTorch における混合精度のパラメータ、集約 (reduction)、バッファの dtype
を変更するオプションの詳細については、PyTorch ドキュメントの「PyTorch FSDP MixedPrecision
API
# Native PyTorch API from torch.distributed.fsdp import MixedPrecision dtype = torch.bfloat16 mixed_precision_policy = MixedPrecision( param_dtype=dtype, reduce_dtype=dtype, buffer_dtype=dtype ) model = FSDP( model, ..., mixed_precision=mixed_precision_policy )
特定のモデル (Hugging Face Transformers Llama モデルなど) では、バッファは float32
として想定されています。float32
を使用するには、dtype
オブジェクトを定義する行で torch.bfloat16
を torch.float32
に置き換えてください。