Skip to content

Commit

Permalink
[microNPU][ETHOSU] Add fixed point for tanh
Browse files Browse the repository at this point in the history
Add support for calculating tanh with 16-bit fixed point. Add the flag enable_fixed_point to enable fixed-point calculation. We get good accuracy with 1 bit for the integer part and 15 bits for the fractional part; with other splits we get worse results.
  • Loading branch information
Aleksei-grovety committed Mar 11, 2024
1 parent 254e90a commit 045973c
Show file tree
Hide file tree
Showing 8 changed files with 232 additions and 50 deletions.
120 changes: 89 additions & 31 deletions python/tvm/relay/backend/contrib/ethosu/legalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,21 +135,67 @@ def get_lut_from_func(
ofm_scale: float,
ofm_zp: int,
func: Callable[[float], float],
dtype,
) -> List[int]:
"""Calculates the values of the lookup table based on the calculation function"""

lut_values = list()
# Only int8 is currently supported
dtype = np.int8
qmin, qmax = np.iinfo(dtype).min, np.iinfo(dtype).max
for x in range(qmin, qmax + 1):
x_real = ifm_scale * (x - ifm_zp)
out_real = func(x_real)
lut_result = int(util.round_away_zero(ofm_zp + out_real / ofm_scale))
lut_result = min(qmax, max(qmin, lut_result))
lut_values.append(lut_result)

return lut_values
if dtype == np.int8:
lut_values = list()
qmin, qmax = np.iinfo(dtype).min, np.iinfo(dtype).max
for x in range(qmin, qmax + 1):
x_real = ifm_scale * (x - ifm_zp)
out_real = func(x_real)
lut_result = int(util.round_away_zero(ofm_zp + out_real / ofm_scale))
lut_result = min(qmax, max(qmin, lut_result))
lut_values.append(lut_result)

return lut_values
elif dtype == np.int16:
input_min = ifm_scale * (np.iinfo(np.int16).min - ifm_zp)
input_max = ifm_scale * (np.iinfo(np.int16).max - ifm_zp)

output_min = ofm_scale * (np.iinfo(np.int16).min - ofm_zp)
output_max = ofm_scale * (np.iinfo(np.int16).max - ofm_zp)
# Create 16 bit lut following the reference
nbr_steps = 512
step = (input_max - input_min) / nbr_steps
half_step = step / 2
output_scaling_inv = (np.iinfo(np.int16).max - np.iinfo(np.int16).min + 1) / (
output_max - output_min
)
table_min = np.iinfo(np.int16).min
table_max = np.iinfo(np.int16).max

values = []
for i in range(nbr_steps):
val = func(input_min + i * step)
val_midpoint = func(input_min + i * step + half_step)
val_next = func(input_min + (i + 1) * step)

sample_val = util.round_away_zero(val * output_scaling_inv)
midpoint_interp_val = util.round_away_zero(
(val_next * output_scaling_inv + util.round_away_zero(val * output_scaling_inv)) / 2
)
midpoint_val = util.round_away_zero(val_midpoint * output_scaling_inv)
midpoint_err = midpoint_interp_val - midpoint_val
bias = util.round_away_zero(midpoint_err / 2)

lut_result = min(max(sample_val - bias, table_min), table_max)
values.append(lut_result)

val = util.round_away_zero(func(input_max) * output_scaling_inv)
lut_result = min(max(val, table_min), table_max)
values.append(lut_result)
# Convert to hardware 16bit lut with base and slope
lut = [0] * nbr_steps
for i in range(nbr_steps):
slope = (int(values[i + 1]) - int(values[i])) << 16
base = int(values[i])
lut[i] = slope + base

return lut
else:
assert f"Unsupported 'dtype = {dtype}' !"


