Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ONNX capability to the SD3 pipeline #830

Open
wants to merge 3 commits into
base: sd3-ckpt-ryzen
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 91 additions & 3 deletions models/turbine_models/custom_models/pipeline_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
)
from turbine_models.utils.sdxl_benchmark import run_benchmark
from turbine_models.model_runner import vmfbRunner
import onnxruntime
import pdb

from PIL import Image
import gc
Expand Down Expand Up @@ -75,6 +77,57 @@ def merge_export_arg(model_map, arg, arg_name):
# return out


class OnnxPipelineComponent:
    """
    Runs a submodel through an onnxruntime InferenceSession.

    Inputs are brought to host memory and cast to `used_dtype` before the
    session runs; the first session output is cast to `dest_dtype` on return.
    """

    def __init__(
        self,
        printer,
        dest_type="numpy",
        dest_dtype="fp16",
    ):
        """
        Args:
          printer: object exposing `print(message)`; used to log model loading.
          dest_type: stored on the instance; not consulted elsewhere in this class.
          dest_dtype: key into `np_dtypes` selecting the dtype of returned outputs.
        """
        self.ort_session = None
        self.onnx_file_path = None
        self.ep = None
        self.dest_type = dest_type
        self.dest_dtype = dest_dtype
        self.printer = printer
        # Only fp32 is listed as runnable; any other requested dtype falls back
        # to the default for the session inputs.
        self.supported_dtypes = ["fp32"]
        self.default_dtype = "fp32"
        self.used_dtype = (
            dest_dtype if dest_dtype in self.supported_dtypes else self.default_dtype
        )

    def load(self, onnx_file_path: str, ep="CPUExecutionProvider"):
        """
        Create the inference session for `onnx_file_path` on execution provider `ep`.
        """
        self.onnx_file_path = onnx_file_path
        self.ep = ep

        # Announce the load before doing it so the message precedes any failure.
        self.printer.print(f"Loading {onnx_file_path} into onnxruntime with {ep}.")
        self.ort_session = onnxruntime.InferenceSession(onnx_file_path, providers=[ep])

    def unload(self):
        """
        Drop the session reference and run a garbage collection pass.
        """
        self.ort_session = None
        gc.collect()

    # The session only accepts numpy inputs.
    def _convert_inputs(self, inputs):
        """
        Return a new dict mapping each input name to a host array of `used_dtype`.

        The caller's dict is left untouched so it can be reused afterwards.
        """
        target_dtype = np_dtypes[self.used_dtype]
        converted = {}
        for name, value in inputs.items():
            if isinstance(value, ireert.DeviceArray):
                value = value.to_host()
            converted[name] = value.astype(target_dtype)
        return converted

    def _convert_output(self, output):
        """
        Cast a session output to the dtype selected by `dest_dtype`.
        """
        return output.astype(np_dtypes[self.dest_dtype])

    def __call__(self, inputs: dict):
        """
        Run the session on `inputs` (name -> array) and return its first output.

        Raises:
          RuntimeError: if no session is loaded (before `load` or after `unload`).
        """
        if self.ort_session is None:
            raise RuntimeError(
                "ONNX session is not loaded; call load() before running inference."
            )
        converted_inputs = self._convert_inputs(inputs)
        # Passing None as the output list requests every output; only the
        # first one is returned.
        out = self.ort_session.run(
            None,
            converted_inputs,
        )[0]
        return self._convert_output(out)


class PipelineComponent:
"""
Wraps a VMFB runner with attributes for embedded metadata, device info, utilities and
Expand Down Expand Up @@ -269,6 +322,18 @@ def __call__(self, function_name, inputs: list):
# def _run_and_validate(self, iree_fn, torch_fn, inputs: list)


class Bcolors:
    """
    Escape-sequence string constants for decorating terminal output.

    A sequence is printed before a message and ENDC is printed after it.
    """

    # Sequences selected by purpose of the message.
    HEADER = "\033[95m"
    OKBLUE = "\033[94m"
    OKCYAN = "\033[96m"
    OKGREEN = "\033[92m"
    WARNING = "\033[93m"
    FAIL = "\033[91m"

    # Terminator emitted once the decorated output is done.
    ENDC = "\033[0m"

    # Text-style sequences, combined with the ones above by concatenation.
    BOLD = "\033[1m"
    UNDERLINE = "\033[4m"


class Printer:
def __init__(self, verbose, start_time, print_time):
"""
Expand All @@ -284,24 +349,29 @@ def __init__(self, verbose, start_time, print_time):

