Skip to content

Commit

Permalink
Only add pass when vision model
Browse files Browse the repository at this point in the history
  • Loading branch information
dvorjackz committed Dec 18, 2024
1 parent 9cdfb43 commit 9e68531
Show file tree
Hide file tree
Showing 7 changed files with 23 additions and 42 deletions.
12 changes: 10 additions & 2 deletions examples/models/llama/export_llama_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@
import torch

from executorch.devtools.etrecord import generate_etrecord
from executorch.exir.passes.cache_pos_init_mutable_pass import (
CachePosToInitializedMutableBufferPass,
)

from executorch.extension.llm.export.builder import DType, LLMEdgeManager

Expand Down Expand Up @@ -760,6 +763,9 @@ def _export_llama(args) -> LLMEdgeManager: # noqa: C901
for partitioner in partitioners:
logging.info(f"--> {partitioner.__class__.__name__}")

additional_passes = []
if args.model in TORCHTUNE_DEFINED_MODELS:
additional_passes = [CachePosToInitializedMutableBufferPass()]
if args.generate_etrecord:
if not builder_exported_to_edge.edge_manager:
raise ValueError("Unable to generate etrecord due to missing edge manager.")
Expand All @@ -774,7 +780,9 @@ def _export_llama(args) -> LLMEdgeManager: # noqa: C901
# pyre-fixme[16]: Module `backends` has no attribute `qualcomm`.
canonicalize_program(builder.edge_manager.exported_program())

builder = builder.to_executorch()
builder = builder.to_executorch(
passes=additional_passes,
)

# Generate ETRecord
if edge_manager_copy:
Expand All @@ -792,7 +800,7 @@ def _export_llama(args) -> LLMEdgeManager: # noqa: C901
# pyre-fixme[16]: Module `backends` has no attribute `qualcomm`.
canonicalize_program(builder.edge_manager.exported_program())

builder = builder.to_executorch()
builder = builder.to_executorch(passes=additional_passes)

if args.profile_memory:
generate_memory_trace(builder.export_program, "memory_profile.json")
Expand Down
2 changes: 0 additions & 2 deletions examples/models/llama3_2_vision/runner/native.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
)

from executorch.extension.pybindings.portable_lib import (
_load_for_executorch,
_load_for_executorch_from_buffer,
)

Expand Down Expand Up @@ -50,7 +49,6 @@ def __init__(self, args):
with open(args.pte, "rb") as f:
self.model_bytes = f.read()
self.model = _load_for_executorch_from_buffer(self.model_bytes)
# self.model = _load_for_executorch(args.pte)
self.use_kv_cache = args.kv_cache

def forward(
Expand Down
1 change: 0 additions & 1 deletion exir/emit/_emitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -1607,7 +1607,6 @@ def placeholder(

if isinstance(target, str) and isinstance(spec, TensorSpec):
fqn, is_mutable_buffer = self._find_fqn_for_placeholder(target, spec)
print(f"fqn: {fqn}, is_mutable_buffer: {is_mutable_buffer}")

# If the placeholder has a constant_tag, it is external to the PTE file
# and requires a fqn and location=TensorDataLocation.EXTERNAL
Expand Down
21 changes: 0 additions & 21 deletions exir/passes/init_mutable_buffer_pass.py

This file was deleted.

2 changes: 0 additions & 2 deletions exir/program/_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@
OpReplacePass,
)
from executorch.exir.passes.external_constants_pass import external_constants_pass
from executorch.exir.passes.init_mutable_buffer_pass import InitMutableBufferPass
from executorch.exir.passes.insert_write_back_for_buffers_pass import (
insert_write_back_for_buffers_pass,
)
Expand Down Expand Up @@ -707,7 +706,6 @@ def edge_to_executorch_passes(
passes: List[PassType] = [
*config.passes,
SpecPropPass(),
InitMutableBufferPass(),
# ExecuTorch backend ops are unable to handle unbacked symints. So after
# this pass, passes cannot be Interpreter-based, because it will fail if
# there exists an unbacked symint operation.
Expand Down
22 changes: 13 additions & 9 deletions extension/llm/export/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from executorch.exir.backend.utils import format_delegated_graph
from executorch.exir.capture._config import EdgeCompileConfig, ExecutorchBackendConfig

from executorch.exir.pass_manager import PassType
from executorch.exir.passes import MemoryPlanningPass
from executorch.exir.passes.quant_fusion_pass import QuantFusionPass
from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass
Expand Down Expand Up @@ -395,26 +396,29 @@ def to_backend(self, partitioners: Optional[List[Partitioner]]) -> "LLMEdgeManag

return self

def to_executorch(self) -> "LLMEdgeManager":
def to_executorch(self, passes: Optional[List[PassType]]) -> "LLMEdgeManager":
"""
Lower the model to executorch and get an ExecutorchProgram.
"""
assert self.edge_manager, "Need to run export_to_edge() first"
to_executorch_passes = [
# If there are Linear operations left in the graph, let's execute
# them with the optimized op_linear rather than materializing a
# transpose followed by a regular op_mm.
ConvertToLinearPass(),
QuantFusionPass(),
]
if passes:
to_executorch_passes.extend(passes)

self.export_program = self.edge_manager.to_executorch(
ExecutorchBackendConfig(
extract_delegate_segments=True,
passes=[
# If there are Linear operations left in the graph, let's execute
# them with the optimized op_linear rather than materializing a
# transpose followed by a regular op_mm.
ConvertToLinearPass(),
QuantFusionPass(),
],
passes=to_executorch_passes,
memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False),
sym_shape_eval_pass=ConstraintBasedSymShapeEvalPass(),
)
)
print(self.export_program.dump_executorch_program(verbose=True))
logging.info(
"Required memory for activation in bytes: {}".format(
self.export_program._emitter_output.program.execution_plan[
Expand Down
5 changes: 0 additions & 5 deletions runtime/executor/method.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
#include <cinttypes> // @donotremove
#include <cstdint>
#include <cstdio>
#include <iostream>

#include <executorch/extension/evalue_util/print_evalue.h>
#include <executorch/runtime/backend/interface.h>
Expand Down Expand Up @@ -1181,10 +1180,6 @@ Error Method::execute_instruction() {
if (err == Error::Ok) {
step_state_.instr_idx = next_instr_idx;
}

// TODO: Print an EValue.
std::cout << "(" << values_[1] << " ) Printing kv_cache k_cache: " << executorch::extension::evalue_edge_items(9216) << values_[2] << std::endl;

return err;
}

Expand Down

0 comments on commit 9e68531

Please sign in to comment.