From de23e58d28d2a5633639f618ab5eb238e157ea86 Mon Sep 17 00:00:00 2001
From: Aleksei-grovety <113356454+Aleksei-grovety@users.noreply.github.com>
Date: Thu, 7 Mar 2024 13:50:11 +0400
Subject: [PATCH] review changes

use fixed_point_multiply to define fraction_size, use LUT for tanh
---
 .../relay/backend/contrib/ethosu/legalize.py  |  46 +++++---
 .../contrib/ethosu/tir_to_cs_translator.py    |   5 +-
 .../tvm/relay/backend/contrib/ethosu/util.py  |   6 -
 python/tvm/relay/op/contrib/ethosu.py         | 111 ++++++++++++------
 .../backend/contrib/ethosu/compiler_attrs.cc  |   8 --
 tests/python/contrib/test_ethosu/infra.py     |   4 -
 .../contrib/test_ethosu/test_codegen.py       |  15 +--
 .../contrib/test_ethosu/test_legalize.py      |  48 ++++----
 8 files changed, 142 insertions(+), 101 deletions(-)

diff --git a/python/tvm/relay/backend/contrib/ethosu/legalize.py b/python/tvm/relay/backend/contrib/ethosu/legalize.py
index 0ad4e8d30f99..b7e24c5bde1e 100644
--- a/python/tvm/relay/backend/contrib/ethosu/legalize.py
+++ b/python/tvm/relay/backend/contrib/ethosu/legalize.py
@@ -139,7 +139,9 @@ def get_lut_from_func(
 ) -> List[int]:
     """Calculates the values of the lookup table based on the calculation function"""
 
-    if dtype == np.int8:
+    assert dtype in ["int8", "int16"]
+
+    if dtype == "int8":
         lut_values = list()
         qmin, qmax = np.iinfo(dtype).min, np.iinfo(dtype).max
         for x in range(qmin, qmax + 1):
@@ -150,7 +152,8 @@ def get_lut_from_func(
             lut_values.append(lut_result)
 
         return lut_values
-    elif dtype == np.int16:
+    else:
+        # dtype == "int16"
         input_min = ifm_scale * (np.iinfo(np.int16).min - ifm_zp)
         input_max = ifm_scale * (np.iinfo(np.int16).max - ifm_zp)
@@ -194,8 +197,6 @@ def get_lut_from_func(
             lut[i] = slope + base
 
         return lut
-    else:
-        assert f"Unsupported 'dtype = {dtype}' !"
 
 
 class LutActivationRewriter(DFPatternCallback):
@@ -222,11 +223,13 @@ def callback(self, pre: tvm.relay.Expr, post: tvm.relay.Expr, node_map: tvm.ir.c
         output_scale = float(params.ofm.q_params.scale_f32)
         output_zp = int(params.ofm.q_params.zero_point)
 
-        if params.ifm.dtype == "int8":
+        # Validation function from pattern matching checks that the input type can be int8 or int16
+        ifm_dtype = params.ifm.dtype
+        if ifm_dtype == "int8":
             lut_values = get_lut_from_func(
-                input_scale, input_zp, output_scale, output_zp, self.calc_func, np.int8
+                input_scale, input_zp, output_scale, output_zp, self.calc_func, ifm_dtype
             )
-            lut = relay.const(lut_values, dtype=params.ifm.dtype)
+            lut = relay.const(lut_values, dtype=ifm_dtype)
 
             # We baked the requantization into the LUT, so we don't requantize the identity operator
             identity = ethosu_ops.ethosu_identity(
@@ -239,12 +242,15 @@ def callback(self, pre: tvm.relay.Expr, post: tvm.relay.Expr, node_map: tvm.ir.c
                 activation=self.activation_type,
             )
 
-            return identity
-        elif params.ifm.dtype == "int16":
-            lut_tanh = relay.const([], "int16")
-            tanh_identity = ethosu_ops.ethosu_identity(
+        else:
+            # ifm_dtype == "int16"
+            lut = get_lut_from_func(
+                input_scale, input_zp, output_scale, output_zp, self.calc_func, ifm_dtype
+            )
+            lut = relay.const(lut, dtype="int32")
+            identity = ethosu_ops.ethosu_identity(
                 ifm=params.ifm.tensor,
-                lut=lut_tanh,
+                lut=lut,
                 ifm_scale=input_scale,
                 ifm_zero_point=0,
                 ofm_scale=output_scale,
                 ofm_zero_point=output_zp,
                 activation=self.activation_type,
             )
-            return tanh_identity
-        else:
-            assert f"Unsupported 'ifm.dtype = {params.ifm.dtype}' !"
+
+        return identity
 
 
 class TanhRewriter(LutActivationRewriter):
@@ -266,6 +270,17 @@ def __init__(self):
         )
 
 
+class TanhFixedPointRewriter(LutActivationRewriter):
+    """This pass adds tanh with fixed point as a LUT to the identity operator"""
+
+    def __init__(self):
+        super().__init__(
+            params_class=ethosu_patterns.TanhFixedPointParams,
+            activation_type="TANH",
+            calc_func=math.tanh,
+        )
+
+
 def sigmoid_calc_func(x: float) -> float:
     """Function to calculate the values for sigmoid"""
     # These limits are inherited from TFLite
@@ -1748,6 +1763,7 @@ def transform_npu_function(self, _, func: relay.Function) -> relay.Function:
             ShlRewriter(),
             AbsRewriter(),
             TanhRewriter(),
+            TanhFixedPointRewriter(),
             HardSwishRewriter(),
             LeakyReLURewriter(),
             MeanRewriter(),
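For reference, the int8 branch of get_lut_from_func boils down to dequantize, apply the activation, requantize, clip, evaluated once per input code. The following is a standalone sketch of that flow; the helper name is ours, and since the loop body is elided from the hunks above, the real implementation's rounding mode may differ:

    import math
    import numpy as np

    def reference_lut_int8(ifm_scale, ifm_zp, ofm_scale, ofm_zp, calc_func):
        # Dequantize each int8 code, apply the activation, requantize, clip.
        qmin, qmax = np.iinfo(np.int8).min, np.iinfo(np.int8).max
        lut_values = []
        for x in range(qmin, qmax + 1):
            x_real = ifm_scale * (x - ifm_zp)
            out_real = calc_func(x_real)                  # e.g. math.tanh
            lut_result = int(round(out_real / ofm_scale)) + ofm_zp
            lut_values.append(min(qmax, max(qmin, lut_result)))
        return lut_values

    # 256 entries, one per int8 input code; symmetric scales of 1/128 for illustration.
    lut = reference_lut_int8(1 / 128, 0, 1 / 128, 0, math.tanh)
    assert len(lut) == 256 and -128 <= min(lut) <= max(lut) <= 127

The int16 branch instead builds a piecewise-linear table whose entries combine a slope and a base (lut[i] = slope + base), which is why its result is emitted as an int32 constant rather than an int16 one.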
diff --git a/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py b/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py
index 2d708d476e26..e88f9047ddc5 100644
--- a/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py
+++ b/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py
@@ -877,7 +877,7 @@ def _create_npu_activation(serial_activation: spec.SerialActivation) -> vapi.Npu
         return None
     op_map = {
         "CLIP": vapi.NpuActivationOp.NONE_OR_RELU,
-        "TANH": vapi.NpuActivationOp.TANH,
+        "TANH": vapi.NpuActivationOp.TABLE_LOOKUP,
         "SIGMOID": vapi.NpuActivationOp.TABLE_LOOKUP,
         "LUT": vapi.NpuActivationOp.TABLE_LOOKUP,
     }
@@ -887,9 +887,6 @@ def _create_npu_activation(serial_activation: spec.SerialActivation) -> vapi.Npu
     if serial_activation.op == "CLIP":
         act_op.min = int(serial_activation.clip_min.value)
         act_op.max = int(serial_activation.clip_max.value)
-    if serial_activation.op == "TANH":
-        act_op.min = float(-1.0)
-        act_op.max = float(1.0)
     if op_map[op] == vapi.NpuActivationOp.TABLE_LOOKUP:
         act_op.lookup_table_index = 0
     return act_op
diff --git a/python/tvm/relay/backend/contrib/ethosu/util.py b/python/tvm/relay/backend/contrib/ethosu/util.py
index d0c4f443f25d..289754d5c370 100644
--- a/python/tvm/relay/backend/contrib/ethosu/util.py
+++ b/python/tvm/relay/backend/contrib/ethosu/util.py
@@ -264,12 +264,6 @@ def is_copying_constants_disabled() -> bool:
     return bool(compiler_attrs.disable_copying_constants)
 
 
-def is_fixed_point_enabled() -> bool:
-    """Determine whether calculation with fixed point is enabled"""
-    compiler_attrs = tvm.get_global_func("relay.ext.ethos-u.get_compiler_attrs")()
-    return bool(compiler_attrs.enable_fixed_point)
-
-
 def is_striping_enabled() -> bool:
     """Determine whether the cascader is enabled"""
     compiler_attrs = tvm.get_global_func("relay.ext.ethos-u.get_compiler_attrs")()
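With the translator change above, TANH no longer gets a dedicated activation op with a hard-coded [-1, 1] clamp; like SIGMOID and LUT it becomes a table lookup on LUT slot 0. A condensed sketch of the resulting mapping, assuming (as in the ethos-u-vela API the translator imports as vapi) that NpuActivation is constructed from the activation op type:

    from ethosu.vela import api as vapi

    OP_MAP = {
        "CLIP": vapi.NpuActivationOp.NONE_OR_RELU,
        "TANH": vapi.NpuActivationOp.TABLE_LOOKUP,  # previously NpuActivationOp.TANH
        "SIGMOID": vapi.NpuActivationOp.TABLE_LOOKUP,
        "LUT": vapi.NpuActivationOp.TABLE_LOOKUP,
    }

    def npu_activation(op: str, clip_min: int = 0, clip_max: int = 0):
        # Condensed, illustrative version of _create_npu_activation after this patch.
        act_op = vapi.NpuActivation(OP_MAP[op])
        if op == "CLIP":
            act_op.min = clip_min
            act_op.max = clip_max
        if OP_MAP[op] == vapi.NpuActivationOp.TABLE_LOOKUP:
            act_op.lookup_table_index = 0  # all LUT-based activations share slot 0
        return act_op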
diff --git a/python/tvm/relay/op/contrib/ethosu.py b/python/tvm/relay/op/contrib/ethosu.py
index 839a478e639a..dd04d613079b 100644
--- a/python/tvm/relay/op/contrib/ethosu.py
+++ b/python/tvm/relay/op/contrib/ethosu.py
@@ -31,7 +31,6 @@
     is_tuple,
     wildcard,
 )
-from tvm.relay.backend.contrib.ethosu import util
 from tvm.relay.expr import Call, Constant  # type: ignore
 from tvm.relay.op.contrib.register import register_pattern_table  # type: ignore
 
@@ -1230,40 +1229,23 @@ def __init__(self, func_body: Call):
 
         layout = "NHWC"
 
-        if util.is_fixed_point_enabled():
-            fract_part_for_16_bits = float(1 / 2**15)
-            in_var = func_body.args[0]
-
-            self.ifm = TensorParams(
-                in_var,
-                layout=layout,
-                scale=tvm.relay.Constant(tvm.nd.array(np.array(fract_part_for_16_bits))),
-                zero_point=tvm.relay.Constant(tvm.nd.array(np.array(0, dtype="int32"))),
-            )
-            self.ofm = TensorParams(
-                func_body,
-                layout=layout,
-                scale=tvm.relay.Constant(tvm.nd.array(np.array(fract_part_for_16_bits))),
-                zero_point=tvm.relay.Constant(tvm.nd.array(np.array(0, dtype="int32"))),
-            )
-        else:
-            quantize = func_body
-            activation = quantize.args[0]
-            dequantize = activation.args[0]
-            in_var = dequantize.args[0]
+        quantize = func_body
+        activation = quantize.args[0]
+        dequantize = activation.args[0]
+        in_var = dequantize.args[0]
 
-            self.ifm = TensorParams(
-                in_var,
-                layout=layout,
-                scale=dequantize.args[DequantizeArgs.IFM_SCALE.value],
-                zero_point=dequantize.args[DequantizeArgs.IFM_ZERO_POINT.value],
-            )
-            self.ofm = TensorParams(
-                quantize,
-                layout=layout,
-                scale=quantize.args[QuantizeArgs.OFM_SCALE.value],
-                zero_point=quantize.args[QuantizeArgs.OFM_ZERO_POINT.value],
-            )
+        self.ifm = TensorParams(
+            in_var,
+            layout=layout,
+            scale=dequantize.args[DequantizeArgs.IFM_SCALE.value],
+            zero_point=dequantize.args[DequantizeArgs.IFM_ZERO_POINT.value],
+        )
+        self.ofm = TensorParams(
+            quantize,
+            layout=layout,
+            scale=quantize.args[QuantizeArgs.OFM_SCALE.value],
+            zero_point=quantize.args[QuantizeArgs.OFM_ZERO_POINT.value],
+        )
 
     def is_valid(self):
         """
@@ -1284,7 +1266,61 @@ def tanh_pattern():
     dequant = is_op("qnn.dequantize")(wildcard(), is_constant(), is_constant())
     tanh = is_op("tanh")(dequant)
     quant = is_op("qnn.quantize")(tanh, is_constant(), is_constant())
-    return quant | is_op("tanh")(wildcard())
+    return quant
+
+
+class TanhFixedPointParams:
+    """
+    This class will parse a call to an ethos-u.tanh_fixed_point composite function
+    and extract the parameter information.
+    """
+
+    composite_name = "ethos-u.tanh_fixed_point"
+
+    @requires_vela
+    def __init__(self, func_body):
+        layout = "NHWC"
+
+        tanh_fixed_point = func_body.args[0]
+        tanh = tanh_fixed_point.args[0]
+        # The fixed_point_multiply relay operation uses a multiplier with 31 fractional bits,
+        # so the size of the fraction is given by the formula: 31 - shift
+        self.fraction_size = 31 - tanh_fixed_point.attrs.shift
+        fract_scale = tvm.relay.Constant(tvm.nd.array(np.array(1 / 2**self.fraction_size)))
+        fract_zero_point = tvm.relay.Constant(tvm.nd.array(np.array(0, dtype="int32")))
+
+        self.ifm = TensorParams(
+            tanh.args[0].args[0].args[0],
+            layout=layout,
+            scale=fract_scale,
+            zero_point=fract_zero_point,
+        )
+        self.ofm = TensorParams(
+            func_body,
+            layout=layout,
+            scale=fract_scale,
+            zero_point=fract_zero_point,
+        )
+
+    def is_valid(self) -> bool:
+        """
+        This function checks whether the activation has attributes compatible with the NPU
+        """
+
+        if self.fraction_size < 0 or self.fraction_size > 16:
+            return False
+        if not check_valid_dtypes([self.ifm, self.ofm], supported_dtypes=[np.int8, np.int16]):
+            return False
+        return True
+
+
+def tanh_fixed_point_pattern():
+    """Create pattern for fixed point tanh"""
+    ifm = is_op("cast")(wildcard())
+    ifm = is_op("fixed_point_multiply")(ifm)
+    tanh = is_op("tanh")(ifm)
+    tanh = is_op("fixed_point_multiply")(tanh)
+    return is_op("cast")(tanh)
 
 
 class SigmoidParams(LutActivationParams):
@@ -2391,6 +2427,11 @@ def pattern_table() -> List[Tuple[str, tvm.relay.dataflow_pattern.DFPattern, Cal
             lambda pat: AbsParams(pat).is_valid(),
         ),
         (TanhParams.composite_name, tanh_pattern(), lambda pat: TanhParams(pat).is_valid()),
+        (
+            TanhFixedPointParams.composite_name,
+            tanh_fixed_point_pattern(),
+            lambda pat: TanhFixedPointParams(pat).is_valid(),
+        ),
         (
             MeanParams.composite_name,
            mean_pattern(),
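To make the 31 - shift bookkeeping concrete: fixed_point_multiply interprets its multiplier as having 31 fractional bits, so the number of fractional bits left in the data is 31 - shift. With shift = 16 that gives fraction_size = 31 - 16 = 15, i.e. Q15 values with scale 1 / 2**15 = 1 / 32768 (the fract_scale constant built in __init__), and 15 sits inside the 0..16 range is_valid accepts. Below is a sketch of the kind of Relay expression tanh_fixed_point_pattern is written to match; the shape, dtype, and multiplier values are invented for illustration, and whether such a graph compiles end to end depends on the frontend that produced it:

    from tvm import relay

    x = relay.var("x", shape=(1, 8, 8, 4), dtype="int16")
    shift = 16  # fraction_size = 31 - shift = 15

    ifm = relay.cast(x, "int32")
    ifm = relay.fixed_point_multiply(ifm, multiplier=1 << 30, shift=shift)
    act = relay.tanh(ifm)
    act = relay.fixed_point_multiply(act, multiplier=1 << 30, shift=shift)
    func = relay.Function([x], relay.cast(act, "int16"))

TanhFixedPointParams reads shift from the outer fixed_point_multiply (func_body.args[0]) and digs three calls deep (tanh.args[0].args[0].args[0]) to recover the pre-cast input as the IFM.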
diff --git a/src/relay/backend/contrib/ethosu/compiler_attrs.cc b/src/relay/backend/contrib/ethosu/compiler_attrs.cc
index d4a532274864..a3a09cf1119b 100644
--- a/src/relay/backend/contrib/ethosu/compiler_attrs.cc
+++ b/src/relay/backend/contrib/ethosu/compiler_attrs.cc
@@ -42,7 +42,6 @@ struct EthosUCompilerConfigNode : public tvm::AttrsNode