diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py index 1d757960f7..9e81c5668f 100644 --- a/examples/models/llama/export_llama_lib.py +++ b/examples/models/llama/export_llama_lib.py @@ -657,6 +657,8 @@ def _export_llama(args) -> LLMEdgeManager: # noqa: C901 # export_to_edge builder_exported = _prepare_for_llama_export(args).export() + builder_exported.run_canonical_optimizations() + if args.export_only: exit() diff --git a/examples/models/llama/llama_transformer.py b/examples/models/llama/llama_transformer.py index aaef3cd980..176b597a94 100644 --- a/examples/models/llama/llama_transformer.py +++ b/examples/models/llama/llama_transformer.py @@ -232,22 +232,16 @@ def __init__( max_seq_length: int, n_heads: int, head_dim: int, - transpose_cache: bool, enable_dynamic_shape: bool, dtype=torch.float32, ): super().__init__() self.max_seq_length = max_seq_length - self.is_transposed = transpose_cache - if transpose_cache: - cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim) - else: - cache_shape = (max_batch_size, max_seq_length, n_heads, head_dim) + cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim) self.max_batch_size = max_batch_size self.n_heads = n_heads self.head_dim = head_dim - self.transpose_cache = transpose_cache self.enable_dynamic_shape = enable_dynamic_shape self.register_buffer( "k_cache", torch.zeros(cache_shape, dtype=dtype, device="cpu") @@ -259,12 +253,12 @@ def __init__( def update( self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor]: - # input_pos: [S], k_val: [B, H, S, D] or [B, S, H, D] depending on transpose_cache + # input_pos: [S], k_val: [B, H, S, D] if self.enable_dynamic_shape: start_pos = input_pos[0].item() torch._check_is_size(start_pos) torch._check(start_pos < self.max_seq_length) - dim_to_slice = 2 if self.transpose_cache else 1 + dim_to_slice = 2 seq_length = k_val.size(dim_to_slice) # Replace the entry in the cache for this token # The following lines are equivalent to: @@ -283,12 +277,8 @@ def update( else: k_out = self.k_cache v_out = self.v_cache - if self.transpose_cache: - k_out[:, :, input_pos] = k_val - v_out[:, :, input_pos] = v_val - else: - k_out[:, input_pos] = k_val - v_out[:, input_pos] = v_val + k_out[:, :, input_pos] = k_val + v_out[:, :, input_pos] = v_val return k_out, v_out @@ -296,7 +286,6 @@ def update( class SDPA(nn.Module): def __init__( self, - kv_cache: KVCache, dim: int, head_dim: int, n_rep: int, @@ -304,7 +293,6 @@ def __init__( enable_dynamic_shape: bool, ): super().__init__() - self.kv_cache = kv_cache self.dim = dim self.head_dim = head_dim self.n_rep = n_rep @@ -314,18 +302,16 @@ def __init__( def forward( self, input_pos: torch.Tensor, - q: torch.Tensor, # Already have rotary embeddings. (bs, seqlen, n_local_heads, head_dim) - k: torch.Tensor, # Already have rotary embeddings. (bs, seqlen, n_local_kv_heads, head_dim) - v: torch.Tensor, # (bs, seqlen, n_local_kv_heads, head_dim) + q: torch.Tensor, # Already have rotary embeddings. (bs, n_local_heads, seqlen, head_dim) + k: torch.Tensor, # Already have rotary embeddings. (bs, n_local_kv_heads, seqlen, head_dim) + v: torch.Tensor, # (bs, n_local_kv_heads, seqlen, head_dim) bsz, seqlen, mask: torch.Tensor, ) -> torch.Tensor: - q = q.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) - k = k.transpose(1, 2) - v = v.transpose(1, 2) - k, v = self.kv_cache.update(input_pos, k, v) + # TODO(kimishpatel): Move this slicing logic to Attention block so that + # SDPA does not have to take input_pos as arg if self.enable_dynamic_shape: start_pos = input_pos[-1].item() torch._check_is_size(start_pos) @@ -336,6 +322,8 @@ def forward( else: attn_mask = mask[None, None, input_pos] + # TODO(kimishpatel): This should not be necessary because scaled_dot_product_attention + # can natively support GQA now. But needs enable_gqa=True k = k.repeat_interleave(self.n_rep, dim=1) v = v.repeat_interleave(self.n_rep, dim=1) y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=0.0) @@ -383,11 +371,9 @@ def __init__(self, args: ModelArgs, layer_id: int, rope: Rope): args.max_seq_len, self.n_kv_heads, self.head_dim, - not args.use_sdpa_with_kv_cache_op, # if we are using the custom op don't transpose the cache. Expect untransposed q k v args.enable_dynamic_shape, ) self.SDPA = SDPA( - kv_cache=self.kv_cache, dim=self.n_local_heads * self.head_dim, head_dim=self.head_dim, n_rep=self.n_rep, @@ -414,15 +400,16 @@ def forward( # RoPE relative positional embeddings q, k = self.rope.forward(q, k, freqs_cos, freqs_sin) + q = q.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + if self.use_kv_cache: assert input_pos is not None + k, v = self.kv_cache.update(input_pos, k, v) output = self.SDPA(input_pos, q, k, v, bsz, seqlen, self.mask) return self.wo(output) - q = q.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) - k = k.transpose(1, 2) - v = v.transpose(1, 2) - # grouped multiquery attention: expand out keys and values k = k.repeat_interleave(self.n_rep, dim=1) v = v.repeat_interleave(self.n_rep, dim=1) diff --git a/examples/models/llama/source_transformation/quantized_kv_cache.py b/examples/models/llama/source_transformation/quantized_kv_cache.py index d8ac99656f..fa0b3f9251 100644 --- a/examples/models/llama/source_transformation/quantized_kv_cache.py +++ b/examples/models/llama/source_transformation/quantized_kv_cache.py @@ -37,8 +37,6 @@ def __init__( n_heads, head_dim, cache_type: QuantizedCacheType = QuantizedCacheType.AffineSymmetric, - tranposed=False, - enable_dynamic_shape=False, ): super().__init__() if cache_type not in ( @@ -52,14 +50,8 @@ def __init__( # For now supporting int8 only self.quantized_cache_dtype = torch.int8 self.cache_fp_type = torch.float32 - self.is_transposed = tranposed - self.enable_dynamic_shape = enable_dynamic_shape - if self.is_transposed: - cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim) - scale_shape = (max_batch_size, n_heads, max_seq_length, 1) - else: - cache_shape = (max_batch_size, max_seq_length, n_heads, head_dim) - scale_shape = (max_batch_size, max_seq_length, n_heads, 1) + cache_shape = (max_batch_size, max_seq_length, n_heads, head_dim) + scale_shape = (max_batch_size, max_seq_length, n_heads, 1) self.register_buffer( "k_cache", torch.zeros(cache_shape, dtype=self.quantized_cache_dtype) ) @@ -98,71 +90,37 @@ def _quantize(self, value): return quantized_value, scales, zero_points def update(self, input_pos, k_val, v_val): + """ + k_val, v_val: [B, H, S, D] + return: [B, H, S, D] + However the storage is [B, S, H, D] so we incur transpose in, transpose out + This shall be removed by subsequent post-export graph pass + """ + k_val = k_val.transpose(1, 2) + v_val = v_val.transpose(1, 2) # quantize current k_val and store it in the cache quantized_k_val, k_scales, k_zero_points = self._quantize(k_val) quantized_v_val, v_scales, v_zero_points = self._quantize(v_val) - if self.is_transposed: - # We cannot use update_cache op at the moment - # if the cache is transposed - # Also note that we shold not need separate paths - # for dynamic shape vs ! - # Only reason it is done this way is to accommodate - # for lowering pains of backends that work better - # with index_put op. - if self.enable_dynamic_shape: - start_pos = input_pos[0].item() - torch._check_is_size(start_pos) - dim_to_slice = 2 if self.is_transposed else 1 - torch._check(start_pos < self.k_cache.size(dim_to_slice)) - seq_length = k_val.size(dim_to_slice) - narrowed_k = self.k_cache.narrow(dim_to_slice, start_pos, seq_length) - narrowed_k_scales = self.k_cache_scales.narrow( - dim_to_slice, start_pos, seq_length - ) - narrowed_k_zp = self.k_cache_zero_points.narrow( - dim_to_slice, start_pos, seq_length - ) - narrowed_k.copy_(quantized_k_val) - narrowed_k_scales.copy_(k_scales) - narrowed_k_zp.copy_(k_zero_points) - narrowed_v = self.v_cache.narrow(dim_to_slice, start_pos, seq_length) - narrowed_v_scales = self.v_cache_scales.narrow( - dim_to_slice, start_pos, seq_length - ) - narrowed_v_zp = self.v_cache_zero_points.narrow( - dim_to_slice, start_pos, seq_length - ) - narrowed_v.copy_(quantized_v_val) - narrowed_v_scales.copy_(v_scales) - narrowed_v_zp.copy_(v_zero_points) - else: - self.k_cache[:, :, input_pos] = quantized_k_val - self.k_cache_scales[:, :, input_pos] = k_scales - self.k_cache_zero_points[:, :, input_pos] = k_zero_points - self.v_cache[:, :, input_pos] = quantized_v_val - self.v_cache_scales[:, :, input_pos] = v_scales - self.v_cache_zero_points[:, :, input_pos] = v_zero_points - else: - # Right now using custom ops on this path. - # In future we can update custom op to handle transposed cache - # as well. - # Note that we may have to revert this change if other ET - # backends such as QNN want to use quantized cache, with dynamic shape, - # instead of quantizing on their own. - # But until this opting for code simplicity - start_pos = input_pos[0].item() - _ = torch.ops.llama.update_cache(quantized_k_val, self.k_cache, start_pos) - _ = torch.ops.llama.update_cache(k_scales, self.k_cache_scales, start_pos) - _ = torch.ops.llama.update_cache( - k_zero_points, self.k_cache_zero_points, start_pos - ) - _ = torch.ops.llama.update_cache(quantized_v_val, self.v_cache, start_pos) - _ = torch.ops.llama.update_cache(v_scales, self.v_cache_scales, start_pos) - _ = torch.ops.llama.update_cache( - v_zero_points, self.v_cache_zero_points, start_pos - ) + # Right now using custom ops on this path. + # In future we can update custom op to handle transposed cache + # as well. + # Note that we may have to revert this change if other ET + # backends such as QNN want to use quantized cache, with dynamic shape, + # instead of quantizing on their own. + # But until this opting for code simplicity + start_pos = input_pos[0].item() + _ = torch.ops.llama.update_cache(quantized_k_val, self.k_cache, start_pos) + _ = torch.ops.llama.update_cache(k_scales, self.k_cache_scales, start_pos) + _ = torch.ops.llama.update_cache( + k_zero_points, self.k_cache_zero_points, start_pos + ) + _ = torch.ops.llama.update_cache(quantized_v_val, self.v_cache, start_pos) + _ = torch.ops.llama.update_cache(v_scales, self.v_cache_scales, start_pos) + _ = torch.ops.llama.update_cache( + v_zero_points, self.v_cache_zero_points, start_pos + ) k_out = torch.ops.quantized_decomposed.dequantize_per_token( self.k_cache, @@ -183,42 +141,24 @@ def update(self, input_pos, k_val, v_val): self.cache_fp_type, ) - if self.is_transposed: - if self.enable_dynamic_shape: - start_pos = input_pos[0].item() - torch._check_is_size(start_pos) - dim_to_slice = 2 if self.is_transposed else 1 - torch._check(start_pos < self.k_cache.size(dim_to_slice)) - seq_length = k_val.size(dim_to_slice) - narrowed_k = k_out.narrow(dim_to_slice, start_pos, seq_length) - narrowed_k.copy_(k_val) - narrowed_v = v_out.narrow(dim_to_slice, start_pos, seq_length) - narrowed_v.copy_(v_val) - else: - k_out[:, :, input_pos] = k_val - v_out[:, :, input_pos] = v_val - else: - start_pos = input_pos[0].item() - _ = torch.ops.llama.update_cache(k_val, k_out, start_pos) - _ = torch.ops.llama.update_cache(v_val, v_out, start_pos) + start_pos = input_pos[0].item() + _ = torch.ops.llama.update_cache(k_val, k_out, start_pos) + _ = torch.ops.llama.update_cache(v_val, v_out, start_pos) - return k_out, v_out + return k_out.transpose(1, 2), v_out.transpose(1, 2) @classmethod def from_float(cls, kv_cache, cache_type: QuantizedCacheType): - cache_shape = kv_cache.k_cache.shape - if kv_cache.is_transposed: - max_batch_size, n_heads, max_seq_length, head_dim = cache_shape - else: - max_batch_size, max_seq_length, n_heads, head_dim = cache_shape + max_batch_size, n_heads, max_seq_length, head_dim = kv_cache.k_cache.shape + if isinstance(kv_cache, CustomKVCache): + # If replacing custom kv cache, then the shape is [B, S, H, D] + max_batch_size, max_seq_length, n_heads, head_dim = kv_cache.k_cache.shape return cls( max_batch_size, max_seq_length, n_heads, head_dim, cache_type, - kv_cache.is_transposed, - kv_cache.enable_dynamic_shape, ) @@ -254,7 +194,7 @@ def replace_kv_cache_with_quantized_kv_cache(module): "Replacing KVCache with QuantizedKVCache. This modifies the model in place." ) for name, child in module.named_children(): - if isinstance(child, KVCache): + if isinstance(child, KVCache) or isinstance(child, CustomKVCache): setattr( module, name, @@ -291,11 +231,13 @@ def __init__( def update( self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor]: - # input_pos: [S], k_val: [B, S, H, D] + # input_pos: [S], k_val: [B, H, S, D] + k_val = k_val.transpose(1, 2) + v_val = v_val.transpose(1, 2) start_pos = input_pos[0].item() _ = torch.ops.llama.update_cache(k_val, self.k_cache, start_pos) _ = torch.ops.llama.update_cache(v_val, self.v_cache, start_pos) - return self.k_cache, self.v_cache + return self.k_cache.transpose(1, 2), self.v_cache.transpose(1, 2) def replace_kv_cache_with_custom_kv_cache(module): @@ -313,10 +255,7 @@ def replace_kv_cache_with_custom_kv_cache(module): if isinstance(child, KVCache): cache_shape = child.k_cache.shape cache_dtype = child.k_cache.dtype - assert ( - child.is_transposed is False - ), "CustomKVCache does not support transposed cache" - max_batch_size, max_seq_length, n_heads, head_dim = cache_shape + max_batch_size, n_heads, max_seq_length, head_dim = cache_shape setattr( module, name, diff --git a/examples/models/llama/source_transformation/sdpa.py b/examples/models/llama/source_transformation/sdpa.py index 4d4b3bf7f5..f68e43cbcd 100644 --- a/examples/models/llama/source_transformation/sdpa.py +++ b/examples/models/llama/source_transformation/sdpa.py @@ -22,19 +22,9 @@ class SDPACustom(torch.nn.Module): def __init__( self, - kv_cache: Union[KVCache, QuantizedKVCache], dim: int, ): super().__init__() - # Custom op only supports float32 currently. Converting to/from float32 is - # faster than not having the op. - self.kv_cache = kv_cache - if not isinstance(kv_cache, QuantizedKVCache): - self.kv_cache = kv_cache.to(torch.float) - else: - assert ( - kv_cache.cache_fp_type == torch.float32 - ), "Only float32 is supported for custom SDPA" self.dim = dim def forward( @@ -47,6 +37,10 @@ def forward( seqlen, mask, ): + q = q.transpose(1, 2) # (bs, seqlen, n_local_heads, head_dim) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + # Custom op only supports float32 currently. Converting to/from float32 is # faster than not having the op. input_dtype = q.dtype @@ -54,13 +48,10 @@ def forward( k = k.to(dtype=torch.float) v = v.to(dtype=torch.float) - k_cache = self.kv_cache.k_cache - v_cache = self.kv_cache.v_cache - k_cache, v_cache = self.kv_cache.update(input_pos, k, v) output = torch.ops.llama.custom_sdpa( q, - k_cache, - v_cache, + k, + v, input_pos[0].item(), None, # Attention mask 0, # dropout probability. Ignored by the code @@ -75,7 +66,7 @@ def _replace_sdpa_with_custom_op(module: torch.nn.Module): setattr( module, name, - SDPACustom(child.kv_cache, child.dim), + SDPACustom(child.dim), ) else: _replace_sdpa_with_custom_op(child) @@ -91,13 +82,11 @@ def replace_sdpa_with_custom_op(module: torch.nn.Module) -> torch.nn.Module: class SDPASimple(torch.nn.Module): def __init__( self, - kv_cache: KVCache, dim: int, head_dim: int, n_rep: int, ): super().__init__() - self.kv_cache = kv_cache self.dim = dim self.head_dim = head_dim self.n_rep = n_rep @@ -112,11 +101,6 @@ def forward( seqlen, mask, ): - q = q.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) - k = k.transpose(1, 2) - v = v.transpose(1, 2) - - k, v = self.kv_cache.update(input_pos, k, v) attn_mask = mask[None, None, input_pos] k = k.repeat_interleave(self.n_rep, dim=1) @@ -150,12 +134,10 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: class SDPAFlex(torch.nn.Module): def __init__( self, - kv_cache: KVCache, dim: int, n_rep: int, ): super().__init__() - self.kv_cache = kv_cache self.dim = dim self.n_rep = n_rep @@ -169,9 +151,10 @@ def forward( seqlen, mask, ): - q = q.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) - - k, v = self.kv_cache.update(input_pos, k, v) + """ + q: (bs, n_heads, seqlen, head_dim) + k, v: (bs, n_local_heads, seqlen, head_dim) + """ k = repeat_kv(k, self.n_rep) v = repeat_kv(v, self.n_rep) attn_mask = mask[input_pos] @@ -191,7 +174,7 @@ def replace_sdpa_with_simple_sdpa(module: torch.nn.Module): setattr( module, name, - SDPASimple(child.kv_cache, child.dim, child.head_dim, child.n_rep), + SDPASimple(child.dim, child.head_dim, child.n_rep), ) else: replace_sdpa_with_simple_sdpa(child) @@ -204,7 +187,7 @@ def replace_sdpa_with_flex_sdpa(module: torch.nn.Module): setattr( module, name, - SDPAFlex(child.kv_cache, child.dim, child.n_rep), + SDPAFlex(child.dim, child.n_rep), ) else: replace_sdpa_with_flex_sdpa(child) @@ -236,13 +219,11 @@ class SDPACoreML(torch.nn.Module): def __init__( self, - kv_cache: KVCache, dim: int, head_dim: int, n_rep: int, ): super().__init__() - self.kv_cache = kv_cache self.dim = dim self.head_dim = head_dim self.n_rep = n_rep @@ -257,11 +238,6 @@ def forward( seqlen, mask, ): - q = q.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) - k = k.transpose(1, 2) - v = v.transpose(1, 2) - - k, v = self.kv_cache.update(input_pos, k, v) attn_mask = mask[None, None, input_pos] if self.n_rep > 1: @@ -279,7 +255,7 @@ def replace_sdpa_with_coreml_sdpa(module: torch.nn.Module): setattr( module, name, - SDPACoreML(child.kv_cache, child.dim, child.head_dim, child.n_rep), + SDPACoreML(child.dim, child.head_dim, child.n_rep), ) else: replace_sdpa_with_coreml_sdpa(child) @@ -366,6 +342,9 @@ def __init__( def update( self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor]: + # can we combine this with KVCacheCoreML? + k_val = k_val.transpose(1, 2) + v_val = v_val.transpose(1, 2) k_out = torch.ops.aten.index_put_(self.past_k_caches, [None, input_pos], k_val) v_out = torch.ops.aten.index_put_(self.past_v_caches, [None, input_pos], v_val) diff --git a/extension/llm/export/builder.py b/extension/llm/export/builder.py index ebc7f02ee1..c50acb878f 100644 --- a/extension/llm/export/builder.py +++ b/extension/llm/export/builder.py @@ -37,6 +37,8 @@ from torch.export import export_for_training from torch.nn.attention import SDPBackend +from executorch.extension.llm.export.export_passes import RemoveRedundantTransposes + FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s" logging.basicConfig(level=logging.INFO, format=FORMAT) @@ -108,6 +110,7 @@ def __init__( self.calibration_seq_length = calibration_seq_length self.calibration_data = calibration_data self.tokenizer_path = tokenizer_path + self.canonical_passes = [RemoveRedundantTransposes()] def set_output_dir(self, output_dir: str) -> "LLMEdgeManager": """ @@ -212,6 +215,13 @@ def export(self) -> "LLMEdgeManager": return self + def run_canonical_optimizations(self): + for pass_instance in self.canonical_passes: + logging.info(f"Running canonical pass: {pass_instance.__class__.__name__}") + res = pass_instance(self.pre_autograd_graph_module) + assert res.graph_module is not None, "Pass returned None" + self.pre_autograd_graph_module = res.graph_module + def pt2e_calibrate( self, prepared_module, diff --git a/extension/llm/export/export_passes.py b/extension/llm/export/export_passes.py new file mode 100644 index 0000000000..a63f1c8fc8 --- /dev/null +++ b/extension/llm/export/export_passes.py @@ -0,0 +1,80 @@ +import torch +from torch._subclasses import FakeTensor + +from executorch.exir.pass_base import ExportPass +from torch.fx.passes.infra.pass_base import PassResult + +def _normalize_dims(tensor: FakeTensor, dim_0: int, dim_1: int): + """ + Normalize the dimensions of a tensor. + """ + assert tensor is not None, "Tensor is None" + ndim = tensor.ndim + if dim_0 < 0: + dim_0 = ndim + dim_0 + if dim_1 < 0: + dim_1 = ndim + dim_1 + assert dim_0 < ndim and dim_1 < ndim, f"Invalid dimensions: {dim_0}, {dim_1}" + return dim_0, dim_1 + +class RemoveRedundantTransposes(ExportPass): + """ + This pass removes redundant transpose nodes in the graph. + It checks if the next node is also a transpose node and if the two transpose nodes undo each other. + For example, if the graph has the following nodes: + + node1 = torch.ops.aten.transpose.int(x, 0, 1) + node2 = torch.ops.aten.transpose.int(node1, 0, 1) + + Then node2's use can be replaced by x + + It will also check for permute nodes + node1 = torch.ops.aten.permute(x, [0, 2, 1]) + node2 = torch.ops.aten.permute(node1, [0, 2, 1]) + + Then also node2's use can be replaced by x + + NB: Does not work for inplace ops or functionalized _copy suffix ops + """ + def call(self, graph_module: torch.fx.GraphModule): + graph_changed = False + for node in graph_module.graph.nodes: + if node.op == 'call_function' and node.target == torch.ops.aten.transpose.int: + # Check if the next node is also a transpose node + tranpose_users = list(node.users.keys()) + dim_0 = node.args[1] + dim_1 = node.args[2] + dim_0, dim_1 = _normalize_dims(node.args[0].meta["val"], dim_0, dim_1) + + for user in tranpose_users: + if user.op == 'call_function' and user.target == torch.ops.aten.transpose.int: + # Get the arguments of the current and next transpose nodes + user_dim_0 = user.args[1] + user_dim_1 = user.args[2] + user_dim_0, user_dim_1 = _normalize_dims(user.args[0].meta["val"], user_dim_0, user_dim_1) + + # Check if the two transpose nodes undo each other + if dim_0 == user_dim_0 and dim_1 == user_dim_1: + graph_changed = True + user.replace_all_uses_with(node.args[0]) + + for node in graph_module.graph.nodes: + if node.op == 'call_function' and node.target == torch.ops.aten.permute.default: + # Check if the next node is also a transpose node + permute_users = list(node.users.keys()) + dim_list = node.args[1] + + for user in permute_users: + if user.op == 'call_function' and user.target == torch.ops.aten.permute.default: + # Get the arguments of the current and next transpose nodes + user_dim_list = user.args[1] + + # Check if the two permutes undo each other + if dim_list == user_dim_list: + graph_changed = True + user.replace_all_uses_with(node.args[0]) + + graph_module.graph.eliminate_dead_code() + graph_module.recompile() + + return PassResult(graph_module, graph_changed) diff --git a/extension/llm/export/test_export_passes.py b/extension/llm/export/test_export_passes.py new file mode 100644 index 0000000000..8478c1ad37 --- /dev/null +++ b/extension/llm/export/test_export_passes.py @@ -0,0 +1,167 @@ +import unittest +import os + +import torch +from torch.testing import FileCheck + +from torch.export import export_for_training + +from executorch.extension.llm.export.export_passes import RemoveRedundantTransposes + +class RemoveRedundantTransposesPassTest(unittest.TestCase): + def _export(self, model, example_inputs): + exported_module = export_for_training( + model, + example_inputs, + ) + return exported_module.module() + + def _check(self, model, example_inputs, key, before_count, after_count): + gm = self._export(model, example_inputs) + FileCheck().check_count(key, before_count, exactly=True).run( + gm.code + ) + pass_res = RemoveRedundantTransposes()(gm) + FileCheck().check_count(key, after_count, exactly=True).run( + pass_res.graph_module.code + ) + + def test_transpose_removal(self): + class TestModule1(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + x = torch.transpose(x, 1, 2) + x = torch.transpose(x, 1, 2) + return x + 1 + + class TestModule2(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + x = torch.transpose(x, 1, 2) + x = torch.transpose(x, 1, 2) + x = x + 1 + + x = torch.transpose(x, 2, 3) + x = torch.transpose(x, 2, 3) + + return x + 2 + + x = torch.rand((1, 2, 3, 4)) + key = "torch.ops.aten.transpose.int" + m = TestModule1() + self._check(m, (x,), key, 2, 0) + + m = TestModule2() + self._check(m, (x,), key, 4, 0) + + def test_transpose_no_removal(self): + class TestModule1(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + x = torch.transpose(x, 1, 2) + x = torch.transpose(x, 1, 2) + x = x + 1 + + x = torch.transpose(x, 2, 3) + x = torch.transpose(x, 1, 2) + + return x + 2 + + x = torch.rand((1, 2, 3, 4)) + key = "torch.ops.aten.transpose.int" + + m = TestModule1() + self._check(m, (x,), key, 4, 2) + + class TestModule2(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + x_1 = torch.transpose(x, 1, 2) + x_2 = torch.transpose(x_1, 1, 2) + x_2 = x_2 + 1 + + x = x_1 + 2 + x = torch.transpose(x, 1, 2) + + return x + x_2 + + m = TestModule2() + self._check(m, (x,), key, 3, 2) + + def test_permute_removal(self): + class TestModule1(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + x = torch.permute(x, [0, 2, 1, 3]) + x = torch.permute(x, [0, 2, 1, 3]) + return x + 1 + + class TestModule2(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + x = torch.permute(x, [0, 2, 1, 3]) + x = torch.permute(x, [0, 2, 1, 3]) + x = x + 1 + + x = torch.permute(x, [0, 1, 3, 2]) + x = torch.permute(x, [0, 1, 3, 2]) + + return x + 2 + + x = torch.rand((1, 2, 3, 4)) + key = "torch.ops.aten.permute.default" + m = TestModule1() + self._check(m, (x,), key, 2, 0) + + m = TestModule2() + self._check(m, (x,), key, 4, 0) + + def test_permute_no_removal(self): + class TestModule1(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + x = torch.permute(x, [0, 2, 1, 3]) + x = torch.permute(x, [0, 2, 1, 3]) + x = x + 1 + + x = torch.permute(x, [0, 1, 3, 2]) + x = torch.permute(x, [0, 2, 1, 3]) + + return x + 2 + + x = torch.rand((1, 2, 3, 4)) + key = "torch.ops.aten.permute.default" + + m = TestModule1() + self._check(m, (x,), key, 4, 2) + + class TestModule2(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + x_1 = torch.permute(x, [0, 2, 1, 3]) + x_2 = torch.permute(x_1, [0, 2, 1, 3]) + x_2 = x_2 + 1 + + x = x_1 + 2 + x = torch.permute(x, [0, 2, 1, 3]) + + return x + x_2 + + m = TestModule2() + self._check(m, (x,), key, 3, 2)