FlashAttentionWrapper.jl

Just a simple wrapper for the Flash Attention operation.

Installation

using FlashAttentionWrapper

FlashAttentionWrapper.install()

Note that by default it will install the latest version of FlashAttention.
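Before running the above, the Julia package itself needs to be in your environment. A minimal sketch, assuming you add it directly from this repository rather than from a registry:

using Pkg
Pkg.add(url="https://github.com/JuliaGenAI/FlashAttentionWrapper.jl")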

Example

using FlashAttentionWrapper

# q, k, v are assumed to be 4-dimensional CuArrays of size (head_dim, n_heads, seq_len, batch_size)
o = mha(q, k, v; kw...)

See the original FlashAttention documentation for an explanation of the supported keyword arguments.
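For instance, a minimal end-to-end call might look like this (the sizes are arbitrary; Float16 inputs and a CUDA-capable GPU are assumed):

using CUDA
using FlashAttentionWrapper

head_dim, n_heads, seq_len, batch_size = 64, 8, 1024, 2  # hypothetical sizes
q = CUDA.randn(Float16, head_dim, n_heads, seq_len, batch_size)
k = CUDA.randn(Float16, head_dim, n_heads, seq_len, batch_size)
v = CUDA.randn(Float16, head_dim, n_heads, seq_len, batch_size)

o = mha(q, k, v)  # output has the same size as q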

The backward pass is also supported:

using CUDA
using Zygote

o, back = Zygote.pullback(q, k, v) do q, k, v
    mha(q, k, v)
end

Δo = CUDA.randn(eltype(o), size(o))

Δq, Δk, Δv = back(Δo)
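Equivalently, you can take gradients of a scalar loss directly. A minimal sketch, reusing the q, k, v from above with an arbitrary sum-of-squares loss:

loss(q, k, v) = sum(abs2, mha(q, k, v))
Δq, Δk, Δv = Zygote.gradient(loss, q, k, v)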

If you'd like to use it with Lux.jl, here's a handy example:

using CUDA, Lux, FlashAttentionWrapper

head_dim, n_head, seq_len, batch_size = 256, 8, 1024, 4
hidden_dim = head_dim * n_head

x = CUDA.randn(Float16, (hidden_dim, seq_len, batch_size))

m = Chain(
    BranchLayer(
        Chain(
            Dense(hidden_dim => hidden_dim, use_bias=false),
            ReshapeLayer((head_dim, n_head, seq_len))
        ),
        Chain(
            Dense(hidden_dim => hidden_dim, use_bias=false),
            ReshapeLayer((head_dim, n_head, seq_len))
        ),
        Chain(
            Dense(hidden_dim => hidden_dim, use_bias=false),
            ReshapeLayer((head_dim, n_head, seq_len))
        ),
    ),
    Attention()
)

using Random
rng = Random.default_rng()
ps, st = Lux.setup(rng, m)
cu_ps = recursive_map(CuArray{Float16}, ps)  # move the parameters to the GPU as Float16

o, _ = m(x, cu_ps, st)
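Gradients flow through the Lux model in the same way. A minimal sketch with Zygote, reusing x, m, cu_ps, and st from above and an arbitrary sum-of-squares loss:

using Zygote

∇ps, = Zygote.gradient(cu_ps) do p
    y, _ = m(x, p, st)
    sum(abs2, y)
end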

TODO List

  • Add benchmark
  • Support FlexAttention?
