[SharkInference] Make SharkInference compile the entire module (#708)
* [SharkInference] Make SharkInference compile the entire module

-- Previously, SharkInference compiled and exposed run APIs only for a
   hardcoded function named "forward".
-- This commit makes the compilation generic, so any function defined
   within the module can now be run.
-- It also adds an API to fetch all the function names defined within
   the compiled module.
-- This commit updates both the web and command-line execution of Stable
   Diffusion to use the new SharkInference API.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
Abhishek-Varma authored Jan 3, 2023
1 parent 4ee3d95 commit e60b456
Showing 12 changed files with 65 additions and 72 deletions.
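For context, a minimal sketch of the caller-side change this commit introduces (the `input_np` numpy input and device string below are illustrative placeholders, not part of the diff):

# Before this commit: one hardcoded entry point.
shark_module = SharkInference(mlir_module, function_name="forward", device="cpu", mlir_dialect="linalg")
shark_module.compile()
output = shark_module.forward((input_np,))

# After this commit: the whole module is compiled, and any function is invoked by name.
shark_module = SharkInference(mlir_module, device="cpu", mlir_dialect="linalg")
shark_module.compile()
print(shark_module.get_functions_in_module())  # lists every function in the compiled module
output = shark_module("forward", (input_np,))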
11 changes: 6 additions & 5 deletions shark/examples/shark_inference/stable_diffusion/main.py
@@ -151,8 +151,8 @@ def end_profiling(device):
 vae_warmup_input = torch.clone(latents).detach().numpy()
 clip_warmup_input = torch.randint(1, 2, (2, args.max_length))
 for i in range(args.warmup_count):
-    vae.forward((vae_warmup_input,))
-    clip.forward((clip_warmup_input,))
+    vae("forward", (vae_warmup_input,))
+    clip("forward", (clip_warmup_input,))
 
 start = time.time()
 
@@ -174,7 +174,7 @@ def end_profiling(device):
 text_input = torch.cat([uncond_input.input_ids, text_input.input_ids])
 
 clip_inf_start = time.time()
-text_embeddings = clip.forward((text_input,))
+text_embeddings = clip("forward", (text_input,))
 clip_inf_end = time.time()
 text_embeddings = torch.from_numpy(text_embeddings).to(dtype)
 text_embeddings_numpy = text_embeddings.detach().numpy()
@@ -196,7 +196,8 @@ def end_profiling(device):
 
 profile_device = start_profiling(file_path="unet.rdc")
 
-noise_pred = unet.forward(
+noise_pred = unet(
+    "forward",
     (
         latent_model_input,
         timestep,
@@ -227,7 +228,7 @@ def end_profiling(device):
 latents_numpy = latents.detach().numpy()
 profile_device = start_profiling(file_path="vae.rdc")
 vae_start = time.time()
-images = vae.forward((latents_numpy,))
+images = vae("forward", (latents_numpy,))
 vae_end = time.time()
 end_profiling(profile_device)
 if args.use_base_vae:
6 changes: 4 additions & 2 deletions shark/examples/shark_inference/stable_diffusion/schedulers.py
@@ -108,7 +108,8 @@ def forward(self, noise_pred, sigma, latent, dt):
     def scale_model_input(self, sample, timestep):
         step_index = (self.timesteps == timestep).nonzero().item()
         sigma = self.sigmas[step_index]
-        return self.scaling_model.forward(
+        return self.scaling_model(
+            "forward",
             (
                 sample,
                 sigma,
@@ -120,7 +121,8 @@ def step(self, noise_pred, timestep, latent):
         step_index = (self.timesteps == timestep).nonzero().item()
         sigma = self.sigmas[step_index]
         dt = self.sigmas[step_index + 1] - sigma
-        return self.step_model.forward(
+        return self.step_model(
+            "forward",
             (
                 noise_pred,
                 sigma,
3 changes: 1 addition & 2 deletions shark/examples/shark_inference/stable_diffusion/utils.py
@@ -53,7 +53,7 @@ def get_shark_model(tank_url, model_name, extra_args=[]):
         frontend="torch",
     )
     shark_module = SharkInference(
-        mlir_model, func_name, device=args.device, mlir_dialect="linalg"
+        mlir_model, device=args.device, mlir_dialect="linalg"
     )
     return _compile_module(shark_module, model_name, extra_args)
 

@@ -65,7 +65,6 @@ def compile_through_fx(model, inputs, model_name, extra_args=[]):
 
     shark_module = SharkInference(
         mlir_module,
-        func_name,
         device=args.device,
         mlir_dialect="linalg",
     )
4 changes: 1 addition & 3 deletions shark/iree_eager_backend.py
@@ -21,7 +21,6 @@
 from iree.runtime import DeviceArray
 from torch_mlir._mlir_libs._mlir.ir import Module
 from torch_mlir.compiler_utils import (
-    get_module_name_for_debug_dump,
     run_pipeline_with_repro_report,
 )
 from torch_mlir.eager_mode.torch_mlir_eager_backend import (
@@ -64,14 +63,13 @@ def get_torch_metadata(
         )
 
     def compile(self, imported_module: Module):
-        fn_name = get_module_name_for_debug_dump(imported_module)
         run_pipeline_with_repro_report(
             imported_module,
             "torch-function-to-torch-backend-pipeline,torch-backend-to-linalg-on-tensors-backend-pipeline",
             "EagerMode",
         )
         callable, _ = get_iree_compiled_module(
-            imported_module, self.raw_device_str, func_name=fn_name
+            imported_module, self.raw_device_str
         )
         return callable
 
32 changes: 16 additions & 16 deletions shark/iree_utils/compile_utils.py
@@ -234,7 +234,6 @@ def compile_module_to_flatbuffer(
     module,
     device,
     frontend,
-    func_name,
     model_config_path,
     extra_args,
     model_name="None",
@@ -277,62 +276,58 @@ def compile_module_to_flatbuffer(
     return flatbuffer_blob
 
 
-def get_iree_module(flatbuffer_blob, device, func_name):
+def get_iree_module(flatbuffer_blob, device):
     # Returns the compiled module and the configs.
     config = get_iree_runtime_config(device)
     vm_module = ireert.VmModule.from_flatbuffer(
         config.vm_instance, flatbuffer_blob
     )
     ctx = ireert.SystemContext(config=config)
     ctx.add_vm_module(vm_module)
-    ModuleCompiled = ctx.modules.module[func_name]
+    ModuleCompiled = ctx.modules.module
     return ModuleCompiled, config
 
 
 def get_iree_compiled_module(
     module,
     device: str,
     frontend: str = "torch",
-    func_name: str = "forward",
     model_config_path: str = None,
     extra_args: list = [],
 ):
     """Given a module returns the compiled .vmfb and configs"""
     flatbuffer_blob = compile_module_to_flatbuffer(
-        module, device, frontend, func_name, model_config_path, extra_args
+        module, device, frontend, model_config_path, extra_args
     )
-    return get_iree_module(flatbuffer_blob, device, func_name)
+    return get_iree_module(flatbuffer_blob, device)
 
 
-def load_flatbuffer(
-    flatbuffer_path: str, device: str, func_name: str = "forward"
-):
+def load_flatbuffer(flatbuffer_path: str, device: str):
 
     with open(os.path.join(flatbuffer_path), "rb") as f:
         flatbuffer_blob = f.read()
 
-    return get_iree_module(flatbuffer_blob, device, func_name)
+    return get_iree_module(flatbuffer_blob, device)
 
 
 def export_iree_module_to_vmfb(
     module,
     device: str,
     directory: str,
     mlir_dialect: str = "linalg",
-    func_name: str = "forward",
     model_config_path: str = None,
     module_name: str = None,
     extra_args: list = [],
 ):
     # Compiles the module given specs and saves it as .vmfb file.
     flatbuffer_blob = compile_module_to_flatbuffer(
-        module, device, mlir_dialect, func_name, model_config_path, extra_args
+        module, device, mlir_dialect, model_config_path, extra_args
     )
     if module_name is None:
         device_name = (
             device if "://" not in device else "-".join(device.split("://"))
         )
-        module_name = f"{mlir_dialect}_{func_name}_{device_name}"
+        module_name = f"{mlir_dialect}_{device_name}"
     filename = os.path.join(directory, module_name + ".vmfb")
     print(f"Saved vmfb in {filename}.")
     with open(filename, "wb") as f:
@@ -355,11 +350,16 @@ def export_module_to_mlir_file(module, frontend, directory: str):
 
 
 def get_results(
-    compiled_vm, input, config, frontend="torch", send_to_host=True
+    compiled_vm,
+    function_name,
+    input,
+    config,
+    frontend="torch",
+    send_to_host=True,
 ):
     """Runs a .vmfb file given inputs and config and returns output."""
     device_inputs = [ireert.asdevicearray(config.device, a) for a in input]
-    result = compiled_vm(*device_inputs)
+    result = compiled_vm[function_name](*device_inputs)
     result_tensors = []
     if isinstance(result, tuple):
         if send_to_host:
@@ -376,7 +376,7 @@ def get_results(
             return np.copy(res)
         return data
     else:
-        if send_to_host:
+        if send_to_host and result is not None:
             return result.to_host()
         return result

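The effect of the get_results change, as a hedged sketch (flatbuffer_blob and input_np are placeholder values; the calls mirror the ones defined in this file):

# get_iree_module now returns the whole compiled module instead of one function;
# get_results selects the function by name at call time via compiled_vm[function_name].
module, config = get_iree_module(flatbuffer_blob, "cpu")
outputs = get_results(module, "forward", (input_np,), config)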
7 changes: 2 additions & 5 deletions shark/shark_benchmark_runner.py
@@ -60,7 +60,6 @@ class SharkBenchmarkRunner(SharkRunner):
     def __init__(
         self,
         mlir_module: bytes,
-        function_name: str = "forward",
         device: str = "none",
         mlir_dialect: str = "linalg",
         extra_args: list = [],
@@ -73,7 +72,6 @@ def __init__(
         SharkRunner.__init__(
             self,
             mlir_module,
-            function_name,
             device,
             self.mlir_dialect,
             self.extra_args,
@@ -85,7 +83,6 @@ def __init__(
             device,
             shark_args.repro_dir,
             self.mlir_dialect,
-            function_name,
             extra_args=self.extra_args,
         )

@@ -185,11 +182,11 @@ def benchmark_c(self):
     def benchmark_python(self, inputs):
         input_list = [x for x in inputs]
         for i in range(shark_args.num_warmup_iterations):
-            self.run(input_list)
+            self.run("forward", input_list)
 
         begin = time.time()
         for i in range(shark_args.num_iterations):
-            out = self.run(input_list)
+            out = self.run("forward", input_list)
             if i == shark_args.num_iterations - 1:
                 end = time.time()
                 print(
31 changes: 13 additions & 18 deletions shark/shark_inference.py
@@ -40,8 +40,6 @@ class SharkInference:
     ----------
     mlir_module : str
         mlir_module represented in string; modules from torch-mlir are serialized in bytecode format.
-    function_name : str
-        function to execute in the given mlir_module.
     device : str
         device to execute the mlir_module on.
         currently supports cpu, cuda, vulkan, and metal backends.
@@ -53,10 +51,10 @@ class SharkInference:
 
     Methods
     -------
-    run(inputs=None):
-        Runs the mlir_module with the given inputs, if the inputs are not
-        given it autogenerates the inputs. Also, the inputs should be a
-        numpy array.
+    __call__(function_name, inputs=None):
+        Runs the function with `function_name` within the mlir_module along
+        with the given inputs, if the inputs are not given it autogenerates the
+        inputs. Also, the inputs should be a numpy array.
     input_info():
         Gives the information about the inputs required by the `function_name`.
         This can be expensive as it does string matching to do so.
@@ -66,15 +64,13 @@ class SharkInference:
     def __init__(
         self,
         mlir_module: bytes,
-        function_name: str = "forward",
         device: str = "none",
         mlir_dialect: str = "linalg",
         is_benchmark: bool = False,
         dispatch_benchmark: str = None,
         dispatch_benchmark_dir: str = "temp_dispatch_benchmarks",
     ):
         self.mlir_module = mlir_module
-        self.function_name = function_name
         self.device = shark_args.device if device == "none" else device
         self.mlir_dialect = mlir_dialect
         self.is_benchmark = is_benchmark
@@ -113,7 +109,6 @@ def compile(self, extra_args=[]):
 
             self.shark_runner = SharkBenchmarkRunner(
                 self.mlir_module,
-                self.function_name,
                 self.device,
                 self.mlir_dialect,
                 extra_args=extra_args,
@@ -122,7 +117,6 @@ def compile(self, extra_args=[]):
         else:
             self.shark_runner = SharkRunner(
                 self.mlir_module,
-                self.function_name,
                 self.device,
                 self.mlir_dialect,
                 extra_args=extra_args,
@@ -138,21 +132,25 @@ def compile(self, extra_args=[]):
             os.system(f"rm -rf {self.temp_dispatch_benchmarks_dir}")
 
     # inputs are considered to be tuple of np.array.
-    def forward(self, inputs: tuple, send_to_host=True):
-        return self.shark_runner.run(inputs, send_to_host)
+    def __call__(self, function_name: str, inputs: tuple, send_to_host=True):
+        return self.shark_runner.run(function_name, inputs, send_to_host)
+
+    # Get all function names defined within the compiled module.
+    def get_functions_in_module(self):
+        return self.shark_runner.get_functions_in_module()
 
     # Captures the static input information from the mlir_module.
     # TODO(pashu123): Generate the input information for dynamic shapes.
-    def _input_info(self):
+    def _input_info(self, function_name):
         # func_key to get the line which contains the function.
-        func_key = "func.func @" + self.function_name
+        func_key = "func.func @" + function_name
         func_header = None
         for line in str(self.mlir_module).splitlines():
             if func_key in line:
                 func_header = line
                 break
         if func_header is None:
-            print(f"Function: {self.function_name} not found")
+            print(f"Function: {function_name} not found")
 
         import re

@@ -190,15 +188,13 @@ def save_module(self, dir=os.getcwd(), module_name=None, extra_args=[]):
             self.device,
             dir,
             self.mlir_dialect,
-            self.function_name,
             module_name=module_name,
             extra_args=extra_args,
         )
 
     # load and return the module.
     def load_module(self, path, extra_args=[]):
         self.shark_runner = SharkRunner(
-            function_name=self.function_name,
             device=self.device,
             compile_vmfb=False,
             extra_args=extra_args,
@@ -209,6 +205,5 @@ def load_module(self, path, extra_args=[]):
         ) = load_flatbuffer(
             path,
             self.device,
-            self.function_name,
         )
         return
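To illustrate the string matching that _input_info now performs per function name, a small hypothetical example (the MLIR line is invented for illustration):

mlir_text = "func.func @forward(%arg0: tensor<2x77xi64>) -> tensor<2x77x768xf32>"
func_key = "func.func @" + "forward"  # built from the requested function name
# Scan the serialized module for the header line of that function; the input
# shapes are then parsed out of this line via string matching.
func_header = next(
    (line for line in mlir_text.splitlines() if func_key in line), None
)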
