Use transform to replace rms_norm
Sheng Feng Wu committed Sep 6, 2024
1 parent 3614d3b commit 8ab0e79
Showing 3 changed files with 29 additions and 3 deletions.
examples/models/llama2/export_llama_lib.py (2 changes: 2 additions & 0 deletions)
@@ -50,6 +50,7 @@
     get_quant_embedding_transform,
     get_quant_weight_transform,
 )
+from .source_transformation.rms_norm import replace_rms_norm_with_native_rms_norm
 from .source_transformation.rope import materialze_broadcast_of_rope_freq_cis
 from .source_transformation.sdpa import (
     replace_causal_mask,
@@ -411,6 +412,7 @@ def _prepare_for_llama_export(modelname: str, args) -> LLMEdgeManager:
         transforms.append(replace_kv_cache_with_simple_kv_cache)
         transforms.append(replace_sdpa_with_flex_sdpa)
         transforms.append(replace_causal_mask)
+        transforms.append(replace_rms_norm_with_native_rms_norm)
         transforms.append(convert_linear_to_conv2d)
         export_fn = torch.export.export

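For context, each entry appended to `transforms` is a callable that takes the model module and returns a rewritten module; the export path applies them in order before tracing. A minimal sketch of that pattern (the `apply_source_transforms` helper and its argument names are illustrative, not part of this commit):

import torch


def apply_source_transforms(model: torch.nn.Module, transforms) -> torch.nn.Module:
    # Each source transform rewrites submodules and returns the module,
    # so the result can be passed on to torch.export.export afterwards.
    for transform in transforms:
        model = transform(model)
    return model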
examples/models/llama2/llama_transformer.py (7 changes: 4 additions & 3 deletions)
@@ -39,6 +39,7 @@ def __init__(self, dim: int, eps: float = 1e-6):
         """
         super().__init__()
         self.dim = dim
+        self.eps = eps
         self.weight = nn.Parameter(torch.ones(dim))

@@ -416,8 +417,8 @@ def __init__(self, layer_id: int, args: ModelArgs):
             self.block_sparse_moe = MOEFeedForward(args)
         else:
             self.feed_forward = FeedForward(args)
-        self.attention_norm = torch.nn.RMSNorm(args.dim, eps=args.norm_eps)
-        self.ffn_norm = torch.nn.RMSNorm(args.dim, eps=args.norm_eps)
+        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
+        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)

     def forward(self, x, freqs_cos, freqs_sin, input_pos=None):  # x: 1xN
         h = self.attention.forward(
@@ -443,7 +444,7 @@ def __init__(self, params: ModelArgs):
         self.layers = torch.nn.ModuleList()
         for layer_id in range(params.n_layers):
             self.layers.append(TransformerBlock(layer_id, params))
-        self.norm = torch.nn.RMSNorm(params.dim, eps=params.norm_eps)
+        self.norm = RMSNorm(params.dim, eps=params.norm_eps)
         self.output = nn.Linear(params.dim, params.vocab_size, bias=False)
         self.use_kv_cache = params.use_kv_cache
         self.generate_full_logits = params.generate_full_logits
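For reference, the custom `RMSNorm` kept here and the `torch.nn.RMSNorm` module that the new transform swaps in at export time both perform root-mean-square normalization. A minimal functional sketch of that computation (illustrative only, not the exact code in `llama_transformer.py`):

import torch


def rms_norm_reference(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Scale by the reciprocal root-mean-square over the last dimension,
    # then apply the learned elementwise gain.
    return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps) * weight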
examples/models/llama2/source_transformation/rms_norm.py (new file: 23 additions & 0 deletions)
@@ -0,0 +1,23 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from executorch.examples.models.llama2.llama_transformer import RMSNorm
+
+
+def replace_rms_norm_with_native_rms_norm(module: torch.nn.Module):
+    for name, child in module.named_children():
+        if isinstance(child, RMSNorm):
+            rms_norm = torch.nn.RMSNorm(child.dim, eps=child.eps)
+            rms_norm.weight = child.weight
+            setattr(
+                module,
+                name,
+                rms_norm,
+            )
+        else:
+            replace_rms_norm_with_native_rms_norm(child)
+    return module