def reset(self):
    """
    Restart the printer's clock; does nothing unless `print_time` is set.

    Both `last_print` and `start_time` are set to the current time, with a
    message before and after when `verbose` is set.
    """
    if not self.print_time:
        return

    # Switch the terminal to the BOLD + WARNING sequence for the reset notices.
    print(Bcolors.BOLD + Bcolors.WARNING)
    if self.verbose:
        self.print("Will now reset clock for printer to 0.0 [s].")

    self.last_print = time.time()
    self.start_time = time.time()

    if self.verbose:
        self.print("Clock for printer reset to t = 0.0 [s].")
    # Terminate the escape sequence without adding another newline.
    print(Bcolors.ENDC, end="")

def print(self, message):
    """
    Print `message` when `verbose` is set, otherwise do nothing.

    Output is preceded by the Bcolors.BOLD + Bcolors.OKCYAN sequence and
    followed by Bcolors.ENDC. When `print_time` is set the message is
    prefixed with the seconds elapsed since `start_time`.
    """
    if not self.verbose:
        return

    print(Bcolors.BOLD + Bcolors.OKCYAN)
    if self.print_time:
        time_now = time.time()
        # Print something like "[ts=0.123s] message". Each message is emitted
        # exactly once; the older "[t=... dt=...]" line is no longer printed
        # alongside it.
        print(f"[ts={time_now - self.start_time:.3f}s] {message}")
        # Still recorded so the time of the previous message stays available.
        self.last_print = time_now
    else:
        print(f"{message}")
    # Terminate the escape sequence without adding another newline.
    print(Bcolors.ENDC, end="")


