From 2f42f74e71234bb61c2b983d23a20d1fd27c57bf Mon Sep 17 00:00:00 2001 From: remi-or Date: Thu, 26 Sep 2024 12:39:24 +0000 Subject: [PATCH 01/12] Added flash attention kernel --- .../ops/flash_attention/backward/caller.py | 178 +++++++++++ .../flash_attention/backward/compute_delta.py | 73 +++++ .../flash_attention/backward/compute_dkdv.py | 296 ++++++++++++++++++ .../flash_attention/backward/compute_dq.py | 261 +++++++++++++++ .../ops/flash_attention/backward/kernel.py | 182 +++++++++++ .../ops/flash_attention/forward/caller.py | 121 +++++++ .../forward/compute_row_blocks.py | 103 ++++++ .../ops/flash_attention/forward/kernel.py | 291 +++++++++++++++++ .../reference_implementation.py | 129 ++++++++ src/liger_kernel/ops/flash_attention/utils.py | 109 +++++++ .../ops/flash_attention/wrapper.py | 100 ++++++ 11 files changed, 1843 insertions(+) create mode 100644 src/liger_kernel/ops/flash_attention/backward/caller.py create mode 100644 src/liger_kernel/ops/flash_attention/backward/compute_delta.py create mode 100644 src/liger_kernel/ops/flash_attention/backward/compute_dkdv.py create mode 100644 src/liger_kernel/ops/flash_attention/backward/compute_dq.py create mode 100644 src/liger_kernel/ops/flash_attention/backward/kernel.py create mode 100644 src/liger_kernel/ops/flash_attention/forward/caller.py create mode 100644 src/liger_kernel/ops/flash_attention/forward/compute_row_blocks.py create mode 100644 src/liger_kernel/ops/flash_attention/forward/kernel.py create mode 100644 src/liger_kernel/ops/flash_attention/reference_implementation.py create mode 100644 src/liger_kernel/ops/flash_attention/utils.py create mode 100644 src/liger_kernel/ops/flash_attention/wrapper.py diff --git a/src/liger_kernel/ops/flash_attention/backward/caller.py b/src/liger_kernel/ops/flash_attention/backward/caller.py new file mode 100644 index 00000000..d80b293e --- /dev/null +++ b/src/liger_kernel/ops/flash_attention/backward/caller.py @@ -0,0 +1,178 @@ +import math +from typing import Optional, Tuple + +import torch +import torch.nn.functional as F +import triton +from torch import Tensor + +from src.liger_kernel.ops.flash_attention.backward.compute_delta import _compute_delta +from src.liger_kernel.ops.flash_attention.backward.kernel import _bwd_kernel +from src.liger_kernel.ops.flash_attention.utils import attention_pack, attention_unpack, torch_ignore_deterministic, infer_bias_strides, handle_dropout, encode_dtype + + +def _flash_attn_backward( + dO: Tensor, # [batch_size, seqlen_q, nheads_q, head_dim] + q: Tensor, # [batch_size, seqlen_q, nheads_q, head_dim] + k: Tensor, # [batch_size, seqlen_k, nheads_kv, head_dim] + v: Tensor, # [batch_size, seqlen_k, nheads_kv, head_dim] + bias: Optional[Tensor], # [1 | batch_size, 1 | nheads_q, seqlen_q, seqlen_k] + attention_mask: Optional[Tensor], # [batch_size, seqlen_qk] + o: Tensor, # [batch_size, seqlen_q, nheads_q, head_dim] + lse: Tensor, # [batch_size, nheads_q, max_seqlen_q_rounded] + dropout_p: float, + causal: bool, + softmax_scale: Optional[float], + dropout_seed: Optional[int], +) -> Tuple[Tensor, Tensor, Tensor]: + + if attention_mask is not None: + assert bias is None, "Attention mask is not supported along with attention bias. Just use bias instead." 
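# --- Illustrative sketch (not part of the original patch): how the "useless padding" trim
# just below behaves. Columns that are padding for every sequence in the batch are sliced
# off before the tensors are packed for varlen mode; the mask values here are hypothetical,
# chosen only to make the arithmetic concrete.
import torch

_mask = torch.tensor([[1, 1, 1, 0, 0],
                      [1, 1, 0, 0, 0]])                      # [batch=2, seqlen=5]
_useless = _mask.size(1) - _mask.sum(-1).max().item()        # 5 - 3 = 2 all-padding columns
assert _mask[:, :-_useless].shape == (2, 3)                  # trailing columns are dropped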
+ assert q.size(1) == k.size(1), "Attention mask is not supported with seqlen_q != seqlen_k" + varlen_mode = (attention_mask.size(0) > 1) + useless_padding = attention_mask.size(1) - attention_mask.sum(-1).max().item() + if useless_padding > 0: + dO = dO[:, :-useless_padding] + q = q[:, :-useless_padding] + k = k[:, :-useless_padding] + v = v[:, :-useless_padding] + attention_mask = attention_mask[:, :-useless_padding] + o = o[:, :-useless_padding] + else: + varlen_mode = False + useless_padding = 0 + + # Retrieve and check shapes + dO = dO.contiguous() if dO.stride(-1) != 1 else dO + batch_size, seqlen_q, nheads_q, head_dim = q.shape + _, seqlen_k, nheads_kv, _ = k.shape + max_seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128 + softmax_scale = 1.0 / math.sqrt(head_dim) if softmax_scale is None else softmax_scale + assert nheads_q % nheads_kv == 0, f"{nheads_q = } is not divisible by {nheads_kv =}" + assert lse.shape == (batch_size, nheads_q, max_seqlen_q_rounded) + assert q.stride(-1) == k.stride(-1) == v.stride(-1) == o.stride(-1) == 1 + + # Depending on attention_mask, switch to varlen + if varlen_mode: + # Compute padding-related statistics + cum_seqlens_q = torch.zeros(size=(attention_mask.size(0)+1,), device=attention_mask.device, dtype=torch.int32) + cum_seqlens_k = torch.zeros(size=(attention_mask.size(0)+1,), device=attention_mask.device, dtype=torch.int32) + with torch_ignore_deterministic(): + cum_seqlens_q[1:] = attention_mask.sum(dim=1).cumsum(0) + cum_seqlens_k[1:] = attention_mask.sum(dim=1).cumsum(0) + # cum_seqlens_q = [0, seqlen_q1, seqlen_q1+seqlen_q2, ..., seqlen_q1+...+seqlen_qB] of shape [B+1] + max_seqlen_q: int = attention_mask.size(1) + max_seqlen_k: int = attention_mask.size(1) + # Collate all matrices + q = attention_pack(q, attention_mask) # [1, sum_seqlens_qk, num_head, head_dim] + k = attention_pack(k, attention_mask) # [1, sum_seqlens_qk, num_head, head_dim] + v = attention_pack(v, attention_mask) # [1, sum_seqlens_qk, num_head, head_dim] + o = attention_pack(o, attention_mask) # [1, sum_seqlens_qk, num_head, head_dim] + dO = attention_pack(dO, attention_mask) # [1, sum_seqlens_qk, num_head, head_dim] + # Update seqlens + seqlen_q = q.size(1) + seqlen_k = k.size(1) + else: + cum_seqlens_q = None + cum_seqlens_k = None + max_seqlen_q = seqlen_q + max_seqlen_k = seqlen_k + + # Handle bias and dropout + stride_bb, stride_bh, stride_bm = infer_bias_strides(bias, batch_size, nheads_q, seqlen_q, seqlen_k) + dropout_seed = handle_dropout(dropout_p, dropout_seed, is_forward=False) + + # Prepare gradient accumulators # TODO: maybe we can initialize this as empty -- check pre hook + dq = torch.zeros_like(q, dtype=torch.float32) # [batch_size|1, seqlen_q|sum_seqlens_qk, nheads_q, head_dim] + dk = torch.zeros(size=(k.size(0), k.size(1), q.size(2), k.size(3)), device=k.device, dtype=k.dtype) + dv = torch.zeros(size=(v.size(0), v.size(1), q.size(2), v.size(3)), device=v.device, dtype=v.dtype) + delta = torch.zeros_like(lse) # [batch_size, nheads_q, max_seqlen_q_rounded] + + # Infer problem size + BLOCK_HEADDIM = max(triton.next_power_of_2(head_dim), 16) + # Launch the delta computation kernel + grid = lambda META: (triton.cdiv(max_seqlen_q, META["BLOCK_M"]), batch_size * nheads_q) # noqa: E731 + _compute_delta[grid]( + o, + dO, + delta, + o.stride(0), + o.stride(2), + o.stride(1), + dO.stride(0), + dO.stride(2), + dO.stride(1), + nheads_q, + seqlen_q, + max_seqlen_q_rounded, + cum_seqlens_q, + head_dim, + max_seqlen_q // 32, + encode_dtype(o), + VARLEN=varlen_mode, + 
BLOCK_HEADDIM=BLOCK_HEADDIM, + ) + + # Launch backward kernel + head_ratio = nheads_q // nheads_kv + grid = lambda META: ( # noqa: E731 + triton.cdiv(seqlen_k, META["BLOCK_N1"]) + triton.cdiv(seqlen_q, META["BLOCK_M2"]), + batch_size * nheads_q, + ) + _bwd_kernel[grid]( + q, + k, + v, + bias, + dO, + dq, + dk, + dv, + lse, + delta, + softmax_scale, + dropout_p, + dropout_seed, + q.stride(0), q.stride(2), q.stride(1), + k.stride(0), k.stride(2), k.stride(1), + v.stride(0), v.stride(2), v.stride(1), + stride_bb, stride_bh, stride_bm, + dO.stride(0), dO.stride(2), dO.stride(1), + dq.stride(0), dq.stride(2), dq.stride(1), + dk.stride(0), dk.stride(2), dk.stride(1), + dv.stride(0), dv.stride(2), dv.stride(1), + nheads_q, + head_ratio, + seqlen_q, + cum_seqlens_q, + seqlen_k, + cum_seqlens_k, + max_seqlen_q_rounded, + head_dim, + max_seqlen_q // 32, + max_seqlen_k // 32, # key for triton cache (limit number of compilations) + encode_dtype(q), + VARLEN=varlen_mode, + IS_CAUSAL=causal, + BIAS_ON=(bias is not None), + USE_DROPOUT=(dropout_p > 0), + BLOCK_HEADDIM=BLOCK_HEADDIM, + ) + + # GQA reduction + if head_ratio > 1: + dk = dk.unflatten(dim=2, sizes=(nheads_kv, head_ratio)).sum(-2) + dv = dv.unflatten(dim=2, sizes=(nheads_kv, head_ratio)).sum(-2) + + # In case of variable length mode, we need to unpack the gradients + if varlen_mode: + dq = attention_unpack(dq, cum_seqlens_q, batch_size, max_seqlen_q) + dk = attention_unpack(dk, cum_seqlens_k, batch_size, max_seqlen_k) + dv = attention_unpack(dv, cum_seqlens_k, batch_size, max_seqlen_k) + # And add back the useless padding if there was any + if useless_padding > 0: + dq = F.pad(dq, (0, 0, 0, 0, 0, useless_padding)) + dk = F.pad(dk, (0, 0, 0, 0, 0, useless_padding)) + dv = F.pad(dv, (0, 0, 0, 0, 0, useless_padding)) + + return dq, dk, dv diff --git a/src/liger_kernel/ops/flash_attention/backward/compute_delta.py b/src/liger_kernel/ops/flash_attention/backward/compute_delta.py new file mode 100644 index 00000000..54ad3b5a --- /dev/null +++ b/src/liger_kernel/ops/flash_attention/backward/compute_delta.py @@ -0,0 +1,73 @@ +import triton +import triton.language as tl +from triton import Config + +MIN_B = 16 + + +@triton.autotune( + configs=[ + Config({"BLOCK_M": MIN_B}, num_warps=4, num_stages=0), + Config({"BLOCK_M": 32}, num_warps=4, num_stages=0), + Config({"BLOCK_M": 64}, num_warps=4, num_stages=0), + Config({"BLOCK_M": 128}, num_warps=4, num_stages=0), + ], + key=["CACHE_KEY_SEQLEN_Q", "DTYPE"], # TODO: add dtype +) +@triton.jit +def _compute_delta( + Out, + DO, + Delta, + stride_ob, + stride_oh, + stride_om, + stride_dob, + stride_doh, + stride_dom, + nheads, + seqlen_q, + max_seqlen_q_rounded, + cum_seqlens_q, + headdim, + CACHE_KEY_SEQLEN_Q, + DTYPE, + VARLEN: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_HEADDIM: tl.constexpr, +): + # Locate kernel inside the grid + start_m = tl.program_id(0) # current block in the Q matrix + off_head_and_batch = tl.program_id(1) + off_batch = off_head_and_batch // nheads + off_head = off_head_and_batch % nheads + # Initialize offsets + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_d = tl.arange(0, BLOCK_HEADDIM) + + # Infer actual sequence length of Q and the offset to the last sequence + if VARLEN: + actual_seqlen_q = tl.load(cum_seqlens_q + off_batch + 1) - tl.load(cum_seqlens_q + off_batch) + cu_seq_start_q = tl.load(cum_seqlens_q + off_batch) + off_batch = 0 + else: + actual_seqlen_q = seqlen_q + cu_seq_start_q = 0 + + # Load the output tensor + Out_offseted = Out + off_batch * stride_ob + 
off_head * stride_oh + cu_seq_start_q * stride_om + o = tl.load( + Out_offseted + offs_m[:, None] * stride_om + offs_d[None, :], + mask=(offs_m[:, None] < actual_seqlen_q) & (offs_d[None, :] < headdim), + other=0.0, + ).to(tl.float32) + # And its gradient + DO_offseted = DO + off_batch * stride_dob + off_head * stride_doh + cu_seq_start_q * stride_dom + do = tl.load( + DO_offseted + offs_m[:, None] * stride_dom + offs_d[None, :], + mask=(offs_m[:, None] < actual_seqlen_q) & (offs_d[None, :] < headdim), + other=0.0, + ).to(tl.float32) + delta = tl.sum(o * do, axis=1) + # write-back + tl.store(Delta + off_head_and_batch * max_seqlen_q_rounded + offs_m, delta) diff --git a/src/liger_kernel/ops/flash_attention/backward/compute_dkdv.py b/src/liger_kernel/ops/flash_attention/backward/compute_dkdv.py new file mode 100644 index 00000000..34c8427a --- /dev/null +++ b/src/liger_kernel/ops/flash_attention/backward/compute_dkdv.py @@ -0,0 +1,296 @@ +import triton +import triton.language as tl + +from src.liger_kernel.ops.flash_attention.utils import load_fn + + +@triton.jit +def _compute_single_block_dkdv( + I_start_m, + k, + v, + dk, + dv, + LSE, + D, + offs_m, + offs_n, + offs_d, + q_ptrs, + bias_ptrs, + dropout_offs, + do_ptrs, + softmax_scale, + dropout_p, + dropout_seed, + stride_qm, + stride_bm, + stride_dom, + actual_seqlen_q, + actual_seqlen_k, + fully_masked_lines, + headdim, + MASKED: tl.constexpr, + IS_CAUSAL: tl.constexpr, + BIAS_ON: tl.constexpr, + USE_DROPOUT: tl.constexpr, + PAD_ROWS: tl.constexpr, + PAD_COLS: tl.constexpr, + HEADS_PADDED: tl.constexpr, +): + # Relocate pointers + q_ptrs = q_ptrs + I_start_m * stride_qm + do_ptrs = do_ptrs + I_start_m * stride_dom + if BIAS_ON: + bias_ptrs = bias_ptrs + I_start_m * stride_bm + if USE_DROPOUT: + dropout_offs += I_start_m * actual_seqlen_k + + # Update row variables + offs_m_curr = I_start_m + offs_m + + # Load Q and LSE now to reduce pipeline stall + # BUG: if one is true and the ther not, q is filled with wrong values + q = load_fn(q_ptrs, + offs_m_curr, offs_d, + PAD_ROWS or HEADS_PADDED, PAD_ROWS or HEADS_PADDED, + actual_seqlen_q, headdim) + lse_i = tl.load(LSE + offs_m_curr) # since lsm is padded to max_seqlen_q, should be good + if BIAS_ON: + bias = load_fn( + bias_ptrs, + offs_m_curr, + offs_n, + PAD_ROWS or HEADS_PADDED, + PAD_ROWS or HEADS_PADDED, + actual_seqlen_q, + actual_seqlen_k + ) + + # Recompute P_ij = softmax(qk, dim=-1).T + qk = tl.dot(q, tl.trans(k)) + if BIAS_ON: + qk += bias / softmax_scale # TODO: check if this is optimal + + # Attention and causal mask + offs_n_causal = (offs_n - actual_seqlen_k + actual_seqlen_q) + if MASKED: + if PAD_COLS: + if IS_CAUSAL: + qk = tl.where( + tl.minimum(actual_seqlen_q - 1, offs_m_curr)[:, None] >= offs_n_causal[None, :], qk, float("-inf") + ) + else: + qk = tl.where(actual_seqlen_q - 1 >= offs_n_causal[None, :], qk, float("-inf")) + elif IS_CAUSAL: + qk = tl.where(offs_m_curr[:, None] >= offs_n_causal[None, :], qk, float("-inf")) + tl.debug_barrier() + + p = tl.exp2(qk * (softmax_scale * 1.44269504089) - lse_i[:, None]) + + # Account for fully masked lines + if MASKED: + if fully_masked_lines > 0: + p = tl.where(offs_m_curr[:, None] < fully_masked_lines, 0, p) + + # Load the gradient of O + do = load_fn(do_ptrs, offs_m_curr, offs_d, PAD_ROWS, HEADS_PADDED, actual_seqlen_q, headdim) + + # Compute the gradient of V + dv += tl.dot(tl.trans(p).to(do.dtype), do) + + # Compute auxiliary gradients + dp = tl.dot(do, tl.trans(v)) + + # Compute the gradient of the scores. 
Placing the substraction before the matmul apparently speeds up the process + Di = tl.load(D + offs_m_curr) + # Converting ds to q.dtype here reduces register pressure and makes it much faster for BLOCK_HEADDIM=128 + ds = (p * (dp - Di[:, None]) * softmax_scale).to(q.dtype) + # compute dk = dot(ds.T, q) + dk += tl.dot(tl.trans(ds), q) + + return dk, dv + + +@triton.jit +def _compute_column_blocks_dkdv( + I_start_n, + Q, + K, + V, + Bias, + Dropout, + DO, + DK, + DV, + LSE, + D, + softmax_scale, + dropout_p, + dropout_seed, + stride_qm, + stride_kn, + stride_vn, + stride_bm, + stride_dom, + stride_dkn, + stride_dvn, + actual_seqlen_q, + actual_seqlen_k, + headdim, + IS_CAUSAL: tl.constexpr, + BIAS_ON: tl.constexpr, + USE_DROPOUT: tl.constexpr, + PAD_COLS: tl.constexpr, + HEADS_PADDED: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_HEADDIM: tl.constexpr, +): + # This fuction goes through a column, so it always ends at m = actual_seqlen_q but can start early due to causality + I_begin_m = max(I_start_n + actual_seqlen_q - actual_seqlen_k, 0) if IS_CAUSAL else 0 + I_begin_m = (I_begin_m // BLOCK_M) * BLOCK_M + I_end_m = actual_seqlen_q + + fully_masked_lines = (actual_seqlen_q - actual_seqlen_k) if IS_CAUSAL else 0 + # Since we are in a grid dimensionned to fit max_seqlen_q, some blocks may exist early + if (I_begin_m >= actual_seqlen_q) or (I_start_n >= actual_seqlen_k): + return + + # Initialize offsets + offs_n = I_start_n + tl.arange(0, BLOCK_N) + offs_m = tl.arange(0, BLOCK_M) + offs_d = tl.arange(0, BLOCK_HEADDIM) + + # Initialize states pointer + q_ptrs = Q + (offs_m[:, None] * stride_qm + offs_d[None, :]) + k_ptrs = K + (offs_n[:, None] * stride_kn + offs_d[None, :]) + v_ptrs = V + (offs_n[:, None] * stride_vn + offs_d[None, :]) + dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_d[None, :]) + dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :]) + do_ptrs = DO + (offs_m[:, None] * stride_dom + offs_d[None, :]) + # ...and maybe bias + if BIAS_ON: + bias_ptrs = Bias + (offs_m[:, None] * stride_bm + offs_n[None, :]) + else: + bias_ptrs = None + # ...and maybe dropout + if USE_DROPOUT: + dropout_offs = Dropout + offs_m[:, None] * actual_seqlen_k + offs_n[None, :] + else: + dropout_offs = None + + # Initialize dv and dk + dk = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32) + dv = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32) + + # Load K and V, which will stay in SRAM for the row-wise loop + k = load_fn( + k_ptrs, offs_n, offs_d, + PAD_AXIS_0=PAD_COLS, PAD_AXIS_1=HEADS_PADDED, + LIM_AXIS_0=actual_seqlen_k, LIM_AXIS_1=headdim, + ) + v = load_fn( + v_ptrs, offs_n, offs_d, + PAD_AXIS_0=PAD_COLS, PAD_AXIS_1=HEADS_PADDED, + LIM_AXIS_0=actual_seqlen_k, LIM_AXIS_1=headdim, + ) + + # Loop over rows to compute dk and dv + first_full_row = max(0, I_start_n + BLOCK_N - 1 + actual_seqlen_q - actual_seqlen_k) + first_full_block = BLOCK_M * ( + (min(first_full_row, actual_seqlen_q) + BLOCK_M - 1) // BLOCK_M + ) + num_masked_blocks = (first_full_block - I_begin_m) // BLOCK_M if IS_CAUSAL else 0 + I_next_start_m = I_begin_m + + # Partially masked blocks + if num_masked_blocks > 0: + for _ in range(0, num_masked_blocks): + dk, dv = _compute_single_block_dkdv( + I_next_start_m, + k, + v, + dk, + dv, + LSE, + D, + offs_m, + offs_n, + offs_d, + q_ptrs, + bias_ptrs, + dropout_offs, + do_ptrs, + softmax_scale, + dropout_p, + dropout_seed, + stride_qm, + stride_bm, + stride_dom, + actual_seqlen_q, + actual_seqlen_k, + fully_masked_lines, + headdim, + MASKED=True, 
+ IS_CAUSAL=IS_CAUSAL, + BIAS_ON=BIAS_ON, + USE_DROPOUT=USE_DROPOUT, + PAD_ROWS=True, # TODO: fix this + PAD_COLS=PAD_COLS, + HEADS_PADDED=HEADS_PADDED, + ) + I_next_start_m += BLOCK_M + + # Full blocks + if I_next_start_m < I_end_m: + for I_start_m in range(I_next_start_m, I_end_m, BLOCK_M): + dk, dv = _compute_single_block_dkdv( + I_start_m, + k, + v, + dk, + dv, + LSE, + D, + offs_m, + offs_n, + offs_d, + q_ptrs, + bias_ptrs, + dropout_offs, + do_ptrs, + softmax_scale, + dropout_p, + dropout_seed, + stride_qm, + stride_bm, + stride_dom, + actual_seqlen_q, + actual_seqlen_k, + fully_masked_lines, + headdim, + MASKED=False, + IS_CAUSAL=IS_CAUSAL, + BIAS_ON=BIAS_ON, + USE_DROPOUT=USE_DROPOUT, + PAD_ROWS=True, # TODO: fix this + PAD_COLS=PAD_COLS, + HEADS_PADDED=HEADS_PADDED, + ) + + # Store dk and dv + if HEADS_PADDED: + if PAD_COLS: + tl.store(dk_ptrs, dk, mask=(offs_n[:, None] < actual_seqlen_k) & (offs_d[None, :] < headdim)) + tl.store(dv_ptrs, dv, mask=(offs_n[:, None] < actual_seqlen_k) & (offs_d[None, :] < headdim)) + else: + tl.store(dk_ptrs, dk, mask=offs_d[None, :] < headdim) + tl.store(dv_ptrs, dv, mask=offs_d[None, :] < headdim) + else: + if PAD_COLS: + tl.store(dk_ptrs, dk, mask=offs_n[:, None] < actual_seqlen_k) + tl.store(dv_ptrs, dv, mask=offs_n[:, None] < actual_seqlen_k) + else: + tl.store(dk_ptrs, dk) + tl.store(dv_ptrs, dv) diff --git a/src/liger_kernel/ops/flash_attention/backward/compute_dq.py b/src/liger_kernel/ops/flash_attention/backward/compute_dq.py new file mode 100644 index 00000000..38c2a606 --- /dev/null +++ b/src/liger_kernel/ops/flash_attention/backward/compute_dq.py @@ -0,0 +1,261 @@ +import triton +import triton.language as tl + +from src.liger_kernel.ops.flash_attention.utils import load_fn + + +@triton.jit +def _compute_single_block_dq( + I_start_n, + q, + dq, + do, + lse_i, + delta_i, + offs_m, + offs_n, + offs_d, + k_ptrs, + v_ptrs, + bias_ptrs, + dropout_offs, + softmax_scale, + dropout_p, + dropout_seed, + stride_kn, + stride_vn, + actual_seqlen_q, + actual_seqlen_k, + headdim, + MASKED: tl.constexpr, + IS_CAUSAL: tl.constexpr, + BIAS_ON: tl.constexpr, + USE_DROPOUT: tl.constexpr, + PAD_COLS: tl.constexpr, + HEADS_PADDED: tl.constexpr, +): + # Relocate pointers and offsets + k_ptrs = k_ptrs + I_start_n * stride_kn + v_ptrs = v_ptrs + I_start_n * stride_vn + offs_n_curr = I_start_n + offs_n + if BIAS_ON: + bias_ptrs += I_start_n + if USE_DROPOUT: + dropout_offs += I_start_n + + # Load K, V and LSE now to reduce pipeline stall + k = load_fn(k_ptrs, offs_n_curr, offs_d, PAD_COLS, HEADS_PADDED, actual_seqlen_k, headdim) + v = load_fn(v_ptrs, offs_n_curr, offs_d, PAD_COLS, HEADS_PADDED, actual_seqlen_k, headdim) + if BIAS_ON: + bias = load_fn(bias_ptrs, offs_m, offs_n_curr, True, PAD_COLS, actual_seqlen_q, actual_seqlen_k) # TODO: pad rows + + # Recompute P_ij = softmax(qk, dim=-1).T + qk = tl.dot(q, tl.trans(k)) + if BIAS_ON: + qk += bias / softmax_scale # TODO: check if this is optimal + + offs_n_causal = (offs_n_curr - actual_seqlen_k + actual_seqlen_q) + + # Attention and causal mask + if MASKED: + if PAD_COLS: + if IS_CAUSAL: + qk = tl.where(tl.minimum(actual_seqlen_q - 1, offs_m)[:, None] >= offs_n_causal[None, :], qk, float("-inf")) + else: + qk = tl.where(actual_seqlen_q - 1 >= offs_n_causal[None, :], qk, float("-inf")) + elif IS_CAUSAL: + qk = tl.where(offs_m[:, None] >= offs_n_causal[None, :], qk, float("-inf")) + tl.debug_barrier() + + p = tl.exp2(qk * (softmax_scale * 1.44269504089) - lse_i[:, None]) + dp = tl.dot(do, tl.trans(v)) + + ds = (p 
* (dp - delta_i[:, None]) * softmax_scale).to(q.dtype) + + # compute dq + dq += tl.dot(ds, k) + + return dq + + +@triton.jit +def _compute_row_blocks_dq( + I_start_m, + Q, + K, + V, + Bias, + Dropout, + DO, + DQ, + LSE, + D, + softmax_scale, + dropout_p, + dropout_seed, + stride_qm, + stride_kn, + stride_vn, + stride_bm, + stride_dom, + stride_dqm, + actual_seqlen_q, + actual_seqlen_k, + headdim, + VARLEN: tl.constexpr, + IS_CAUSAL: tl.constexpr, + BIAS_ON: tl.constexpr, + USE_DROPOUT: tl.constexpr, + PAD_ROWS: tl.constexpr, + HEADS_PADDED: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_HEADDIM: tl.constexpr, + EVEN_N: tl.constexpr, +): + # This fuction goes through a row, so it always starts at i = 0 but the end can vary because of causality + if IS_CAUSAL: + I_end_n = min(actual_seqlen_k - actual_seqlen_q + I_start_m + BLOCK_M, actual_seqlen_k) + # For a seqlen_q >> seqlen_k, there migh be entire block skipped + if I_end_n < 0: + return + else: + I_end_n = actual_seqlen_k + # Compute the number of fully masked lines + fully_masked_lines = actual_seqlen_q - actual_seqlen_k if IS_CAUSAL else 0 + # Exit if the block is fully masked or the current row is greater than the actual sequence length + if (I_start_m >= actual_seqlen_q) or (fully_masked_lines >= I_start_m + BLOCK_M): + return + + # Initialize offsets + offs_m = tl.arange(0, BLOCK_M) + I_start_m + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_HEADDIM) + + # Initialize value-related pointer (not stats-related) + q_ptrs = Q + (offs_m[:, None] * stride_qm + offs_d[None, :]) + k_ptrs = K + (offs_n[:, None] * stride_kn + offs_d[None, :]) + v_ptrs = V + (offs_n[:, None] * stride_vn + offs_d[None, :]) + dq_ptrs = DQ + (offs_m[:, None] * stride_dqm + offs_d[None, :]) + do_ptrs = DO + (offs_m[:, None] * stride_dom + offs_d[None, :]) + if BIAS_ON: + bias_ptrs = Bias + (offs_m[:, None] * stride_bm + offs_n[None, :]) + else: + bias_ptrs = None + if USE_DROPOUT: + dropout_offs = Dropout + (offs_m[:, None] * stride_bm + offs_n[None, :]) + else: + dropout_offs = None + + # Initialize the dq accumulator + dq = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32) + + # Load Q, DO, LSE and D, which will stay in SRAM for the row-wise loop + q = load_fn( + q_ptrs, offs_m, offs_d, + PAD_AXIS_0=PAD_ROWS, PAD_AXIS_1=HEADS_PADDED, + LIM_AXIS_0=actual_seqlen_q, LIM_AXIS_1=headdim, + ) + do = load_fn( + do_ptrs, offs_m, offs_d, + PAD_AXIS_0=PAD_ROWS, PAD_AXIS_1=HEADS_PADDED, + LIM_AXIS_0=actual_seqlen_q, LIM_AXIS_1=headdim, + ) + lse_i = tl.load(LSE + offs_m) # since lse is padded to max_seqlen_q, should be good + delta_i = tl.load(D + offs_m) # same as LSE for now + + # Infer the number of full and partially masked blocks + uneven_n = (actual_seqlen_k % BLOCK_N != 0) + attention_padding = VARLEN & uneven_n + if IS_CAUSAL: + first_masked_col = I_start_m + 1 + actual_seqlen_k - actual_seqlen_q + elif attention_padding: + first_masked_col = actual_seqlen_k + else: + first_masked_col = I_end_n + nb_full_blocks = first_masked_col // BLOCK_N + + # Loop over rows to compute dk and dv + I_next_start_n = 0 + if nb_full_blocks > 0: + for _ in range(0, nb_full_blocks): + I_next_start_n = tl.multiple_of(I_next_start_n, BLOCK_N) + dq = _compute_single_block_dq( + I_next_start_n, + q, + dq, + do, + lse_i, + delta_i, + offs_m, + offs_n, + offs_d, + k_ptrs, + v_ptrs, + bias_ptrs, + dropout_offs, + softmax_scale, + dropout_p, + dropout_seed, + stride_kn, + stride_vn, + actual_seqlen_q, + actual_seqlen_k, + headdim, + 
IS_CAUSAL=IS_CAUSAL, + BIAS_ON=BIAS_ON, + USE_DROPOUT=USE_DROPOUT, + MASKED=False, + PAD_COLS=False, + HEADS_PADDED=HEADS_PADDED, + ) + I_next_start_n += BLOCK_N + + if I_next_start_n < I_end_n: + for I_start_n in range(I_next_start_n, I_end_n, BLOCK_N): + pad_cols = (not EVEN_N) or (VARLEN and (I_start_n + BLOCK_N > actual_seqlen_k)) + dq = _compute_single_block_dq( + I_start_n, + q, + dq, + do, + lse_i, + delta_i, + offs_m, + offs_n, + offs_d, + k_ptrs, + v_ptrs, + bias_ptrs, + dropout_offs, + softmax_scale, + dropout_p, + dropout_seed, + stride_kn, + stride_vn, + actual_seqlen_q, + actual_seqlen_k, + headdim, + IS_CAUSAL=IS_CAUSAL, + BIAS_ON=BIAS_ON, + USE_DROPOUT=USE_DROPOUT, + MASKED=True, + PAD_COLS=pad_cols, + HEADS_PADDED=HEADS_PADDED, + ) + + # Account for fully masked lines + if fully_masked_lines > 0: + dq = tl.where(offs_m[:, None] < fully_masked_lines, 0, dq) + + # Store dq + if HEADS_PADDED: + if PAD_ROWS: + tl.store(dq_ptrs, dq, mask=(offs_m[:, None] < actual_seqlen_q) & (offs_d[None, :] < headdim)) + else: + tl.store(dq_ptrs, dq, mask=offs_d[None, :] < headdim) + else: + if PAD_ROWS: + tl.store(dq_ptrs, dq, mask=offs_m[:, None] < actual_seqlen_q) + else: + tl.store(dq_ptrs, dq) diff --git a/src/liger_kernel/ops/flash_attention/backward/kernel.py b/src/liger_kernel/ops/flash_attention/backward/kernel.py new file mode 100644 index 00000000..560f2bfd --- /dev/null +++ b/src/liger_kernel/ops/flash_attention/backward/kernel.py @@ -0,0 +1,182 @@ +import math +from typing import Any, Dict, List + +import triton +import triton.language as tl +from triton import Config + +from src.liger_kernel.ops.flash_attention.backward.compute_dkdv import _compute_column_blocks_dkdv +from src.liger_kernel.ops.flash_attention.backward.compute_dq import _compute_row_blocks_dq + +MIN_B = 16 + + +def early_config_prune_bwd_kernel( + configs: List[Config], + named_args: Dict[str, Any], + **kwargs, +) -> List[Config]: + # Remove the configs where BLOCK_ > seqlen_ + kept_configs = [] + for cfg in configs: + block_m_too_large = max(cfg.kwargs["BLOCK_M1"], cfg.kwargs["BLOCK_M2"]) > named_args["seqlen_q"] + block_n_too_large = max(cfg.kwargs["BLOCK_N1"], cfg.kwargs["BLOCK_N2"]) > named_args["seqlen_k"] + if (block_m_too_large or block_n_too_large): + pass + else: + kept_configs.append(cfg) + # If no config is left, go for the minimal config + if kept_configs: + return kept_configs + return [Config({"BLOCK_M1": MIN_B, "BLOCK_N1": MIN_B, "BLOCK_M2": MIN_B, "BLOCK_N2": MIN_B}, num_warps=4, num_stages=0)] + + +@triton.autotune( + configs=[ + Config({"BLOCK_M1": MIN_B, "BLOCK_N1": MIN_B, "BLOCK_M2": MIN_B, "BLOCK_N2": MIN_B}, num_warps=4, num_stages=0), + Config({"BLOCK_M1": 32, "BLOCK_N1": 16, "BLOCK_M2": 16, "BLOCK_N2": 32}, num_warps=4, num_stages=0), + Config({"BLOCK_M1": 32, "BLOCK_N1": 64, "BLOCK_M2": 64, "BLOCK_N2": 32}, num_warps=4, num_stages=0), + Config({"BLOCK_M1": 64, "BLOCK_N1": 64, "BLOCK_M2": 64, "BLOCK_N2": 64}, num_warps=4, num_stages=0), + Config({"BLOCK_M1": 64, "BLOCK_N1": 128, "BLOCK_M2": 128, "BLOCK_N2": 64}, num_warps=4, num_stages=0), + ], + key=[ + "CACHE_KEY_SEQLEN_Q", + "CACHE_KEY_SEQLEN_K", + "DTYPE", + "VARLEN", + "USE_DROPOUT", + "IS_CAUSAL", + "BIAS_ON", + "BLOCK_HEADDIM", + ], + prune_configs_by={"early_config_prune": early_config_prune_bwd_kernel}, +) +@triton.heuristics( + { + "EVEN_M1": lambda args: args["seqlen_q"] % args["BLOCK_M1"] == 0, + "EVEN_N1": lambda args: args["seqlen_k"] % args["BLOCK_N1"] == 0, + "EVEN_M2": lambda args: args["seqlen_q"] % args["BLOCK_M2"] == 0, + 
"EVEN_N2": lambda args: args["seqlen_k"] % args["BLOCK_N2"] == 0, + "HEADS_PADDED": lambda args: args["headdim"] != args["BLOCK_HEADDIM"], + "NUM_BLOCKS_KV": lambda args: math.ceil(args["seqlen_k"] / args["BLOCK_N1"]), + } +) +@triton.jit +def _bwd_kernel( + Q, + K, + V, + Bias, + DO, + DQ, + DK, + DV, + LSE, + D, + softmax_scale, + dropout_p, + dropout_seed, + stride_qb, stride_qh, stride_qm, + stride_kb, stride_kh, stride_kn, + stride_vb, stride_vh, stride_vn, + stride_bb, stride_bh, stride_bm, + stride_dob, stride_doh, stride_dom, + stride_dqb, stride_dqh, stride_dqm, + stride_dkb, stride_dkh, stride_dkn, + stride_dvb, stride_dvh, stride_dvn, + nheads_q, + head_ratio, + seqlen_q, + cum_seqlens_q, + seqlen_k, + cum_seqlens_k, + seqlen_q_rounded, + headdim, + CACHE_KEY_SEQLEN_Q, + CACHE_KEY_SEQLEN_K, + DTYPE, + VARLEN: tl.constexpr, + IS_CAUSAL: tl.constexpr, + BIAS_ON: tl.constexpr, + USE_DROPOUT: tl.constexpr, + BLOCK_HEADDIM: tl.constexpr, + BLOCK_M1: tl.constexpr, + BLOCK_N1: tl.constexpr, + BLOCK_M2: tl.constexpr, + BLOCK_N2: tl.constexpr, + NUM_BLOCKS_KV: tl.constexpr, + EVEN_M1: tl.constexpr, + EVEN_N1: tl.constexpr, + EVEN_M2: tl.constexpr, + EVEN_N2: tl.constexpr, + HEADS_PADDED: tl.constexpr, +): + # Locate kernel inside the grid + pid = tl.program_id(0) + off_head_and_batch = tl.program_id(1) + off_batch = off_head_and_batch // nheads_q + off_head_q = off_head_and_batch % nheads_q + off_head_kv = off_head_q // head_ratio + + # If in variable length mode, retrieve the actual sequence lengths + if VARLEN: + cu_seq_start_q = tl.load(cum_seqlens_q + off_batch) + cu_seq_start_k = tl.load(cum_seqlens_k + off_batch) + actual_seqlen_q = tl.load(cum_seqlens_q + off_batch + 1) - cu_seq_start_q + actual_seqlen_k = tl.load(cum_seqlens_k + off_batch + 1) - cu_seq_start_k + off_batch = 0 + else: + cu_seq_start_q = 0 + cu_seq_start_k = 0 + actual_seqlen_q = seqlen_q + actual_seqlen_k = seqlen_k + + # Offset matrix pointers for batch and head + Q += off_batch * stride_qb + off_head_q * stride_qh + cu_seq_start_q * stride_qm + K += off_batch * stride_kb + off_head_kv * stride_kh + cu_seq_start_k * stride_kn + V += off_batch * stride_vb + off_head_kv * stride_vh + cu_seq_start_k * stride_vn + DO += off_batch * stride_dob + off_head_q * stride_doh + cu_seq_start_q * stride_dom + DQ += off_batch * stride_dqb + off_head_q * stride_dqh + cu_seq_start_q * stride_dqm + DK += off_batch * stride_dkb + off_head_q * stride_dkh + cu_seq_start_k * stride_dkn + DV += off_batch * stride_dvb + off_head_q * stride_dvh + cu_seq_start_k * stride_dvn + if BIAS_ON: + Bias += off_batch * stride_bb + off_head_q * stride_bh + cu_seq_start_q * stride_bm + if USE_DROPOUT: + Dropout = actual_seqlen_k * (cu_seq_start_q + actual_seqlen_q * (off_head_q + nheads_q * off_batch)) + else: + Dropout = None + + # Offset vector pointers for batch and head + D += off_head_and_batch * seqlen_q_rounded + LSE += off_head_and_batch * seqlen_q_rounded + + # Case: this block works on dk and dv + if pid < NUM_BLOCKS_KV: + i_start_n = pid + pad_cols = (not EVEN_N1) or (VARLEN and ((i_start_n + 1) * BLOCK_N1 > actual_seqlen_k)) + _compute_column_blocks_dkdv( + i_start_n * BLOCK_N1, + Q, K, V, Bias, Dropout, DO, DK, DV, LSE, D, + softmax_scale, dropout_p, dropout_seed, + stride_qm, stride_kn, stride_vn, stride_bm, stride_dom, stride_dkn, stride_dvn, + actual_seqlen_q, actual_seqlen_k, headdim, + IS_CAUSAL=IS_CAUSAL, BIAS_ON=BIAS_ON, USE_DROPOUT=USE_DROPOUT, + PAD_COLS=pad_cols, HEADS_PADDED=HEADS_PADDED, + BLOCK_M=BLOCK_M1, BLOCK_N=BLOCK_N1, 
BLOCK_HEADDIM=BLOCK_HEADDIM, + ) + + # Case: this block works on dq + else: + i_start_m = pid - NUM_BLOCKS_KV + pad_rows = (not EVEN_M2) or (VARLEN and ((i_start_m + 1) * BLOCK_M2 > actual_seqlen_q)) + _compute_row_blocks_dq( + i_start_m * BLOCK_M2, + Q, K, V, Bias, Dropout, DO, DQ, LSE, D, + softmax_scale, dropout_p, dropout_seed, + stride_qm, stride_kn, stride_vn, stride_bm, stride_dom, stride_dqm, + actual_seqlen_q, actual_seqlen_k, headdim, + VARLEN=VARLEN, IS_CAUSAL=IS_CAUSAL, BIAS_ON=BIAS_ON, USE_DROPOUT=USE_DROPOUT, + PAD_ROWS=pad_rows, HEADS_PADDED=HEADS_PADDED, + BLOCK_M=BLOCK_M2, BLOCK_N=BLOCK_N2, BLOCK_HEADDIM=BLOCK_HEADDIM, + EVEN_N=EVEN_N2, + ) diff --git a/src/liger_kernel/ops/flash_attention/forward/caller.py b/src/liger_kernel/ops/flash_attention/forward/caller.py new file mode 100644 index 00000000..89ef7d51 --- /dev/null +++ b/src/liger_kernel/ops/flash_attention/forward/caller.py @@ -0,0 +1,121 @@ +import math +from typing import Optional, Tuple + +import torch +import triton +from torch import Tensor + +from src.liger_kernel.ops.flash_attention.forward.kernel import _fwd_kernel +from src.liger_kernel.ops.flash_attention.utils import attention_pack, attention_unpack, torch_ignore_deterministic, infer_bias_strides, handle_dropout, encode_dtype + + +def _flash_attn_forward( + q: Tensor, # [batch_size, seqlen_q, num_heads_q, head_dim] + k: Tensor, # [batch_size, seqlen_k, num_heads_kv, head_dim] + v: Tensor, # [batch_size, seqlen_k, num_heads_kv, head_dim] + attention_mask: Optional[Tensor], # [batch_size, seqlen_qk] + bias: Optional[Tensor], # [1 | batch_size, 1 | num_heads_q, seqlen_q, seqlen_k] + dropout_p: float = 0.0, + causal: bool = False, + softmax_scale: Optional[float] = None, + dropout_seed: Optional[int] = None, +) -> Tuple[Tensor, Tensor, float, int]: + + # Currently, variable length (varlen) mode is mutually exclusive with attention masking (TODO) + if attention_mask is not None: + varlen_mode = True + assert bias is None, "Attention mask is not supported along with attention bias. Just use bias instead." 
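# --- Illustrative sketch (not part of the original patch): the varlen bookkeeping that the
# packing below relies on. cum_seqlens_q marks where each sequence starts and ends once the
# batch is packed into a single row; the mask is hypothetical.
import torch

_mask = torch.tensor([[1, 1, 1, 0],
                      [1, 1, 0, 0]])                          # [batch=2, seqlen=4]
_cum = torch.zeros(_mask.size(0) + 1, dtype=torch.int32)
_cum[1:] = _mask.sum(dim=1).cumsum(0)                         # cumulative sequence lengths
assert _cum.tolist() == [0, 3, 5]
# attention_pack then yields a [1, 5, num_heads, head_dim] tensor in which rows 0:3 belong
# to sequence 0 and rows 3:5 to sequence 1; attention_unpack reverses this after the kernel.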
+ assert q.size(1) == k.size(1), "Attention mask is not supported with seqlen_q != seqlen_k" + else: + varlen_mode = False + + # Retrieve and check shapes (TODO: remove as much as possible of those) + batch, seqlen_q, nheads_q, head_dim = q.shape + _, seqlen_k, nheads_kv, _ = k.shape + expected_kv_shape = (batch, seqlen_k, nheads_kv, head_dim) + assert nheads_q % nheads_kv == 0, f"{nheads_q = } is not divisible by {nheads_kv =}" + assert k.shape == expected_kv_shape, f"{k.shape = } <> {expected_kv_shape = }" + assert v.shape == expected_kv_shape, f"{v.shape = } <> {expected_kv_shape = }" + assert q.dtype == k.dtype == v.dtype, "All tensors must have the same type" + assert q.is_cuda and k.is_cuda and v.is_cuda + softmax_scale = 1.0 / math.sqrt(head_dim) if softmax_scale is None else softmax_scale + + # Depending on attention_mask, switch to varlen + varlen_mode = varlen_mode and (batch > 1) + if varlen_mode: + # Compute padding-related statistics + cum_seqlens_q = torch.zeros(size=(attention_mask.size(0) + 1,), device=attention_mask.device, dtype=torch.int32) + with torch_ignore_deterministic(): + cum_seqlens_q[1:] = attention_mask.sum(dim=1).cumsum(0) + # cum_seqlens_q = [0, seqlen_q1, seqlen_q1+seqlen_q2, ..., seqlen_q1+...+seqlen_qB] of shape [B+1] + max_seqlen_q: int = attention_mask.size(1) + max_seqlen_k: int = attention_mask.size(1) + # Collate all matrices + q = attention_pack(q, attention_mask) # [1, sum_seqlens_qk, num_head, head_dim] + k = attention_pack(k, attention_mask) # [1, sum_seqlens_qk, num_head, head_dim] + v = attention_pack(v, attention_mask) # [1, sum_seqlens_qk, num_head, head_dim] + # Update seqlens + seqlen_q = q.size(1) + else: + cum_seqlens_q = None + max_seqlen_q = seqlen_q + max_seqlen_k = seqlen_k + + # Account for bias and dropout + stride_bb, stride_bh, stride_bm = infer_bias_strides(bias, batch, nheads_q, seqlen_q, seqlen_k) + dropout_seed = handle_dropout(dropout_p, dropout_seed, is_forward=True) + + # Setup output accumulator + o = torch.zeros_like(q) + + # Setup LSE accumulators: in varlen mode, batch is still equal to the nb of queries + max_seqlen_q_rounded = math.ceil(max_seqlen_q / 128) * 128 # wastefull in varlen and not (just use mask) + lse = torch.zeros((batch, nheads_q, max_seqlen_q_rounded), device=q.device, dtype=torch.float32) + + # Infer problem size and launch kernel + BLOCK_HEADDIM = max(triton.next_power_of_2(head_dim), 16) + PADDED_HEADS = BLOCK_HEADDIM > head_dim + # BLOCK = 128 + # num_warps = 4 if head_dim <= 64 else 8 + head_ratio = nheads_q // nheads_kv + grid = lambda META: (triton.cdiv(max_seqlen_q, META["BLOCK_M"]), batch * nheads_q) # noqa: E731 + _fwd_kernel[grid]( + q, + k, + v, + o, + lse, + bias, + softmax_scale, + dropout_p, + dropout_seed, + q.stride(0), q.stride(2), q.stride(1), + k.stride(0), k.stride(2), k.stride(1), + v.stride(0), v.stride(2), v.stride(1), + o.stride(0), o.stride(2), o.stride(1), + stride_bb, stride_bh, stride_bm, + nheads_q, + head_ratio, + seqlen_q, + cum_seqlens_q, # array containing [seqlen_q_1, ..., seqlen_q_B] , if VARLEN, else None + seqlen_k, + max_seqlen_q_rounded, + head_dim, + max_seqlen_q // 128, + max_seqlen_k // 128, # key for triton cache (limit number of compilations) + encode_dtype(q), + # Can't use kwargs here because triton autotune expects key to be args, not kwargs + # VARLEN=varlen_mode, IS_CAUSAL=causal, BLOCK_HEADDIM=d, + VARLEN=varlen_mode, + USE_DROPOUT=(dropout_p > 0), + IS_CAUSAL=causal, + BIAS_ON=(bias is not None), + BLOCK_HEADDIM=BLOCK_HEADDIM, + 
PADDED_HEADS=PADDED_HEADS, + ) + + # When in variable length mode, we need to unpack the packed tensors + if varlen_mode: + o = attention_unpack(o, cum_seqlens_q, *attention_mask.shape) + + return o, lse, softmax_scale, dropout_seed # softmax_scale could have been updated diff --git a/src/liger_kernel/ops/flash_attention/forward/compute_row_blocks.py b/src/liger_kernel/ops/flash_attention/forward/compute_row_blocks.py new file mode 100644 index 00000000..f717cf14 --- /dev/null +++ b/src/liger_kernel/ops/flash_attention/forward/compute_row_blocks.py @@ -0,0 +1,103 @@ +import triton +import triton.language as tl + +from src.liger_kernel.ops.flash_attention.utils import load_fn + + +@triton.jit +def compute_row_block( + q, + m_i, + lse_i, + k_ptrs, + v_ptrs, + bias_ptrs, + acc_o, + offs_m, + offs_n, + offs_d, + softmax_scale, + dropout_p, + dropout_seed, + dropout_offs, + stride_kn, + stride_vn, + I_start_n, + actual_seqlen_q, + actual_seqlen_k, + headdim, + USE_DROPOUT: tl.constexpr, + IS_CAUSAL: tl.constexpr, + BIAS_ON: tl.constexpr, + MASKED: tl.constexpr, + PADDED_COLS: tl.constexpr, + PADDED_HEADS: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + I_start_n = tl.multiple_of(I_start_n, BLOCK_N) + + # Load K (same mechanism as for Q, only check cols instead of rows) + offset_k_ptrs = k_ptrs + I_start_n * stride_kn + k = load_fn( + offset_k_ptrs, + I_start_n + offs_n, offs_d, + PAD_AXIS_0=PADDED_COLS, PAD_AXIS_1=PADDED_HEADS, + LIM_AXIS_0=actual_seqlen_k, LIM_AXIS_1=headdim, + ) + if BIAS_ON: + bias = load_fn( + bias_ptrs + I_start_n, + offs_m, I_start_n + offs_n, + PAD_AXIS_0=True, PAD_AXIS_1=PADDED_COLS, # check + LIM_AXIS_0=actual_seqlen_q, LIM_AXIS_1=actual_seqlen_k, + ) + + # Compute QK + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, tl.trans(k)) + + # Apply attention masking and/or account for padding of the keys + if PADDED_COLS: # TODO: check impact on speed when conditionned by MASKED (always true?) 
+ qk += tl.where((I_start_n + offs_n)[None, :] < actual_seqlen_k, 0, float("-inf")) + # Apply causal mask + if MASKED and IS_CAUSAL: + causal_mask = offs_m[:, None] >= (I_start_n + offs_n - actual_seqlen_k + actual_seqlen_q)[None, :] + qk += tl.where(causal_mask, 0, float("-inf")) + + if BIAS_ON: + qk += bias * (1.44269504089 / softmax_scale) # TODO: check if this is optimal + + m_ij = tl.maximum(tl.max(qk, 1) * softmax_scale, lse_i) + P_ij = tl.exp2(qk * softmax_scale - m_ij[:, None]) + l_ij = tl.sum(P_ij, 1) + + # Dropout + if USE_DROPOUT: + dropout_offs = dropout_offs + I_start_n + dropout_mask = (tl.rand(dropout_seed, dropout_offs) > dropout_p) # TODO: replace this w/ randint for better perfs + P_ij = tl.where(dropout_mask, P_ij, 0.0) + + # Scale the output accumulator + acc_o_scale = tl.exp2(m_i - m_ij) + acc_o = acc_o * acc_o_scale[:, None] + + # Load V (same mechanism as K) + offset_v_ptrs = v_ptrs + I_start_n * stride_vn + v = load_fn( + offset_v_ptrs, + I_start_n + offs_n, offs_d, + PAD_AXIS_0=PADDED_COLS, PAD_AXIS_1=PADDED_HEADS, + LIM_AXIS_0=actual_seqlen_k, LIM_AXIS_1=headdim, + ) + + # Update the output accumulator + P_ij = P_ij.to(v.dtype) + acc_o += tl.dot(P_ij, v) + + # Update the statistics + m_i = m_ij + l_i_new = tl.exp2(lse_i - m_ij) + l_ij + lse_i = m_ij + tl.log2(l_i_new) + + return m_i, lse_i, acc_o diff --git a/src/liger_kernel/ops/flash_attention/forward/kernel.py b/src/liger_kernel/ops/flash_attention/forward/kernel.py new file mode 100644 index 00000000..e7eeec03 --- /dev/null +++ b/src/liger_kernel/ops/flash_attention/forward/kernel.py @@ -0,0 +1,291 @@ +import triton +import triton.language as tl +from triton import Config + +from typing import List, Any, Dict +from src.liger_kernel.ops.flash_attention.forward.compute_row_blocks import compute_row_block +from src.liger_kernel.ops.flash_attention.utils import load_fn + +# TODO: exit causal blocks early +# TODO: can we initialize accO to empty instead of 0? 
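# --- Illustrative sketch (not part of the original patch): the effect of the autotune config
# pruning defined below. Candidate configs whose block size exceeds the actual sequence length
# are discarded, and a minimal fallback config is used when nothing survives. A plain-Python
# analogue (names are hypothetical):
def _prune_sketch(block_sizes, seqlen, minimum=32):
    kept = [b for b in block_sizes if b <= seqlen]            # drop BLOCK > seqlen
    return kept if kept else [minimum]                        # fall back to the minimal config

assert _prune_sketch([32, 64, 128, 256], seqlen=100) == [32, 64]
assert _prune_sketch([64, 128], seqlen=48) == [32]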
+ +MIN_B = 32 + + +def early_config_prune_fwd_kernel( + configs: List[Config], + named_args: Dict[str, Any], + **kwargs, +) -> List[Config]: + # Remove the configs where BLOCK_ > seqlen_ + kept_configs = [] + for cfg in configs: + block_m_too_large = cfg.kwargs["BLOCK_M"] > named_args["seqlen_q"] + block_n_too_large = cfg.kwargs["BLOCK_N"] > named_args["seqlen_k"] + if (block_m_too_large or block_n_too_large): + pass + else: + kept_configs.append(cfg) + # If no config is left, go for the minimal config + if kept_configs: + return kept_configs + return [Config({"BLOCK_M": MIN_B, "BLOCK_N": MIN_B}, num_warps=4, num_stages=1)] + + +@triton.autotune( + configs=[ + triton.Config({"BLOCK_M": MIN_B, "BLOCK_N": MIN_B}, num_warps=4, num_stages=1), + triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_warps=4, num_stages=1), + triton.Config({"BLOCK_M": 128, "BLOCK_N": 128}, num_warps=4, num_stages=1), + triton.Config({"BLOCK_M": 256, "BLOCK_N": 256}, num_warps=4, num_stages=1), + ], + key=[ + "CACHE_KEY_SEQLEN_Q", + "CACHE_KEY_SEQLEN_K", + "DTYPE", + "VARLEN", + "USE_DROPOUT", + "IS_CAUSAL", + "BIAS_ON", + "BLOCK_HEADDIM", + ], + prune_configs_by={"early_config_prune": early_config_prune_fwd_kernel}, +) +@triton.heuristics( + { + "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0, + "EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0, + } +) +@triton.jit +def _fwd_kernel( + Q, + K, + V, + Out, + Lse, + Bias, + softmax_scale, + dropout_p, + dropout_seed, + stride_qb, stride_qh, stride_qm, # Q stride for the batch, head and sequence axis (sequence subscript is m for rows) + stride_kb, stride_kh, stride_kn, # Same for K (sequence subscript is n for cols) + stride_vb, stride_vh, stride_vn, # Same for V (sequence subscript is n for cols) + stride_ob, stride_oh, stride_om, # Same for O (sequence subscript is m for rows) + stride_bb, stride_bh, stride_bm, + nheads_q, + head_ratio, + seqlen_q, + cum_seqlens_q, + seqlen_k, + max_seqlen_q_rounded, + headdim, + CACHE_KEY_SEQLEN_Q, + CACHE_KEY_SEQLEN_K, + DTYPE, + VARLEN: tl.constexpr, + USE_DROPOUT: tl.constexpr, + IS_CAUSAL: tl.constexpr, + BIAS_ON: tl.constexpr, + BLOCK_HEADDIM: tl.constexpr, + EVEN_M: tl.constexpr, + EVEN_N: tl.constexpr, + PADDED_HEADS: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + # Locate kernel inside the grid + i_start_m = tl.program_id(0) # current block in the Q matrix + off_head_and_batch = tl.program_id(1) + off_head_q = off_head_and_batch % nheads_q + off_head_kv = off_head_q // head_ratio + off_batch = off_head_and_batch // nheads_q + + # Infer actual sequence length of Q and the offset to the last sequence + if VARLEN: + cu_seq_start_q = tl.load(cum_seqlens_q + off_batch) + actual_seqlen_q = tl.load(cum_seqlens_q + off_batch + 1) - cu_seq_start_q + if i_start_m * BLOCK_M >= actual_seqlen_q: + return + actual_seqlen_k = actual_seqlen_q # TODO: support packed + varlen? rn, check is done outside + cu_seq_start_k = cu_seq_start_q + off_batch = 0 + else: + actual_seqlen_q = seqlen_q + actual_seqlen_k = seqlen_k + cu_seq_start_q = 0 + cu_seq_start_k = 0 + + softmax_scale = softmax_scale * 1.44269504089 + # Initialize offsets + offs_m = i_start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_HEADDIM) + + # When in VARLEN mode, since we dimension the grid to be large enough for all sequences, the + # current sequence might have less rows than the current row (detemined through the grid). 
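# Illustrative note (not part of the original patch): the causal mask used below is aligned to
# the bottom-right corner, i.e. query row m may attend to key columns n <= m + seqlen_k - seqlen_q.
# When seqlen_q > seqlen_k this leaves the first (seqlen_q - seqlen_k) query rows with no visible
# keys at all; e.g. seqlen_q = 5, seqlen_k = 3 gives fully_masked_lines = 2, so rows 0 and 1 are
# zeroed in the output and blocks made entirely of such rows can return early.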
+ + fully_masked_lines = actual_seqlen_q - actual_seqlen_k if IS_CAUSAL else 0 + if fully_masked_lines >= (i_start_m+1) * BLOCK_M: + return + + # Initialize pointers to Q, K, V + offseted_Q = Q + off_batch * stride_qb + off_head_q * stride_qh + cu_seq_start_q * stride_qm + q_ptrs = (offseted_Q + (offs_m[:, None] * stride_qm + offs_d[None, :])) + offseted_K = K + off_batch * stride_kb + off_head_kv * stride_kh + cu_seq_start_k * stride_kn + k_ptrs = (offseted_K + (offs_n[:, None] * stride_kn + offs_d[None, :])) + offseted_V = V + off_batch * stride_vb + off_head_kv * stride_vh + cu_seq_start_k * stride_vn + v_ptrs = (offseted_V + (offs_n[:, None] * stride_vn + offs_d[None, :])) + # ...and maybe bias + if BIAS_ON: + offseted_Bias = Bias + off_batch * stride_bb + off_head_kv * stride_bh + cu_seq_start_q * stride_bm + bias_ptrs = (offseted_Bias + (offs_m[:, None] * stride_bm + offs_n[None, :])) + else: + bias_ptrs = None + # ...and maybe dropout + if USE_DROPOUT: + dropout_off = actual_seqlen_k * (cu_seq_start_q + actual_seqlen_q * (off_head_q + nheads_q * off_batch)) + dropout_offs = dropout_off + offs_m[:, None] * actual_seqlen_k + offs_n[None, :] + else: + dropout_offs = None + + # Initialize pointers to m and l + lse_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32) + + # Load Q, which will stay in SRAM for the whole loop + pad_rows = (not EVEN_M) or (VARLEN and (i_start_m * BLOCK_M > actual_seqlen_q)) # this works while other bools fail. Why? + q = load_fn( + q_ptrs, + offs_m, offs_d, + PAD_AXIS_0=pad_rows, PAD_AXIS_1=PADDED_HEADS, + LIM_AXIS_0=actual_seqlen_q, LIM_AXIS_1=headdim, + ) + + # Compute last visited column of KV which + if IS_CAUSAL: + end_n = min(actual_seqlen_k - actual_seqlen_q + (i_start_m + 1) * BLOCK_M, actual_seqlen_k) + # For a seqlen_q >> seqlen_k, there migh be entire block skipped + if end_n < 0: + return + else: + end_n = actual_seqlen_k + + # first_masked_block = min(start_m * BLOCK_M + 1 + actual_seqlen_k - actual_seqlen_q, end_n) if IS_CAUSAL else end_n + uneven_n = (actual_seqlen_k % BLOCK_N != 0) + attention_padding = VARLEN & uneven_n + if IS_CAUSAL: + first_masked_col = i_start_m * BLOCK_M + 1 + actual_seqlen_k - actual_seqlen_q + elif attention_padding: + first_masked_col = actual_seqlen_k + else: + first_masked_col = end_n + nb_full_blocks = first_masked_col // BLOCK_N + + next_start_n = 0 + if nb_full_blocks > 0: + for _ in range(0, nb_full_blocks): + m_i, lse_i, acc_o = compute_row_block( + q, + m_i, + lse_i, + k_ptrs, + v_ptrs, + bias_ptrs, + acc_o, + offs_m, + offs_n, + offs_d, + softmax_scale, + dropout_p, + dropout_seed, + dropout_offs, + stride_kn, + stride_vn, + next_start_n, + actual_seqlen_q, + actual_seqlen_k, + headdim, + USE_DROPOUT=USE_DROPOUT, + IS_CAUSAL=IS_CAUSAL, + BIAS_ON=BIAS_ON, + MASKED=False, + PADDED_COLS=False, + PADDED_HEADS=PADDED_HEADS, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + ) + next_start_n += BLOCK_N + + if next_start_n < end_n: + for I_start_n in range(next_start_n, end_n, BLOCK_N): + pad_cols = (not EVEN_N) or VARLEN # TODO: refine varlen side + m_i, lse_i, acc_o = compute_row_block( + q, + m_i, + lse_i, + k_ptrs, + v_ptrs, + bias_ptrs, + acc_o, + offs_m, + offs_n, + offs_d, + softmax_scale, + dropout_p, + dropout_seed, + dropout_offs, + stride_kn, + stride_vn, + I_start_n, + actual_seqlen_q, + actual_seqlen_k, + headdim, + USE_DROPOUT=USE_DROPOUT, + IS_CAUSAL=IS_CAUSAL, + BIAS_ON=BIAS_ON, + 
MASKED=True, + PADDED_COLS=pad_cols, + PADDED_HEADS=PADDED_HEADS, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + ) + + # Final scaling of the output accumulator + if USE_DROPOUT: + o_scale = tl.exp2((m_i - lse_i) - tl.log2(1 - dropout_p)) + else: + o_scale = tl.exp2(m_i - lse_i) + acc_o = acc_o * o_scale[:, None] + + # For seqlen_q >> seqlen_k, there might be entire lines masked, so we account for that + if fully_masked_lines > i_start_m * BLOCK_M: + acc_o = tl.where(offs_m[:, None] < fully_masked_lines, 0, acc_o) + + # rematerialize offsets to save registers (?) + i_start_m = tl.program_id(0) + offs_m = i_start_m * BLOCK_M + tl.arange(0, BLOCK_M) + # Write back l and m + # Q + off_batch * stride_qb + off_head * stride_qh + cu_seq_start_q * stride_qm + lse_ptrs = Lse + off_head_and_batch * max_seqlen_q_rounded + offs_m + tl.store(lse_ptrs, lse_i) + # Initialize pointers to output + offs_d = tl.arange(0, BLOCK_HEADDIM) + out_ptrs = ( + Out + + off_batch * stride_ob + + off_head_q * stride_oh + + cu_seq_start_q * stride_om + + (offs_m[:, None] * stride_om + offs_d[None, :]) + ) + + # Store O (same mechanism as Q) BUG: here, the store instruction seems to fail when one of the two bools is false + if True: + tl.store(out_ptrs, acc_o, mask=(offs_m[:, None] < actual_seqlen_q) & (offs_d[None, :] < headdim)) + elif pad_rows: + tl.store(out_ptrs, acc_o, mask=offs_m[:, None] < actual_seqlen_q) + elif PADDED_HEADS: # nothing is padded + tl.store(out_ptrs, acc_o, mask=offs_d[None, :] < headdim) + else: # only heads are padded + tl.store(out_ptrs, acc_o) diff --git a/src/liger_kernel/ops/flash_attention/reference_implementation.py b/src/liger_kernel/ops/flash_attention/reference_implementation.py new file mode 100644 index 00000000..f113ada7 --- /dev/null +++ b/src/liger_kernel/ops/flash_attention/reference_implementation.py @@ -0,0 +1,129 @@ +import math + +import torch +from einops import rearrange, repeat + + +def construct_local_mask( + seqlen_q, + seqlen_k, + window_size=(-1, -1), # -1 means infinite window size + query_padding_mask=None, + key_padding_mask=None, + device=None, + key_leftpad=None, +): + row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1") + col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long) + if key_leftpad is not None: + key_leftpad = rearrange(key_leftpad, "b -> b 1 1 1") + col_idx = repeat(col_idx, "s -> b 1 1 s", b=key_leftpad.shape[0]) + col_idx = torch.where(col_idx >= key_leftpad, col_idx - key_leftpad, 2**32) + sk = ( + seqlen_k + if key_padding_mask is None + else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1") + ) + sq = ( + seqlen_q + if query_padding_mask is None + else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1") + ) + if window_size[0] < 0: + return col_idx > row_idx + sk - sq + window_size[1] + else: + sk = torch.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk + return torch.logical_or( + col_idx > torch.minimum(row_idx + sk - sq + window_size[1], sk), + col_idx < row_idx + sk - sq - window_size[0], + ) + + +def attention_ref( + q, + k, + v, + query_padding_mask=None, + key_padding_mask=None, + attn_bias=None, + dropout_p=0.0, + dropout_mask=None, + causal=False, + window_size=(-1, -1), # -1 means infinite window size + softcap=0.0, + upcast=True, + reorder_ops=False, + key_leftpad=None, +): + """ + Arguments: + q: (batch_size, seqlen_q, nheads, head_dim) + k: (batch_size, seqlen_k, nheads_k, head_dim) + v: (batch_size, seqlen_k, nheads_k, head_dim) + query_padding_mask: (batch_size, seqlen_q) 
+ key_padding_mask: (batch_size, seqlen_k) + attn_bias: broadcastable to (batch_size, nheads, seqlen_q, seqlen_k) + dropout_p: float + dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k) + causal: whether to apply causal masking + window_size: (int, int), left and right window size + upcast: whether to cast all inputs to fp32, do all computation in fp32, then cast + output back to fp16/bf16. + reorder_ops: whether to change the order of operations (scaling k instead of scaling q, etc.) + without changing the math. This is to estimate the numerical error from operation + reordering. + Output: + output: (batch_size, seqlen_q, nheads, head_dim) + attention: (batch_size, nheads, seqlen_q, seqlen_k), softmax after dropout + """ + if causal: + window_size = (window_size[0], 0) + dtype_og = q.dtype + if upcast: + q, k, v = q.float(), k.float(), v.float() + seqlen_q, seqlen_k = q.shape[1], k.shape[1] + k = repeat(k, "b s h d -> b s (h g) d", g=q.shape[2] // k.shape[2]) + v = repeat(v, "b s h d -> b s (h g) d", g=q.shape[2] // v.shape[2]) + d = q.shape[-1] + if not reorder_ops: + scores = torch.einsum("bthd,bshd->bhts", q / math.sqrt(d), k) + else: + scores = torch.einsum("bthd,bshd->bhts", q, k / math.sqrt(d)) + if softcap > 0: + scores = scores / softcap + scores = scores.tanh() + scores = scores * softcap + if key_padding_mask is not None: + scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf")) + if window_size[0] >= 0 or window_size[1] >= 0: + local_mask = construct_local_mask( + seqlen_q, + seqlen_k, + window_size, + query_padding_mask, + key_padding_mask, + q.device, + key_leftpad=key_leftpad, + ) + scores.masked_fill_(local_mask, float("-inf")) + if attn_bias is not None: + scores = scores + attn_bias + attention = torch.softmax(scores, dim=-1).to(v.dtype) + # Some rows might be completely masked out so we fill them with zero instead of NaN + if window_size[0] >= 0 or window_size[1] >= 0: + attention = attention.masked_fill(torch.all(local_mask, dim=-1, keepdim=True), 0.0) + # We want to mask here so that the attention matrix doesn't have any NaNs + # Otherwise we'll get NaN in dV + if query_padding_mask is not None: + attention = attention.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0) + dropout_scaling = 1.0 / (1 - dropout_p) + # attention_drop = attention.masked_fill(~dropout_mask, 0.0) * dropout_scaling + # output = torch.einsum('bhts,bshd->bthd', attention_drop , v) + if dropout_mask is not None: + attention_drop = attention.masked_fill(~dropout_mask, 0.0) + else: + attention_drop = attention + output = torch.einsum("bhts,bshd->bthd", attention_drop, v * dropout_scaling) + if query_padding_mask is not None: + output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0) + return output.to(dtype=dtype_og) diff --git a/src/liger_kernel/ops/flash_attention/utils.py b/src/liger_kernel/ops/flash_attention/utils.py new file mode 100644 index 00000000..6bdd9713 --- /dev/null +++ b/src/liger_kernel/ops/flash_attention/utils.py @@ -0,0 +1,109 @@ +import torch +import triton +import triton.language as tl +from torch import Tensor +from typing import Tuple, Optional + + +def attention_pack( + x: torch.Tensor, # [batch_size, seqlen, num_heads, head_dim] + attention_mask: torch.Tensor, # [batch_size, seqlen] +) -> torch.Tensor: + to_pack = [] + for i, attn_mask in enumerate(attention_mask): + seqlen = attn_mask.sum().int().item() + kept = x[i, :seqlen] # [seqlen, num_heads, head_dim] + to_pack.append(kept) + return 
torch.concatenate(to_pack, dim=0).unsqueeze(0) + + +def attention_unpack( + x: torch.Tensor, # [1, sum_seqlens, num_heads, head_dim] + cum_seqlens: torch.Tensor, # [0, seqlen_1, seqlen_1+seqlen_2, ...] + batch_size: int, + goal_seqlen: int, +) -> torch.Tensor: + unpacked = torch.zeros(size=(batch_size, goal_seqlen, *x.shape[2:]), dtype=x.dtype, device=x.device) + for i in range(cum_seqlens.size(0)-1): + seq_start = cum_seqlens[i] + seq_end = cum_seqlens[i+1] + unpacked[i, :seq_end-seq_start] = x[0, seq_start:seq_end] + return unpacked + + +@triton.jit +def load_fn( + ptrs, + offs_axis_0: tl.const_pointer_type, + offs_axis_1: tl.const_pointer_type, + PAD_AXIS_0: tl.constexpr, + PAD_AXIS_1: tl.constexpr, + LIM_AXIS_0: tl.constexpr, + LIM_AXIS_1: tl.constexpr, +): + if PAD_AXIS_0: + if PAD_AXIS_1: + x = tl.load(ptrs, mask=(offs_axis_0[:, None] < LIM_AXIS_0) & (offs_axis_1[None, :] < LIM_AXIS_1), other=0.0) + else: + x = tl.load(ptrs, mask=offs_axis_0[:, None] < LIM_AXIS_0, other=0.0) + else: + if PAD_AXIS_1: + x = tl.load(ptrs, mask=offs_axis_1[None, :] < LIM_AXIS_1, other=0.0) + else: + x = tl.load(ptrs) + return x + + +def infer_bias_strides( + bias: Optional[Tensor], batch: int, nheads_q: int, seqlen_q: int, seqlen_k: int, +) -> Tuple[int, ...]: + if bias is not None: + assert (bias.size(2) == seqlen_q and bias.size(3) == seqlen_k), f"{bias.shape = }" + if bias.size(0) == 1: + stride_bb = 0 + elif bias.size(0) == batch: + stride_bb = bias.stride(0) + else: + raise ValueError(f"Attention bias has {bias.size(0) = } while {batch = }") + if bias.size(1) == 1: + stride_bh = 0 + elif bias.stride(1) == nheads_q: + stride_bh = bias.stride(1) + else: + raise ValueError(f"Attention bias has {bias.size(1) = } while {nheads_q = }") + stride_bm = bias.stride(2) + else: + stride_bb, stride_bh, stride_bm = 0, 0, 0 + return stride_bb, stride_bh, stride_bm + + +def handle_dropout(dropout_p: float, dropout_seed: Optional[int], is_forward: bool) -> int: + assert dropout_p >= 0, f"Dropout probability {dropout_p = } must be above 0." + assert dropout_p < 1, f"Dropout probability {dropout_p = } must be strictly below 1." 
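# --- Illustrative sketch (not part of the original patch): expected behaviour of this helper,
# written as doctest-style comments (the argument values are hypothetical):
#   handle_dropout(0.0, None, is_forward=True)   -> 0              (dropout disabled)
#   handle_dropout(0.1, 1234, is_forward=True)   -> 1234           (caller-supplied seed is kept)
#   handle_dropout(0.1, None, is_forward=True)   -> fresh random seed drawn from [0, 2**32)
#   handle_dropout(0.1, None, is_forward=False)  -> NotImplementedError (backward dropout is TODO)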
+ if dropout_p == 0: + return 0 + elif is_forward: + return torch.randint(low=0, high=2**32, size=(1,)).item() if dropout_seed is None else dropout_seed + else: + raise NotImplementedError("Backward pass does not yet support dropout.") + + +class torch_ignore_deterministic: + def __enter__(self): + self.previous_mode = torch.are_deterministic_algorithms_enabled() + torch.use_deterministic_algorithms(False) + + def __exit__(self, exc_type, exc_val, exc_tb): + if exc_type: + raise exc_val + torch.use_deterministic_algorithms(self.previous_mode) + + +def encode_dtype(x: Tensor) -> int: + if x.dtype == torch.float16: + return 16 + if x.dtype == torch.bfloat16: + return 17 + if x.dtype == torch.float32: + return 32 + raise ValueError(x.dtype) diff --git a/src/liger_kernel/ops/flash_attention/wrapper.py b/src/liger_kernel/ops/flash_attention/wrapper.py new file mode 100644 index 00000000..66fe83e3 --- /dev/null +++ b/src/liger_kernel/ops/flash_attention/wrapper.py @@ -0,0 +1,100 @@ +from typing import Optional + +import torch +from torch import Tensor + +from src.liger_kernel.ops.flash_attention.backward.caller import _flash_attn_backward +from src.liger_kernel.ops.flash_attention.forward.caller import _flash_attn_forward + + +class FlashAttnFunc(torch.autograd.Function): + + @staticmethod + def forward( + ctx, + q: Tensor, + k: Tensor, + v: Tensor, + attention_mask: Optional[Tensor] = None, + attention_bias: Optional[Tensor] = None, + dropout_p: float = 0.0, + causal: bool = False, + softmax_scale: Optional[Tensor] = None, + dropout_seed: Optional[int] = None, + ): + """ + Compute the forward pass of the FlashAttention function. + Args: + - ctx (): the autograd.Function context + - q (Tensor): the query projection tensor, of shape [batch_size, seqlen_q, num_heads, head_dim] + - k (Tensor): the key projection tensor, of shape [batch_size, seqlen_k, num_heads, head_dim] + - v (Tensor): the values projection tensor, of shape [batch_size, seqlen_k, num_heads, head_dim] + - attention_mask (Optional[Tensor]): an optional attention mask of shape [batch_size, seqlen_q]. + Forces seqlen_q == seqlen_k. + - causal (bool): a boolean to indicate whether or not to use causal attention + - softmax_scale (Optional[float]): an optional float to scale the pre-softmax attention scores. Defaults + to 1 / sqrt(head_dim) + Return: + the attention output tensor + """ + # Make sure that the last dimension is contiguous + q = q if q.stride(-1) == 1 else q.contiguous() + k = k if k.stride(-1) == 1 else k.contiguous() + v = v if v.stride(-1) == 1 else v.contiguous() + attention_bias = None if (attention_bias is None) else attention_bias.contiguous() + o, lse, ctx.softmax_scale, ctx.dropout_seed = _flash_attn_forward( + q=q, + k=k, + v=v, + attention_mask=attention_mask, + bias=attention_bias, + dropout_p=dropout_p, + causal=causal, + softmax_scale=softmax_scale, + dropout_seed=dropout_seed, + ) + ctx.save_for_backward(q, k, v, attention_bias, attention_mask, o, lse) + ctx.causal = causal + ctx.dropout_p = dropout_p + return o + + @staticmethod + def backward(ctx, do): + """ + Compute the backward pass of the FlashAttention function. 
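+        Note: dropout is not yet supported in the backward pass, so a non-zero dropout_p will raise NotImplementedError (see handle_dropout in utils.py).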
+ Args: + - ctx (): the autograd.Function context + - do (Tensor): the gradient of the output tensor, of shape [batch_size, seqlen_q, num_heads, head_dim] + Return: + three tensors, the gradients of q, k and v respectively (check forward for shape info) + """ + q, k, v, bias, attention_mask, o, lse = ctx.saved_tensors + dq, dk, dv = _flash_attn_backward( + dO=do, + q=q, + k=k, + v=v, + bias=bias, + attention_mask=attention_mask, + o=o, + lse=lse, + dropout_p=ctx.dropout_p, + causal=ctx.causal, + softmax_scale=ctx.softmax_scale, + dropout_seed=ctx.dropout_seed, + ) + return dq, dk, dv, None, None, None, None, None, None + + +def flash_attn_func( + q: Tensor, + k: Tensor, + v: Tensor, + attention_mask: Optional[Tensor] = None, + attention_bias: Optional[Tensor] = None, + dropout_p: float = 0.0, + causal: bool = False, + softmax_scale: Optional[Tensor] = None, + dropout_seed: Optional[int] = None, +) -> Tensor: + return FlashAttnFunc.apply(q, k, v, attention_mask, attention_bias, dropout_p, causal, softmax_scale, dropout_seed) From 8412147972207c8ec104f3f7d8c360a7c4050ccb Mon Sep 17 00:00:00 2001 From: remi-or Date: Thu, 26 Sep 2024 14:18:52 +0000 Subject: [PATCH 02/12] Added monkey patching for SDPA --- src/liger_kernel/transformers/attention.py | 88 ++++++++++++++++++ src/liger_kernel/transformers/model/gemma2.py | 89 ++++++++++++++++++ src/liger_kernel/transformers/model/phi3.py | 92 ++++++++++++++++++- src/liger_kernel/transformers/model/qwen2.py | 92 +++++++++++++++++++ src/liger_kernel/transformers/monkey_patch.py | 32 +++++++ 5 files changed, 392 insertions(+), 1 deletion(-) create mode 100644 src/liger_kernel/transformers/attention.py create mode 100644 src/liger_kernel/transformers/model/gemma2.py diff --git a/src/liger_kernel/transformers/attention.py b/src/liger_kernel/transformers/attention.py new file mode 100644 index 00000000..76e5aa80 --- /dev/null +++ b/src/liger_kernel/transformers/attention.py @@ -0,0 +1,88 @@ +from typing import Optional, Tuple + +from transformers.cache_utils import Cache +import torch +from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, logger, LlamaSdpaAttention + +from liger_kernel.ops.flash_attention.wrapper import flash_attn_func + + +# Adapted from LlamaSdpaAttention.forward +def liger_general_sdpa_forward( + self: LlamaSdpaAttention, # Might not always be this module in particular, but is a good general placholder + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45 + **kwargs, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + + if output_attentions: + raise NotImplementedError("Output attentions") # TODO: support this? 
+ + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj.forward(hidden_states) + key_states = self.k_proj.forward(hidden_states) + value_states = self.v_proj.forward(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.45 `position_ids` will be " + "removed and `position_embeddings` will be mandatory." + ) + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # key_states = repeat_kv(key_states, self.num_key_value_groups) not needed as we support GQA + # value_states = repeat_kv(value_states, self.num_key_value_groups) + + if attention_mask is not None: + attn_bias = attention_mask[:, :, :, : key_states.shape[-2]] + else: + attn_bias = None + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + query_states = query_states.transpose(1, 2).contiguous() + key_states = key_states.transpose(1, 2).contiguous() + value_states = value_states.transpose(1, 2).contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. 
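+    # The q_len > 1 check is necessary to match AttentionMaskConverter.to_causal_4d, which does not create a causal mask when q_len == 1.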
+ is_causal = True if attention_mask is None and q_len > 1 else False + + attn_output = flash_attn_func( + q=query_states, + k=key_states, + v=value_states, + attention_mask=None, + attention_bias=attn_bias, + dropout_p=(self.attention_dropout if self.training else 0.0), + causal=is_causal, + softmax_scale=None, + dropout_seed=None, + ) + + attn_output = attn_output.view(bsz, q_len, -1) + + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value diff --git a/src/liger_kernel/transformers/model/gemma2.py b/src/liger_kernel/transformers/model/gemma2.py new file mode 100644 index 00000000..655aea71 --- /dev/null +++ b/src/liger_kernel/transformers/model/gemma2.py @@ -0,0 +1,89 @@ +from typing import Optional, Tuple + +from transformers.cache_utils import Cache +import torch +from transformers.models.gemma.modeling_gemma import apply_rotary_pos_emb, logger, GemmaSdpaAttention + +from liger_kernel.ops.flash_attention.wrapper import flash_attn_func + + +# Copied from GemmaSdpaAttention.forward +def liger_gemma2_sdpa_forward( + self: GemmaSdpaAttention, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if output_attentions: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "GemmaModel is using GemmaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
+ ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj.forward(hidden_states) + key_states = self.k_proj.forward(hidden_states) + value_states = self.v_proj.forward(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # Commented out because we support GQA + # key_states = repeat_kv(key_states, self.num_key_value_groups) + # value_states = repeat_kv(value_states, self.num_key_value_groups) + + if attention_mask is not None: + attn_bias = attention_mask[:, :, :, : key_states.shape[-2]] + else: + attn_bias = None + + query_states = query_states.transpose(1, 2).contiguous() + key_states = key_states.transpose(1, 2).contiguous() + value_states = value_states.transpose(1, 2).contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. 
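+    # Unlike the generic SDPA path, Gemma2 passes its own softmax scale (self.scaling) to the kernel call below instead of the default 1 / sqrt(head_dim).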
+ is_causal = True if attention_mask is None and q_len > 1 else False + + attn_output = flash_attn_func( + q=query_states, + k=key_states, + v=value_states, + attention_mask=None, + attention_bias=attn_bias, + dropout_p=(self.attention_dropout if self.training else 0.0), + causal=is_causal, + softmax_scale=self.scaling, + dropout_seed=None, + ) + + attn_output = attn_output.view(bsz, q_len, -1) + + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value diff --git a/src/liger_kernel/transformers/model/phi3.py b/src/liger_kernel/transformers/model/phi3.py index 4cb7ec0e..f322c3ba 100644 --- a/src/liger_kernel/transformers/model/phi3.py +++ b/src/liger_kernel/transformers/model/phi3.py @@ -6,12 +6,16 @@ from transformers.models.phi3.modeling_phi3 import ( _CONFIG_FOR_DOC, PHI3_INPUTS_DOCSTRING, + Phi3SdpaAttention, + Cache, + logger, + apply_rotary_pos_emb ) from transformers.utils import ( add_start_docstrings_to_model_forward, replace_return_docstrings, ) - +from liger_kernel.ops.flash_attention.wrapper import flash_attn_func from liger_kernel.transformers.fused_linear_cross_entropy import ( LigerFusedLinearCrossEntropyLoss, ) @@ -134,3 +138,89 @@ def lce_forward( hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) + + +def liger_phi3_sdpa_attention_forward( + self: Phi3SdpaAttention, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if output_attentions: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "Phi3Model is using Phi3SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
+ ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + bsz, q_len, _ = hidden_states.size() + + qkv = self.qkv_proj.forward(hidden_states) + query_pos = self.num_heads * self.head_dim + query_states = qkv[..., :query_pos] + key_states = qkv[..., query_pos : query_pos + self.num_key_value_heads * self.head_dim] + value_states = qkv[..., query_pos + self.num_key_value_heads * self.head_dim :] + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len) + + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # Commented out because we support GQA + # key_states = repeat_kv(key_states, self.num_key_value_groups) + # value_states = repeat_kv(value_states, self.num_key_value_groups) + + if attention_mask is not None: + attn_bias = attention_mask[:, :, :, : key_states.shape[-2]] + else: + attn_bias = None + + # We always need re-arranging and contiguousness + query_states = query_states.transpose(1, 2).contiguous() + key_states = key_states.transpose(1, 2).contiguous() + value_states = value_states.transpose(1, 2).contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. 
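+    # Note: dropout_p below is self.attention_dropout only in training mode; since the Liger flash attention backward does not support dropout yet, a non-zero attention_dropout will raise NotImplementedError once gradients are computed.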
+ is_causal = True if attention_mask is None and q_len > 1 else False + + attn_output = flash_attn_func( + q=query_states, + k=key_states, + v=value_states, + attention_mask=None, + attention_bias=attn_bias, + dropout_p=(self.attention_dropout if self.training else 0.0), + causal=is_causal, + softmax_scale=None, + dropout_seed=None, + ) + + attn_output = attn_output.view(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value diff --git a/src/liger_kernel/transformers/model/qwen2.py b/src/liger_kernel/transformers/model/qwen2.py index b8e9957e..896f42bc 100644 --- a/src/liger_kernel/transformers/model/qwen2.py +++ b/src/liger_kernel/transformers/model/qwen2.py @@ -6,6 +6,10 @@ from transformers.models.qwen2.modeling_qwen2 import ( _CONFIG_FOR_DOC, QWEN2_INPUTS_DOCSTRING, + Qwen2SdpaAttention, + Cache, + logger, + apply_rotary_pos_emb, ) from transformers.utils import ( add_start_docstrings_to_model_forward, @@ -15,6 +19,7 @@ from liger_kernel.transformers.fused_linear_cross_entropy import ( LigerFusedLinearCrossEntropyLoss, ) +from liger_kernel.ops.flash_attention.wrapper import flash_attn_func @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING) @@ -133,3 +138,90 @@ def lce_forward( hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) + + +# Adaptation of Qwen2SdpaAttention.forward +def liger_qwen2_sdpa_forward( + self: Qwen2SdpaAttention, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if output_attentions: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "Qwen2Model is using Qwen2SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
+        )
+        return super().forward(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_value=past_key_value,
+            output_attentions=output_attentions,
+            use_cache=use_cache,
+        )
+
+    bsz, q_len, _ = hidden_states.size()
+
+    query_states = self.q_proj(hidden_states)
+    key_states = self.k_proj(hidden_states)
+    value_states = self.v_proj(hidden_states)
+
+    query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+    key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+    value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+    kv_seq_len = key_states.shape[-2]
+    if past_key_value is not None:
+        kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+
+    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+
+    if past_key_value is not None:
+        cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}  # Specific to RoPE models
+        key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+    # [Liger-Kernel modification] As we support GQA, we don't need to do this
+    # key_states = repeat_kv(key_states, self.num_key_value_groups)
+    # value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+    # [Liger-Kernel modification] Renamed causal_mask to attn_bias, but the behavior is the same
+    if attention_mask is not None:
+        attn_bias = attention_mask[:, :, :, : key_states.shape[-2]]
+    else:
+        attn_bias = None
+
+    # [Liger-Kernel modification] We ensure contiguity and transpose the axes to have [B, seqlen, heads, ...]
+    query_states = query_states.transpose(1, 2).contiguous()
+    key_states = key_states.transpose(1, 2).contiguous()
+    value_states = value_states.transpose(1, 2).contiguous()
+
+    # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+    # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
+    # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
+ is_causal = True if attention_mask is None and q_len > 1 else False + + attn_output = flash_attn_func( + q=query_states, + k=key_states, + v=value_states, + attention_mask=None, + attention_bias=attn_bias, + dropout_p=(self.attention_dropout if self.training else 0.0), + causal=is_causal, + softmax_scale=None, + dropout_seed=None, + ) + + # [Liger-Kernel modification] Already done in our FA kernel + # attn_output = attn_output.transpose(1, 2).contiguous() + + attn_output = attn_output.view(bsz, q_len, self.hidden_size) + attn_output = self.o_proj(attn_output) + return attn_output, None, past_key_value diff --git a/src/liger_kernel/transformers/monkey_patch.py b/src/liger_kernel/transformers/monkey_patch.py index b60e328f..87951c13 100644 --- a/src/liger_kernel/transformers/monkey_patch.py +++ b/src/liger_kernel/transformers/monkey_patch.py @@ -10,10 +10,13 @@ from liger_kernel.transformers.layer_norm import LigerLayerNorm from liger_kernel.transformers.model.gemma import lce_forward as gemma_lce_forward from liger_kernel.transformers.model.llama import lce_forward as llama_lce_forward +from liger_kernel.transformers.model.gemma2 import liger_gemma2_sdpa_forward from liger_kernel.transformers.model.mistral import lce_forward as mistral_lce_forward from liger_kernel.transformers.model.mixtral import lce_forward as mixtral_lce_forward from liger_kernel.transformers.model.phi3 import lce_forward as phi3_lce_forward +from liger_kernel.transformers.model.phi3 import liger_phi3_sdpa_attention_forward from liger_kernel.transformers.model.qwen2 import lce_forward as qwen2_lce_forward +from liger_kernel.transformers.model.qwen2 import liger_qwen2_sdpa_forward from liger_kernel.transformers.rms_norm import LigerRMSNorm from liger_kernel.transformers.rope import liger_rotary_pos_emb from liger_kernel.transformers.swiglu import ( @@ -21,6 +24,7 @@ LigerPhi3SwiGLUMLP, LigerSwiGLUMLP, ) +from liger_kernel.transformers.attention import liger_general_sdpa_forward logger = logging.getLogger(__name__) @@ -31,6 +35,7 @@ def apply_liger_kernel_to_llama( fused_linear_cross_entropy: bool = True, rms_norm: bool = True, swiglu: bool = True, + sdpa_attention: bool = True, model: PreTrainedModel = None, ) -> None: """ @@ -45,6 +50,7 @@ def apply_liger_kernel_to_llama( If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient. rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True. swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True. + sdpa_attention (bool): Whether to apply Liger's FlashAttention instead of SDPA. Default is True. model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been loaded. Default is None. """ @@ -65,6 +71,8 @@ def apply_liger_kernel_to_llama( modeling_llama.CrossEntropyLoss = LigerCrossEntropyLoss if fused_linear_cross_entropy: modeling_llama.LlamaForCausalLM.forward = llama_lce_forward + if sdpa_attention: + modeling_llama.LlamaSdpaAttention.forward = liger_general_sdpa_forward if model is not None: # The model instance already exists, so we need to additionally patch the @@ -105,6 +113,7 @@ def apply_liger_kernel_to_mistral( fused_linear_cross_entropy: bool = True, rms_norm: bool = True, swiglu: bool = True, + sdpa_attention: bool = True, model: PreTrainedModel = None, ) -> None: """ @@ -120,6 +129,7 @@ def apply_liger_kernel_to_mistral( rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True. rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True. 
swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True. + sdpa_attention (bool): Whether to apply Liger's FlashAttention instead of SDPA. Default is True. model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been loaded. Default is None. """ @@ -139,6 +149,8 @@ def apply_liger_kernel_to_mistral( modeling_mistral.MistralForCausalLM.forward = mistral_lce_forward if swiglu: modeling_mistral.MistralMLP = LigerSwiGLUMLP + if sdpa_attention: + modeling_mistral.MistralSdpaAttention.forward = liger_general_sdpa_forward if model is not None: # The model instance already exists, so we need to additionally patch the @@ -176,6 +188,7 @@ def apply_liger_kernel_to_mixtral( fused_linear_cross_entropy: bool = True, rms_norm: bool = True, swiglu: bool = True, + sdpa_attention: bool = True, model: PreTrainedModel = None, ) -> None: """ @@ -190,6 +203,7 @@ def apply_liger_kernel_to_mixtral( If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient. rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True. swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True. + sdpa_attention (bool): Whether to apply Liger's FlashAttention instead of SDPA. Default is True. model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been loaded. Default is None. """ @@ -210,6 +224,8 @@ def apply_liger_kernel_to_mixtral( modeling_mixtral.MixtralForCausalLM.forward = mixtral_lce_forward if swiglu: modeling_mixtral.MixtralBlockSparseTop2MLP = LigerBlockSparseTop2MLP + if sdpa_attention: + modeling_mixtral.MixtralSdpaAttention.forward = liger_general_sdpa_forward if model is not None: # The model instance already exists, so we need to additionally patch the @@ -254,6 +270,7 @@ def apply_liger_kernel_to_gemma( fused_linear_cross_entropy: bool = True, rms_norm: bool = True, geglu: bool = True, + sdpa_attention: bool = True, model: PreTrainedModel = None, ) -> None: """ @@ -269,6 +286,7 @@ def apply_liger_kernel_to_gemma( If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient. rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True. geglu (bool): Whether to apply Liger's GeGLU MLP. Default is True. + sdpa_attention (bool): Whether to apply Liger's FlashAttention instead of SDPA. Default is True. model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been loaded. Default is None. """ @@ -293,6 +311,8 @@ def apply_liger_kernel_to_gemma( modeling_gemma.GemmaMLP = LigerGEGLUMLP if fused_linear_cross_entropy: modeling_gemma.GemmaForCausalLM.forward = gemma_lce_forward + if sdpa_attention: + modeling_gemma.GemmaSdpaAttention.forward = liger_general_sdpa_forward if model is not None: # The model instance already exists, so we need to additionally patch the @@ -329,6 +349,7 @@ def apply_liger_kernel_to_gemma2( cross_entropy: bool = True, rms_norm: bool = True, geglu: bool = True, + sdpa_attention: bool = True, model: PreTrainedModel = None, ) -> None: """ @@ -340,6 +361,7 @@ def apply_liger_kernel_to_gemma2( cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is True. rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True. geglu (bool): Whether to apply Liger's GeGLU MLP. Default is True. + sdpa_attention (bool): Whether to apply Liger's FlashAttention instead of SDPA. Default is True. 
model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been loaded. Default is None. """ @@ -355,6 +377,8 @@ def apply_liger_kernel_to_gemma2( modeling_gemma2.CrossEntropyLoss = LigerCrossEntropyLoss if geglu: modeling_gemma2.Gemma2MLP = LigerGEGLUMLP + if sdpa_attention: + modeling_gemma2.Gemma2SdpaAttention.forward = liger_gemma2_sdpa_forward if model is not None: # The model instance already exists, so we need to additionally patch the @@ -398,6 +422,7 @@ def apply_liger_kernel_to_qwen2( fused_linear_cross_entropy: bool = True, rms_norm: bool = True, swiglu: bool = True, + sdpa_attention: bool = True, model: PreTrainedModel = None, ) -> None: """ @@ -412,6 +437,7 @@ def apply_liger_kernel_to_qwen2( If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient. rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True. swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True. + sdpa_attention (bool): Whether to apply Liger's FlashAttention instead of SDPA. Default is True. model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been loaded. Default is None. """ @@ -431,6 +457,8 @@ def apply_liger_kernel_to_qwen2( modeling_qwen2.Qwen2ForCausalLM.forward = qwen2_lce_forward if swiglu: modeling_qwen2.Qwen2MLP = LigerSwiGLUMLP + if sdpa_attention: + modeling_qwen2.Qwen2SdpaAttention.forward = liger_qwen2_sdpa_forward if model is not None: # The model instance already exists, so we need to additionally patch the @@ -558,6 +586,7 @@ def apply_liger_kernel_to_phi3( fused_linear_cross_entropy: bool = True, rms_norm: bool = True, swiglu: bool = True, + sdpa_attention: bool = True, model: PreTrainedModel = None, ) -> None: """ @@ -572,6 +601,7 @@ def apply_liger_kernel_to_phi3( If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient. rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True. swiglu (bool): Whether to apply Liger's SwiGLU Phi3MLP. Default is True. + sdpa_attention (bool): Whether to apply Liger's FlashAttention instead of SDPA. Default is True. model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been loaded. Default is None. 
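+
+    Example (illustrative sketch only; the checkpoint name is a placeholder for any Phi-3 model loaded with the SDPA attention implementation):
+        >>> from transformers import AutoModelForCausalLM
+        >>> apply_liger_kernel_to_phi3(sdpa_attention=True)
+        >>> model = AutoModelForCausalLM.from_pretrained("microsoft/Phi-3-mini-4k-instruct", attn_implementation="sdpa")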
""" @@ -591,6 +621,8 @@ def apply_liger_kernel_to_phi3( modeling_phi3.CrossEntropyLoss = LigerCrossEntropyLoss if fused_linear_cross_entropy: modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward + if sdpa_attention: + modeling_phi3.Phi3SdpaAttention.forward = liger_phi3_sdpa_attention_forward if model is not None: # The model instance already exists, so we need to additionally patch the From 57317079e8f7de10138894b3a5cce440827138d2 Mon Sep 17 00:00:00 2001 From: remi-or Date: Thu, 26 Sep 2024 16:00:42 +0000 Subject: [PATCH 03/12] Added and passed tests in test_mini_models --- test/convergence/test_mini_models.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/test/convergence/test_mini_models.py b/test/convergence/test_mini_models.py index f648a88c..51bed64a 100644 --- a/test/convergence/test_mini_models.py +++ b/test/convergence/test_mini_models.py @@ -326,6 +326,7 @@ def run_mini_model( "rope": True, "rms_norm": True, "cross_entropy": True, + "sdpa_attention": (dtype in [torch.float16, torch.bfloat16]), } if "gemma" in model_name: kwargs["geglu"] = True @@ -364,6 +365,7 @@ def run_mini_model( [ # Gemma 1 has more tolerance because currently, the kernel is not a perfect match (casts are not done the same way) ("mini_gemma1", 32, 1e-4, torch.float32, 1e-8, 6e-4, 5e-3, 1e-5, 5e-3, 1e-5), + ("mini_gemma1", 32, 1e-4, torch.float16, 1e-8, 6e-4, 5e-3, 1e-5, 5e-3, 1e-5), pytest.param( "mini_gemma1", 32, @@ -380,6 +382,7 @@ def run_mini_model( ), ), ("mini_gemma1.1", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5), + ("mini_gemma1.1", 32, 1e-4, torch.float16, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5), pytest.param( "mini_gemma1.1", 32, @@ -396,6 +399,7 @@ def run_mini_model( ), ), ("mini_gemma2", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5), + ("mini_gemma2", 32, 1e-4, torch.float16, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5), pytest.param( "mini_gemma2", 32, @@ -412,6 +416,7 @@ def run_mini_model( ), ), ("mini_llama3", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5), + ("mini_llama3", 32, 1e-4, torch.float16, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5), pytest.param( "mini_llama3", 32, @@ -432,6 +437,7 @@ def run_mini_model( # ("mini_mixtral", 32, 1e-4, torch.float32, 1e-8, 1e-4, 5e-3, 1e-5, 8e-3, 1e-5), # ("mini_mixtral", 32, 1e-4, torch.bfloat16, 1e-8, 1e-5, 2.0, 1e-5, 1e-2, 1e-5), ("mini_mistral", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5), + ("mini_mistral", 32, 1e-4, torch.float16, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5), pytest.param( "mini_mistral", 32, @@ -448,6 +454,7 @@ def run_mini_model( ), ), ("mini_qwen2", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5), + ("mini_qwen2", 32, 1e-4, torch.float16, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5), pytest.param( "mini_qwen2", 32, @@ -464,6 +471,7 @@ def run_mini_model( ), ), ("mini_phi3", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5), + ("mini_phi3", 32, 1e-4, torch.float16, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5), pytest.param( "mini_phi3", 32, From 638882f8423273361311098d8ba8af461b53e6d4 Mon Sep 17 00:00:00 2001 From: remi-or Date: Thu, 26 Sep 2024 16:02:39 +0000 Subject: [PATCH 04/12] Small changes to the fa ops --- src/liger_kernel/ops/flash_attention/__init__.py | 4 ++++ src/liger_kernel/ops/flash_attention/forward/caller.py | 1 + .../ops/flash_attention/reference_implementation.py | 2 +- 3 files changed, 6 insertions(+), 1 deletion(-) create mode 100644 src/liger_kernel/ops/flash_attention/__init__.py diff --git a/src/liger_kernel/ops/flash_attention/__init__.py 
b/src/liger_kernel/ops/flash_attention/__init__.py new file mode 100644 index 00000000..2a2a7aef --- /dev/null +++ b/src/liger_kernel/ops/flash_attention/__init__.py @@ -0,0 +1,4 @@ +from .wrapper import flash_attn_func +from .reference_implementation import flash_attn_reference + +__all__ = ["flash_attn_func", "flash_attn_reference"] diff --git a/src/liger_kernel/ops/flash_attention/forward/caller.py b/src/liger_kernel/ops/flash_attention/forward/caller.py index 89ef7d51..e514e745 100644 --- a/src/liger_kernel/ops/flash_attention/forward/caller.py +++ b/src/liger_kernel/ops/flash_attention/forward/caller.py @@ -37,6 +37,7 @@ def _flash_attn_forward( assert k.shape == expected_kv_shape, f"{k.shape = } <> {expected_kv_shape = }" assert v.shape == expected_kv_shape, f"{v.shape = } <> {expected_kv_shape = }" assert q.dtype == k.dtype == v.dtype, "All tensors must have the same type" + assert q.dtype in [torch.float16, torch.bfloat16], "Only support fp16 and bf16" assert q.is_cuda and k.is_cuda and v.is_cuda softmax_scale = 1.0 / math.sqrt(head_dim) if softmax_scale is None else softmax_scale diff --git a/src/liger_kernel/ops/flash_attention/reference_implementation.py b/src/liger_kernel/ops/flash_attention/reference_implementation.py index f113ada7..93304a12 100644 --- a/src/liger_kernel/ops/flash_attention/reference_implementation.py +++ b/src/liger_kernel/ops/flash_attention/reference_implementation.py @@ -39,7 +39,7 @@ def construct_local_mask( ) -def attention_ref( +def flash_attn_reference( q, k, v, From e04dc8718cc4ad26405735107a2a2f80c6bad685 Mon Sep 17 00:00:00 2001 From: remi-or Date: Thu, 26 Sep 2024 16:31:29 +0000 Subject: [PATCH 05/12] Added and passed tests in no_logits (skipped in multi) --- .../test_mini_models_multimodal.py | 22 ++++++++++++++ .../convergence/test_mini_models_no_logits.py | 29 +++++++++++++++++++ 2 files changed, 51 insertions(+) diff --git a/test/convergence/test_mini_models_multimodal.py b/test/convergence/test_mini_models_multimodal.py index 4c164ba5..6c2a79ef 100644 --- a/test/convergence/test_mini_models_multimodal.py +++ b/test/convergence/test_mini_models_multimodal.py @@ -182,6 +182,7 @@ def run_mini_model_multimodal( kwargs = { "rms_norm": True, "cross_entropy": True, + "sdpa_attention": (dtype in [torch.float16, torch.bfloat16]), } model_supports_rope = "qwen2_vl" not in model_name if model_supports_rope: @@ -244,6 +245,27 @@ def run_mini_model_multimodal( reason="Qwen2-VL not available in this version of transformers", ), ), + pytest.param( + "mini_qwen2_vl", + 32, + 1e-4, + torch.float16, + 1e-3, + 1e-2, + 1e-1, + 1e-2, + 1e-2, + 1e-2, + marks=[ + pytest.mark.skipif( + not supports_bfloat16(), reason="bfloat16 not supported on this GPU" + ), + pytest.mark.skipif( + not QWEN2_VL_AVAILABLE, + reason="Qwen2-VL not available in this version of transformers", + ), + ], + ), pytest.param( "mini_qwen2_vl", 32, diff --git a/test/convergence/test_mini_models_no_logits.py b/test/convergence/test_mini_models_no_logits.py index 7dfaa00f..bb54ffb8 100644 --- a/test/convergence/test_mini_models_no_logits.py +++ b/test/convergence/test_mini_models_no_logits.py @@ -350,6 +350,7 @@ def run_mini_model( if with_liger is True: kwargs = { "rms_norm": True, + "sdpa_attention": (dtype in [torch.float16, torch.bfloat16]), } model_supports_rope = "qwen2_vl" not in model_name if model_supports_rope: @@ -402,6 +403,7 @@ def run_mini_model( "model_name, num_steps, lr, dtype, loss_atol, loss_rtol, logits_atol, logits_rtol, param_atol, param_rtol", [ ("mini_llama3", 32, 1e-4, 
torch.float32, 1e-8, 2e-5, 1e-4, 1e-5, 5e-3, 1e-5), + ("mini_llama3", 32, 1e-4, torch.float16, 1e-8, 2e-5, 1e-4, 1e-5, 5e-3, 1e-5), pytest.param( "mini_llama3", 32, @@ -418,6 +420,7 @@ def run_mini_model( ), ), ("mini_qwen2", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5), + ("mini_qwen2", 32, 1e-4, torch.float16, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5), pytest.param( "mini_qwen2", 32, @@ -449,6 +452,27 @@ def run_mini_model( reason="Qwen2-VL not available in this version of transformers", ), ), + pytest.param( + "mini_qwen2_vl", + 32, + 1e-4, + torch.float16, + 1e-3, + 1e-2, + 1e-1, + 1e-2, + 1e-2, + 1e-2, + marks=[ + pytest.mark.skipif( + not supports_bfloat16(), reason="bfloat16 not supported on this GPU" + ), + pytest.mark.skipif( + not QWEN2_VL_AVAILABLE, + reason="Qwen2-VL not available in this version of transformers", + ), + ], + ), pytest.param( "mini_qwen2_vl", 32, @@ -471,6 +495,7 @@ def run_mini_model( ], ), ("mini_phi3", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5), + ("mini_phi3", 32, 1e-4, torch.float16, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5), pytest.param( "mini_phi3", 32, @@ -487,6 +512,7 @@ def run_mini_model( ), ), ("mini_mistral", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5), + ("mini_mistral", 32, 1e-4, torch.float16, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5), pytest.param( "mini_mistral", 32, @@ -521,6 +547,7 @@ def run_mini_model( # ), # Gemma 1.1 and 2 has more tolerance because currently, the kernel is not a perfect match (casts are not done the same way) ("mini_gemma1", 32, 1e-4, torch.float32, 1e-8, 1e-4, 5e-3, 1e-5, 5e-3, 1e-5), + ("mini_gemma1", 32, 1e-4, torch.float16, 1e-8, 1e-4, 5e-3, 1e-5, 5e-3, 1e-5), pytest.param( "mini_gemma1", 32, @@ -537,6 +564,7 @@ def run_mini_model( ), ), ("mini_gemma1.1", 32, 1e-4, torch.float32, 1e-8, 1e-4, 5e-3, 1e-5, 5e-3, 1e-5), + ("mini_gemma1.1", 32, 1e-4, torch.float16, 1e-8, 1e-4, 5e-3, 1e-5, 5e-3, 1e-5), pytest.param( "mini_gemma1.1", 32, @@ -553,6 +581,7 @@ def run_mini_model( ), ), ("mini_gemma2", 32, 1e-4, torch.float32, 1e-8, 1e-4, 5e-3, 1e-5, 5e-3, 1e-5), + ("mini_gemma2", 32, 1e-4, torch.float16, 1e-8, 1e-4, 5e-3, 1e-5, 5e-3, 1e-5), pytest.param( "mini_gemma2", 32, From 23d0df96927fe4594c72337a39753188e001f8d7 Mon Sep 17 00:00:00 2001 From: remi-or Date: Thu, 26 Sep 2024 16:44:33 +0000 Subject: [PATCH 06/12] Unit tests, nearly all passed --- test/transformers/test_attention.py | 108 ++++++++++++++++++++++++++++ 1 file changed, 108 insertions(+) create mode 100644 test/transformers/test_attention.py diff --git a/test/transformers/test_attention.py b/test/transformers/test_attention.py new file mode 100644 index 00000000..a22afd21 --- /dev/null +++ b/test/transformers/test_attention.py @@ -0,0 +1,108 @@ +import pytest +import torch +from torch import Tensor +from typing import Optional, Tuple + +from liger_kernel.ops.flash_attention import flash_attn_func, flash_attn_reference +from test.utils import set_seed + +set_seed() +DEVICE = "cuda" + + +def compare_numerical_errors( + x_ref: Tensor, + x_pt: Tensor, + x_triton: Tensor, + error_mul: float, + error_atol: float, + tensor_name: str, +) -> None: + max_pt_error = (x_pt - x_ref).abs().max().item() + max_triton_error = (x_triton - x_ref).abs().max().item() + assert max_triton_error <= max(error_mul * max_pt_error, error_atol), tensor_name + + +def _test_attention( + batch_size: int, + nheads_q: int, + nheads_kv: int, + seqlen_q: int, + seqlen_k: int, + head_dim: int, + causal: bool, + dropout_p: float, + use_attention: bool, + use_bias: 
bool, + dtype: torch.dtype, +) -> Optional[Tuple[Tensor, ...]]: + + # Prepare data + q = torch.normal(0, 0.5, (batch_size, seqlen_q, nheads_q, head_dim), dtype=dtype, device=DEVICE).requires_grad_() + k = torch.normal(0, 0.5, (batch_size, seqlen_k, nheads_kv, head_dim), dtype=dtype, device=DEVICE).requires_grad_() + v = torch.normal(0, 0.5, (batch_size, seqlen_k, nheads_kv, head_dim), dtype=dtype, device=DEVICE).requires_grad_() + do = torch.randn_like(q) + attn_bias = torch.rand(size=(1, 1, seqlen_q, seqlen_k), dtype=dtype, device=q.device) if use_bias else None + + # Compute the outputs of the forward pass + ref_output = flash_attn_reference(q, k, v, attn_bias=attn_bias, causal=causal, upcast=True, reorder_ops=False) + pt_output = flash_attn_reference(q, k, v, attn_bias=attn_bias, causal=causal, upcast=False, reorder_ops=True) + liger_output = flash_attn_func(q, k, v, attention_bias=attn_bias, causal=causal) + compare_numerical_errors(ref_output, pt_output, liger_output, 1, 1e-4, "output") + + # Compare the gradients after the backward pass + ref_dq, ref_dk, ref_dv = torch.autograd.grad(ref_output, (q, k, v), do, retain_graph=True) + pt_dq, pt_dk, pt_dv = torch.autograd.grad(pt_output, (q, k, v), do, retain_graph=True) + liger_dq, liger_dk, liger_dv = torch.autograd.grad(liger_output, (q, k, v), do, retain_graph=True) + compare_numerical_errors(ref_dq, pt_dq, liger_dq, 2, 1e-4, "dq") + compare_numerical_errors(ref_dk, pt_dk, liger_dk, 2, 1e-4, "dk") + compare_numerical_errors(ref_dv, pt_dv, liger_dv, 2, 1e-4, "dv") + + +@pytest.mark.parametrize("dtype", [(torch.float16), (torch.bfloat16)]) +@pytest.mark.parametrize("causal", [False, True]) +@pytest.mark.parametrize( + "head_dim, nheads_q, nheads_kv", + [(32, 9, 9), (40, 9, 3), (64, 8, 8), (128, 8, 2), (256, 4, 2)], +) +@pytest.mark.parametrize("swap_seqlens", [False, True]) +@pytest.mark.parametrize("use_bias", [False, True]) +@pytest.mark.parametrize( + "seqlen_q,seqlen_k", + [ + (1, 239), + (3, 799), + (113, 203), + (127, 512), + (128, 217), + (256, 512), + ], +) +@pytest.mark.parametrize("batch_size", [4]) +def test_fwd_bwd( + batch_size: int, + nheads_q: int, + nheads_kv: int, + seqlen_q: int, + seqlen_k: int, + swap_seqlens: bool, + head_dim: int, + causal: bool, + use_bias: bool, + dtype: torch.dtype, +) -> None: + if swap_seqlens: + seqlen_q, seqlen_k = seqlen_k, seqlen_q + _test_attention( + batch_size=batch_size, + nheads_q=nheads_q, + nheads_kv=nheads_kv, + seqlen_q=seqlen_q, + seqlen_k=seqlen_k, + head_dim=head_dim, + causal=causal, + dropout_p=0, + use_attention=False, + use_bias=use_bias, + dtype=dtype, + ) From 6c03dcf8605584218d5056b7998dbb72e601f2be Mon Sep 17 00:00:00 2001 From: remi-or Date: Thu, 26 Sep 2024 16:54:22 +0000 Subject: [PATCH 07/12] Benchmarking --- benchmark/benchmarks_visualizer.py | 2 +- benchmark/scripts/benchmark_attention.py | 162 +++++++++++++++++++++++ 2 files changed, 163 insertions(+), 1 deletion(-) create mode 100644 benchmark/scripts/benchmark_attention.py diff --git a/benchmark/benchmarks_visualizer.py b/benchmark/benchmarks_visualizer.py index 2cb9b133..360057a4 100644 --- a/benchmark/benchmarks_visualizer.py +++ b/benchmark/benchmarks_visualizer.py @@ -7,7 +7,7 @@ import pandas as pd import seaborn as sns -DATA_PATH = "data/all_benchmark_data.csv" +DATA_PATH = os.path.join(os.path.dirname(__file__), "data/all_benchmark_data.csv") VISUALIZATIONS_PATH = "visualizations/" diff --git a/benchmark/scripts/benchmark_attention.py b/benchmark/scripts/benchmark_attention.py new file mode 100644 index 
00000000..bb4bf0a0 --- /dev/null +++ b/benchmark/scripts/benchmark_attention.py @@ -0,0 +1,162 @@ +import torch +import triton +from transformers.models.llama.modeling_llama import repeat_kv + +from utils import ( + QUANTILES, + SingleBenchmarkRunInput, + SingleBenchmarkRunOutput, + _test_memory, + parse_benchmark_script_args, + run_benchmarks, +) +from liger_kernel.ops.flash_attention import flash_attn_func + + +############################################################################# +# Test the memory consumption of the attention layer +############################################################################# + + +def bench_memory_attention( + input: SingleBenchmarkRunInput, +) -> SingleBenchmarkRunOutput: + seqlen = input.x + batch_size = input.extra_benchmark_config["batch_size"] + nheads_q = input.extra_benchmark_config["nheads_q"] + nheads_kv = input.extra_benchmark_config["nheads_kv"] + hidden_size = input.extra_benchmark_config["hidden_size"] + dtype = input.extra_benchmark_config["dtype"] + provider = input.kernel_provider + device = "cuda" + + head_dim = hidden_size // nheads_q + q = torch.normal(0, 0.5, (batch_size, seqlen, nheads_q, head_dim), dtype=dtype, device=device).requires_grad_() + k = torch.normal(0, 0.5, (batch_size, seqlen, nheads_kv, head_dim), dtype=dtype, device=device).requires_grad_() + v = torch.normal(0, 0.5, (batch_size, seqlen, nheads_kv, head_dim), dtype=dtype, device=device).requires_grad_() + do = torch.randn_like(q) + + if provider == "torch": + q, k, v, do = [x.transpose(1, 2).contiguous() for x in [q, k, v, do]] + + def fwd(): + if provider == "liger": + return flash_attn_func(q, k, v) + if provider == "torch": + if nheads_q == nheads_kv: + return torch.nn.functional.scaled_dot_product_attention(q, k, v) + else: + ngroups = nheads_q // nheads_kv + return torch.nn.functional.scaled_dot_product_attention(q, repeat_kv(k, ngroups), repeat_kv(v, ngroups)) + + def full(): + y = fwd() + y.backward(do) + + mem_50, mem_20, mem_80 = _test_memory(full, _iter=10, quantiles=QUANTILES) + return SingleBenchmarkRunOutput( + y_20=mem_20, + y_50=mem_50, + y_80=mem_80, + ) + + +# ############################################################################# +# # Test the speed of the fused linear cross entropy loss +# ############################################################################# + + +def bench_speed_attention( + input: SingleBenchmarkRunInput, +) -> SingleBenchmarkRunOutput: + seqlen = input.x + batch_size = input.extra_benchmark_config["batch_size"] + nheads_q = input.extra_benchmark_config["nheads_q"] + nheads_kv = input.extra_benchmark_config["nheads_kv"] + hidden_size = input.extra_benchmark_config["hidden_size"] + dtype = input.extra_benchmark_config["dtype"] + provider = input.kernel_provider + mode = input.kernel_operation_mode + device = "cuda" + + head_dim = hidden_size // nheads_q + q = torch.normal(0, 0.5, (batch_size, seqlen, nheads_q, head_dim), dtype=dtype, device=device).requires_grad_() + k = torch.normal(0, 0.5, (batch_size, seqlen, nheads_kv, head_dim), dtype=dtype, device=device).requires_grad_() + v = torch.normal(0, 0.5, (batch_size, seqlen, nheads_kv, head_dim), dtype=dtype, device=device).requires_grad_() + do = torch.randn_like(q) + + if provider == "torch": + q, k, v, do = [x.transpose(1, 2).contiguous() for x in [q, k, v, do]] + + def fwd(): + if provider == "liger": + return flash_attn_func(q, k, v) + if provider == "torch": + if nheads_q == nheads_kv: + return torch.nn.functional.scaled_dot_product_attention(q, k, v) + 
else: + ngroups = nheads_q // nheads_kv + return torch.nn.functional.scaled_dot_product_attention(q, repeat_kv(k, ngroups), repeat_kv(v, ngroups)) + + if mode == "forward": + ms_50, ms_20, ms_80 = triton.testing.do_bench( + fwd, + rep=100, + quantiles=QUANTILES, + ) + elif mode == "backward": + y = fwd() + + ms_50, ms_20, ms_80 = triton.testing.do_bench( + lambda: y.backward(do, retain_graph=True), + grad_to_none=[q, k, v], + rep=100, + quantiles=QUANTILES, + ) + elif mode == "full": + + def full(): + y = fwd() + y.backward(do) + + ms_50, ms_20, ms_80 = triton.testing.do_bench( + full, + rep=100, + quantiles=QUANTILES, + ) + return SingleBenchmarkRunOutput( + y_20=ms_20, + y_50=ms_50, + y_80=ms_80, + ) + + +if __name__ == "__main__": + args = parse_benchmark_script_args() + + common_configs = { + "kernel_name": "attention", + "x_name": "seqlen", + "x_label": "Sequence length", + "x_values": [2**i for i in range(5, 15)], + "kernel_providers": ["liger", "torch"], + "extra_benchmark_configs": [ + {"batch_size": 4, "nheads_q": 32, "nheads_kv": 8, "hidden_size": 4096, "dtype": torch.float16} + ], + "overwrite": args.overwrite, + } + + run_benchmarks( + bench_test_fn=bench_speed_attention, + kernel_operation_modes=["forward", "full"], + metric_name="speed", + metric_unit="ms", + **common_configs + ) + run_benchmarks( + bench_test_fn=bench_memory_attention, + kernel_operation_modes=["full"], + metric_name="memory", + metric_unit="MB", + **common_configs + ) From b32cf9fa7e7f87ae02c4d2da9a542feb154962aa Mon Sep 17 00:00:00 2001 From: remi-or Date: Thu, 26 Sep 2024 16:57:31 +0000 Subject: [PATCH 08/12] Reduced test numbers --- test/transformers/test_attention.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/test/transformers/test_attention.py b/test/transformers/test_attention.py index a22afd21..645f7ed9 100644 --- a/test/transformers/test_attention.py +++ b/test/transformers/test_attention.py @@ -59,19 +59,21 @@ def _test_attention( compare_numerical_errors(ref_dv, pt_dv, liger_dv, 2, 1e-4, "dv") -@pytest.mark.parametrize("dtype", [(torch.float16), (torch.bfloat16)]) -@pytest.mark.parametrize("causal", [False, True]) @pytest.mark.parametrize( - "head_dim, nheads_q, nheads_kv", - [(32, 9, 9), (40, 9, 3), (64, 8, 8), (128, 8, 2), (256, 4, 2)], + "dtype, swap_seqlens", + [(torch.float16, True), (torch.bfloat16, False)], ) -@pytest.mark.parametrize("swap_seqlens", [False, True]) -@pytest.mark.parametrize("use_bias", [False, True]) +@pytest.mark.parametrize("head_dim, nheads_q, nheads_kv, use_bias, causal", [ + (32, 9, 9, True, False), + (40, 9, 3, True, True), + (64, 8, 8, False, False), + (128, 8, 2, True, True), + (256, 4, 2, False, True), +]) @pytest.mark.parametrize( "seqlen_q,seqlen_k", [ (1, 239), - (3, 799), (113, 203), (127, 512), (128, 217), @@ -79,7 +81,7 @@ def _test_attention( ], ) @pytest.mark.parametrize("batch_size", [4]) -def test_fwd_bwd( +def test_attention( batch_size: int, nheads_q: int, nheads_kv: int, From b213b0e5ca0c45ed12073e11920eb49ef0f58ac7 Mon Sep 17 00:00:00 2001 From: remi-or Date: Thu, 26 Sep 2024 17:10:48 +0000 Subject: [PATCH 09/12] Formatted the style --- .../ops/flash_attention/backward/caller.py | 99 +++++++-- .../flash_attention/backward/compute_delta.py | 15 +- .../flash_attention/backward/compute_dkdv.py | 74 +++++-- .../flash_attention/backward/compute_dq.py | 67 +++++-- .../ops/flash_attention/backward/kernel.py | 189 ++++++++++++++---- .../ops/flash_attention/forward/caller.py | 64 ++++-- 
.../forward/compute_row_blocks.py | 44 ++-- .../ops/flash_attention/forward/kernel.py | 89 ++++++--- .../reference_implementation.py | 16 +- src/liger_kernel/ops/flash_attention/utils.py | 43 +++- .../ops/flash_attention/wrapper.py | 16 +- src/liger_kernel/transformers/attention.py | 26 ++- src/liger_kernel/transformers/model/gemma2.py | 22 +- src/liger_kernel/transformers/model/phi3.py | 32 ++- src/liger_kernel/transformers/model/qwen2.py | 26 ++- test/transformers/test_attention.py | 55 +++-- 16 files changed, 669 insertions(+), 208 deletions(-) diff --git a/src/liger_kernel/ops/flash_attention/backward/caller.py b/src/liger_kernel/ops/flash_attention/backward/caller.py index d80b293e..e655331c 100644 --- a/src/liger_kernel/ops/flash_attention/backward/caller.py +++ b/src/liger_kernel/ops/flash_attention/backward/caller.py @@ -8,7 +8,14 @@ from src.liger_kernel.ops.flash_attention.backward.compute_delta import _compute_delta from src.liger_kernel.ops.flash_attention.backward.kernel import _bwd_kernel -from src.liger_kernel.ops.flash_attention.utils import attention_pack, attention_unpack, torch_ignore_deterministic, infer_bias_strides, handle_dropout, encode_dtype +from src.liger_kernel.ops.flash_attention.utils import ( + attention_pack, + attention_unpack, + torch_ignore_deterministic, + infer_bias_strides, + handle_dropout, + encode_dtype, +) def _flash_attn_backward( @@ -27,9 +34,13 @@ def _flash_attn_backward( ) -> Tuple[Tensor, Tensor, Tensor]: if attention_mask is not None: - assert bias is None, "Attention mask is not supported along with attention bias. Just use bias instead." - assert q.size(1) == k.size(1), "Attention mask is not supported with seqlen_q != seqlen_k" - varlen_mode = (attention_mask.size(0) > 1) + assert ( + bias is None + ), "Attention mask is not supported along with attention bias. Just use bias instead." 
+ assert q.size(1) == k.size( + 1 + ), "Attention mask is not supported with seqlen_q != seqlen_k" + varlen_mode = attention_mask.size(0) > 1 useless_padding = attention_mask.size(1) - attention_mask.sum(-1).max().item() if useless_padding > 0: dO = dO[:, :-useless_padding] @@ -47,7 +58,9 @@ def _flash_attn_backward( batch_size, seqlen_q, nheads_q, head_dim = q.shape _, seqlen_k, nheads_kv, _ = k.shape max_seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128 - softmax_scale = 1.0 / math.sqrt(head_dim) if softmax_scale is None else softmax_scale + softmax_scale = ( + 1.0 / math.sqrt(head_dim) if softmax_scale is None else softmax_scale + ) assert nheads_q % nheads_kv == 0, f"{nheads_q = } is not divisible by {nheads_kv =}" assert lse.shape == (batch_size, nheads_q, max_seqlen_q_rounded) assert q.stride(-1) == k.stride(-1) == v.stride(-1) == o.stride(-1) == 1 @@ -55,8 +68,16 @@ def _flash_attn_backward( # Depending on attention_mask, switch to varlen if varlen_mode: # Compute padding-related statistics - cum_seqlens_q = torch.zeros(size=(attention_mask.size(0)+1,), device=attention_mask.device, dtype=torch.int32) - cum_seqlens_k = torch.zeros(size=(attention_mask.size(0)+1,), device=attention_mask.device, dtype=torch.int32) + cum_seqlens_q = torch.zeros( + size=(attention_mask.size(0) + 1,), + device=attention_mask.device, + dtype=torch.int32, + ) + cum_seqlens_k = torch.zeros( + size=(attention_mask.size(0) + 1,), + device=attention_mask.device, + dtype=torch.int32, + ) with torch_ignore_deterministic(): cum_seqlens_q[1:] = attention_mask.sum(dim=1).cumsum(0) cum_seqlens_k[1:] = attention_mask.sum(dim=1).cumsum(0) @@ -68,7 +89,9 @@ def _flash_attn_backward( k = attention_pack(k, attention_mask) # [1, sum_seqlens_qk, num_head, head_dim] v = attention_pack(v, attention_mask) # [1, sum_seqlens_qk, num_head, head_dim] o = attention_pack(o, attention_mask) # [1, sum_seqlens_qk, num_head, head_dim] - dO = attention_pack(dO, attention_mask) # [1, sum_seqlens_qk, num_head, head_dim] + dO = attention_pack( + dO, attention_mask + ) # [1, sum_seqlens_qk, num_head, head_dim] # Update seqlens seqlen_q = q.size(1) seqlen_k = k.size(1) @@ -79,19 +102,34 @@ def _flash_attn_backward( max_seqlen_k = seqlen_k # Handle bias and dropout - stride_bb, stride_bh, stride_bm = infer_bias_strides(bias, batch_size, nheads_q, seqlen_q, seqlen_k) + stride_bb, stride_bh, stride_bm = infer_bias_strides( + bias, batch_size, nheads_q, seqlen_q, seqlen_k + ) dropout_seed = handle_dropout(dropout_p, dropout_seed, is_forward=False) # Prepare gradient accumulators # TODO: maybe we can initialize this as empty -- check pre hook - dq = torch.zeros_like(q, dtype=torch.float32) # [batch_size|1, seqlen_q|sum_seqlens_qk, nheads_q, head_dim] - dk = torch.zeros(size=(k.size(0), k.size(1), q.size(2), k.size(3)), device=k.device, dtype=k.dtype) - dv = torch.zeros(size=(v.size(0), v.size(1), q.size(2), v.size(3)), device=v.device, dtype=v.dtype) + dq = torch.zeros_like( + q, dtype=torch.float32 + ) # [batch_size|1, seqlen_q|sum_seqlens_qk, nheads_q, head_dim] + dk = torch.zeros( + size=(k.size(0), k.size(1), q.size(2), k.size(3)), + device=k.device, + dtype=k.dtype, + ) + dv = torch.zeros( + size=(v.size(0), v.size(1), q.size(2), v.size(3)), + device=v.device, + dtype=v.dtype, + ) delta = torch.zeros_like(lse) # [batch_size, nheads_q, max_seqlen_q_rounded] # Infer problem size BLOCK_HEADDIM = max(triton.next_power_of_2(head_dim), 16) # Launch the delta computation kernel - grid = lambda META: (triton.cdiv(max_seqlen_q, META["BLOCK_M"]), 
batch_size * nheads_q) # noqa: E731 + grid = lambda META: ( + triton.cdiv(max_seqlen_q, META["BLOCK_M"]), + batch_size * nheads_q, + ) # noqa: E731 _compute_delta[grid]( o, dO, @@ -116,7 +154,8 @@ def _flash_attn_backward( # Launch backward kernel head_ratio = nheads_q // nheads_kv grid = lambda META: ( # noqa: E731 - triton.cdiv(seqlen_k, META["BLOCK_N1"]) + triton.cdiv(seqlen_q, META["BLOCK_M2"]), + triton.cdiv(seqlen_k, META["BLOCK_N1"]) + + triton.cdiv(seqlen_q, META["BLOCK_M2"]), batch_size * nheads_q, ) _bwd_kernel[grid]( @@ -133,14 +172,30 @@ def _flash_attn_backward( softmax_scale, dropout_p, dropout_seed, - q.stride(0), q.stride(2), q.stride(1), - k.stride(0), k.stride(2), k.stride(1), - v.stride(0), v.stride(2), v.stride(1), - stride_bb, stride_bh, stride_bm, - dO.stride(0), dO.stride(2), dO.stride(1), - dq.stride(0), dq.stride(2), dq.stride(1), - dk.stride(0), dk.stride(2), dk.stride(1), - dv.stride(0), dv.stride(2), dv.stride(1), + q.stride(0), + q.stride(2), + q.stride(1), + k.stride(0), + k.stride(2), + k.stride(1), + v.stride(0), + v.stride(2), + v.stride(1), + stride_bb, + stride_bh, + stride_bm, + dO.stride(0), + dO.stride(2), + dO.stride(1), + dq.stride(0), + dq.stride(2), + dq.stride(1), + dk.stride(0), + dk.stride(2), + dk.stride(1), + dv.stride(0), + dv.stride(2), + dv.stride(1), nheads_q, head_ratio, seqlen_q, diff --git a/src/liger_kernel/ops/flash_attention/backward/compute_delta.py b/src/liger_kernel/ops/flash_attention/backward/compute_delta.py index 54ad3b5a..9f7eac67 100644 --- a/src/liger_kernel/ops/flash_attention/backward/compute_delta.py +++ b/src/liger_kernel/ops/flash_attention/backward/compute_delta.py @@ -47,7 +47,9 @@ def _compute_delta( # Infer actual sequence length of Q and the offset to the last sequence if VARLEN: - actual_seqlen_q = tl.load(cum_seqlens_q + off_batch + 1) - tl.load(cum_seqlens_q + off_batch) + actual_seqlen_q = tl.load(cum_seqlens_q + off_batch + 1) - tl.load( + cum_seqlens_q + off_batch + ) cu_seq_start_q = tl.load(cum_seqlens_q + off_batch) off_batch = 0 else: @@ -55,14 +57,21 @@ def _compute_delta( cu_seq_start_q = 0 # Load the output tensor - Out_offseted = Out + off_batch * stride_ob + off_head * stride_oh + cu_seq_start_q * stride_om + Out_offseted = ( + Out + off_batch * stride_ob + off_head * stride_oh + cu_seq_start_q * stride_om + ) o = tl.load( Out_offseted + offs_m[:, None] * stride_om + offs_d[None, :], mask=(offs_m[:, None] < actual_seqlen_q) & (offs_d[None, :] < headdim), other=0.0, ).to(tl.float32) # And its gradient - DO_offseted = DO + off_batch * stride_dob + off_head * stride_doh + cu_seq_start_q * stride_dom + DO_offseted = ( + DO + + off_batch * stride_dob + + off_head * stride_doh + + cu_seq_start_q * stride_dom + ) do = tl.load( DO_offseted + offs_m[:, None] * stride_dom + offs_d[None, :], mask=(offs_m[:, None] < actual_seqlen_q) & (offs_d[None, :] < headdim), diff --git a/src/liger_kernel/ops/flash_attention/backward/compute_dkdv.py b/src/liger_kernel/ops/flash_attention/backward/compute_dkdv.py index 34c8427a..26390f28 100644 --- a/src/liger_kernel/ops/flash_attention/backward/compute_dkdv.py +++ b/src/liger_kernel/ops/flash_attention/backward/compute_dkdv.py @@ -51,11 +51,18 @@ def _compute_single_block_dkdv( # Load Q and LSE now to reduce pipeline stall # BUG: if one is true and the ther not, q is filled with wrong values - q = load_fn(q_ptrs, - offs_m_curr, offs_d, - PAD_ROWS or HEADS_PADDED, PAD_ROWS or HEADS_PADDED, - actual_seqlen_q, headdim) - lse_i = tl.load(LSE + offs_m_curr) # since lsm is padded 
to max_seqlen_q, should be good + q = load_fn( + q_ptrs, + offs_m_curr, + offs_d, + PAD_ROWS or HEADS_PADDED, + PAD_ROWS or HEADS_PADDED, + actual_seqlen_q, + headdim, + ) + lse_i = tl.load( + LSE + offs_m_curr + ) # since lsm is padded to max_seqlen_q, should be good if BIAS_ON: bias = load_fn( bias_ptrs, @@ -64,7 +71,7 @@ def _compute_single_block_dkdv( PAD_ROWS or HEADS_PADDED, PAD_ROWS or HEADS_PADDED, actual_seqlen_q, - actual_seqlen_k + actual_seqlen_k, ) # Recompute P_ij = softmax(qk, dim=-1).T @@ -73,17 +80,24 @@ def _compute_single_block_dkdv( qk += bias / softmax_scale # TODO: check if this is optimal # Attention and causal mask - offs_n_causal = (offs_n - actual_seqlen_k + actual_seqlen_q) + offs_n_causal = offs_n - actual_seqlen_k + actual_seqlen_q if MASKED: if PAD_COLS: if IS_CAUSAL: qk = tl.where( - tl.minimum(actual_seqlen_q - 1, offs_m_curr)[:, None] >= offs_n_causal[None, :], qk, float("-inf") + tl.minimum(actual_seqlen_q - 1, offs_m_curr)[:, None] + >= offs_n_causal[None, :], + qk, + float("-inf"), ) else: - qk = tl.where(actual_seqlen_q - 1 >= offs_n_causal[None, :], qk, float("-inf")) + qk = tl.where( + actual_seqlen_q - 1 >= offs_n_causal[None, :], qk, float("-inf") + ) elif IS_CAUSAL: - qk = tl.where(offs_m_curr[:, None] >= offs_n_causal[None, :], qk, float("-inf")) + qk = tl.where( + offs_m_curr[:, None] >= offs_n_causal[None, :], qk, float("-inf") + ) tl.debug_barrier() p = tl.exp2(qk * (softmax_scale * 1.44269504089) - lse_i[:, None]) @@ -94,7 +108,9 @@ def _compute_single_block_dkdv( p = tl.where(offs_m_curr[:, None] < fully_masked_lines, 0, p) # Load the gradient of O - do = load_fn(do_ptrs, offs_m_curr, offs_d, PAD_ROWS, HEADS_PADDED, actual_seqlen_q, headdim) + do = load_fn( + do_ptrs, offs_m_curr, offs_d, PAD_ROWS, HEADS_PADDED, actual_seqlen_q, headdim + ) # Compute the gradient of V dv += tl.dot(tl.trans(p).to(do.dtype), do) @@ -148,7 +164,9 @@ def _compute_column_blocks_dkdv( BLOCK_HEADDIM: tl.constexpr, ): # This fuction goes through a column, so it always ends at m = actual_seqlen_q but can start early due to causality - I_begin_m = max(I_start_n + actual_seqlen_q - actual_seqlen_k, 0) if IS_CAUSAL else 0 + I_begin_m = ( + max(I_start_n + actual_seqlen_q - actual_seqlen_k, 0) if IS_CAUSAL else 0 + ) I_begin_m = (I_begin_m // BLOCK_M) * BLOCK_M I_end_m = actual_seqlen_q @@ -186,14 +204,22 @@ def _compute_column_blocks_dkdv( # Load K and V, which will stay in SRAM for the row-wise loop k = load_fn( - k_ptrs, offs_n, offs_d, - PAD_AXIS_0=PAD_COLS, PAD_AXIS_1=HEADS_PADDED, - LIM_AXIS_0=actual_seqlen_k, LIM_AXIS_1=headdim, + k_ptrs, + offs_n, + offs_d, + PAD_AXIS_0=PAD_COLS, + PAD_AXIS_1=HEADS_PADDED, + LIM_AXIS_0=actual_seqlen_k, + LIM_AXIS_1=headdim, ) v = load_fn( - v_ptrs, offs_n, offs_d, - PAD_AXIS_0=PAD_COLS, PAD_AXIS_1=HEADS_PADDED, - LIM_AXIS_0=actual_seqlen_k, LIM_AXIS_1=headdim, + v_ptrs, + offs_n, + offs_d, + PAD_AXIS_0=PAD_COLS, + PAD_AXIS_1=HEADS_PADDED, + LIM_AXIS_0=actual_seqlen_k, + LIM_AXIS_1=headdim, ) # Loop over rows to compute dk and dv @@ -282,8 +308,16 @@ def _compute_column_blocks_dkdv( # Store dk and dv if HEADS_PADDED: if PAD_COLS: - tl.store(dk_ptrs, dk, mask=(offs_n[:, None] < actual_seqlen_k) & (offs_d[None, :] < headdim)) - tl.store(dv_ptrs, dv, mask=(offs_n[:, None] < actual_seqlen_k) & (offs_d[None, :] < headdim)) + tl.store( + dk_ptrs, + dk, + mask=(offs_n[:, None] < actual_seqlen_k) & (offs_d[None, :] < headdim), + ) + tl.store( + dv_ptrs, + dv, + mask=(offs_n[:, None] < actual_seqlen_k) & (offs_d[None, :] < headdim), + ) 
else: tl.store(dk_ptrs, dk, mask=offs_d[None, :] < headdim) tl.store(dv_ptrs, dv, mask=offs_d[None, :] < headdim) diff --git a/src/liger_kernel/ops/flash_attention/backward/compute_dq.py b/src/liger_kernel/ops/flash_attention/backward/compute_dq.py index 38c2a606..657f79de 100644 --- a/src/liger_kernel/ops/flash_attention/backward/compute_dq.py +++ b/src/liger_kernel/ops/flash_attention/backward/compute_dq.py @@ -44,25 +44,44 @@ def _compute_single_block_dq( dropout_offs += I_start_n # Load K, V and LSE now to reduce pipeline stall - k = load_fn(k_ptrs, offs_n_curr, offs_d, PAD_COLS, HEADS_PADDED, actual_seqlen_k, headdim) - v = load_fn(v_ptrs, offs_n_curr, offs_d, PAD_COLS, HEADS_PADDED, actual_seqlen_k, headdim) + k = load_fn( + k_ptrs, offs_n_curr, offs_d, PAD_COLS, HEADS_PADDED, actual_seqlen_k, headdim + ) + v = load_fn( + v_ptrs, offs_n_curr, offs_d, PAD_COLS, HEADS_PADDED, actual_seqlen_k, headdim + ) if BIAS_ON: - bias = load_fn(bias_ptrs, offs_m, offs_n_curr, True, PAD_COLS, actual_seqlen_q, actual_seqlen_k) # TODO: pad rows + bias = load_fn( + bias_ptrs, + offs_m, + offs_n_curr, + True, + PAD_COLS, + actual_seqlen_q, + actual_seqlen_k, + ) # TODO: pad rows # Recompute P_ij = softmax(qk, dim=-1).T qk = tl.dot(q, tl.trans(k)) if BIAS_ON: qk += bias / softmax_scale # TODO: check if this is optimal - offs_n_causal = (offs_n_curr - actual_seqlen_k + actual_seqlen_q) + offs_n_causal = offs_n_curr - actual_seqlen_k + actual_seqlen_q # Attention and causal mask if MASKED: if PAD_COLS: if IS_CAUSAL: - qk = tl.where(tl.minimum(actual_seqlen_q - 1, offs_m)[:, None] >= offs_n_causal[None, :], qk, float("-inf")) + qk = tl.where( + tl.minimum(actual_seqlen_q - 1, offs_m)[:, None] + >= offs_n_causal[None, :], + qk, + float("-inf"), + ) else: - qk = tl.where(actual_seqlen_q - 1 >= offs_n_causal[None, :], qk, float("-inf")) + qk = tl.where( + actual_seqlen_q - 1 >= offs_n_causal[None, :], qk, float("-inf") + ) elif IS_CAUSAL: qk = tl.where(offs_m[:, None] >= offs_n_causal[None, :], qk, float("-inf")) tl.debug_barrier() @@ -115,7 +134,9 @@ def _compute_row_blocks_dq( ): # This fuction goes through a row, so it always starts at i = 0 but the end can vary because of causality if IS_CAUSAL: - I_end_n = min(actual_seqlen_k - actual_seqlen_q + I_start_m + BLOCK_M, actual_seqlen_k) + I_end_n = min( + actual_seqlen_k - actual_seqlen_q + I_start_m + BLOCK_M, actual_seqlen_k + ) # For a seqlen_q >> seqlen_k, there migh be entire block skipped if I_end_n < 0: return @@ -152,20 +173,28 @@ def _compute_row_blocks_dq( # Load Q, DO, LSE and D, which will stay in SRAM for the row-wise loop q = load_fn( - q_ptrs, offs_m, offs_d, - PAD_AXIS_0=PAD_ROWS, PAD_AXIS_1=HEADS_PADDED, - LIM_AXIS_0=actual_seqlen_q, LIM_AXIS_1=headdim, + q_ptrs, + offs_m, + offs_d, + PAD_AXIS_0=PAD_ROWS, + PAD_AXIS_1=HEADS_PADDED, + LIM_AXIS_0=actual_seqlen_q, + LIM_AXIS_1=headdim, ) do = load_fn( - do_ptrs, offs_m, offs_d, - PAD_AXIS_0=PAD_ROWS, PAD_AXIS_1=HEADS_PADDED, - LIM_AXIS_0=actual_seqlen_q, LIM_AXIS_1=headdim, + do_ptrs, + offs_m, + offs_d, + PAD_AXIS_0=PAD_ROWS, + PAD_AXIS_1=HEADS_PADDED, + LIM_AXIS_0=actual_seqlen_q, + LIM_AXIS_1=headdim, ) lse_i = tl.load(LSE + offs_m) # since lse is padded to max_seqlen_q, should be good delta_i = tl.load(D + offs_m) # same as LSE for now # Infer the number of full and partially masked blocks - uneven_n = (actual_seqlen_k % BLOCK_N != 0) + uneven_n = actual_seqlen_k % BLOCK_N != 0 attention_padding = VARLEN & uneven_n if IS_CAUSAL: first_masked_col = I_start_m + 1 + actual_seqlen_k - 
actual_seqlen_q @@ -213,7 +242,9 @@ def _compute_row_blocks_dq( if I_next_start_n < I_end_n: for I_start_n in range(I_next_start_n, I_end_n, BLOCK_N): - pad_cols = (not EVEN_N) or (VARLEN and (I_start_n + BLOCK_N > actual_seqlen_k)) + pad_cols = (not EVEN_N) or ( + VARLEN and (I_start_n + BLOCK_N > actual_seqlen_k) + ) dq = _compute_single_block_dq( I_start_n, q, @@ -251,7 +282,11 @@ def _compute_row_blocks_dq( # Store dq if HEADS_PADDED: if PAD_ROWS: - tl.store(dq_ptrs, dq, mask=(offs_m[:, None] < actual_seqlen_q) & (offs_d[None, :] < headdim)) + tl.store( + dq_ptrs, + dq, + mask=(offs_m[:, None] < actual_seqlen_q) & (offs_d[None, :] < headdim), + ) else: tl.store(dq_ptrs, dq, mask=offs_d[None, :] < headdim) else: diff --git a/src/liger_kernel/ops/flash_attention/backward/kernel.py b/src/liger_kernel/ops/flash_attention/backward/kernel.py index 560f2bfd..17a1452a 100644 --- a/src/liger_kernel/ops/flash_attention/backward/kernel.py +++ b/src/liger_kernel/ops/flash_attention/backward/kernel.py @@ -5,8 +5,12 @@ import triton.language as tl from triton import Config -from src.liger_kernel.ops.flash_attention.backward.compute_dkdv import _compute_column_blocks_dkdv -from src.liger_kernel.ops.flash_attention.backward.compute_dq import _compute_row_blocks_dq +from src.liger_kernel.ops.flash_attention.backward.compute_dkdv import ( + _compute_column_blocks_dkdv, +) +from src.liger_kernel.ops.flash_attention.backward.compute_dq import ( + _compute_row_blocks_dq, +) MIN_B = 16 @@ -19,25 +23,65 @@ def early_config_prune_bwd_kernel( # Remove the configs where BLOCK_ > seqlen_ kept_configs = [] for cfg in configs: - block_m_too_large = max(cfg.kwargs["BLOCK_M1"], cfg.kwargs["BLOCK_M2"]) > named_args["seqlen_q"] - block_n_too_large = max(cfg.kwargs["BLOCK_N1"], cfg.kwargs["BLOCK_N2"]) > named_args["seqlen_k"] - if (block_m_too_large or block_n_too_large): + block_m_too_large = ( + max(cfg.kwargs["BLOCK_M1"], cfg.kwargs["BLOCK_M2"]) > named_args["seqlen_q"] + ) + block_n_too_large = ( + max(cfg.kwargs["BLOCK_N1"], cfg.kwargs["BLOCK_N2"]) > named_args["seqlen_k"] + ) + if block_m_too_large or block_n_too_large: pass else: kept_configs.append(cfg) # If no config is left, go for the minimal config if kept_configs: return kept_configs - return [Config({"BLOCK_M1": MIN_B, "BLOCK_N1": MIN_B, "BLOCK_M2": MIN_B, "BLOCK_N2": MIN_B}, num_warps=4, num_stages=0)] + return [ + Config( + { + "BLOCK_M1": MIN_B, + "BLOCK_N1": MIN_B, + "BLOCK_M2": MIN_B, + "BLOCK_N2": MIN_B, + }, + num_warps=4, + num_stages=0, + ) + ] @triton.autotune( configs=[ - Config({"BLOCK_M1": MIN_B, "BLOCK_N1": MIN_B, "BLOCK_M2": MIN_B, "BLOCK_N2": MIN_B}, num_warps=4, num_stages=0), - Config({"BLOCK_M1": 32, "BLOCK_N1": 16, "BLOCK_M2": 16, "BLOCK_N2": 32}, num_warps=4, num_stages=0), - Config({"BLOCK_M1": 32, "BLOCK_N1": 64, "BLOCK_M2": 64, "BLOCK_N2": 32}, num_warps=4, num_stages=0), - Config({"BLOCK_M1": 64, "BLOCK_N1": 64, "BLOCK_M2": 64, "BLOCK_N2": 64}, num_warps=4, num_stages=0), - Config({"BLOCK_M1": 64, "BLOCK_N1": 128, "BLOCK_M2": 128, "BLOCK_N2": 64}, num_warps=4, num_stages=0), + Config( + { + "BLOCK_M1": MIN_B, + "BLOCK_N1": MIN_B, + "BLOCK_M2": MIN_B, + "BLOCK_N2": MIN_B, + }, + num_warps=4, + num_stages=0, + ), + Config( + {"BLOCK_M1": 32, "BLOCK_N1": 16, "BLOCK_M2": 16, "BLOCK_N2": 32}, + num_warps=4, + num_stages=0, + ), + Config( + {"BLOCK_M1": 32, "BLOCK_N1": 64, "BLOCK_M2": 64, "BLOCK_N2": 32}, + num_warps=4, + num_stages=0, + ), + Config( + {"BLOCK_M1": 64, "BLOCK_N1": 64, "BLOCK_M2": 64, "BLOCK_N2": 64}, + num_warps=4, + 
num_stages=0, + ), + Config( + {"BLOCK_M1": 64, "BLOCK_N1": 128, "BLOCK_M2": 128, "BLOCK_N2": 64}, + num_warps=4, + num_stages=0, + ), ], key=[ "CACHE_KEY_SEQLEN_Q", @@ -76,14 +120,30 @@ def _bwd_kernel( softmax_scale, dropout_p, dropout_seed, - stride_qb, stride_qh, stride_qm, - stride_kb, stride_kh, stride_kn, - stride_vb, stride_vh, stride_vn, - stride_bb, stride_bh, stride_bm, - stride_dob, stride_doh, stride_dom, - stride_dqb, stride_dqh, stride_dqm, - stride_dkb, stride_dkh, stride_dkn, - stride_dvb, stride_dvh, stride_dvn, + stride_qb, + stride_qh, + stride_qm, + stride_kb, + stride_kh, + stride_kn, + stride_vb, + stride_vh, + stride_vn, + stride_bb, + stride_bh, + stride_bm, + stride_dob, + stride_doh, + stride_dom, + stride_dqb, + stride_dqh, + stride_dqm, + stride_dkb, + stride_dkh, + stride_dkn, + stride_dvb, + stride_dvh, + stride_dvn, nheads_q, head_ratio, seqlen_q, @@ -140,9 +200,13 @@ def _bwd_kernel( DK += off_batch * stride_dkb + off_head_q * stride_dkh + cu_seq_start_k * stride_dkn DV += off_batch * stride_dvb + off_head_q * stride_dvh + cu_seq_start_k * stride_dvn if BIAS_ON: - Bias += off_batch * stride_bb + off_head_q * stride_bh + cu_seq_start_q * stride_bm + Bias += ( + off_batch * stride_bb + off_head_q * stride_bh + cu_seq_start_q * stride_bm + ) if USE_DROPOUT: - Dropout = actual_seqlen_k * (cu_seq_start_q + actual_seqlen_q * (off_head_q + nheads_q * off_batch)) + Dropout = actual_seqlen_k * ( + cu_seq_start_q + actual_seqlen_q * (off_head_q + nheads_q * off_batch) + ) else: Dropout = None @@ -153,30 +217,81 @@ def _bwd_kernel( # Case: this block works on dk and dv if pid < NUM_BLOCKS_KV: i_start_n = pid - pad_cols = (not EVEN_N1) or (VARLEN and ((i_start_n + 1) * BLOCK_N1 > actual_seqlen_k)) + pad_cols = (not EVEN_N1) or ( + VARLEN and ((i_start_n + 1) * BLOCK_N1 > actual_seqlen_k) + ) _compute_column_blocks_dkdv( i_start_n * BLOCK_N1, - Q, K, V, Bias, Dropout, DO, DK, DV, LSE, D, - softmax_scale, dropout_p, dropout_seed, - stride_qm, stride_kn, stride_vn, stride_bm, stride_dom, stride_dkn, stride_dvn, - actual_seqlen_q, actual_seqlen_k, headdim, - IS_CAUSAL=IS_CAUSAL, BIAS_ON=BIAS_ON, USE_DROPOUT=USE_DROPOUT, - PAD_COLS=pad_cols, HEADS_PADDED=HEADS_PADDED, - BLOCK_M=BLOCK_M1, BLOCK_N=BLOCK_N1, BLOCK_HEADDIM=BLOCK_HEADDIM, + Q, + K, + V, + Bias, + Dropout, + DO, + DK, + DV, + LSE, + D, + softmax_scale, + dropout_p, + dropout_seed, + stride_qm, + stride_kn, + stride_vn, + stride_bm, + stride_dom, + stride_dkn, + stride_dvn, + actual_seqlen_q, + actual_seqlen_k, + headdim, + IS_CAUSAL=IS_CAUSAL, + BIAS_ON=BIAS_ON, + USE_DROPOUT=USE_DROPOUT, + PAD_COLS=pad_cols, + HEADS_PADDED=HEADS_PADDED, + BLOCK_M=BLOCK_M1, + BLOCK_N=BLOCK_N1, + BLOCK_HEADDIM=BLOCK_HEADDIM, ) # Case: this block works on dq else: i_start_m = pid - NUM_BLOCKS_KV - pad_rows = (not EVEN_M2) or (VARLEN and ((i_start_m + 1) * BLOCK_M2 > actual_seqlen_q)) + pad_rows = (not EVEN_M2) or ( + VARLEN and ((i_start_m + 1) * BLOCK_M2 > actual_seqlen_q) + ) _compute_row_blocks_dq( i_start_m * BLOCK_M2, - Q, K, V, Bias, Dropout, DO, DQ, LSE, D, - softmax_scale, dropout_p, dropout_seed, - stride_qm, stride_kn, stride_vn, stride_bm, stride_dom, stride_dqm, - actual_seqlen_q, actual_seqlen_k, headdim, - VARLEN=VARLEN, IS_CAUSAL=IS_CAUSAL, BIAS_ON=BIAS_ON, USE_DROPOUT=USE_DROPOUT, - PAD_ROWS=pad_rows, HEADS_PADDED=HEADS_PADDED, - BLOCK_M=BLOCK_M2, BLOCK_N=BLOCK_N2, BLOCK_HEADDIM=BLOCK_HEADDIM, + Q, + K, + V, + Bias, + Dropout, + DO, + DQ, + LSE, + D, + softmax_scale, + dropout_p, + dropout_seed, + stride_qm, + 
stride_kn, + stride_vn, + stride_bm, + stride_dom, + stride_dqm, + actual_seqlen_q, + actual_seqlen_k, + headdim, + VARLEN=VARLEN, + IS_CAUSAL=IS_CAUSAL, + BIAS_ON=BIAS_ON, + USE_DROPOUT=USE_DROPOUT, + PAD_ROWS=pad_rows, + HEADS_PADDED=HEADS_PADDED, + BLOCK_M=BLOCK_M2, + BLOCK_N=BLOCK_N2, + BLOCK_HEADDIM=BLOCK_HEADDIM, EVEN_N=EVEN_N2, ) diff --git a/src/liger_kernel/ops/flash_attention/forward/caller.py b/src/liger_kernel/ops/flash_attention/forward/caller.py index e514e745..cccd474c 100644 --- a/src/liger_kernel/ops/flash_attention/forward/caller.py +++ b/src/liger_kernel/ops/flash_attention/forward/caller.py @@ -6,7 +6,14 @@ from torch import Tensor from src.liger_kernel.ops.flash_attention.forward.kernel import _fwd_kernel -from src.liger_kernel.ops.flash_attention.utils import attention_pack, attention_unpack, torch_ignore_deterministic, infer_bias_strides, handle_dropout, encode_dtype +from src.liger_kernel.ops.flash_attention.utils import ( + attention_pack, + attention_unpack, + torch_ignore_deterministic, + infer_bias_strides, + handle_dropout, + encode_dtype, +) def _flash_attn_forward( @@ -24,8 +31,12 @@ def _flash_attn_forward( # Currently, variable length (varlen) mode is mutually exclusive with attention masking (TODO) if attention_mask is not None: varlen_mode = True - assert bias is None, "Attention mask is not supported along with attention bias. Just use bias instead." - assert q.size(1) == k.size(1), "Attention mask is not supported with seqlen_q != seqlen_k" + assert ( + bias is None + ), "Attention mask is not supported along with attention bias. Just use bias instead." + assert q.size(1) == k.size( + 1 + ), "Attention mask is not supported with seqlen_q != seqlen_k" else: varlen_mode = False @@ -39,13 +50,19 @@ def _flash_attn_forward( assert q.dtype == k.dtype == v.dtype, "All tensors must have the same type" assert q.dtype in [torch.float16, torch.bfloat16], "Only support fp16 and bf16" assert q.is_cuda and k.is_cuda and v.is_cuda - softmax_scale = 1.0 / math.sqrt(head_dim) if softmax_scale is None else softmax_scale + softmax_scale = ( + 1.0 / math.sqrt(head_dim) if softmax_scale is None else softmax_scale + ) # Depending on attention_mask, switch to varlen varlen_mode = varlen_mode and (batch > 1) if varlen_mode: # Compute padding-related statistics - cum_seqlens_q = torch.zeros(size=(attention_mask.size(0) + 1,), device=attention_mask.device, dtype=torch.int32) + cum_seqlens_q = torch.zeros( + size=(attention_mask.size(0) + 1,), + device=attention_mask.device, + dtype=torch.int32, + ) with torch_ignore_deterministic(): cum_seqlens_q[1:] = attention_mask.sum(dim=1).cumsum(0) # cum_seqlens_q = [0, seqlen_q1, seqlen_q1+seqlen_q2, ..., seqlen_q1+...+seqlen_qB] of shape [B+1] @@ -63,15 +80,21 @@ def _flash_attn_forward( max_seqlen_k = seqlen_k # Account for bias and dropout - stride_bb, stride_bh, stride_bm = infer_bias_strides(bias, batch, nheads_q, seqlen_q, seqlen_k) + stride_bb, stride_bh, stride_bm = infer_bias_strides( + bias, batch, nheads_q, seqlen_q, seqlen_k + ) dropout_seed = handle_dropout(dropout_p, dropout_seed, is_forward=True) # Setup output accumulator o = torch.zeros_like(q) # Setup LSE accumulators: in varlen mode, batch is still equal to the nb of queries - max_seqlen_q_rounded = math.ceil(max_seqlen_q / 128) * 128 # wastefull in varlen and not (just use mask) - lse = torch.zeros((batch, nheads_q, max_seqlen_q_rounded), device=q.device, dtype=torch.float32) + max_seqlen_q_rounded = ( + math.ceil(max_seqlen_q / 128) * 128 + ) # wastefull in varlen 
and not (just use mask) + lse = torch.zeros( + (batch, nheads_q, max_seqlen_q_rounded), device=q.device, dtype=torch.float32 + ) # Infer problem size and launch kernel BLOCK_HEADDIM = max(triton.next_power_of_2(head_dim), 16) @@ -79,7 +102,10 @@ def _flash_attn_forward( # BLOCK = 128 # num_warps = 4 if head_dim <= 64 else 8 head_ratio = nheads_q // nheads_kv - grid = lambda META: (triton.cdiv(max_seqlen_q, META["BLOCK_M"]), batch * nheads_q) # noqa: E731 + grid = lambda META: ( + triton.cdiv(max_seqlen_q, META["BLOCK_M"]), + batch * nheads_q, + ) # noqa: E731 _fwd_kernel[grid]( q, k, @@ -90,11 +116,21 @@ def _flash_attn_forward( softmax_scale, dropout_p, dropout_seed, - q.stride(0), q.stride(2), q.stride(1), - k.stride(0), k.stride(2), k.stride(1), - v.stride(0), v.stride(2), v.stride(1), - o.stride(0), o.stride(2), o.stride(1), - stride_bb, stride_bh, stride_bm, + q.stride(0), + q.stride(2), + q.stride(1), + k.stride(0), + k.stride(2), + k.stride(1), + v.stride(0), + v.stride(2), + v.stride(1), + o.stride(0), + o.stride(2), + o.stride(1), + stride_bb, + stride_bh, + stride_bm, nheads_q, head_ratio, seqlen_q, diff --git a/src/liger_kernel/ops/flash_attention/forward/compute_row_blocks.py b/src/liger_kernel/ops/flash_attention/forward/compute_row_blocks.py index f717cf14..ce94e681 100644 --- a/src/liger_kernel/ops/flash_attention/forward/compute_row_blocks.py +++ b/src/liger_kernel/ops/flash_attention/forward/compute_row_blocks.py @@ -41,16 +41,22 @@ def compute_row_block( offset_k_ptrs = k_ptrs + I_start_n * stride_kn k = load_fn( offset_k_ptrs, - I_start_n + offs_n, offs_d, - PAD_AXIS_0=PADDED_COLS, PAD_AXIS_1=PADDED_HEADS, - LIM_AXIS_0=actual_seqlen_k, LIM_AXIS_1=headdim, + I_start_n + offs_n, + offs_d, + PAD_AXIS_0=PADDED_COLS, + PAD_AXIS_1=PADDED_HEADS, + LIM_AXIS_0=actual_seqlen_k, + LIM_AXIS_1=headdim, ) if BIAS_ON: bias = load_fn( bias_ptrs + I_start_n, - offs_m, I_start_n + offs_n, - PAD_AXIS_0=True, PAD_AXIS_1=PADDED_COLS, # check - LIM_AXIS_0=actual_seqlen_q, LIM_AXIS_1=actual_seqlen_k, + offs_m, + I_start_n + offs_n, + PAD_AXIS_0=True, + PAD_AXIS_1=PADDED_COLS, # check + LIM_AXIS_0=actual_seqlen_q, + LIM_AXIS_1=actual_seqlen_k, ) # Compute QK @@ -58,11 +64,18 @@ def compute_row_block( qk += tl.dot(q, tl.trans(k)) # Apply attention masking and/or account for padding of the keys - if PADDED_COLS: # TODO: check impact on speed when conditionned by MASKED (always true?) - qk += tl.where((I_start_n + offs_n)[None, :] < actual_seqlen_k, 0, float("-inf")) + if ( + PADDED_COLS + ): # TODO: check impact on speed when conditionned by MASKED (always true?) 
+ qk += tl.where( + (I_start_n + offs_n)[None, :] < actual_seqlen_k, 0, float("-inf") + ) # Apply causal mask if MASKED and IS_CAUSAL: - causal_mask = offs_m[:, None] >= (I_start_n + offs_n - actual_seqlen_k + actual_seqlen_q)[None, :] + causal_mask = ( + offs_m[:, None] + >= (I_start_n + offs_n - actual_seqlen_k + actual_seqlen_q)[None, :] + ) qk += tl.where(causal_mask, 0, float("-inf")) if BIAS_ON: @@ -75,7 +88,9 @@ def compute_row_block( # Dropout if USE_DROPOUT: dropout_offs = dropout_offs + I_start_n - dropout_mask = (tl.rand(dropout_seed, dropout_offs) > dropout_p) # TODO: replace this w/ randint for better perfs + dropout_mask = ( + tl.rand(dropout_seed, dropout_offs) > dropout_p + ) # TODO: replace this w/ randint for better perfs P_ij = tl.where(dropout_mask, P_ij, 0.0) # Scale the output accumulator @@ -86,9 +101,12 @@ def compute_row_block( offset_v_ptrs = v_ptrs + I_start_n * stride_vn v = load_fn( offset_v_ptrs, - I_start_n + offs_n, offs_d, - PAD_AXIS_0=PADDED_COLS, PAD_AXIS_1=PADDED_HEADS, - LIM_AXIS_0=actual_seqlen_k, LIM_AXIS_1=headdim, + I_start_n + offs_n, + offs_d, + PAD_AXIS_0=PADDED_COLS, + PAD_AXIS_1=PADDED_HEADS, + LIM_AXIS_0=actual_seqlen_k, + LIM_AXIS_1=headdim, ) # Update the output accumulator diff --git a/src/liger_kernel/ops/flash_attention/forward/kernel.py b/src/liger_kernel/ops/flash_attention/forward/kernel.py index e7eeec03..6316947d 100644 --- a/src/liger_kernel/ops/flash_attention/forward/kernel.py +++ b/src/liger_kernel/ops/flash_attention/forward/kernel.py @@ -3,7 +3,9 @@ from triton import Config from typing import List, Any, Dict -from src.liger_kernel.ops.flash_attention.forward.compute_row_blocks import compute_row_block +from src.liger_kernel.ops.flash_attention.forward.compute_row_blocks import ( + compute_row_block, +) from src.liger_kernel.ops.flash_attention.utils import load_fn # TODO: exit causal blocks early @@ -22,7 +24,7 @@ def early_config_prune_fwd_kernel( for cfg in configs: block_m_too_large = cfg.kwargs["BLOCK_M"] > named_args["seqlen_q"] block_n_too_large = cfg.kwargs["BLOCK_N"] > named_args["seqlen_k"] - if (block_m_too_large or block_n_too_large): + if block_m_too_large or block_n_too_large: pass else: kept_configs.append(cfg) @@ -68,11 +70,21 @@ def _fwd_kernel( softmax_scale, dropout_p, dropout_seed, - stride_qb, stride_qh, stride_qm, # Q stride for the batch, head and sequence axis (sequence subscript is m for rows) - stride_kb, stride_kh, stride_kn, # Same for K (sequence subscript is n for cols) - stride_vb, stride_vh, stride_vn, # Same for V (sequence subscript is n for cols) - stride_ob, stride_oh, stride_om, # Same for O (sequence subscript is m for rows) - stride_bb, stride_bh, stride_bm, + stride_qb, + stride_qh, + stride_qm, # Q stride for the batch, head and sequence axis (sequence subscript is m for rows) + stride_kb, + stride_kh, + stride_kn, # Same for K (sequence subscript is n for cols) + stride_vb, + stride_vh, + stride_vn, # Same for V (sequence subscript is n for cols) + stride_ob, + stride_oh, + stride_om, # Same for O (sequence subscript is m for rows) + stride_bb, + stride_bh, + stride_bm, nheads_q, head_ratio, seqlen_q, @@ -107,7 +119,9 @@ def _fwd_kernel( actual_seqlen_q = tl.load(cum_seqlens_q + off_batch + 1) - cu_seq_start_q if i_start_m * BLOCK_M >= actual_seqlen_q: return - actual_seqlen_k = actual_seqlen_q # TODO: support packed + varlen? rn, check is done outside + actual_seqlen_k = ( + actual_seqlen_q # TODO: support packed + varlen? 
rn, check is done outside + ) cu_seq_start_k = cu_seq_start_q off_batch = 0 else: @@ -126,25 +140,38 @@ def _fwd_kernel( # current sequence might have less rows than the current row (detemined through the grid). fully_masked_lines = actual_seqlen_q - actual_seqlen_k if IS_CAUSAL else 0 - if fully_masked_lines >= (i_start_m+1) * BLOCK_M: + if fully_masked_lines >= (i_start_m + 1) * BLOCK_M: return # Initialize pointers to Q, K, V - offseted_Q = Q + off_batch * stride_qb + off_head_q * stride_qh + cu_seq_start_q * stride_qm - q_ptrs = (offseted_Q + (offs_m[:, None] * stride_qm + offs_d[None, :])) - offseted_K = K + off_batch * stride_kb + off_head_kv * stride_kh + cu_seq_start_k * stride_kn - k_ptrs = (offseted_K + (offs_n[:, None] * stride_kn + offs_d[None, :])) - offseted_V = V + off_batch * stride_vb + off_head_kv * stride_vh + cu_seq_start_k * stride_vn - v_ptrs = (offseted_V + (offs_n[:, None] * stride_vn + offs_d[None, :])) + offseted_Q = ( + Q + off_batch * stride_qb + off_head_q * stride_qh + cu_seq_start_q * stride_qm + ) + q_ptrs = offseted_Q + (offs_m[:, None] * stride_qm + offs_d[None, :]) + offseted_K = ( + K + off_batch * stride_kb + off_head_kv * stride_kh + cu_seq_start_k * stride_kn + ) + k_ptrs = offseted_K + (offs_n[:, None] * stride_kn + offs_d[None, :]) + offseted_V = ( + V + off_batch * stride_vb + off_head_kv * stride_vh + cu_seq_start_k * stride_vn + ) + v_ptrs = offseted_V + (offs_n[:, None] * stride_vn + offs_d[None, :]) # ...and maybe bias if BIAS_ON: - offseted_Bias = Bias + off_batch * stride_bb + off_head_kv * stride_bh + cu_seq_start_q * stride_bm - bias_ptrs = (offseted_Bias + (offs_m[:, None] * stride_bm + offs_n[None, :])) + offseted_Bias = ( + Bias + + off_batch * stride_bb + + off_head_kv * stride_bh + + cu_seq_start_q * stride_bm + ) + bias_ptrs = offseted_Bias + (offs_m[:, None] * stride_bm + offs_n[None, :]) else: bias_ptrs = None # ...and maybe dropout if USE_DROPOUT: - dropout_off = actual_seqlen_k * (cu_seq_start_q + actual_seqlen_q * (off_head_q + nheads_q * off_batch)) + dropout_off = actual_seqlen_k * ( + cu_seq_start_q + actual_seqlen_q * (off_head_q + nheads_q * off_batch) + ) dropout_offs = dropout_off + offs_m[:, None] * actual_seqlen_k + offs_n[None, :] else: dropout_offs = None @@ -155,17 +182,25 @@ def _fwd_kernel( acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32) # Load Q, which will stay in SRAM for the whole loop - pad_rows = (not EVEN_M) or (VARLEN and (i_start_m * BLOCK_M > actual_seqlen_q)) # this works while other bools fail. Why? + pad_rows = (not EVEN_M) or ( + VARLEN and (i_start_m * BLOCK_M > actual_seqlen_q) + ) # this works while other bools fail. Why? 
q = load_fn( q_ptrs, - offs_m, offs_d, - PAD_AXIS_0=pad_rows, PAD_AXIS_1=PADDED_HEADS, - LIM_AXIS_0=actual_seqlen_q, LIM_AXIS_1=headdim, + offs_m, + offs_d, + PAD_AXIS_0=pad_rows, + PAD_AXIS_1=PADDED_HEADS, + LIM_AXIS_0=actual_seqlen_q, + LIM_AXIS_1=headdim, ) # Compute last visited column of KV which if IS_CAUSAL: - end_n = min(actual_seqlen_k - actual_seqlen_q + (i_start_m + 1) * BLOCK_M, actual_seqlen_k) + end_n = min( + actual_seqlen_k - actual_seqlen_q + (i_start_m + 1) * BLOCK_M, + actual_seqlen_k, + ) # For a seqlen_q >> seqlen_k, there migh be entire block skipped if end_n < 0: return @@ -173,7 +208,7 @@ def _fwd_kernel( end_n = actual_seqlen_k # first_masked_block = min(start_m * BLOCK_M + 1 + actual_seqlen_k - actual_seqlen_q, end_n) if IS_CAUSAL else end_n - uneven_n = (actual_seqlen_k % BLOCK_N != 0) + uneven_n = actual_seqlen_k % BLOCK_N != 0 attention_padding = VARLEN & uneven_n if IS_CAUSAL: first_masked_col = i_start_m * BLOCK_M + 1 + actual_seqlen_k - actual_seqlen_q @@ -282,7 +317,11 @@ def _fwd_kernel( # Store O (same mechanism as Q) BUG: here, the store instruction seems to fail when one of the two bools is false if True: - tl.store(out_ptrs, acc_o, mask=(offs_m[:, None] < actual_seqlen_q) & (offs_d[None, :] < headdim)) + tl.store( + out_ptrs, + acc_o, + mask=(offs_m[:, None] < actual_seqlen_q) & (offs_d[None, :] < headdim), + ) elif pad_rows: tl.store(out_ptrs, acc_o, mask=offs_m[:, None] < actual_seqlen_q) elif PADDED_HEADS: # nothing is padded diff --git a/src/liger_kernel/ops/flash_attention/reference_implementation.py b/src/liger_kernel/ops/flash_attention/reference_implementation.py index 93304a12..311b5d24 100644 --- a/src/liger_kernel/ops/flash_attention/reference_implementation.py +++ b/src/liger_kernel/ops/flash_attention/reference_implementation.py @@ -13,7 +13,9 @@ def construct_local_mask( device=None, key_leftpad=None, ): - row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1") + row_idx = rearrange( + torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1" + ) col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long) if key_leftpad is not None: key_leftpad = rearrange(key_leftpad, "b -> b 1 1 1") @@ -94,7 +96,9 @@ def flash_attn_reference( scores = scores.tanh() scores = scores * softcap if key_padding_mask is not None: - scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf")) + scores.masked_fill_( + rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf") + ) if window_size[0] >= 0 or window_size[1] >= 0: local_mask = construct_local_mask( seqlen_q, @@ -111,11 +115,15 @@ def flash_attn_reference( attention = torch.softmax(scores, dim=-1).to(v.dtype) # Some rows might be completely masked out so we fill them with zero instead of NaN if window_size[0] >= 0 or window_size[1] >= 0: - attention = attention.masked_fill(torch.all(local_mask, dim=-1, keepdim=True), 0.0) + attention = attention.masked_fill( + torch.all(local_mask, dim=-1, keepdim=True), 0.0 + ) # We want to mask here so that the attention matrix doesn't have any NaNs # Otherwise we'll get NaN in dV if query_padding_mask is not None: - attention = attention.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0) + attention = attention.masked_fill( + rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0 + ) dropout_scaling = 1.0 / (1 - dropout_p) # attention_drop = attention.masked_fill(~dropout_mask, 0.0) * dropout_scaling # output = torch.einsum('bhts,bshd->bthd', attention_drop , v) diff --git 
a/src/liger_kernel/ops/flash_attention/utils.py b/src/liger_kernel/ops/flash_attention/utils.py index 6bdd9713..74fe5720 100644 --- a/src/liger_kernel/ops/flash_attention/utils.py +++ b/src/liger_kernel/ops/flash_attention/utils.py @@ -23,11 +23,13 @@ def attention_unpack( batch_size: int, goal_seqlen: int, ) -> torch.Tensor: - unpacked = torch.zeros(size=(batch_size, goal_seqlen, *x.shape[2:]), dtype=x.dtype, device=x.device) - for i in range(cum_seqlens.size(0)-1): + unpacked = torch.zeros( + size=(batch_size, goal_seqlen, *x.shape[2:]), dtype=x.dtype, device=x.device + ) + for i in range(cum_seqlens.size(0) - 1): seq_start = cum_seqlens[i] - seq_end = cum_seqlens[i+1] - unpacked[i, :seq_end-seq_start] = x[0, seq_start:seq_end] + seq_end = cum_seqlens[i + 1] + unpacked[i, : seq_end - seq_start] = x[0, seq_start:seq_end] return unpacked @@ -43,7 +45,12 @@ def load_fn( ): if PAD_AXIS_0: if PAD_AXIS_1: - x = tl.load(ptrs, mask=(offs_axis_0[:, None] < LIM_AXIS_0) & (offs_axis_1[None, :] < LIM_AXIS_1), other=0.0) + x = tl.load( + ptrs, + mask=(offs_axis_0[:, None] < LIM_AXIS_0) + & (offs_axis_1[None, :] < LIM_AXIS_1), + other=0.0, + ) else: x = tl.load(ptrs, mask=offs_axis_0[:, None] < LIM_AXIS_0, other=0.0) else: @@ -55,10 +62,14 @@ def load_fn( def infer_bias_strides( - bias: Optional[Tensor], batch: int, nheads_q: int, seqlen_q: int, seqlen_k: int, + bias: Optional[Tensor], + batch: int, + nheads_q: int, + seqlen_q: int, + seqlen_k: int, ) -> Tuple[int, ...]: if bias is not None: - assert (bias.size(2) == seqlen_q and bias.size(3) == seqlen_k), f"{bias.shape = }" + assert bias.size(2) == seqlen_q and bias.size(3) == seqlen_k, f"{bias.shape = }" if bias.size(0) == 1: stride_bb = 0 elif bias.size(0) == batch: @@ -70,20 +81,30 @@ def infer_bias_strides( elif bias.stride(1) == nheads_q: stride_bh = bias.stride(1) else: - raise ValueError(f"Attention bias has {bias.size(1) = } while {nheads_q = }") + raise ValueError( + f"Attention bias has {bias.size(1) = } while {nheads_q = }" + ) stride_bm = bias.stride(2) else: stride_bb, stride_bh, stride_bm = 0, 0, 0 return stride_bb, stride_bh, stride_bm -def handle_dropout(dropout_p: float, dropout_seed: Optional[int], is_forward: bool) -> int: +def handle_dropout( + dropout_p: float, dropout_seed: Optional[int], is_forward: bool +) -> int: assert dropout_p >= 0, f"Dropout probability {dropout_p = } must be above 0." - assert dropout_p < 1, f"Dropout probability {dropout_p = } must be strictly below 1." + assert ( + dropout_p < 1 + ), f"Dropout probability {dropout_p = } must be strictly below 1." 
if dropout_p == 0: return 0 elif is_forward: - return torch.randint(low=0, high=2**32, size=(1,)).item() if dropout_seed is None else dropout_seed + return ( + torch.randint(low=0, high=2**32, size=(1,)).item() + if dropout_seed is None + else dropout_seed + ) else: raise NotImplementedError("Backward pass does not yet support dropout.") diff --git a/src/liger_kernel/ops/flash_attention/wrapper.py b/src/liger_kernel/ops/flash_attention/wrapper.py index 66fe83e3..31e6abe7 100644 --- a/src/liger_kernel/ops/flash_attention/wrapper.py +++ b/src/liger_kernel/ops/flash_attention/wrapper.py @@ -41,7 +41,9 @@ def forward( q = q if q.stride(-1) == 1 else q.contiguous() k = k if k.stride(-1) == 1 else k.contiguous() v = v if v.stride(-1) == 1 else v.contiguous() - attention_bias = None if (attention_bias is None) else attention_bias.contiguous() + attention_bias = ( + None if (attention_bias is None) else attention_bias.contiguous() + ) o, lse, ctx.softmax_scale, ctx.dropout_seed = _flash_attn_forward( q=q, k=k, @@ -97,4 +99,14 @@ def flash_attn_func( softmax_scale: Optional[Tensor] = None, dropout_seed: Optional[int] = None, ) -> Tensor: - return FlashAttnFunc.apply(q, k, v, attention_mask, attention_bias, dropout_p, causal, softmax_scale, dropout_seed) + return FlashAttnFunc.apply( + q, + k, + v, + attention_mask, + attention_bias, + dropout_p, + causal, + softmax_scale, + dropout_seed, + ) diff --git a/src/liger_kernel/transformers/attention.py b/src/liger_kernel/transformers/attention.py index 76e5aa80..3dea7910 100644 --- a/src/liger_kernel/transformers/attention.py +++ b/src/liger_kernel/transformers/attention.py @@ -2,7 +2,11 @@ from transformers.cache_utils import Cache import torch -from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, logger, LlamaSdpaAttention +from transformers.models.llama.modeling_llama import ( + apply_rotary_pos_emb, + logger, + LlamaSdpaAttention, +) from liger_kernel.ops.flash_attention.wrapper import flash_attn_func @@ -17,7 +21,9 @@ def liger_general_sdpa_forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45 + position_embeddings: Optional[ + Tuple[torch.Tensor, torch.Tensor] + ] = None, # will become mandatory in v4.45 **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: @@ -30,9 +36,15 @@ def liger_general_sdpa_forward( key_states = self.k_proj.forward(hidden_states) value_states = self.v_proj.forward(hidden_states) - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + query_states = query_states.view( + bsz, q_len, self.num_heads, self.head_dim + ).transpose(1, 2) + key_states = key_states.view( + bsz, q_len, self.num_key_value_heads, self.head_dim + ).transpose(1, 2) + value_states = value_states.view( + bsz, q_len, self.num_key_value_heads, self.head_dim + ).transpose(1, 2) if position_embeddings is None: logger.warning_once( @@ -49,7 +61,9 @@ def liger_general_sdpa_forward( if past_key_value is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = 
past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) # key_states = repeat_kv(key_states, self.num_key_value_groups) not needed as we support GQA # value_states = repeat_kv(value_states, self.num_key_value_groups) diff --git a/src/liger_kernel/transformers/model/gemma2.py b/src/liger_kernel/transformers/model/gemma2.py index 655aea71..7063ecf5 100644 --- a/src/liger_kernel/transformers/model/gemma2.py +++ b/src/liger_kernel/transformers/model/gemma2.py @@ -2,7 +2,11 @@ from transformers.cache_utils import Cache import torch -from transformers.models.gemma.modeling_gemma import apply_rotary_pos_emb, logger, GemmaSdpaAttention +from transformers.models.gemma.modeling_gemma import ( + apply_rotary_pos_emb, + logger, + GemmaSdpaAttention, +) from liger_kernel.ops.flash_attention.wrapper import flash_attn_func @@ -41,9 +45,15 @@ def liger_gemma2_sdpa_forward( key_states = self.k_proj.forward(hidden_states) value_states = self.v_proj.forward(hidden_states) - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + query_states = query_states.view( + bsz, q_len, self.num_heads, self.head_dim + ).transpose(1, 2) + key_states = key_states.view( + bsz, q_len, self.num_key_value_heads, self.head_dim + ).transpose(1, 2) + value_states = value_states.view( + bsz, q_len, self.num_key_value_heads, self.head_dim + ).transpose(1, 2) cos, sin = self.rotary_emb(value_states, position_ids) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) @@ -51,7 +61,9 @@ def liger_gemma2_sdpa_forward( if past_key_value is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) # Commented out because we support GQA # key_states = repeat_kv(key_states, self.num_key_value_groups) diff --git a/src/liger_kernel/transformers/model/phi3.py b/src/liger_kernel/transformers/model/phi3.py index f322c3ba..be5eb01c 100644 --- a/src/liger_kernel/transformers/model/phi3.py +++ b/src/liger_kernel/transformers/model/phi3.py @@ -9,7 +9,7 @@ Phi3SdpaAttention, Cache, logger, - apply_rotary_pos_emb + apply_rotary_pos_emb, ) from transformers.utils import ( add_start_docstrings_to_model_forward, @@ -170,23 +170,39 @@ def liger_phi3_sdpa_attention_forward( qkv = self.qkv_proj.forward(hidden_states) query_pos = self.num_heads * self.head_dim query_states = qkv[..., :query_pos] - key_states = qkv[..., query_pos : query_pos + self.num_key_value_heads * self.head_dim] + key_states = qkv[ + ..., query_pos : query_pos + self.num_key_value_heads * self.head_dim + ] value_states = qkv[..., query_pos + self.num_key_value_heads * self.head_dim :] - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, 
self.head_dim).transpose(1, 2) + query_states = query_states.view( + bsz, q_len, self.num_heads, self.head_dim + ).transpose(1, 2) + key_states = key_states.view( + bsz, q_len, self.num_key_value_heads, self.head_dim + ).transpose(1, 2) + value_states = value_states.view( + bsz, q_len, self.num_key_value_heads, self.head_dim + ).transpose(1, 2) kv_seq_len = key_states.shape[-2] if past_key_value is not None: kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + query_states, key_states = apply_rotary_pos_emb( + query_states, key_states, cos, sin, position_ids + ) if past_key_value is not None: - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + cache_kwargs = { + "sin": sin, + "cos": cos, + "cache_position": cache_position, + } # Specific to RoPE models + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) # Commented out because we support GQA # key_states = repeat_kv(key_states, self.num_key_value_groups) diff --git a/src/liger_kernel/transformers/model/qwen2.py b/src/liger_kernel/transformers/model/qwen2.py index 896f42bc..1c071b4b 100644 --- a/src/liger_kernel/transformers/model/qwen2.py +++ b/src/liger_kernel/transformers/model/qwen2.py @@ -172,20 +172,34 @@ def liger_qwen2_sdpa_forward( key_states = self.k_proj(hidden_states) value_states = self.v_proj(hidden_states) - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + query_states = query_states.view( + bsz, q_len, self.num_heads, self.head_dim + ).transpose(1, 2) + key_states = key_states.view( + bsz, q_len, self.num_key_value_heads, self.head_dim + ).transpose(1, 2) + value_states = value_states.view( + bsz, q_len, self.num_key_value_heads, self.head_dim + ).transpose(1, 2) kv_seq_len = key_states.shape[-2] if past_key_value is not None: kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + query_states, key_states = apply_rotary_pos_emb( + query_states, key_states, cos, sin, position_ids + ) if past_key_value is not None: - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + cache_kwargs = { + "sin": sin, + "cos": cos, + "cache_position": cache_position, + } # Specific to RoPE models + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) # [Liger-Kernel modification] As we support GQa, we don't need to do this # key_states = repeat_kv(key_states, self.num_key_value_groups) diff --git a/test/transformers/test_attention.py b/test/transformers/test_attention.py index 645f7ed9..623cb9d9 100644 --- a/test/transformers/test_attention.py +++ b/test/transformers/test_attention.py @@ -38,22 +38,42 
@@ def _test_attention( ) -> Optional[Tuple[Tensor, ...]]: # Prepare data - q = torch.normal(0, 0.5, (batch_size, seqlen_q, nheads_q, head_dim), dtype=dtype, device=DEVICE).requires_grad_() - k = torch.normal(0, 0.5, (batch_size, seqlen_k, nheads_kv, head_dim), dtype=dtype, device=DEVICE).requires_grad_() - v = torch.normal(0, 0.5, (batch_size, seqlen_k, nheads_kv, head_dim), dtype=dtype, device=DEVICE).requires_grad_() + q = torch.normal( + 0, 0.5, (batch_size, seqlen_q, nheads_q, head_dim), dtype=dtype, device=DEVICE + ).requires_grad_() + k = torch.normal( + 0, 0.5, (batch_size, seqlen_k, nheads_kv, head_dim), dtype=dtype, device=DEVICE + ).requires_grad_() + v = torch.normal( + 0, 0.5, (batch_size, seqlen_k, nheads_kv, head_dim), dtype=dtype, device=DEVICE + ).requires_grad_() do = torch.randn_like(q) - attn_bias = torch.rand(size=(1, 1, seqlen_q, seqlen_k), dtype=dtype, device=q.device) if use_bias else None + attn_bias = ( + torch.rand(size=(1, 1, seqlen_q, seqlen_k), dtype=dtype, device=q.device) + if use_bias + else None + ) # Compute the outputs of the forward pass - ref_output = flash_attn_reference(q, k, v, attn_bias=attn_bias, causal=causal, upcast=True, reorder_ops=False) - pt_output = flash_attn_reference(q, k, v, attn_bias=attn_bias, causal=causal, upcast=False, reorder_ops=True) + ref_output = flash_attn_reference( + q, k, v, attn_bias=attn_bias, causal=causal, upcast=True, reorder_ops=False + ) + pt_output = flash_attn_reference( + q, k, v, attn_bias=attn_bias, causal=causal, upcast=False, reorder_ops=True + ) liger_output = flash_attn_func(q, k, v, attention_bias=attn_bias, causal=causal) compare_numerical_errors(ref_output, pt_output, liger_output, 1, 1e-4, "output") # Compare the gradients after the backward pass - ref_dq, ref_dk, ref_dv = torch.autograd.grad(ref_output, (q, k, v), do, retain_graph=True) - pt_dq, pt_dk, pt_dv = torch.autograd.grad(pt_output, (q, k, v), do, retain_graph=True) - liger_dq, liger_dk, liger_dv = torch.autograd.grad(liger_output, (q, k, v), do, retain_graph=True) + ref_dq, ref_dk, ref_dv = torch.autograd.grad( + ref_output, (q, k, v), do, retain_graph=True + ) + pt_dq, pt_dk, pt_dv = torch.autograd.grad( + pt_output, (q, k, v), do, retain_graph=True + ) + liger_dq, liger_dk, liger_dv = torch.autograd.grad( + liger_output, (q, k, v), do, retain_graph=True + ) compare_numerical_errors(ref_dq, pt_dq, liger_dq, 2, 1e-4, "dq") compare_numerical_errors(ref_dk, pt_dk, liger_dk, 2, 1e-4, "dk") compare_numerical_errors(ref_dv, pt_dv, liger_dv, 2, 1e-4, "dv") @@ -63,13 +83,16 @@ def _test_attention( "dtype, swap_seqlens", [(torch.float16, True), (torch.bfloat16, False)], ) -@pytest.mark.parametrize("head_dim, nheads_q, nheads_kv, use_bias, causal", [ - (32, 9, 9, True, False), - (40, 9, 3, True, True), - (64, 8, 8, False, False), - (128, 8, 2, True, True), - (256, 4, 2, False, True), -]) +@pytest.mark.parametrize( + "head_dim, nheads_q, nheads_kv, use_bias, causal", + [ + (32, 9, 9, True, False), + (40, 9, 3, True, True), + (64, 8, 8, False, False), + (128, 8, 2, True, True), + (256, 4, 2, False, True), + ], +) @pytest.mark.parametrize( "seqlen_q,seqlen_k", [ From c316b8793ad1c48770113973d0cc8bc9e60e44bb Mon Sep 17 00:00:00 2001 From: remi-or Date: Thu, 26 Sep 2024 17:11:46 +0000 Subject: [PATCH 10/12] Further style --- benchmark/benchmarks_visualizer.py | 2 +- benchmark/scripts/benchmark_attention.py | 40 ++++++++++++++++++------ 2 files changed, 32 insertions(+), 10 deletions(-) diff --git a/benchmark/benchmarks_visualizer.py 
b/benchmark/benchmarks_visualizer.py index 360057a4..4179b6f0 100644 --- a/benchmark/benchmarks_visualizer.py +++ b/benchmark/benchmarks_visualizer.py @@ -8,7 +8,7 @@ import seaborn as sns DATA_PATH = os.path.join(os.path.dirname(__file__), "data/all_benchmark_data.csv") -VISUALIZATIONS_PATH = "visualizations/" +VISUALIZATIONS_PATH = os.path.join(os.path.dirname(__file__), "visualizations/") @dataclass diff --git a/benchmark/scripts/benchmark_attention.py b/benchmark/scripts/benchmark_attention.py index bb4bf0a0..085fce90 100644 --- a/benchmark/scripts/benchmark_attention.py +++ b/benchmark/scripts/benchmark_attention.py @@ -31,9 +31,15 @@ def bench_memory_attention( device = "cuda" head_dim = hidden_size // nheads_q - q = torch.normal(0, 0.5, (batch_size, seqlen, nheads_q, head_dim), dtype=dtype, device=device).requires_grad_() - k = torch.normal(0, 0.5, (batch_size, seqlen, nheads_kv, head_dim), dtype=dtype, device=device).requires_grad_() - v = torch.normal(0, 0.5, (batch_size, seqlen, nheads_kv, head_dim), dtype=dtype, device=device).requires_grad_() + q = torch.normal( + 0, 0.5, (batch_size, seqlen, nheads_q, head_dim), dtype=dtype, device=device + ).requires_grad_() + k = torch.normal( + 0, 0.5, (batch_size, seqlen, nheads_kv, head_dim), dtype=dtype, device=device + ).requires_grad_() + v = torch.normal( + 0, 0.5, (batch_size, seqlen, nheads_kv, head_dim), dtype=dtype, device=device + ).requires_grad_() do = torch.randn_like(q) if provider == "torch": @@ -47,7 +53,9 @@ def fwd(): return torch.nn.functional.scaled_dot_product_attention(q, k, v) else: ngroups = nheads_q // nheads_kv - return torch.nn.functional.scaled_dot_product_attention(q, repeat_kv(k, ngroups), repeat_kv(v, ngroups)) + return torch.nn.functional.scaled_dot_product_attention( + q, repeat_kv(k, ngroups), repeat_kv(v, ngroups) + ) def full(): y = fwd() @@ -80,9 +88,15 @@ def bench_speed_attention( device = "cuda" head_dim = hidden_size // nheads_q - q = torch.normal(0, 0.5, (batch_size, seqlen, nheads_q, head_dim), dtype=dtype, device=device).requires_grad_() - k = torch.normal(0, 0.5, (batch_size, seqlen, nheads_kv, head_dim), dtype=dtype, device=device).requires_grad_() - v = torch.normal(0, 0.5, (batch_size, seqlen, nheads_kv, head_dim), dtype=dtype, device=device).requires_grad_() + q = torch.normal( + 0, 0.5, (batch_size, seqlen, nheads_q, head_dim), dtype=dtype, device=device + ).requires_grad_() + k = torch.normal( + 0, 0.5, (batch_size, seqlen, nheads_kv, head_dim), dtype=dtype, device=device + ).requires_grad_() + v = torch.normal( + 0, 0.5, (batch_size, seqlen, nheads_kv, head_dim), dtype=dtype, device=device + ).requires_grad_() do = torch.randn_like(q) if provider == "torch": @@ -96,7 +110,9 @@ def fwd(): return torch.nn.functional.scaled_dot_product_attention(q, k, v) else: ngroups = nheads_q // nheads_kv - return torch.nn.functional.scaled_dot_product_attention(q, repeat_kv(k, ngroups), repeat_kv(v, ngroups)) + return torch.nn.functional.scaled_dot_product_attention( + q, repeat_kv(k, ngroups), repeat_kv(v, ngroups) + ) if mode == "forward": ms_50, ms_20, ms_80 = triton.testing.do_bench( @@ -141,7 +157,13 @@ def full(): "x_values": [2**i for i in range(5, 15)], "kernel_providers": ["liger", "torch"], "extra_benchmark_configs": [ - {"batch_size": 4, "nheads_q": 32, "nheads_kv": 8, "hidden_size": 4096, "dtype": torch.float16} + { + "batch_size": 4, + "nheads_q": 32, + "nheads_kv": 8, + "hidden_size": 4096, + "dtype": torch.float16, + } ], "overwrite": args.overwrite, } From 
15cfc639440f76a16980c4156db741e6be275ea9 Mon Sep 17 00:00:00 2001 From: remi-or Date: Thu, 26 Sep 2024 21:31:15 +0000 Subject: [PATCH 11/12] Changed the test threshold in accordance with FA repo --- test/transformers/test_attention.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/transformers/test_attention.py b/test/transformers/test_attention.py index 623cb9d9..291f7fe2 100644 --- a/test/transformers/test_attention.py +++ b/test/transformers/test_attention.py @@ -62,7 +62,7 @@ def _test_attention( q, k, v, attn_bias=attn_bias, causal=causal, upcast=False, reorder_ops=True ) liger_output = flash_attn_func(q, k, v, attention_bias=attn_bias, causal=causal) - compare_numerical_errors(ref_output, pt_output, liger_output, 1, 1e-4, "output") + compare_numerical_errors(ref_output, pt_output, liger_output, 2, 1e-4, "output") # Compare the gradients after the backward pass ref_dq, ref_dk, ref_dv = torch.autograd.grad( @@ -94,7 +94,7 @@ def _test_attention( ], ) @pytest.mark.parametrize( - "seqlen_q,seqlen_k", + "seqlen_q, seqlen_k", [ (1, 239), (113, 203), From 6a31c03378c9b3254f29a8b0a382058685a623c0 Mon Sep 17 00:00:00 2001 From: remi-or Date: Thu, 26 Sep 2024 21:31:15 +0000 Subject: [PATCH 12/12] Changed the test threshold in accordance with FA repo --- .../reference_implementation.py | 66 ++++++++----------- 1 file changed, 29 insertions(+), 37 deletions(-) diff --git a/src/liger_kernel/ops/flash_attention/reference_implementation.py b/src/liger_kernel/ops/flash_attention/reference_implementation.py index 311b5d24..546a2b96 100644 --- a/src/liger_kernel/ops/flash_attention/reference_implementation.py +++ b/src/liger_kernel/ops/flash_attention/reference_implementation.py @@ -1,35 +1,29 @@ import math +from typing import Optional, Tuple import torch -from einops import rearrange, repeat +from torch import Tensor def construct_local_mask( - seqlen_q, - seqlen_k, - window_size=(-1, -1), # -1 means infinite window size - query_padding_mask=None, - key_padding_mask=None, - device=None, - key_leftpad=None, + seqlen_q: int, + seqlen_k: int, + window_size: Tuple[int, int] = (-1, -1), # -1 means infinite window size + query_padding_mask: Optional[Tensor] = None, + key_padding_mask: Optional[Tensor] = None, + device: Optional[str] = None, ): - row_idx = rearrange( - torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1" - ) + row_idx = torch.arange(seqlen_q, device=device, dtype=torch.long).unsqueeze(-1) col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long) - if key_leftpad is not None: - key_leftpad = rearrange(key_leftpad, "b -> b 1 1 1") - col_idx = repeat(col_idx, "s -> b 1 1 s", b=key_leftpad.shape[0]) - col_idx = torch.where(col_idx >= key_leftpad, col_idx - key_leftpad, 2**32) sk = ( seqlen_k if key_padding_mask is None - else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1") + else key_padding_mask.sum(-1).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) ) sq = ( seqlen_q if query_padding_mask is None - else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1") + else query_padding_mask.sum(-1).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) ) if window_size[0] < 0: return col_idx > row_idx + sk - sq + window_size[1] @@ -42,20 +36,19 @@ def construct_local_mask( def flash_attn_reference( - q, - k, - v, - query_padding_mask=None, - key_padding_mask=None, - attn_bias=None, - dropout_p=0.0, - dropout_mask=None, - causal=False, - window_size=(-1, -1), # -1 means infinite window size - softcap=0.0, - upcast=True, - reorder_ops=False, - key_leftpad=None, + q: 
Tensor, + k: Tensor, + v: Tensor, + query_padding_mask: Optional[Tensor] = None, + key_padding_mask: Optional[Tensor] = None, + attn_bias: Optional[Tensor] = None, + dropout_p: float = 0.0, + dropout_mask: Optional[Tensor] = None, + causal: bool = False, + window_size: Tuple[int, int] = (-1, -1), # -1 means infinite window size + softcap: float = 0.0, + upcast: bool = True, + reorder_ops: bool = False, ): """ Arguments: @@ -84,8 +77,8 @@ def flash_attn_reference( if upcast: q, k, v = q.float(), k.float(), v.float() seqlen_q, seqlen_k = q.shape[1], k.shape[1] - k = repeat(k, "b s h d -> b s (h g) d", g=q.shape[2] // k.shape[2]) - v = repeat(v, "b s h d -> b s (h g) d", g=q.shape[2] // v.shape[2]) + k = k.repeat_interleave(repeats=q.shape[2] // k.shape[2], dim=2) + v = v.repeat_interleave(repeats=q.shape[2] // v.shape[2], dim=2) d = q.shape[-1] if not reorder_ops: scores = torch.einsum("bthd,bshd->bhts", q / math.sqrt(d), k) @@ -97,7 +90,7 @@ def flash_attn_reference( scores = scores * softcap if key_padding_mask is not None: scores.masked_fill_( - rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf") + (~key_padding_mask).unsqueeze(1).unsqueeze(1), float("-inf") ) if window_size[0] >= 0 or window_size[1] >= 0: local_mask = construct_local_mask( @@ -107,7 +100,6 @@ def flash_attn_reference( query_padding_mask, key_padding_mask, q.device, - key_leftpad=key_leftpad, ) scores.masked_fill_(local_mask, float("-inf")) if attn_bias is not None: @@ -122,7 +114,7 @@ def flash_attn_reference( # Otherwise we'll get NaN in dV if query_padding_mask is not None: attention = attention.masked_fill( - rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0 + (~query_padding_mask).unsqueeze(1).unsqueeze(-1), 0.0 ) dropout_scaling = 1.0 / (1 - dropout_p) # attention_drop = attention.masked_fill(~dropout_mask, 0.0) * dropout_scaling @@ -133,5 +125,5 @@ def flash_attn_reference( attention_drop = attention output = torch.einsum("bhts,bshd->bthd", attention_drop, v * dropout_scaling) if query_padding_mask is not None: - output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0) + output.masked_fill_((~query_padding_mask).unsqueeze(-1).unsqueeze(-1), 0.0) return output.to(dtype=dtype_og)
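
Editor's note on the final patch: PATCH 12 rewrites reference_implementation.py to drop the einops dependency, replacing rearrange/repeat with plain torch.unsqueeze and Tensor.repeat_interleave, as the hunks above show. The snippet below is a minimal standalone sketch (not part of the patch series) that checks the two formulations agree; it assumes einops is still available in the development environment purely for the comparison, and the tensor shapes are arbitrary illustrative values.

import torch
from einops import rearrange, repeat  # only needed for this check, not by the patched code

b, s, h, d, g = 2, 5, 3, 4, 2  # batch, seqlen, kv heads, head_dim, GQA group size

# rearrange(x, "s -> s 1") matches x.unsqueeze(-1)
x = torch.arange(s)
assert torch.equal(rearrange(x, "s -> s 1"), x.unsqueeze(-1))

# rearrange(mask.sum(-1), "b -> b 1 1 1") matches three chained unsqueeze(-1) calls
mask = torch.ones(b, s, dtype=torch.bool)
assert torch.equal(
    rearrange(mask.sum(-1), "b -> b 1 1 1"),
    mask.sum(-1).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1),
)

# repeat(k, "b s h d -> b s (h g) d", g=g) matches k.repeat_interleave(g, dim=2):
# both repeat each KV head g times consecutively along the head axis (GQA expansion)
k = torch.randn(b, s, h, d)
assert torch.equal(
    repeat(k, "b s h d -> b s (h g) d", g=g),
    k.repeat_interleave(repeats=g, dim=2),
)

# rearrange(~mask, "b s -> b 1 1 s") matches (~mask).unsqueeze(1).unsqueeze(1)
assert torch.equal(rearrange(~mask, "b s -> b 1 1 s"), (~mask).unsqueeze(1).unsqueeze(1))

print("einops-free rewrites match the original einops expressions")

The rewrite keeps the same broadcasting semantics while removing a dependency that, within this series, is used only by the reference implementation, so this code path no longer pulls in einops.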