FlashAttention - Amazon SageMaker

기계 번역으로 제공되는 번역입니다. 제공된 번역과 원본 영어의 내용이 상충하는 경우에는 영어 버전이 우선합니다.

FlashAttention

SMP v2는 FlashAttention 커널을 지원하며 Hugging Face Transformer 모델의 다양한 시나리오에 쉽게 적용할 수 있습니다. FlashAttention 패키지 v2.0 이상을 사용하는 경우 는 FlashAttention v2SMP를 사용하지만 Triton 플래시 주의는 v1.x의 플래시 주의 커널로 기본 설정되어 FlashAttention v FlashAttention 1에서만 지원됩니다.

모듈(nn.Module)은 모델의 주의 계층을 API 정의하는 낮은 수준입니다. AutoModelForCausalLM.from_config() API 예를 들어 에서 모델 생성 직후, 모델을 로 변환하거나 래핑하기 전에 적용해야 합니다FSDP.

자기 관심을 위해 FlashAttention 커널 사용

다음 코드 조각은 SMP v2에서 torch.sagemaker.nn.attn.FlashSelfAttention API 제공하는 를 사용하는 방법을 보여줍니다.

def new_attn(self, q, k, v, attention_mask=None, head_mask=None): return ( self.flashmod((q, k, v), causal=True, cast_dtype=torch.bfloat16, layout="b h s d"), None, ) for layer in model.gpt_neox.layers: layer.attention.flash_mod = torch.sagemaker.nn.attn.FlashSelfAttention() layer.attention._attn = functools.partial(new_attn, layer.attention)

그룹화된 쿼리 주의에 FlashAttention 커널 사용

SMP 또한 v2는 그룹화된 쿼리 주의(GQA)에 대한 FlashAttention 커널을 지원하며 Hugging Face Transformer 모델의 다양한 시나리오에 쉽게 적용할 수 있습니다. 원래 주의 아키텍처와 달리 는 쿼리 헤드를 그룹으로 GQA균등하게 분할하고 동일한 그룹의 쿼리 헤드는 동일한 키와 값 헤드를 공유합니다. 따라서 q 및 kv 헤드는 전달 호출로 별도로 전달됩니다. 참고: q 헤드 수는 kv 헤드 수로 나눌 수 있어야 합니다.

사용 예제 FlashGroupedQueryAttention

다음 코드 조각은 SMP v2에서 torch.sagemaker.nn.attn.FlashGroupedQueryAttention API 제공하는 를 사용하는 방법을 보여줍니다.

from transformers.models.llama.modeling_llama import LlamaAttention from torch.sagemaker.nn.attn import FlashGroupedQueryAttention class LlamaFlashAttention(LlamaAttention): def __init__(self, config: LlamaConfig): super().__init__(config) self.flash_attn = FlashGroupedQueryAttention( attention_dropout_prob=0.0, ) def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, ... ): query_states = self.q_proj(hidden_states) key_states = self.k_proj(hidden_states) value_states = self.v_proj(hidden_states) ... kv = (key_states, value_states) attn_output = self.flash_attn( query_states, kv, attn_mask=attention_mask, causal=True, layout="b h s d", ) ... attn_output = self.o_proj(attn_output) ... return attn_output

SMP 라이브러리는 torch.sagemaker.nn.attn.FlashGroupedQueryAttentionAPI를 낮은 수준에서 torch.sagemaker.nn.huggingface.llama_flashattn.LlamaFlashAttention사용하는 도 제공합니다. Hugging Face Transformer는 v4.36.0LlamaFlashAttention2에서 라는 유사한 구현을 수행합니다. 다음 코드 조각은 SMP v2 LlamaFlashAttention API 또는 변환기를 사용하여 기존 Llama 모델의 주의 계층LlamaFlashAttention2API을 대체하는 방법을 보여줍니다.

from torch.sagemaker.nn.huggingface.llama_flashattn import LlamaFlashAttention from transformers.models.llama.modeling_llama import LlamaFlashAttention2 flash_attn_class = LlamaFlashAttention # or flash_attn_class = LlamaFlashAttention2 attn_name = "self_attn" for layer in model.model.layers: prev_layer = getattr(layer, attn_name) setattr(layer, attn_name, flash_attn_class(model.config))