Support para FlashAttention
Support for FlashAttention é um atributo da biblioteca aplicável apenas ao modelo de transformador distribuído, que é um modelo de transformador incluído smp.DistributedModel()
A biblioteca FlashAttentionattention_head_size
é definida com um valor múltiplo de 8 e menor que 128. Portanto, ao treinar um transformador distribuído e garantir que o FlashAttention funcione corretamente, você deve ajustar os parâmetros para que o tamanho da cabeça de atenção atenda aos requisitos. Para obter mais informações, consulte também Instalação e atributos
Por exemplo, suponha que você configure um modelo Transformador com hidden_width=864
e num_heads=48
. O tamanho da cabeça do FlashAttention é calculado como attention_head_size = hidden_width / num_heads = 864 / 48 = 18
. Para ativar o FlashAttention, você precisa ajustar o parâmetro num_heads
para 54
, de forma que attention_head_size = hidden_width / num_heads = 864
/ 54 = 16
, o qual seja um múltiplo de 8.