Terjemahan disediakan oleh mesin penerjemah. Jika konten terjemahan yang diberikan bertentangan dengan versi bahasa Inggris aslinya, utamakan versi bahasa Inggris.
FlashAttention
SMP v2 mendukung FlashAttention
Module (nn.Module
) adalah API tingkat rendah yang mendefinisikan lapisan perhatian model. Ini harus diterapkan tepat setelah pembuatan model, dari AutoModelForCausalLM.from_config()
API misalnya, dan sebelum model diubah atau dibungkus dengan FSDP.
Gunakan FlashAttention kernel untuk perhatian diri
Cuplikan kode berikut menunjukkan cara menggunakan torch.sagemaker.nn.attn.FlashSelfAttention API yang disediakan oleh SMP v2.
def new_attn(self, q, k, v, attention_mask=None, head_mask=None): return ( self.flashmod((q, k, v), causal=True, cast_dtype=torch.bfloat16, layout="b h s d"), None, ) for layer in model.gpt_neox.layers: layer.attention.flash_mod = torch.sagemaker.nn.attn.FlashSelfAttention() layer.attention._attn = functools.partial(new_attn, layer.attention)
Gunakan FlashAttention kernel untuk perhatian kueri yang dikelompokkan
SMP v2 juga mendukung FlashAttention
Contoh penggunaan FlashGroupedQueryAttention
Cuplikan kode berikut menunjukkan cara menggunakan torch.sagemaker.nn.attn.FlashGroupedQueryAttention API yang disediakan oleh SMP v2.
from transformers.models.llama.modeling_llama import LlamaAttention from torch.sagemaker.nn.attn import FlashGroupedQueryAttention class LlamaFlashAttention(LlamaAttention): def __init__(self, config: LlamaConfig): super().__init__(config) self.flash_attn = FlashGroupedQueryAttention( attention_dropout_prob=0.0, ) def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, ... ): query_states = self.q_proj(hidden_states) key_states = self.k_proj(hidden_states) value_states = self.v_proj(hidden_states) ... kv = (key_states, value_states) attn_output = self.flash_attn( query_states, kv, attn_mask=attention_mask, causal=True, layout="b h s d", ) ... attn_output = self.o_proj(attn_output) ... return attn_output
Pustaka SMP juga menyediakantorch.sagemaker.nn.huggingface.llama_flashattn.LlamaFlashAttention, yang menggunakan torch.sagemaker.nn.attn.FlashGroupedQueryAttention API pada tingkat rendah. Hugging Face Transformers memiliki LlamaFlashAttention2
LlamaFlashAttention
LlamaFlashAttention2
API untuk mengganti lapisan perhatian model Llama yang ada.
from torch.sagemaker.nn.huggingface.llama_flashattn import LlamaFlashAttention from transformers.models.llama.modeling_llama import LlamaFlashAttention2 flash_attn_class = LlamaFlashAttention # or flash_attn_class = LlamaFlashAttention2 attn_name = "self_attn" for layer in model.model.layers: prev_layer = getattr(layer, attn_name) setattr(layer, attn_name, flash_attn_class(model.config))