class LutActivationRewriter(DFPatternCallback):
Expand All @@ -176,27 +222,39 @@ def callback(self, pre: tvm.relay.Expr, post: tvm.relay.Expr, node_map: tvm.ir.c
output_scale = float(params.ofm.q_params.scale_f32)
output_zp = int(params.ofm.q_params.zero_point)

lut_values = get_lut_from_func(
input_scale,
input_zp,
output_scale,
output_zp,
self.calc_func,
)
lut = relay.const(lut_values, dtype=params.ifm.dtype)
if params.ifm.dtype == "int8":
lut_values = get_lut_from_func(
input_scale, input_zp, output_scale, output_zp, self.calc_func, np.int8
)
lut = relay.const(lut_values, dtype=params.ifm.dtype)

# We baked the requantization into the LUT, so we don't requantize the identity operator
identity = ethosu_ops.ethosu_identity(
ifm=params.ifm.tensor,
lut=lut,
ifm_scale=input_scale,
ifm_zero_point=input_zp,
ofm_scale=input_scale,
ofm_zero_point=input_zp,
activation=self.activation_type,
)
# We baked the requantization into the LUT, so we don't requantize the identity operator
identity = ethosu_ops.ethosu_identity(
ifm=params.ifm.tensor,
lut=lut,
ifm_scale=input_scale,
ifm_zero_point=input_zp,
ofm_scale=input_scale,
ofm_zero_point=input_zp,
activation=self.activation_type,
)

return identity
return identity
elif params.ifm.dtype == "int16":
lut_tanh = relay.const([], "int16")
tanh_identity = ethosu_ops.ethosu_identity(
ifm=params.ifm.tensor,
lut=lut_tanh,
ifm_scale=input_scale,
ifm_zero_point=0,
ofm_scale=output_scale,
ofm_zero_point=0,
activation=self.activation_type,
)

return tanh_identity
else:
assert f"Unsupported 'ifm.dtype = {params.ifm.dtype}' !"


class TanhRewriter(LutActivationRewriter):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -877,7 +877,7 @@ def _create_npu_activation(serial_activation: spec.SerialActivation) -> vapi.Npu
return None
op_map = {
"CLIP": vapi.NpuActivationOp.NONE_OR_RELU,
"TANH": vapi.NpuActivationOp.TABLE_LOOKUP,
"TANH": vapi.NpuActivationOp.TANH,
"SIGMOID": vapi.NpuActivationOp.TABLE_LOOKUP,
"LUT": vapi.NpuActivationOp.TABLE_LOOKUP,
}
Expand All @@ -887,6 +887,9 @@ def _create_npu_activation(serial_activation: spec.SerialActivation) -> vapi.Npu
if serial_activation.op == "CLIP":
act_op.min = int(serial_activation.clip_min.value)
act_op.max = int(serial_activation.clip_max.value)
if serial_activation.op == "TANH":
act_op.min = float(-1.0)
act_op.max = float(1.0)
if op_map[op] == vapi.NpuActivationOp.TABLE_LOOKUP:
act_op.lookup_table_index = 0
return act_op
Expand Down
6 changes: 6 additions & 0 deletions python/tvm/relay/backend/contrib/ethosu/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,12 @@ def is_copying_constants_disabled() -> bool:
return bool(compiler_attrs.disable_copying_constants)


def is_fixed_point_enabled() -> bool:
    """Report whether the Ethos-U compiler attributes enable fixed point calculation.

    The compiler attributes are fetched through the
    "relay.ext.ethos-u.get_compiler_attrs" global function and their
    ``enable_fixed_point`` field is coerced to a Python bool.
    """
    get_compiler_attrs = tvm.get_global_func("relay.ext.ethos-u.get_compiler_attrs")
    return bool(get_compiler_attrs().enable_fixed_point)


def is_striping_enabled() -> bool:
"""Determine whether the cascader is enabled"""
compiler_attrs = tvm.get_global_func("relay.ext.ethos-u.get_compiler_attrs")()
Expand Down
54 changes: 36 additions & 18 deletions python/tvm/relay/op/contrib/ethosu.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
is_tuple,
wildcard,
)
from tvm.relay.backend.contrib.ethosu import util
from tvm.relay.expr import Call, Constant # type: ignore
from tvm.relay.op.contrib.register import register_pattern_table # type: ignore

