[ExecuTorch][BE] Split kv cache and SDPA for better code sharing
Summary:

Why?
We have coupled SDPA with the KV cache for a while. Initially this was done
when we implemented the sdpa_with_kv_cache custom op to reduce the multiple
copy overheads of KV cache updates. (This could instead have been done with a
separate custom KV cache update op and a custom SDPA op; recent changes
enabled that.)
Because the SDPA module owns the KV cache, we get a) a non-composable
implementation and b) model definitions and components that are harder to
reuse from repos like tune. The upshot is that multiple definitions of the
same model, llama, are lying around in ET, TorchChat, and tune. This diff and
subsequent ones move toward decoupling the custom KV cache and custom SDPA so
they compose, making the model definition more module-swap friendly with
tune's.

How?
Earlier PRs decoupled the KV cache update from SDPA. So now:
1. Decouple the SDPA nn.Module from the KV cache.
2. Standardize the KVCache and SDPA interfaces: both operate on q, k, v
   tensors in [B, n_heads, seq_len, head_dim] layout (see the sketch after
   this list).
3. Step 2 introduces extra transposes when KVCache and SDPA are replaced by
   custom modules, but a later graph pass will undo them.

Test Plan:
Existing tests.
Make sure perf doesn't regress.

ghstack-source-id: 6289ce22a2c190da7e38e098ba8a5d0254d6bf9d
Pull Request resolved: #7413
kimishpatel committed Dec 21, 2024
1 parent d8e1d04 commit 38c2d0d
Showing 7 changed files with 335 additions and 171 deletions.
2 changes: 2 additions & 0 deletions examples/models/llama/export_llama_lib.py
@@ -657,6 +657,8 @@ def _export_llama(args) -> LLMEdgeManager: # noqa: C901
# export_to_edge
builder_exported = _prepare_for_llama_export(args).export()

builder_exported.run_canonical_optimizations()

if args.export_only:
exit()

47 changes: 17 additions & 30 deletions examples/models/llama/llama_transformer.py
@@ -232,22 +232,16 @@ def __init__(
max_seq_length: int,
n_heads: int,
head_dim: int,
transpose_cache: bool,
enable_dynamic_shape: bool,
dtype=torch.float32,
):
super().__init__()
self.max_seq_length = max_seq_length
self.is_transposed = transpose_cache
if transpose_cache:
cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
else:
cache_shape = (max_batch_size, max_seq_length, n_heads, head_dim)
cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)

