From 38c2d0d753b772392c813cbc0f51e85168bb7010 Mon Sep 17 00:00:00 2001 From: Kimish Patel Date: Fri, 20 Dec 2024 16:09:47 -0800 Subject: [PATCH] [ExecuTorch][BE] Split kv cache and SDPA for better code sharing Summary: Why? We have coupled SDPA with kv cache for a while. Initially this was done as we implemented sdpa_with_kv_cache custom op to reduce multiple copy overheads from kv cache update. (This could have been done by having separate custom kv cache update and custom sdpa op. Recent changes enabled this.) As a result of SDPA module owning kv cache, we get a) non-composable implementation and b) harder to reuse model definition and components from repos like tune. Output of this is that we have multiple definition of the same model, llama, lying around in ET, TorchChat and Tune. This diff and subsequent ones will try to move in the direction where custom kv cache and custom sdpa become decoupled and composable, making it more module-swap friendly with tune's model definition. How. Earlier PRs decoupled kv cache update from sdpa. So now 1. Decouple SDPA nn.Module from KV cache. 2. Standardize on KVCache and SDPA interface. That is KVCache and SDPA both operate on q, k, v in [B, # heads, seq_len, head_dim] formatted tensors. 3. 2 will introduce multiple tranposes when KVCache and SDPA are replaced by custom modules, but we will write graph pass to undo those. Test Plan: Existing tests. Make sure perf doesnt regress ghstack-source-id: 6289ce22a2c190da7e38e098ba8a5d0254d6bf9d Pull Request resolved: https://github.com/pytorch/executorch/pull/7413 --- examples/models/llama/export_llama_lib.py | 2 + examples/models/llama/llama_transformer.py | 47 ++--- .../quantized_kv_cache.py | 145 +++++---------- .../llama/source_transformation/sdpa.py | 55 ++---- extension/llm/export/builder.py | 10 ++ extension/llm/export/export_passes.py | 80 +++++++++ extension/llm/export/test_export_passes.py | 167 ++++++++++++++++++ 7 files changed, 335 insertions(+), 171 deletions(-) create mode 100644 extension/llm/export/export_passes.py create mode 100644 extension/llm/export/test_export_passes.py diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py index 1d757960f7..9e81c5668f 100644 --- a/examples/models/llama/export_llama_lib.py +++ b/examples/models/llama/export_llama_lib.py @@ -657,6 +657,8 @@ def _export_llama(args) -> LLMEdgeManager: # noqa: C901 # export_to_edge builder_exported = _prepare_for_llama_export(args).export() + builder_exported.run_canonical_optimizations() + if args.export_only: exit() diff --git a/examples/models/llama/llama_transformer.py b/examples/models/llama/llama_transformer.py index aaef3cd980..176b597a94 100644 --- a/examples/models/llama/llama_transformer.py +++ b/examples/models/llama/llama_transformer.py @@ -232,22 +232,16 @@ def __init__( max_seq_length: int, n_heads: int, head_dim: int, - transpose_cache: bool, enable_dynamic_shape: bool, dtype=torch.float32, ): super().__init__() self.max_seq_length = max_seq_length - self.is_transposed = transpose_cache - if transpose_cache: - cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim) - else: - cache_shape = (max_batch_size, max_seq_length, n_heads, head_dim) + cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim) self.max_batch_size = max_batch_size self.n_heads = n_heads self.head_dim = head_dim - self.transpose_cache = transpose_cache self.enable_dynamic_shape = enable_dynamic_shape self.register_buffer( "k_cache", torch.zeros(cache_shape, dtype=dtype, device="cpu") @@ -259,12 +253,12 @@ def __init__( def update( self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor]: - # input_pos: [S], k_val: [B, H, S, D] or [B, S, H, D] depending on transpose_cache + # input_pos: [S], k_val: [B, H, S, D] if self.enable_dynamic_shape: start_pos = input_pos[0].item() torch._check_is_size(start_pos) torch._check(start_pos < self.max_seq_length) - dim_to_slice = 2 if self.transpose_cache else 1 + dim_to_slice = 2 seq_length = k_val.size(dim_to_slice) # Replace the entry in the cache for this token # The following lines are equivalent to: @@ -283,12 +277,8 @@ def update( else: k_out = self.k_cache v_out = self.v_cache - if self.transpose_cache: - k_out[:, :, input_pos] = k_val - v_out[:, :, input_pos] = v_val - else: - k_out[:, input_pos] = k_val - v_out[:, input_pos] = v_val + k_out[:, :, input_pos] = k_val + v_out[:, :, input_pos] = v_val return k_out, v_out @@ -296,7 +286,6 @@ def update( class SDPA(nn.Module): def __init__( self, - kv_cache: KVCache, dim: int, head_dim: int, n_rep: int, @@ -304,7 +293,6 @@ def __init__( enable_dynamic_shape: bool, ): super().__init__() - self.kv_cache = kv_cache self.dim = dim self.head_dim = head_dim self.n_rep = n_rep @@ -314,18 +302,16 @@ def __init__( def forward( self, input_pos: torch.Tensor, - q: torch.Tensor, # Already have rotary embeddings. (bs, seqlen, n_local_heads, head_dim) - k: torch.Tensor, # Already have rotary embeddings. (bs, seqlen, n_local_kv_heads, head_dim) - v: torch.Tensor, # (bs, seqlen, n_local_kv_heads, head_dim) + q: torch.Tensor, # Already have rotary embeddings. (bs, n_local_heads, seqlen, head_dim) + k: torch.Tensor, # Already have rotary embeddings. (bs, n_local_kv_heads, seqlen, head_dim) + v: torch.Tensor, # (bs, n_local_kv_heads, seqlen, head_dim) bsz, seqlen, mask: torch.Tensor, ) -> torch.Tensor: - q = q.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) - k = k.transpose(1, 2) - v = v.transpose(1, 2) - k, v = self.kv_cache.update(input_pos, k, v) + # TODO(kimishpatel): Move this slicing logic to Attention block so that + # SDPA does not have to take input_pos as arg if self.enable_dynamic_shape: start_pos = input_pos[-1].item() torch._check_is_size(start_pos) @@ -336,6 +322,8 @@ def forward( else: attn_mask = mask[None, None, input_pos] + # TODO(kimishpatel): This should not be necessary because scaled_dot_product_attention + # can natively support GQA now. But needs enable_gqa=True k = k.repeat_interleave(self.n_rep, dim=1) v = v.repeat_interleave(self.n_rep, dim=1) y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=0.0) @@ -383,11 +371,9 @@ def __init__(self, args: ModelArgs, layer_id: int, rope: Rope): args.max_seq_len, self.n_kv_heads, self.head_dim, - not args.use_sdpa_with_kv_cache_op, # if we are using the custom op don't transpose the cache. Expect untransposed q k v args.enable_dynamic_shape, ) self.SDPA = SDPA( - kv_cache=self.kv_cache, dim=self.n_local_heads * self.head_dim, head_dim=self.head_dim, n_rep=self.n_rep, @@ -414,15 +400,16 @@ def forward( # RoPE relative positional embeddings q, k = self.rope.forward(q, k, freqs_cos, freqs_sin) + q = q.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + if self.use_kv_cache: assert input_pos is not None + k, v = self.kv_cache.update(input_pos, k, v) output = self.SDPA(input_pos, q, k, v, bsz, seqlen, self.mask) return self.wo(output) - q = q.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) - k = k.transpose(1, 2) - v = v.transpose(1, 2) - # grouped multiquery attention: expand out keys and values k = k.repeat_interleave(self.n_rep, dim=1) v = v.repeat_interleave(self.n_rep, dim=1) diff --git a/examples/models/llama/source_transformation/quantized_kv_cache.py b/examples/models/llama/source_transformation/quantized_kv_cache.py index d8ac99656f..fa0b3f9251 100644 --- a/examples/models/llama/source_transformation/quantized_kv_cache.py +++ b/examples/models/llama/source_transformation/quantized_kv_cache.py @@ -37,8 +37,6 @@ def __init__( n_heads, head_dim, cache_type: QuantizedCacheType = QuantizedCacheType.AffineSymmetric, - tranposed=False, - enable_dynamic_shape=False, ): super().__init__() if cache_type not in ( @@ -52,14 +50,8 @@ def __init__( # For now supporting int8 only self.quantized_cache_dtype = torch.int8 self.cache_fp_type = torch.float32 - self.is_transposed = tranposed - self.enable_dynamic_shape = enable_dynamic_shape - if self.is_transposed: - cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim) - scale_shape = (max_batch_size, n_heads, max_seq_length, 1) - else: - cache_shape = (max_batch_size, max_seq_length, n_heads, head_dim) - scale_shape = (max_batch_size, max_seq_length, n_heads, 1) + cache_shape = (max_batch_size, max_seq_length, n_heads, head_dim) + scale_shape = (max_batch_size, max_seq_length, n_heads, 1) self.register_buffer( "k_cache", torch.zeros(cache_shape, dtype=self.quantized_cache_dtype) ) @@ -98,71 +90,37 @@ def _quantize(self, value): return quantized_value, scales, zero_points def update(self, input_pos, k_val, v_val): + """ + k_val, v_val: [B, H, S, D] + return: [B, H, S, D] + However the storage is [B, S, H, D] so we incur transpose in, transpose out + This shall be removed by subsequent post-export graph pass + """ + k_val = k_val.transpose(1, 2) + v_val = v_val.transpose(1, 2) # quantize current k_val and store it in the cache quantized_k_val, k_scales, k_zero_points = self._quantize(k_val) quantized_v_val, v_scales, v_zero_points = self._quantize(v_val) - if self.is_transposed: - # We cannot use update_cache op at the moment - # if the cache is transposed - # Also note that we shold not need separate paths - # for dynamic shape vs ! - # Only reason it is done this way is to accommodate - # for lowering pains of backends that work better - # with index_put op. - if self.enable_dynamic_shape: - start_pos = input_pos[0].item() - torch._check_is_size(start_pos) - dim_to_slice = 2 if self.is_transposed else 1 - torch._check(start_pos < self.k_cache.size(dim_to_slice)) - seq_length = k_val.size(dim_to_slice) - narrowed_k = self.k_cache.narrow(dim_to_slice, start_pos, seq_length) - narrowed_k_scales = self.k_cache_scales.narrow( - dim_to_slice, start_pos, seq_length - ) - narrowed_k_zp = self.k_cache_zero_points.narrow( - dim_to_slice, start_pos, seq_length - ) - narrowed_k.copy_(quantized_k_val) - narrowed_k_scales.copy_(k_scales) - narrowed_k_zp.copy_(k_zero_points) - narrowed_v = self.v_cache.narrow(dim_to_slice, start_pos, seq_length) - narrowed_v_scales = self.v_cache_scales.narrow( - dim_to_slice, start_pos, seq_length - ) - narrowed_v_zp = self.v_cache_zero_points.narrow( - dim_to_slice, start_pos, seq_length - ) - narrowed_v.copy_(quantized_v_val) - narrowed_v_scales.copy_(v_scales) - narrowed_v_zp.copy_(v_zero_points) - else: - self.k_cache[:, :, input_pos] = quantized_k_val - self.k_cache_scales[:, :, input_pos] = k_scales - self.k_cache_zero_points[:, :, input_pos] = k_zero_points - self.v_cache[:, :, input_pos] = quantized_v_val - self.v_cache_scales[:, :, input_pos] = v_scales - self.v_cache_zero_points[:, :, input_pos] = v_zero_points - else: - # Right now using custom ops on this path. - # In future we can update custom op to handle transposed cache - # as well. - # Note that we may have to revert this change if other ET - # backends such as QNN want to use quantized cache, with dynamic shape, - # instead of quantizing on their own. - # But until this opting for code simplicity - start_pos = input_pos[0].item() - _ = torch.ops.llama.update_cache(quantized_k_val, self.k_cache, start_pos) - _ = torch.ops.llama.update_cache(k_scales, self.k_cache_scales, start_pos) - _ = torch.ops.llama.update_cache( - k_zero_points, self.k_cache_zero_points, start_pos - ) - _ = torch.ops.llama.update_cache(quantized_v_val, self.v_cache, start_pos) - _ = torch.ops.llama.update_cache(v_scales, self.v_cache_scales, start_pos) - _ = torch.ops.llama.update_cache( - v_zero_points, self.v_cache_zero_points, start_pos - ) + # Right now using custom ops on this path. + # In future we can update custom op to handle transposed cache + # as well. + # Note that we may have to revert this change if other ET + # backends such as QNN want to use quantized cache, with dynamic shape, + # instead of quantizing on their own. + # But until this opting for code simplicity + start_pos = input_pos[0].item() + _ = torch.ops.llama.update_cache(quantized_k_val, self.k_cache, start_pos) + _ = torch.ops.llama.update_cache(k_scales, self.k_cache_scales, start_pos) + _ = torch.ops.llama.update_cache( + k_zero_points, self.k_cache_zero_points, start_pos + ) + _ = torch.ops.llama.update_cache(quantized_v_val, self.v_cache, start_pos) + _ = torch.ops.llama.update_cache(v_scales, self.v_cache_scales, start_pos) + _ = torch.ops.llama.update_cache( + v_zero_points, self.v_cache_zero_points, start_pos + ) k_out = torch.ops.quantized_decomposed.dequantize_per_token( self.k_cache, @@ -183,42 +141,24 @@ def update(self, input_pos, k_val, v_val): self.cache_fp_type, ) - if self.is_transposed: - if self.enable_dynamic_shape: - start_pos = input_pos[0].item() - torch._check_is_size(start_pos) - dim_to_slice = 2 if self.is_transposed else 1 - torch._check(start_pos < self.k_cache.size(dim_to_slice)) - seq_length = k_val.size(dim_to_slice) - narrowed_k = k_out.narrow(dim_to_slice, start_pos, seq_length) - narrowed_k.copy_(k_val) - narrowed_v = v_out.narrow(dim_to_slice, start_pos, seq_length) - narrowed_v.copy_(v_val) - else: - k_out[:, :, input_pos] = k_val - v_out[:, :, input_pos] = v_val - else: - start_pos = input_pos[0].item() - _ = torch.ops.llama.update_cache(k_val, k_out, start_pos) - _ = torch.ops.llama.update_cache(v_val, v_out, start_pos) + start_pos = input_pos[0].item() + _ = torch.ops.llama.update_cache(k_val, k_out, start_pos) + _ = torch.ops.llama.update_cache(v_val, v_out, start_pos) - return k_out, v_out + return k_out.transpose(1, 2), v_out.transpose(1, 2) @classmethod def from_float(cls, kv_cache, cache_type: QuantizedCacheType): - cache_shape = kv_cache.k_cache.shape - if kv_cache.is_transposed: - max_batch_size, n_heads, max_seq_length, head_dim = cache_shape - else: - max_batch_size, max_seq_length, n_heads, head_dim = cache_shape + max_batch_size, n_heads, max_seq_length, head_dim = kv_cache.k_cache.shape + if isinstance(kv_cache, CustomKVCache): + # If replacing custom kv cache, then the shape is [B, S, H, D] + max_batch_size, max_seq_length, n_heads, head_dim = kv_cache.k_cache.shape return cls( max_batch_size, max_seq_length, n_heads, head_dim, cache_type, - kv_cache.is_transposed, - kv_cache.enable_dynamic_shape, ) @@ -254,7 +194,7 @@ def replace_kv_cache_with_quantized_kv_cache(module): "Replacing KVCache with QuantizedKVCache. This modifies the model in place." ) for name, child in module.named_children(): - if isinstance(child, KVCache): + if isinstance(child, KVCache) or isinstance(child, CustomKVCache): setattr( module, name, @@ -291,11 +231,13 @@ def __init__( def update( self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor]: - # input_pos: [S], k_val: [B, S, H, D] + # input_pos: [S], k_val: [B, H, S, D] + k_val = k_val.transpose(1, 2) + v_val = v_val.transpose(1, 2) start_pos = input_pos[0].item() _ = torch.ops.llama.update_cache(k_val, self.k_cache, start_pos) _ = torch.ops.llama.update_cache(v_val, self.v_cache, start_pos) - return self.k_cache, self.v_cache + return self.k_cache.transpose(1, 2), self.v_cache.transpose(1, 2) def replace_kv_cache_with_custom_kv_cache(module): @@ -313,10 +255,7 @@ def replace_kv_cache_with_custom_kv_cache(module): if isinstance(child, KVCache): cache_shape = child.k_cache.shape cache_dtype = child.k_cache.dtype - assert ( - child.is_transposed is False - ), "CustomKVCache does not support transposed cache" - max_batch_size, max_seq_length, n_heads, head_dim = cache_shape + max_batch_size, n_heads, max_seq_length, head_dim = cache_shape setattr( module, name, diff --git a/examples/models/llama/source_transformation/sdpa.py b/examples/models/llama/source_transformation/sdpa.py index 4d4b3bf7f5..f68e43cbcd 100644 --- a/examples/models/llama/source_transformation/sdpa.py +++ b/examples/models/llama/source_transformation/sdpa.py @@ -22,19 +22,9 @@ class SDPACustom(torch.nn.Module): def __init__( self, - kv_cache: Union[KVCache, QuantizedKVCache], dim: int, ): super().__init__() - # Custom op only supports float32 currently. Converting to/from float32 is - # faster than not having the op. - self.kv_cache = kv_cache - if not isinstance(kv_cache, QuantizedKVCache): - self.kv_cache = kv_cache.to(torch.float) - else: - assert ( - kv_cache.cache_fp_type == torch.float32 - ), "Only float32 is supported for custom SDPA" self.dim = dim def forward( @@ -47,6 +37,10 @@ def forward( seqlen, mask, ): + q = q.transpose(1, 2) # (bs, seqlen, n_local_heads, head_dim) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + # Custom op only supports float32 currently. Converting to/from float32 is # faster than not having the op. input_dtype = q.dtype @@ -54,13 +48,10 @@ def forward( k = k.to(dtype=torch.float) v = v.to(dtype=torch.float) - k_cache = self.kv_cache.k_cache - v_cache = self.kv_cache.v_cache - k_cache, v_cache = self.kv_cache.update(input_pos, k, v) output = torch.ops.llama.custom_sdpa( q, - k_cache, - v_cache, + k, + v, input_pos[0].item(), None, # Attention mask 0, # dropout probability. Ignored by the code @@ -75,7 +66,7 @@ def _replace_sdpa_with_custom_op(module: torch.nn.Module): setattr( module, name, - SDPACustom(child.kv_cache, child.dim), + SDPACustom(child.dim), ) else: _replace_sdpa_with_custom_op(child) @@ -91,13 +82,11 @@ def replace_sdpa_with_custom_op(module: torch.nn.Module) -> torch.nn.Module: class SDPASimple(torch.nn.Module): def __init__( self, - kv_cache: KVCache, dim: int, head_dim: int, n_rep: int, ): super().__init__() - self.kv_cache = kv_cache self.dim = dim self.head_dim = head_dim self.n_rep = n_rep @@ -112,11 +101,6 @@ def forward( seqlen, mask, ): - q = q.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) - k = k.transpose(1, 2) - v = v.transpose(1, 2) - - k, v = self.kv_cache.update(input_pos, k, v) attn_mask = mask[None, None, input_pos] k = k.repeat_interleave(self.n_rep, dim=1) @@ -150,12 +134,10 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: class SDPAFlex(torch.nn.Module): def __init__( self, - kv_cache: KVCache, dim: int, n_rep: int, ): super().__init__() - self.kv_cache = kv_cache self.dim = dim self.n_rep = n_rep @@ -169,9 +151,10 @@ def forward( seqlen, mask, ): - q = q.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) - - k, v = self.kv_cache.update(input_pos, k, v) + """ + q: (bs, n_heads, seqlen, head_dim) + k, v: (bs, n_local_heads, seqlen, head_dim) + """ k = repeat_kv(k, self.n_rep) v = repeat_kv(v, self.n_rep) attn_mask = mask[input_pos] @@ -191,7 +174,7 @@ def replace_sdpa_with_simple_sdpa(module: torch.nn.Module): setattr( module, name, - SDPASimple(child.kv_cache, child.dim, child.head_dim, child.n_rep), + SDPASimple(child.dim, child.head_dim, child.n_rep), ) else: replace_sdpa_with_simple_sdpa(child) @@ -204,7 +187,7 @@ def replace_sdpa_with_flex_sdpa(module: torch.nn.Module): setattr( module, name, - SDPAFlex(child.kv_cache, child.dim, child.n_rep), + SDPAFlex(child.dim, child.n_rep), ) else: replace_sdpa_with_flex_sdpa(child) @@ -236,13 +219,11 @@ class SDPACoreML(torch.nn.Module): def __init__( self, - kv_cache: KVCache, dim: int, head_dim: int, n_rep: int, ): super().__init__() - self.kv_cache = kv_cache self.dim = dim self.head_dim = head_dim self.n_rep = n_rep @@ -257,11 +238,6 @@ def forward( seqlen, mask, ): - q = q.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) - k = k.transpose(1, 2) - v = v.transpose(1, 2) - - k, v = self.kv_cache.update(input_pos, k, v) attn_mask = mask[None, None, input_pos] if self.n_rep > 1: @@ -279,7 +255,7 @@ def replace_sdpa_with_coreml_sdpa(module: torch.nn.Module): setattr( module, name, - SDPACoreML(child.kv_cache, child.dim, child.head_dim, child.n_rep), + SDPACoreML(child.dim, child.head_dim, child.n_rep), ) else: replace_sdpa_with_coreml_sdpa(child) @@ -366,6 +342,9 @@ def __init__( def update( self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor]: + # can we combine this with KVCacheCoreML? + k_val = k_val.transpose(1, 2) + v_val = v_val.transpose(1, 2) k_out = torch.ops.aten.index_put_(self.past_k_caches, [None, input_pos], k_val) v_out = torch.ops.aten.index_put_(self.past_v_caches, [None, input_pos], v_val) diff --git a/extension/llm/export/builder.py b/extension/llm/export/builder.py index ebc7f02ee1..c50acb878f 100644 --- a/extension/llm/export/builder.py +++ b/extension/llm/export/builder.py @@ -37,6 +37,8 @@ from torch.export import export_for_training from torch.nn.attention import SDPBackend +from executorch.extension.llm.export.export_passes import RemoveRedundantTransposes + FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s" logging.basicConfig(level=logging.INFO, format=FORMAT) @@ -108,6 +110,7 @@ def __init__( self.calibration_seq_length = calibration_seq_length self.calibration_data = calibration_data self.tokenizer_path = tokenizer_path + self.canonical_passes = [RemoveRedundantTransposes()] def set_output_dir(self, output_dir: str) -> "LLMEdgeManager": """ @@ -212,6 +215,13 @@ def export(self) -> "LLMEdgeManager": return self + def run_canonical_optimizations(self): + for pass_instance in self.canonical_passes: + logging.info(f"Running canonical pass: {pass_instance.__class__.__name__}") + res = pass_instance(self.pre_autograd_graph_module) + assert res.graph_module is not None, "Pass returned None" + self.pre_autograd_graph_module = res.graph_module + def pt2e_calibrate( self, prepared_module, diff --git a/extension/llm/export/export_passes.py b/extension/llm/export/export_passes.py new file mode 100644 index 0000000000..a63f1c8fc8 --- /dev/null +++ b/extension/llm/export/export_passes.py @@ -0,0 +1,80 @@ +import torch +from torch._subclasses import FakeTensor + +from executorch.exir.pass_base import ExportPass +from torch.fx.passes.infra.pass_base import PassResult + +def _normalize_dims(tensor: FakeTensor, dim_0: int, dim_1: int): + """ + Normalize the dimensions of a tensor. + """ + assert tensor is not None, "Tensor is None" + ndim = tensor.ndim + if dim_0 < 0: + dim_0 = ndim + dim_0 + if dim_1 < 0: + dim_1 = ndim + dim_1 + assert dim_0 < ndim and dim_1 < ndim, f"Invalid dimensions: {dim_0}, {dim_1}" + return dim_0, dim_1 + +class RemoveRedundantTransposes(ExportPass): + """ + This pass removes redundant transpose nodes in the graph. + It checks if the next node is also a transpose node and if the two transpose nodes undo each other. + For example, if the graph has the following nodes: + + node1 = torch.ops.aten.transpose.int(x, 0, 1) + node2 = torch.ops.aten.transpose.int(node1, 0, 1) + + Then node2's use can be replaced by x + + It will also check for permute nodes + node1 = torch.ops.aten.permute(x, [0, 2, 1]) + node2 = torch.ops.aten.permute(node1, [0, 2, 1]) + + Then also node2's use can be replaced by x + + NB: Does not work for inplace ops or functionalized _copy suffix ops + """ + def call(self, graph_module: torch.fx.GraphModule): + graph_changed = False + for node in graph_module.graph.nodes: + if node.op == 'call_function' and node.target == torch.ops.aten.transpose.int: + # Check if the next node is also a transpose node + tranpose_users = list(node.users.keys()) + dim_0 = node.args[1] + dim_1 = node.args[2] + dim_0, dim_1 = _normalize_dims(node.args[0].meta["val"], dim_0, dim_1) + + for user in tranpose_users: + if user.op == 'call_function' and user.target == torch.ops.aten.transpose.int: + # Get the arguments of the current and next transpose nodes + user_dim_0 = user.args[1] + user_dim_1 = user.args[2] + user_dim_0, user_dim_1 = _normalize_dims(user.args[0].meta["val"], user_dim_0, user_dim_1) + + # Check if the two transpose nodes undo each other + if dim_0 == user_dim_0 and dim_1 == user_dim_1: + graph_changed = True + user.replace_all_uses_with(node.args[0]) + + for node in graph_module.graph.nodes: + if node.op == 'call_function' and node.target == torch.ops.aten.permute.default: + # Check if the next node is also a transpose node + permute_users = list(node.users.keys()) + dim_list = node.args[1] + + for user in permute_users: + if user.op == 'call_function' and user.target == torch.ops.aten.permute.default: + # Get the arguments of the current and next transpose nodes + user_dim_list = user.args[1] + + # Check if the two permutes undo each other + if dim_list == user_dim_list: + graph_changed = True + user.replace_all_uses_with(node.args[0]) + + graph_module.graph.eliminate_dead_code() + graph_module.recompile() + + return PassResult(graph_module, graph_changed) diff --git a/extension/llm/export/test_export_passes.py b/extension/llm/export/test_export_passes.py new file mode 100644 index 0000000000..8478c1ad37 --- /dev/null +++ b/extension/llm/export/test_export_passes.py @@ -0,0 +1,167 @@ +import unittest +import os + +import torch +from torch.testing import FileCheck + +from torch.export import export_for_training + +from executorch.extension.llm.export.export_passes import RemoveRedundantTransposes + +class RemoveRedundantTransposesPassTest(unittest.TestCase): + def _export(self, model, example_inputs): + exported_module = export_for_training( + model, + example_inputs, + ) + return exported_module.module() + + def _check(self, model, example_inputs, key, before_count, after_count): + gm = self._export(model, example_inputs) + FileCheck().check_count(key, before_count, exactly=True).run( + gm.code + ) + pass_res = RemoveRedundantTransposes()(gm) + FileCheck().check_count(key, after_count, exactly=True).run( + pass_res.graph_module.code + ) + + def test_transpose_removal(self): + class TestModule1(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + x = torch.transpose(x, 1, 2) + x = torch.transpose(x, 1, 2) + return x + 1 + + class TestModule2(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + x = torch.transpose(x, 1, 2) + x = torch.transpose(x, 1, 2) + x = x + 1 + + x = torch.transpose(x, 2, 3) + x = torch.transpose(x, 2, 3) + + return x + 2 + + x = torch.rand((1, 2, 3, 4)) + key = "torch.ops.aten.transpose.int" + m = TestModule1() + self._check(m, (x,), key, 2, 0) + + m = TestModule2() + self._check(m, (x,), key, 4, 0) + + def test_transpose_no_removal(self): + class TestModule1(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + x = torch.transpose(x, 1, 2) + x = torch.transpose(x, 1, 2) + x = x + 1 + + x = torch.transpose(x, 2, 3) + x = torch.transpose(x, 1, 2) + + return x + 2 + + x = torch.rand((1, 2, 3, 4)) + key = "torch.ops.aten.transpose.int" + + m = TestModule1() + self._check(m, (x,), key, 4, 2) + + class TestModule2(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + x_1 = torch.transpose(x, 1, 2) + x_2 = torch.transpose(x_1, 1, 2) + x_2 = x_2 + 1 + + x = x_1 + 2 + x = torch.transpose(x, 1, 2) + + return x + x_2 + + m = TestModule2() + self._check(m, (x,), key, 3, 2) + + def test_permute_removal(self): + class TestModule1(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + x = torch.permute(x, [0, 2, 1, 3]) + x = torch.permute(x, [0, 2, 1, 3]) + return x + 1 + + class TestModule2(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + x = torch.permute(x, [0, 2, 1, 3]) + x = torch.permute(x, [0, 2, 1, 3]) + x = x + 1 + + x = torch.permute(x, [0, 1, 3, 2]) + x = torch.permute(x, [0, 1, 3, 2]) + + return x + 2 + + x = torch.rand((1, 2, 3, 4)) + key = "torch.ops.aten.permute.default" + m = TestModule1() + self._check(m, (x,), key, 2, 0) + + m = TestModule2() + self._check(m, (x,), key, 4, 0) + + def test_permute_no_removal(self): + class TestModule1(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + x = torch.permute(x, [0, 2, 1, 3]) + x = torch.permute(x, [0, 2, 1, 3]) + x = x + 1 + + x = torch.permute(x, [0, 1, 3, 2]) + x = torch.permute(x, [0, 2, 1, 3]) + + return x + 2 + + x = torch.rand((1, 2, 3, 4)) + key = "torch.ops.aten.permute.default" + + m = TestModule1() + self._check(m, (x,), key, 4, 2) + + class TestModule2(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + x_1 = torch.permute(x, [0, 2, 1, 3]) + x_2 = torch.permute(x_1, [0, 2, 1, 3]) + x_2 = x_2 + 1 + + x = x_1 + 2 + x = torch.permute(x, [0, 2, 1, 3]) + + return x + x_2 + + m = TestModule2() + self._check(m, (x,), key, 3, 2)