FlashAttention - Amazon SageMaker

Le traduzioni sono generate tramite traduzione automatica. In caso di conflitto tra il contenuto di una traduzione e la versione originale in Inglese, quest'ultima prevarrà.

FlashAttention

SMPla v2 supporta i FlashAttentionkernel e ne semplifica l'applicazione a vari scenari per i modelli Hugging Face Transformer. Nota che se usi il FlashAttention pacchetto v2.0 o successivo, SMP utilizza la FlashAttention v2; tuttavia, Triton flash attention utilizza per impostazione predefinita il kernel flash attention nella v1.x, rendendolo supportato esclusivamente nella v1. FlashAttention FlashAttention

Il module (nn.Module) è un livello basso che definisce i livelli API di attenzione di un modello. Dovrebbe essere applicato subito dopo la creazione del modello, ad AutoModelForCausalLM.from_config() API esempio, e prima che il modello venga trasformato o avvolto. FSDP

Utilizzate i FlashAttention kernel per l'attenzione personale

Il seguente frammento di codice mostra come utilizzare quanto torch.sagemaker.nn.attn.FlashSelfAttention API fornito dalla v2. SMP

def new_attn(self, q, k, v, attention_mask=None, head_mask=None): return ( self.flashmod((q, k, v), causal=True, cast_dtype=torch.bfloat16, layout="b h s d"), None, ) for layer in model.gpt_neox.layers: layer.attention.flash_mod = torch.sagemaker.nn.attn.FlashSelfAttention() layer.attention._attn = functools.partial(new_attn, layer.attention)

Usa i FlashAttention kernel per attirare l'attenzione sulle query raggruppate

SMPv2 supporta anche i FlashAttentionkernel per grouped-query attention (GQA) e ne semplifica l'applicazione a vari scenari per i modelli Hugging Face Transformer. A differenza dell'architettura di attenzione originale, suddivide GQA equamente le teste di interrogazione in gruppi e le testine di query dello stesso gruppo condividono le stesse teste di chiave e di valore. Pertanto, le testine q e kv vengono passate separatamente alla chiamata in avanti. Nota: il numero di teste q deve essere divisibile per il numero di teste kv.

Esempio di utilizzo FlashGroupedQueryAttention

Il seguente frammento di codice mostra come utilizzare il file torch.sagemaker.nn.attn.FlashGroupedQueryAttention API fornito da SMP v2.

from transformers.models.llama.modeling_llama import LlamaAttention from torch.sagemaker.nn.attn import FlashGroupedQueryAttention class LlamaFlashAttention(LlamaAttention): def __init__(self, config: LlamaConfig): super().__init__(config) self.flash_attn = FlashGroupedQueryAttention( attention_dropout_prob=0.0, ) def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, ... ): query_states = self.q_proj(hidden_states) key_states = self.k_proj(hidden_states) value_states = self.v_proj(hidden_states) ... kv = (key_states, value_states) attn_output = self.flash_attn( query_states, kv, attn_mask=attention_mask, causal=True, layout="b h s d", ) ... attn_output = self.o_proj(attn_output) ... return attn_output

La SMP libreria fornisce anchetorch.sagemaker.nn.huggingface.llama_flashattn.LlamaFlashAttention, che utilizza il torch.sagemaker.nn.attn.FlashGroupedQueryAttention API a basso livello. Hugging Face Transformers ha un'implementazione simile chiamata dalla v4.36.0. LlamaFlashAttention2 Il seguente frammento di codice mostra come utilizzare la SMP v2 LlamaFlashAttention API o i LlamaFlashAttention2 API Transformers per sostituire i livelli di attenzione di un modello Llama esistente.

from torch.sagemaker.nn.huggingface.llama_flashattn import LlamaFlashAttention from transformers.models.llama.modeling_llama import LlamaFlashAttention2 flash_attn_class = LlamaFlashAttention # or flash_attn_class = LlamaFlashAttention2 attn_name = "self_attn" for layer in model.model.layers: prev_layer = getattr(layer, attn_name) setattr(layer, attn_name, flash_attn_class(model.config))