self.max_batch_size = max_batch_size
self.n_heads = n_heads
self.head_dim = head_dim
self.transpose_cache = transpose_cache
self.enable_dynamic_shape = enable_dynamic_shape
self.register_buffer(
"k_cache", torch.zeros(cache_shape, dtype=dtype, device="cpu")
@@ -259,12 +253,12 @@ def __init__(
def update(
self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
# input_pos: [S], k_val: [B, H, S, D] or [B, S, H, D] depending on transpose_cache
# input_pos: [S], k_val: [B, H, S, D]
if self.enable_dynamic_shape:
start_pos = input_pos[0].item()
torch._check_is_size(start_pos)
torch._check(start_pos < self.max_seq_length)
dim_to_slice = 2 if self.transpose_cache else 1
dim_to_slice = 2
seq_length = k_val.size(dim_to_slice)
# Replace the entry in the cache for this token
# The following lines are equivalent to:
@@ -283,28 +277,22 @@ def update(
else:
k_out = self.k_cache
v_out = self.v_cache
if self.transpose_cache:
k_out[:, :, input_pos] = k_val
v_out[:, :, input_pos] = v_val
else:
k_out[:, input_pos] = k_val
v_out[:, input_pos] = v_val
k_out[:, :, input_pos] = k_val
v_out[:, :, input_pos] = v_val

return k_out, v_out


class SDPA(nn.Module):
def __init__(
self,
kv_cache: KVCache,
dim: int,
head_dim: int,
n_rep: int,
max_seq_len: int,
enable_dynamic_shape: bool,
):
super().__init__()
self.kv_cache = kv_cache
self.dim = dim
self.head_dim = head_dim
self.n_rep = n_rep
@@ -314,18 +302,16 @@ def __init__(
def forward(
self,
input_pos: torch.Tensor,
q: torch.Tensor, # Already have rotary embeddings. (bs, seqlen, n_local_heads, head_dim)
k: torch.Tensor, # Already have rotary embeddings. (bs, seqlen, n_local_kv_heads, head_dim)
v: torch.Tensor, # (bs, seqlen, n_local_kv_heads, head_dim)
q: torch.Tensor, # Already have rotary embeddings. (bs, n_local_heads, seqlen, head_dim)
k: torch.Tensor, # Already have rotary embeddings. (bs, n_local_kv_heads, seqlen, head_dim)
v: torch.Tensor, # (bs, n_local_kv_heads, seqlen, head_dim)
bsz,
seqlen,
mask: torch.Tensor,
) -> torch.Tensor:
q = q.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim)
k = k.transpose(1, 2)
v = v.transpose(1, 2)

k, v = self.kv_cache.update(input_pos, k, v)
# TODO(kimishpatel): Move this slicing logic to Attention block so that
# SDPA does not have to take input_pos as arg
if self.enable_dynamic_shape:
start_pos = input_pos[-1].item()
torch._check_is_size(start_pos)
@@ -336,6 +322,8 @@ def forward(
else:
attn_mask = mask[None, None, input_pos]

# TODO(kimishpatel): This should not be necessary because scaled_dot_product_attention
# can natively support GQA now. But needs enable_gqa=True
k = k.repeat_interleave(self.n_rep, dim=1)
v = v.repeat_interleave(self.n_rep, dim=1)
y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=0.0)
@@ -383,11 +371,9 @@ def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
args.max_seq_len,
self.n_kv_heads,
self.head_dim,
not args.use_sdpa_with_kv_cache_op, # if we are using the custom op don't transpose the cache. Expect untransposed q k v
args.enable_dynamic_shape,
)
self.SDPA = SDPA(
kv_cache=self.kv_cache,
dim=self.n_local_heads * self.head_dim,
head_dim=self.head_dim,
n_rep=self.n_rep,
@@ -414,15 +400,16 @@ def forward(
# RoPE relative positional embeddings
q, k = self.rope.forward(q, k, freqs_cos, freqs_sin)

q = q.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim)
k = k.transpose(1, 2)
v = v.transpose(1, 2)

if self.use_kv_cache:
assert input_pos is not None
k, v = self.kv_cache.update(input_pos, k, v)
output = self.SDPA(input_pos, q, k, v, bsz, seqlen, self.mask)
return self.wo(output)

q = q.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim)
k = k.transpose(1, 2)
v = v.transpose(1, 2)

# grouped multiquery attention: expand out keys and values
k = k.repeat_interleave(self.n_rep, dim=1)
v = v.repeat_interleave(self.n_rep, dim=1)
145 changes: 42 additions & 103 deletions examples/models/llama/source_transformation/quantized_kv_cache.py
@@ -37,8 +37,6 @@ def __init__(
n_heads,
head_dim,
cache_type: QuantizedCacheType = QuantizedCacheType.AffineSymmetric,
tranposed=False,
enable_dynamic_shape=False,
):
super().__init__()
if cache_type not in (
@@ -52,14 +50,8 @@ def __init__(
# For now supporting int8 only
self.quantized_cache_dtype = torch.int8
self.cache_fp_type = torch.float32
self.is_transposed = tranposed
self.enable_dynamic_shape = enable_dynamic_shape
if self.is_transposed:
cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
scale_shape = (max_batch_size, n_heads, max_seq_length, 1)
else:
cache_shape = (max_batch_size, max_seq_length, n_heads, head_dim)
scale_shape = (max_batch_size, max_seq_length, n_heads, 1)
cache_shape = (max_batch_size, max_seq_length, n_heads, head_dim)
scale_shape = (max_batch_size, max_seq_length, n_heads, 1)
self.register_buffer(
"k_cache", torch.zeros(cache_shape, dtype=self.quantized_cache_dtype)
)
@@ -98,71 +90,37 @@ def _quantize(self, value):
return quantized_value, scales, zero_points

def update(self, input_pos, k_val, v_val):
"""
k_val, v_val: [B, H, S, D]
return: [B, H, S, D]
However the storage is [B, S, H, D] so we incur transpose in, transpose out
This shall be removed by subsequent post-export graph pass
"""
k_val = k_val.transpose(1, 2)
v_val = v_val.transpose(1, 2)
# quantize current k_val and store it in the cache
quantized_k_val, k_scales, k_zero_points = self._quantize(k_val)

quantized_v_val, v_scales, v_zero_points = self._quantize(v_val)

if self.is_transposed:
# We cannot use update_cache op at the moment
# if the cache is transposed
# Also note that we should not need separate paths
# for dynamic shape vs. not.
# Only reason it is done this way is to accommodate
# for lowering pains of backends that work better
# with index_put op.
if self.enable_dynamic_shape:
start_pos = input_pos[0].item()
torch._check_is_size(start_pos)
dim_to_slice = 2 if self.is_transposed else 1
torch._check(start_pos < self.k_cache.size(dim_to_slice))
seq_length = k_val.size(dim_to_slice)
narrowed_k = self.k_cache.narrow(dim_to_slice, start_pos, seq_length)
narrowed_k_scales = self.k_cache_scales.narrow(
dim_to_slice, start_pos, seq_length
)
narrowed_k_zp = self.k_cache_zero_points.narrow(
dim_to_slice, start_pos, seq_length
)
narrowed_k.copy_(quantized_k_val)
narrowed_k_scales.copy_(k_scales)
narrowed_k_zp.copy_(k_zero_points)
narrowed_v = self.v_cache.narrow(dim_to_slice, start_pos, seq_length)
narrowed_v_scales = self.v_cache_scales.narrow(
dim_to_slice, start_pos, seq_length
)
narrowed_v_zp = self.v_cache_zero_points.narrow(
dim_to_slice, start_pos, seq_length
)
narrowed_v.copy_(quantized_v_val)
narrowed_v_scales.copy_(v_scales)
narrowed_v_zp.copy_(v_zero_points)
else:
self.k_cache[:, :, input_pos] = quantized_k_val
self.k_cache_scales[:, :, input_pos] = k_scales
self.k_cache_zero_points[:, :, input_pos] = k_zero_points
self.v_cache[:, :, input_pos] = quantized_v_val
self.v_cache_scales[:, :, input_pos] = v_scales
self.v_cache_zero_points[:, :, input_pos] = v_zero_points
else:
# Right now using custom ops on this path.
# In future we can update custom op to handle transposed cache
# as well.
# Note that we may have to revert this change if other ET
# backends such as QNN want to use quantized cache, with dynamic shape,
# instead of quantizing on their own.
# But until this opting for code simplicity
start_pos = input_pos[0].item()
_ = torch.ops.llama.update_cache(quantized_k_val, self.k_cache, start_pos)
_ = torch.ops.llama.update_cache(k_scales, self.k_cache_scales, start_pos)
_ = torch.ops.llama.update_cache(
k_zero_points, self.k_cache_zero_points, start_pos
)
_ = torch.ops.llama.update_cache(quantized_v_val, self.v_cache, start_pos)
_ = torch.ops.llama.update_cache(v_scales, self.v_cache_scales, start_pos)
_ = torch.ops.llama.update_cache(
v_zero_points, self.v_cache_zero_points, start_pos
)
# Right now using custom ops on this path.
# In future we can update custom op to handle transposed cache
# as well.
# Note that we may have to revert this change if other ET
# backends such as QNN want to use quantized cache, with dynamic shape,
# instead of quantizing on their own.
# But until this opting for code simplicity
start_pos = input_pos[0].item()
_ = torch.ops.llama.update_cache(quantized_k_val, self.k_cache, start_pos)
_ = torch.ops.llama.update_cache(k_scales, self.k_cache_scales, start_pos)
_ = torch.ops.llama.update_cache(
k_zero_points, self.k_cache_zero_points, start_pos
)
_ = torch.ops.llama.update_cache(quantized_v_val, self.v_cache, start_pos)
_ = torch.ops.llama.update_cache(v_scales, self.v_cache_scales, start_pos)
_ = torch.ops.llama.update_cache(
v_zero_points, self.v_cache_zero_points, start_pos
)

k_out = torch.ops.quantized_decomposed.dequantize_per_token(
self.k_cache,
@@ -183,42 +141,24 @@ def update(self, input_pos, k_val, v_val):
self.cache_fp_type,
)

if self.is_transposed:
if self.enable_dynamic_shape:
start_pos = input_pos[0].item()
torch._check_is_size(start_pos)
dim_to_slice = 2 if self.is_transposed else 1
torch._check(start_pos < self.k_cache.size(dim_to_slice))
seq_length = k_val.size(dim_to_slice)
narrowed_k = k_out.narrow(dim_to_slice, start_pos, seq_length)
narrowed_k.copy_(k_val)
narrowed_v = v_out.narrow(dim_to_slice, start_pos, seq_length)
narrowed_v.copy_(v_val)
else:
k_out[:, :, input_pos] = k_val
v_out[:, :, input_pos] = v_val
else:
start_pos = input_pos[0].item()
_ = torch.ops.llama.update_cache(k_val, k_out, start_pos)
_ = torch.ops.llama.update_cache(v_val, v_out, start_pos)
start_pos = input_pos[0].item()
_ = torch.ops.llama.update_cache(k_val, k_out, start_pos)
_ = torch.ops.llama.update_cache(v_val, v_out, start_pos)

return k_out, v_out
return k_out.transpose(1, 2), v_out.transpose(1, 2)

@classmethod
def from_float(cls, kv_cache, cache_type: QuantizedCacheType):
cache_shape = kv_cache.k_cache.shape
if kv_cache.is_transposed:
max_batch_size, n_heads, max_seq_length, head_dim = cache_shape
else:
max_batch_size, max_seq_length, n_heads, head_dim = cache_shape
max_batch_size, n_heads, max_seq_length, head_dim = kv_cache.k_cache.shape
if isinstance(kv_cache, CustomKVCache):
# If replacing custom kv cache, then the shape is [B, S, H, D]
max_batch_size, max_seq_length, n_heads, head_dim = kv_cache.k_cache.shape
return cls(
max_batch_size,
max_seq_length,
n_heads,
head_dim,
cache_type,
kv_cache.is_transposed,
kv_cache.enable_dynamic_shape,
)


@@ -254,7 +194,7 @@ def replace_kv_cache_with_quantized_kv_cache(module):
"Replacing KVCache with QuantizedKVCache. This modifies the model in place."
)
for name, child in module.named_children():
if isinstance(child, KVCache):
if isinstance(child, KVCache) or isinstance(child, CustomKVCache):
setattr(
module,
name,
@@ -291,11 +231,13 @@ def __init__(
def update(
self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
# input_pos: [S], k_val: [B, S, H, D]
# input_pos: [S], k_val: [B, H, S, D]
k_val = k_val.transpose(1, 2)
v_val = v_val.transpose(1, 2)
start_pos = input_pos[0].item()
_ = torch.ops.llama.update_cache(k_val, self.k_cache, start_pos)
_ = torch.ops.llama.update_cache(v_val, self.v_cache, start_pos)
return self.k_cache, self.v_cache
return self.k_cache.transpose(1, 2), self.v_cache.transpose(1, 2)


def replace_kv_cache_with_custom_kv_cache(module):
@@ -313,10 +255,7 @@ def replace_kv_cache_with_custom_kv_cache(module):
if isinstance(child, KVCache):
cache_shape = child.k_cache.shape
cache_dtype = child.k_cache.dtype
assert (
child.is_transposed is False
), "CustomKVCache does not support transposed cache"
max_batch_size, max_seq_length, n_heads, head_dim = cache_shape
max_batch_size, n_heads, max_seq_length, head_dim = cache_shape
setattr(
module,
name,