のサポート FlashAttention - Amazon SageMaker

翻訳は機械翻訳により提供されています。提供された翻訳内容と英語版の間で齟齬、不一致または矛盾がある場合、英語版が優先します。

のサポート FlashAttention

のサポート FlashAttention は、分散トランスフォーマーモデルにのみ適用可能なライブラリの機能です。分散トランスフォーマーモデルは、モデル並列トレーニングsmp.DistributedModel()のために によってラップされたトランスフォーマーモデルです。この機能は テンソル並列処理 とも互換性があります。

FlashAttention ライブラリは、 attention_head_sizeが 8 の倍数で 128 未満の値に設定されている場合にのみモデルをサポートします。したがって、分散トランスフォーマーをトレーニングし、 が正しく FlashAttention 動作することを確認するときは、注意ヘッドのサイズが要件に準拠するようにパラメータを調整する必要があります。詳細については、「 FlashAttention GitHubリポジトリ」の「インストールと機能」も参照してください。

例えば、hidden_width=864num_heads=48 を使用して Transformer モデルを設定すると仮定します。のヘッドサイズ FlashAttention は として計算されますattention_head_size = hidden_width / num_heads = 864 / 48 = 18。を有効にするには FlashAttention、 num_headsパラメータを に調整する必要があります。54これによりattention_head_size = hidden_width / num_heads = 864 / 54 = 16、 は 8 の倍数になります。