Skip to content

Commit

Permalink
review changes
Browse files Browse the repository at this point in the history
use fixed_point_multiply to define fraction_size, use LUT for tanh
  • Loading branch information
Aleksei-grovety committed Mar 11, 2024
1 parent 045973c commit de23e58
Show file tree
Hide file tree
Showing 8 changed files with 142 additions and 101 deletions.
46 changes: 31 additions & 15 deletions python/tvm/relay/backend/contrib/ethosu/legalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,9 @@ def get_lut_from_func(
) -> List[int]:
"""Calculates the values of the lookup table based on the calculation function"""

if dtype == np.int8:
assert dtype in ["int8", "int16"]

if dtype == "int8":
lut_values = list()
qmin, qmax = np.iinfo(dtype).min, np.iinfo(dtype).max
for x in range(qmin, qmax + 1):
Expand All @@ -150,7 +152,8 @@ def get_lut_from_func(
lut_values.append(lut_result)

return lut_values
elif dtype == np.int16:
else:
# dtype == "int16"
input_min = ifm_scale * (np.iinfo(np.int16).min - ifm_zp)
input_max = ifm_scale * (np.iinfo(np.int16).max - ifm_zp)

Expand Down Expand Up @@ -194,8 +197,6 @@ def get_lut_from_func(
lut[i] = slope + base

return lut
else:
assert f"Unsupported 'dtype = {dtype}' !"


class LutActivationRewriter(DFPatternCallback):
Expand All @@ -222,11 +223,13 @@ def callback(self, pre: tvm.relay.Expr, post: tvm.relay.Expr, node_map: tvm.ir.c
output_scale = float(params.ofm.q_params.scale_f32)
output_zp = int(params.ofm.q_params.zero_point)

if params.ifm.dtype == "int8":
# Validation function from pattern matching checks that the input type can be int8 or int16
ifm_dtype = params.ifm.dtype
if ifm_dtype == "int8":
lut_values = get_lut_from_func(
input_scale, input_zp, output_scale, output_zp, self.calc_func, np.int8
input_scale, input_zp, output_scale, output_zp, self.calc_func, ifm_dtype
)
lut = relay.const(lut_values, dtype=params.ifm.dtype)
lut = relay.const(lut_values, dtype=ifm_dtype)

# We baked the requantization into the LUT, so we don't requantize the identity operator
identity = ethosu_ops.ethosu_identity(
Expand All @@ -239,22 +242,23 @@ def callback(self, pre: tvm.relay.Expr, post: tvm.relay.Expr, node_map: tvm.ir.c
activation=self.activation_type,
)

return identity
elif params.ifm.dtype == "int16":
lut_tanh = relay.const([], "int16")
tanh_identity = ethosu_ops.ethosu_identity(
else:
# ifm_dtype == "int16"
lut = get_lut_from_func(
input_scale, input_zp, output_scale, output_zp, self.calc_func, ifm_dtype
)
lut = relay.const(lut, dtype="int32")
identity = ethosu_ops.ethosu_identity(
ifm=params.ifm.tensor,
lut=lut_tanh,
lut=lut,
ifm_scale=input_scale,
ifm_zero_point=0,
ofm_scale=output_scale,
ofm_zero_point=0,
activation=self.activation_type,
)

return tanh_identity
else:
assert f"Unsupported 'ifm.dtype = {params.ifm.dtype}' !"
return identity


class TanhRewriter(LutActivationRewriter):
Expand All @@ -266,6 +270,17 @@ def __init__(self):
)


class TanhFixedPointRewriter(LutActivationRewriter):
    """Rewriter that attaches a fixed point tanh, expressed as a LUT, to the identity operator"""

    def __init__(self):
        # Configure the generic LUT rewriter for the fixed point tanh composite:
        # the table values are produced by math.tanh.
        rewriter_kwargs = {
            "params_class": ethosu_patterns.TanhFixedPointParams,
            "activation_type": "TANH",
            "calc_func": math.tanh,
        }
        super().__init__(**rewriter_kwargs)


def sigmoid_calc_func(x: float) -> float:
"""Function to calculate the values for sigmoid"""
# These limits are inherited from TFLite
Expand Down Expand Up @@ -1748,6 +1763,7 @@ def transform_npu_function(self, _, func: relay.Function) -> relay.Function:
ShlRewriter(),
AbsRewriter(),
TanhRewriter(),
TanhFixedPointRewriter(),
HardSwishRewriter(),
LeakyReLURewriter(),
MeanRewriter(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -877,7 +877,7 @@ def _create_npu_activation(serial_activation: spec.SerialActivation) -> vapi.Npu
return None
op_map = {
"CLIP": vapi.NpuActivationOp.NONE_OR_RELU,
"TANH": vapi.NpuActivationOp.TANH,
"TANH": vapi.NpuActivationOp.TABLE_LOOKUP,
"SIGMOID": vapi.NpuActivationOp.TABLE_LOOKUP,
"LUT": vapi.NpuActivationOp.TABLE_LOOKUP,
}
Expand All @@ -887,9 +887,6 @@ def _create_npu_activation(serial_activation: spec.SerialActivation) -> vapi.Npu
if serial_activation.op == "CLIP":
act_op.min = int(serial_activation.clip_min.value)
act_op.max = int(serial_activation.clip_max.value)
if serial_activation.op == "TANH":
act_op.min = float(-1.0)
act_op.max = float(1.0)
if op_map[op] == vapi.NpuActivationOp.TABLE_LOOKUP:
act_op.lookup_table_index = 0
return act_op
Expand Down
6 changes: 0 additions & 6 deletions python/tvm/relay/backend/contrib/ethosu/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,12 +264,6 @@ def is_copying_constants_disabled() -> bool:
return bool(compiler_attrs.disable_copying_constants)


def is_fixed_point_enabled() -> bool:
    """Report whether the compiler attributes enable calculation with fixed point"""
    # Look up the registered global function first, then call it to obtain the attributes.
    get_compiler_attrs = tvm.get_global_func("relay.ext.ethos-u.get_compiler_attrs")
    return bool(get_compiler_attrs().enable_fixed_point)


def is_striping_enabled() -> bool:
"""Determine whether the cascader is enabled"""
compiler_attrs = tvm.get_global_func("relay.ext.ethos-u.get_compiler_attrs")()
Expand Down
111 changes: 76 additions & 35 deletions python/tvm/relay/op/contrib/ethosu.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
is_tuple,
wildcard,
)
from tvm.relay.backend.contrib.ethosu import util
from tvm.relay.expr import Call, Constant # type: ignore
from tvm.relay.op.contrib.register import register_pattern_table # type: ignore

Expand Down Expand Up @@ -1230,40 +1229,23 @@ def __init__(self, func_body: Call):

layout = "NHWC"

if util.is_fixed_point_enabled():
fract_part_for_16_bits = float(1 / 2**15)
in_var = func_body.args[0]

self.ifm = TensorParams(
in_var,
layout=layout,
scale=tvm.relay.Constant(tvm.nd.array(np.array(fract_part_for_16_bits))),
zero_point=tvm.relay.Constant(tvm.nd.array(np.array(0, dtype="int32"))),
)
self.ofm = TensorParams(
func_body,
layout=layout,
scale=tvm.relay.Constant(tvm.nd.array(np.array(fract_part_for_16_bits))),
zero_point=tvm.relay.Constant(tvm.nd.array(np.array(0, dtype="int32"))),
)
else:
quantize = func_body
activation = quantize.args[0]
dequantize = activation.args[0]
in_var = dequantize.args[0]
quantize = func_body
activation = quantize.args[0]
dequantize = activation.args[0]
in_var = dequantize.args[0]

self.ifm = TensorParams(
in_var,
layout=layout,
scale=dequantize.args[DequantizeArgs.IFM_SCALE.value],
zero_point=dequantize.args[DequantizeArgs.IFM_ZERO_POINT.value],
)
self.ofm = TensorParams(
quantize,
layout=layout,
scale=quantize.args[QuantizeArgs.OFM_SCALE.value],
zero_point=quantize.args[QuantizeArgs.OFM_ZERO_POINT.value],
)
self.ifm = TensorParams(
in_var,
layout=layout,
scale=dequantize.args[DequantizeArgs.IFM_SCALE.value],
zero_point=dequantize.args[DequantizeArgs.IFM_ZERO_POINT.value],
)
self.ofm = TensorParams(
quantize,
layout=layout,
scale=quantize.args[QuantizeArgs.OFM_SCALE.value],
zero_point=quantize.args[QuantizeArgs.OFM_ZERO_POINT.value],
)

def is_valid(self):
"""
Expand All @@ -1284,7 +1266,61 @@ def tanh_pattern():
dequant = is_op("qnn.dequantize")(wildcard(), is_constant(), is_constant())
tanh = is_op("tanh")(dequant)
quant = is_op("qnn.quantize")(tanh, is_constant(), is_constant())
return quant | is_op("tanh")(wildcard())
return quant


class TanhFixedPointParams:
    """
    Parser for a call to an ethos-u.tanh_fixed_point composite function.

    Extracts the fraction size together with the input (ifm) and output (ofm)
    tensor parameters from the body of the composite function.
    """

    composite_name = "ethos-u.tanh_fixed_point"

    @requires_vela
    def __init__(self, func_body):
        # The body has the structure built by tanh_fixed_point_pattern:
        # cast(fixed_point_multiply(tanh(fixed_point_multiply(cast(x))))),
        # so args[0] of the body is the outer fixed_point_multiply call and
        # args[0] of that call is the tanh call.
        output_multiply = func_body.args[0]
        tanh_call = output_multiply.args[0]

        # fixed_point_multiply relay operation uses multiplier with 31 fractional bits
        # so to determine the size of the fraction use the formula: 31 - shift
        self.fraction_size = 31 - output_multiply.attrs.shift

        # Both tensors share the same quantization parameters:
        # scale = 2^-fraction_size and zero point = 0.
        scale_value = np.array(1 / 2**self.fraction_size)
        shared_quantization = {
            "layout": "NHWC",
            "scale": tvm.relay.Constant(tvm.nd.array(scale_value)),
            "zero_point": tvm.relay.Constant(tvm.nd.array(np.array(0, dtype="int32"))),
        }

        # The input expression sits three args[0] hops below the tanh call
        # (inner fixed_point_multiply -> cast -> x).
        self.ifm = TensorParams(tanh_call.args[0].args[0].args[0], **shared_quantization)
        self.ofm = TensorParams(func_body, **shared_quantization)

    def is_valid(self) -> bool:
        """
        Check whether the activation has attributes compatible with the NPU.

        Returns False when the fraction size lies outside [0, 16] or when the
        ifm/ofm dtypes are not int8 or int16; True otherwise.
        """
        fraction_out_of_range = self.fraction_size < 0 or self.fraction_size > 16
        if fraction_out_of_range:
            return False

        tensors = [self.ifm, self.ofm]
        return bool(check_valid_dtypes(tensors, supported_dtypes=[np.int8, np.int16]))


def tanh_fixed_point_pattern():
    """Build the pattern matching a fixed point tanh"""
    # Operators listed from the innermost call to the outermost one; each is
    # applied to the pattern built so far, starting from a wildcard input.
    op_chain = ("cast", "fixed_point_multiply", "tanh", "fixed_point_multiply", "cast")
    pattern = wildcard()
    for op_name in op_chain:
        pattern = is_op(op_name)(pattern)
    return pattern


class SigmoidParams(LutActivationParams):
Expand Down Expand Up @@ -2391,6 +2427,11 @@ def pattern_table() -> List[Tuple[str, tvm.relay.dataflow_pattern.DFPattern, Cal
lambda pat: AbsParams(pat).is_valid(),
),
(TanhParams.composite_name, tanh_pattern(), lambda pat: TanhParams(pat).is_valid()),
(
TanhFixedPointParams.composite_name,
tanh_fixed_point_pattern(),
lambda pat: TanhFixedPointParams(pat).is_valid(),
),
(
MeanParams.composite_name,
mean_pattern(),
Expand Down
8 changes: 0 additions & 8 deletions src/relay/backend/contrib/ethosu/compiler_attrs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@ struct EthosUCompilerConfigNode : public tvm::AttrsNode<EthosUCompilerConfigNode
Bool enable_cascader = Bool(false);
Bool enable_striping = Bool(false);
Bool disable_copying_constants = Bool(false);
Bool enable_fixed_point = Bool(false);
String dev_force_block_config;
String dev_max_open_plans;
String dev_max_closed_plans;
Expand Down Expand Up @@ -72,13 +71,6 @@ struct EthosUCompilerConfigNode : public tvm::AttrsNode<EthosUCompilerConfigNode
"in "
"the linker script for section \".rodata.tvm\" that the constants are located in SRAM)")
.set_default(Bool(false));
TVM_ATTR_FIELD(enable_fixed_point)
.describe(
"Whether calculation with fixed point is enabled. When this option "
"is "
"enabled, it is assumed that input data should be converted to fixed point "
"representation")
.set_default(Bool(false));
String dev_warning = "Option is intended for development and debugging purposes only. ";
TVM_ATTR_FIELD(dev_force_block_config)
.describe((dev_warning + String("Force the block config to a given value; format = "
Expand Down
4 changes: 0 additions & 4 deletions tests/python/contrib/test_ethosu/infra.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,6 @@ def create_test_runner(
enable_cascader=False,
enable_striping=False,
workspace_pools=None,
enable_fixed_point=False,
):

file_dir = os.path.dirname(os.path.abspath(__file__))
Expand Down Expand Up @@ -169,7 +168,6 @@ def create_test_runner(
"accelerator_config": accel,
"enable_cascader": enable_cascader,
"enable_striping": enable_striping,
"enable_fixed_point": enable_fixed_point,
},
"tir.usmp.enable": enable_usmp,
"tir.usmp.algorithm": "hill_climb",
Expand Down Expand Up @@ -335,7 +333,6 @@ def compare_ethosu_with_reference(
output_tolerance=0,
print_cmm=False,
enable_cascader=None,
enable_fixed_point=False,
):
if enable_cascader is None:
enable_cascader = "u65" not in accel_type
Expand All @@ -362,7 +359,6 @@ def compare_ethosu_with_reference(
enable_cascader=enable_cascader,
enable_striping=False,
workspace_pools=workspace_pools,
enable_fixed_point=enable_fixed_point,
)
compiled_models = build_source(
mod,
Expand Down
15 changes: 8 additions & 7 deletions tests/python/contrib/test_ethosu/test_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -1678,15 +1678,19 @@ def convert_to_fixed_point(arr, fract_size):
@pytest.mark.parametrize("accel_type", ["ethos-u55-256", "ethos-u65-256"])
@pytest.mark.parametrize(
"ifm_shape,fract_size,tolerance",
[[(1, 2, 8, 4), 15, 0.001], [(1, 8), 12, 0.15], [(1, 1, 4, 8), 10, 0.25]],
[[(1, 2, 8, 4), 15, 0.001], [(1, 8), 12, 0.001], [(1, 1, 4, 8), 10, 0.002]],
)
def test_ethosu_tanh_fixed_point(accel_type, ifm_shape, fract_size, tolerance):
np.random.seed(0)
dtype = "int16"

def create_model():
ifm = relay.var("ifm", shape=ifm_shape, dtype=dtype)
tanh = relay.tanh(ifm)
ifm_fixed_point = relay.cast(ifm, "int32")
ifm_fixed_point = relay.fixed_point_multiply(ifm_fixed_point, 2**31 - 1, 0)
tanh = relay.tanh(ifm_fixed_point)
tanh = relay.fixed_point_multiply(tanh, 1, 31 - fract_size)
tanh = relay.cast(tanh, dtype)
return tvm.IRModule.from_expr(relay.Function([ifm], tanh))

def generate_ref(input_data):
Expand All @@ -1697,25 +1701,22 @@ def convert_to_fixed_point(arr, fract_size):
return np.array(arr * fract_fact, dtype=np.int16)

cpu_mod = create_model()
ethosu_mod = partition_for_ethosu(cpu_mod)

input_data = {"ifm": np.random.uniform(-1, 1, size=ifm_shape)}
output_data = generate_ref(input_data["ifm"])

input_data = {"ifm": convert_to_fixed_point(input_data["ifm"], fract_size)}
output_data = {"output": convert_to_fixed_point(output_data, fract_size)}
tolerance = convert_to_fixed_point(tolerance, fract_size)

config = {"enable_fixed_point": True}
with tvm.transform.PassContext(config={"relay.ext.ethos-u.options": config}):
ethosu_mod = partition_for_ethosu(cpu_mod)

infra.compare_ethosu_with_reference(
ethosu_mod,
input_data,
output_data,
accel_type,
enable_cascader=is_u55_accel_type(accel_type),
output_tolerance=tolerance,
enable_fixed_point=True,
)


Expand Down
Loading

0 comments on commit de23e58

Please sign in to comment.