Expand Down Expand Up @@ -1229,29 +1230,46 @@ def __init__(self, func_body: Call):

layout = "NHWC"

quantize = func_body
activation = quantize.args[0]
dequantize = activation.args[0]
in_var = dequantize.args[0]
if util.is_fixed_point_enabled():
fract_part_for_16_bits = float(1 / 2**15)
in_var = func_body.args[0]

self.ifm = TensorParams(
in_var,
layout=layout,
scale=dequantize.args[DequantizeArgs.IFM_SCALE.value],
zero_point=dequantize.args[DequantizeArgs.IFM_ZERO_POINT.value],
)
self.ofm = TensorParams(
quantize,
layout=layout,
scale=quantize.args[QuantizeArgs.OFM_SCALE.value],
zero_point=quantize.args[QuantizeArgs.OFM_ZERO_POINT.value],
)
self.ifm = TensorParams(
in_var,
layout=layout,
scale=tvm.relay.Constant(tvm.nd.array(np.array(fract_part_for_16_bits))),
zero_point=tvm.relay.Constant(tvm.nd.array(np.array(0, dtype="int32"))),
)
self.ofm = TensorParams(
func_body,
layout=layout,
scale=tvm.relay.Constant(tvm.nd.array(np.array(fract_part_for_16_bits))),
zero_point=tvm.relay.Constant(tvm.nd.array(np.array(0, dtype="int32"))),
)
else:
quantize = func_body
activation = quantize.args[0]
dequantize = activation.args[0]
in_var = dequantize.args[0]

self.ifm = TensorParams(
in_var,
layout=layout,
scale=dequantize.args[DequantizeArgs.IFM_SCALE.value],
zero_point=dequantize.args[DequantizeArgs.IFM_ZERO_POINT.value],
)
self.ofm = TensorParams(
quantize,
layout=layout,
scale=quantize.args[QuantizeArgs.OFM_SCALE.value],
zero_point=quantize.args[QuantizeArgs.OFM_ZERO_POINT.value],
)

def is_valid(self):
"""
This function checks whether activation has compatible attributes with the NPU
"""
if not check_valid_dtypes([self.ifm, self.ofm], supported_dtypes=[np.int8]):
if not check_valid_dtypes([self.ifm, self.ofm], supported_dtypes=[np.int8, np.int16]):
return False
return True

Expand All @@ -1266,7 +1284,7 @@ def tanh_pattern():
dequant = is_op("qnn.dequantize")(wildcard(), is_constant(), is_constant())
tanh = is_op("tanh")(dequant)
quant = is_op("qnn.quantize")(tanh, is_constant(), is_constant())
return quant
return quant | is_op("tanh")(wildcard())