class TurbinePipelineBase:
Expand Down Expand Up @@ -359,6 +429,8 @@ def __init__(
ireec_flags: str | dict[str] = None,
precision: str | dict[str] = "fp16",
attn_spec: str | dict[str] = None,
onnx_model_path: str | dict[str] = None,
run_onnx_mmdit: bool = False,
decomp_attn: bool | dict[bool] = False,
external_weights: str | dict[str] = None,
pipeline_dir: str = "./shark_vmfbs",
Expand All @@ -372,6 +444,7 @@ def __init__(
self.map = model_map
self.verbose = verbose
self.printer = Printer(self.verbose, time.time(), True)
self.run_onnx_mmdit = run_onnx_mmdit
if isinstance(device, dict):
assert isinstance(
target, dict
Expand All @@ -396,6 +469,7 @@ def __init__(
map_arguments = {
"ireec_flags": ireec_flags,
"precision": precision,
"onnx_model_path": onnx_model_path,
"attn_spec": attn_spec,
"decomp_attn": decomp_attn,
"external_weights": external_weights,
Expand Down Expand Up @@ -761,6 +835,7 @@ def load_map(self):
self.load_submodel(submodel)

def load_submodel(self, submodel):

if not self.map[submodel].get("vmfb"):
raise ValueError(f"VMFB not found for {submodel}.")
if not self.map[submodel].get("weights") and self.map[submodel].get(
Expand All @@ -783,6 +858,19 @@ def load_submodel(self, submodel):
)
setattr(self, submodel, self.map[submodel]["runner"])

# add an onnx runners
if self.run_onnx_mmdit and submodel == "mmdit":
dest_type = "numpy"
dest_dtype = self.map[submodel]["precision"]
onnx_runner = OnnxPipelineComponent(
printer=self.printer, dest_type=dest_type, dest_dtype=dest_dtype
)
ep = "CPUExecutionProvider"
onnx_runner.load(
onnx_file_path=self.map[submodel]["onnx_model_path"], ep=ep
)
setattr(self, submodel + "_onnx", onnx_runner)

def unload_submodel(self, submodel):
self.map[submodel]["runner"].unload()
self.map[submodel]["vmfb"] = None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@ def export_mmdit_model(
attn_spec=None,
input_mlir=None,
weights_only=False,
onnx_model_path=None,
):
dtype = torch.float16 if precision == "fp16" else torch.float32
mmdit_model = MMDiTModel(
Expand Down
134 changes: 134 additions & 0 deletions models/turbine_models/custom_models/sd3_inference/sd3_mmdit_onnx.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import copy
import os
import sys
import math

import numpy as np
from shark_turbine.aot import *

from turbine_models.custom_models.sd_inference import utils
import torch
import torch._dynamo as dynamo
from diffusers import SD3Transformer2DModel


class MMDiTModel(torch.nn.Module):
    """
    torch module wrapping a pretrained SD3Transformer2DModel.

    `forward` returns only the first element of the wrapped model's output.
    """

    def __init__(
        self,
        hf_model_name="stabilityai/stable-diffusion-3-medium-diffusers",
        dtype=torch.float16,
    ):
        """
        Load the "transformer" subfolder of `hf_model_name` with weights in `dtype`.
        """
        super().__init__()
        load_options = {
            "subfolder": "transformer",
            "torch_dtype": dtype,
            "low_cpu_mem_usage": False,
        }
        self.mmdit = SD3Transformer2DModel.from_pretrained(
            hf_model_name, **load_options
        )

    def forward(
        self,
        hidden_states,
        encoder_hidden_states,
        pooled_projections,
        timestep,
    ):
        """
        Call the wrapped transformer and return the first element of its result.
        """
        # return_dict=False makes the wrapped model return an indexable result.
        outputs = self.mmdit(
            hidden_states,
            encoder_hidden_states,
            pooled_projections,
            timestep,
            return_dict=False,
        )
        return outputs[0]


@torch.no_grad()
def export_mmdit_model(
    hf_model_name="stabilityai/stable-diffusion-3-medium-diffusers",
    batch_size=1,
    height=512,
    width=512,
    precision="fp16",
    max_length=77,
    output_dir="C:/Users/chiz/work/sd3/mmdit/exported/",
):
    """
    Export the MMDiT transformer of `hf_model_name` to an ONNX file.

    Args:
      hf_model_name: model id handed to MMDiTModel and used in the file name.
      batch_size: per-prompt batch size; doubled internally because
        classifier-free guidance is always enabled here.
      height, width: image size in pixels; the sample input uses height // 8
        by width // 8.
      precision: "fp16" selects torch.float16, anything else torch.float32.
      max_length: only used to build the output file name.
      output_dir: directory that receives the .onnx file; created if missing.
        The default is kept from the previous hard-coded location.

    Returns:
      Path of the written ONNX file.
    """
    dtype = torch.float16 if precision == "fp16" else torch.float32
    # Load the requested checkpoint rather than always the class default.
    mmdit_model = MMDiTModel(
        hf_model_name=hf_model_name,
        dtype=dtype,
    )
    file_name = (
        utils.create_safe_name(
            hf_model_name,
            f"_bs{batch_size}_{max_length}_{height}x{width}_{precision}_mmdit",
        )
        + ".onnx"
    )
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    safe_name = os.path.join(output_dir, file_name)
    print(safe_name)

    do_classifier_free_guidance = True
    init_batch_dim = 2 if do_classifier_free_guidance else 1
    batch_size = batch_size * init_batch_dim
    hidden_states_shape = (
        batch_size,
        16,
        height // 8,
        width // 8,
    )
    # NOTE(review): 154 is fixed regardless of max_length; presumably
    # 2 * 77 -- confirm before changing max_length.
    encoder_hidden_states_shape = (batch_size, 154, 4096)
    pooled_projections_shape = (batch_size, 2048)
    # Sample inputs are uninitialized: they are only handed to the exporter,
    # so their values are presumably irrelevant -- confirm.
    hidden_states = torch.empty(hidden_states_shape, dtype=dtype)
    encoder_hidden_states = torch.empty(encoder_hidden_states_shape, dtype=dtype)
    pooled_projections = torch.empty(pooled_projections_shape, dtype=dtype)
    timestep = torch.empty(batch_size, dtype=dtype)

    torch.onnx.export(
        mmdit_model,  # model being run
        (
            hidden_states,
            encoder_hidden_states,
            pooled_projections,
            timestep,
        ),  # model input (or a tuple for multiple inputs)
        safe_name,  # where to save the model (can be a file or file-like object)
        export_params=True,  # store the trained parameter weights inside the model file
        opset_version=17,  # the ONNX version to export the model to
        do_constant_folding=True,  # whether to execute constant folding for optimization
        input_names=[
            "hidden_states",
            "encoder_hidden_states",
            "pooled_projections",
            "timestep",
        ],  # the model's input names
        output_names=[
            "sample_out",
        ],  # the model's output names
    )
    return safe_name


if __name__ == "__main__":
    import logging

    logging.basicConfig(level=logging.DEBUG)
    from turbine_models.custom_models.sd3_inference.sd3_cmd_opts import args

    # Only the model name comes from the command line; the remaining values
    # are pinned here instead of read from args.batch_size, args.height,
    # args.width, args.precision and args.max_length.
    onnx_model_name = export_mmdit_model(
        hf_model_name=args.hf_model_name,
        batch_size=1,
        height=512,
        width=512,
        precision="fp16",
        max_length=77,
    )

    print("Saved to", onnx_model_name)
Loading
Loading