class SigmoidParams(LutActivationParams):
Expand Down
8 changes: 8 additions & 0 deletions src/relay/backend/contrib/ethosu/compiler_attrs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ struct EthosUCompilerConfigNode : public tvm::AttrsNode<EthosUCompilerConfigNode
Bool enable_cascader = Bool(false);
Bool enable_striping = Bool(false);
Bool disable_copying_constants = Bool(false);
Bool enable_fixed_point = Bool(false);
String dev_force_block_config;
String dev_max_open_plans;
String dev_max_closed_plans;
Expand Down Expand Up @@ -71,6 +72,13 @@ struct EthosUCompilerConfigNode : public tvm::AttrsNode<EthosUCompilerConfigNode
"in "
"the linker script for section \".rodata.tvm\" that the constants are located in SRAM)")
.set_default(Bool(false));
TVM_ATTR_FIELD(enable_fixed_point)
.describe(
"Whether calculation with fixed point is enabled. When this option "
"is "
"enabled, it is assumed that input data should be converted to fixed point "
"representation")
.set_default(Bool(false));
String dev_warning = "Option is intended for development and debugging purposes only. ";
TVM_ATTR_FIELD(dev_force_block_config)
.describe((dev_warning + String("Force the block config to a given value; format = "
Expand Down
4 changes: 4 additions & 0 deletions tests/python/contrib/test_ethosu/infra.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ def create_test_runner(
enable_cascader=False,
enable_striping=False,
workspace_pools=None,
enable_fixed_point=False,
):

file_dir = os.path.dirname(os.path.abspath(__file__))
Expand Down Expand Up @@ -168,6 +169,7 @@ def create_test_runner(
"accelerator_config": accel,
"enable_cascader": enable_cascader,
"enable_striping": enable_striping,
"enable_fixed_point": enable_fixed_point,
},
"tir.usmp.enable": enable_usmp,
"tir.usmp.algorithm": "hill_climb",
Expand Down Expand Up @@ -333,6 +335,7 @@ def compare_ethosu_with_reference(
output_tolerance=0,
print_cmm=False,
enable_cascader=None,
enable_fixed_point=False,
):
if enable_cascader is None:
enable_cascader = "u65" not in accel_type
Expand All @@ -359,6 +362,7 @@ def compare_ethosu_with_reference(
enable_cascader=enable_cascader,
enable_striping=False,
workspace_pools=workspace_pools,
enable_fixed_point=enable_fixed_point,
)
compiled_models = build_source(
mod,
Expand Down
44 changes: 44 additions & 0 deletions tests/python/contrib/test_ethosu/test_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -1675,5 +1675,49 @@ def convert_to_fixed_point(arr, fract_size):
)


@pytest.mark.parametrize("accel_type", ["ethos-u55-256", "ethos-u65-256"])
@pytest.mark.parametrize(
    "ifm_shape,fract_size,tolerance",
    [[(1, 2, 8, 4), 15, 0.001], [(1, 8), 12, 0.15], [(1, 1, 4, 8), 10, 0.25]],
)
def test_ethosu_tanh_fixed_point(accel_type, ifm_shape, fract_size, tolerance):
    """Compare an int16 fixed point tanh offloaded to the NPU against np.tanh.

    Real inputs are drawn uniformly from [-1, 1). The inputs, the np.tanh
    reference outputs and the tolerance are all scaled by 2**fract_size and
    stored as int16 before the comparison. The module is partitioned with the
    "enable_fixed_point" option switched on.
    """
    np.random.seed(0)
    dtype = "int16"

    def to_fixed_point(values):
        # Scale by 2**fract_size; the int16 conversion drops the fractional remainder.
        return np.array(values * (0b1 << fract_size), dtype=np.int16)

    # Relay module consisting of a single tanh applied to the int16 input.
    ifm = relay.var("ifm", shape=ifm_shape, dtype=dtype)
    cpu_mod = tvm.IRModule.from_expr(relay.Function([ifm], relay.tanh(ifm)))

    # Reference is computed on the real values before any fixed point conversion.
    real_input = np.random.uniform(-1, 1, size=ifm_shape)
    real_output = np.tanh(real_input)

    input_data = {"ifm": to_fixed_point(real_input)}
    output_data = {"output": to_fixed_point(real_output)}
    fixed_tolerance = to_fixed_point(tolerance)

    ethosu_options = {"enable_fixed_point": True}
    with tvm.transform.PassContext(config={"relay.ext.ethos-u.options": ethosu_options}):
        ethosu_mod = partition_for_ethosu(cpu_mod)

    infra.compare_ethosu_with_reference(
        ethosu_mod,
        input_data,
        output_data,
        accel_type,
        enable_cascader=is_u55_accel_type(accel_type),
        output_tolerance=fixed_tolerance,
        enable_fixed_point=True,
    )


if __name__ == "__main__":
tvm.testing.main()
Loading

0 comments on commit 045973c

Please sign in to comment.