From 73e0184e64de21053a325f360f8524e06471092d Mon Sep 17 00:00:00 2001 From: snigdha dalvi Date: Wed, 23 Aug 2023 17:04:26 -0500 Subject: [PATCH 1/4] Replacing unary ops with take op --- .../tvm/contrib/hexagon/generate_take_op.py | 86 ++++ .../tvm/contrib/hexagon/hexagon_unary_ops.py | 93 ++++ .../python/contrib/test_hexagon/test_take.py | 403 ++++++++++++++++++ 3 files changed, 582 insertions(+) create mode 100644 python/tvm/contrib/hexagon/generate_take_op.py create mode 100644 python/tvm/contrib/hexagon/hexagon_unary_ops.py create mode 100644 tests/python/contrib/test_hexagon/test_take.py diff --git a/python/tvm/contrib/hexagon/generate_take_op.py b/python/tvm/contrib/hexagon/generate_take_op.py new file mode 100644 index 000000000000..0ca6aefb52f0 --- /dev/null +++ b/python/tvm/contrib/hexagon/generate_take_op.py @@ -0,0 +1,86 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import numpy as np + +import tvm +import tvm.testing +from tvm import relax +from tvm.contrib.hexagon import hexagon_unary_ops + + +def op_replace(call_node): + def is_op(op_name: str, call_node: relax.Call) -> bool: + if not isinstance(call_node, relax.Call): + return False + call_tir_op = tvm.ir.Op.get("relax.call_tir") + if call_node.op != call_tir_op: + return False + global_var = call_node.args[0] + return op_name in global_var.name_hint + + ops = ["tanh", "sqrt", "rsqrt", "exp", "erf", "sigmoid", "hardswish", "log", "abs"] + for op in ops: + if is_op(op, call_node): + return True + return False + + +@relax.expr_functor.mutator +class Tanh2TakeReplace(tvm.relax.PyExprMutator): + def __init__(self, mod: tvm.IRModule) -> None: + super().__init__(mod) + self.mod_ = mod + + def transform(self) -> tvm.IRModule: + # Iterate over all the nodes to check for the node replaceable + for global_var, func in self.mod_.functions.items(): + # Skip non-relax functions + if not isinstance(func, relax.Function): + continue + updated_func = self.visit_expr(func) + self.builder_.normalize(updated_func) + self.builder_.update_func(global_var, updated_func) + # At the end of the transformation we return the updated IRModule from the BlockBuilder. + return self.builder_.get() + + def visit_call_(self, call_node: relax.Call) -> relax.Call: + if call_node.args[1][0].struct_info.dtype == "uint8": + if op_replace(call_node): + inp, inp_scale, inp_zp, out_scale, out_zp = [x for x in call_node.args[1]] + # LUT node creation + LUT = hexagon_unary_ops.LUT_generation( + inp_scale, inp_zp, out_scale, out_zp, call_node.args[0].name_hint + ) + # Take operation node creation + take_func = hexagon_unary_ops.generate_take_primfunc(inp, call_node.struct_info) + take_func_gv = self.builder_.add_func(take_func, "take") + take_node = relax.call_tir( + take_func_gv, + relax.expr.Tuple( + [call_node.args[1][0], relax.expr.Constant(tvm.nd.array(LUT))] + ), + call_node.struct_info, + ) + return take_node + return call_node + + +@tvm.ir.transform.module_pass(opt_level=2, name="replace_tanh_take") +class PassReplaceWithTakeOpPrimFuncs: + def transform_module(self, mod, ctx): + return Tanh2TakeReplace(mod).transform() diff --git a/python/tvm/contrib/hexagon/hexagon_unary_ops.py b/python/tvm/contrib/hexagon/hexagon_unary_ops.py new file mode 100644 index 000000000000..c4cdf48997a4 --- /dev/null +++ b/python/tvm/contrib/hexagon/hexagon_unary_ops.py @@ -0,0 +1,93 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import numpy as np +from scipy import special + +from tvm import te +from tvm.script import tir as T + +###################################################################### +#################### PRIMFUNC FOR LUT and Take Op #################### +###################################################################### + + +def saturate(x: te.Tensor, dtype: str): + """Saturate value for the specified data type""" + return te.max(te.min_value(dtype), te.min(x, te.max_value(dtype))) + + +def hardswish_func(x): + x2 = np.add(x, 3.0) + x2 = np.clip(x2, 0.0, 6.0) + return x * x2 / 6.0 + + +def LUT_generation(inp_scale, inp_zp, out_scale, out_zp, op_name) -> None: + LUT = [] + for i in range(256): + i = np.int32(i) + # converting the constants to the numpy value + if inp_zp.data.shape == (): + i_zp = inp_zp.data.numpy()[()] + if inp_scale.data.shape == (): + i_scale = inp_scale.data.numpy()[()] + if out_zp.data.shape == (): + o_zp = out_zp.data.numpy()[()] + if out_scale.data.shape == (): + o_scale = out_scale.data.numpy()[()] + # Dequantization followed by computing the op value + dequant = (i - i_zp) * i_scale + if op_name == "tanh": + op_val = np.tanh(dequant) + elif op_name == "sqrt": + op_val = np.sqrt(dequant) + elif op_name == "rsqrt": + op_val = 1 / np.sqrt(dequant) + elif op_name == "exp": + op_val = np.exp(dequant) + elif op_name == "erf": + op_val = special.erf(dequant) + elif op_name == "sigmoid": + op_val = 1 / (1 + np.exp(np.negative(dequant))) + elif op_name == "hardswish": + op_val = hardswish_func(dequant) + elif op_name == "log": + op_val = np.log(dequant) + elif op_name == "abs": + op_val = np.abs(dequant) + # Quantizing the value generated and appending in the Look Up Table + quant = np.round((op_val) / o_scale) + o_zp + val = np.maximum(0, np.minimum(quant, 255)).astype(np.uint8) + LUT.append(val) + return LUT + + +def generate_take_primfunc(inp, struct_info): + # Generating the take op + N, H, W, C = inp.struct_info.shape + data = te.placeholder((N, H, W, C), dtype=struct_info.dtype, name="data") + LUT_func = te.placeholder((256,), dtype="uint8", name="LUT") + take = te.compute( + struct_info.shape, + lambda *indices: saturate( + (LUT_func[data[indices].astype("uint8")]), struct_info.dtype + ).astype(struct_info.dtype), + name="take_op", + ) + mod = te.create_prim_func([data, LUT_func, take]) + return mod diff --git a/tests/python/contrib/test_hexagon/test_take.py b/tests/python/contrib/test_hexagon/test_take.py new file mode 100644 index 000000000000..b37966570a14 --- /dev/null +++ b/tests/python/contrib/test_hexagon/test_take.py @@ -0,0 +1,403 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import numpy as np +from scipy import special + +import tvm +import tvm.testing +from tvm import relax +from tvm.script import tir as T, relax as R +from tvm.contrib.hexagon import generate_take_op +from tvm.contrib.hexagon import hexagon_unary_ops + +from .infrastructure import quantize_np + + +# Testing the structural and value correctness on replacing unary op with take op. + + +@tvm.script.ir_module +class Module_tanh: + @R.function + def main( + input: R.Tensor((1, 2, 2, 2), dtype="uint8"), + ) -> R.Tensor((1, 2, 2, 2), dtype="uint8"): + out = R.call_tir( + Module_tanh.tanh, + ( + input, + R.const(0.003186821002586215, "float32"), + R.const(0, "int32"), + R.const(0.002631544131858676, "float32"), + R.const(0, "int32"), + ), + out_sinfo=R.Tensor((1, 2, 2, 2), dtype="uint8"), + ) + return out + + @T.prim_func + def tanh( + rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(2), T.int64(2)), "uint8"), + rxplaceholder_1: T.Buffer((), "float32"), + rxplaceholder_2: T.Buffer((), "int32"), + rxplaceholder_3: T.Buffer((), "float32"), + rxplaceholder_4: T.Buffer((), "int32"), + compute: T.Buffer((T.int64(1), T.int64(2), T.int64(2), T.int64(2)), "uint8"), + ): + T.func_attr({"tir.noalias": True}) + pass + + +@tvm.script.ir_module +class Module_sqrt: + @R.function + def main( + input: R.Tensor((1, 2, 2, 2), dtype="uint8"), + ) -> R.Tensor((1, 2, 2, 2), dtype="uint8"): + out = R.call_tir( + Module_sqrt.sqrt, + ( + input, + R.const(0.003186821002586215, "float32"), + R.const(0, "int32"), + R.const(0.003535157327728918, "float32"), + R.const(0, "int32"), + ), + out_sinfo=R.Tensor((1, 2, 2, 2), dtype="uint8"), + ) + return out + + @T.prim_func + def sqrt( + rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(2), T.int64(2)), "uint8"), + rxplaceholder_1: T.Buffer((), "float32"), + rxplaceholder_2: T.Buffer((), "int32"), + rxplaceholder_3: T.Buffer((), "float32"), + rxplaceholder_4: T.Buffer((), "int32"), + compute: T.Buffer((T.int64(1), T.int64(2), T.int64(2), T.int64(2)), "uint8"), + ): + T.func_attr({"tir.noalias": True}) + pass + + +@tvm.script.ir_module +class Module_rsqrt: + @R.function + def main( + input: R.Tensor((1, 2, 2, 2), dtype="uint8"), + ) -> R.Tensor((1, 2, 2, 2), dtype="uint8"): + out = R.call_tir( + Module_rsqrt.rsqrt, + ( + input, + R.const(0.003186821002586215, "float32"), + R.const(0, "int32"), + R.const(0.008154160766635542, "float32"), + R.const(0, "int32"), + ), + out_sinfo=R.Tensor((1, 2, 2, 2), dtype="uint8"), + ) + return out + + @T.prim_func + def rsqrt( + rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(2), T.int64(2)), "uint8"), + rxplaceholder_1: T.Buffer((), "float32"), + rxplaceholder_2: T.Buffer((), "int32"), + rxplaceholder_3: T.Buffer((), "float32"), + rxplaceholder_4: T.Buffer((), "int32"), + compute: T.Buffer((T.int64(1), T.int64(2), T.int64(2), T.int64(2)), "uint8"), + ): + T.func_attr({"tir.noalias": True}) + pass + + +@tvm.script.ir_module +class Module_exp: + @R.function + def main( + input: R.Tensor((1, 2, 2, 2), dtype="uint8"), + ) -> R.Tensor((1, 2, 2, 2), dtype="uint8"): + out = R.call_tir( + Module_exp.exp, + ( + input, + R.const(0.003186821002586215, "float32"), + R.const(0, "int32"), + R.const(0.008838622987079832, "float32"), + R.const(0, "int32"), + ), + out_sinfo=R.Tensor((1, 2, 2, 2), dtype="uint8"), + ) + return out + + @T.prim_func + def exp( + rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(2), T.int64(2)), "uint8"), + rxplaceholder_1: T.Buffer((), "float32"), + rxplaceholder_2: T.Buffer((), "int32"), + rxplaceholder_3: T.Buffer((), "float32"), + rxplaceholder_4: T.Buffer((), "int32"), + compute: T.Buffer((T.int64(1), T.int64(2), T.int64(2), T.int64(2)), "uint8"), + ): + T.func_attr({"tir.noalias": True}) + pass + + +@tvm.script.ir_module +class Module_erf: + @R.function + def main( + input: R.Tensor((1, 2, 2, 2), dtype="uint8"), + ) -> R.Tensor((1, 2, 2, 2), dtype="uint8"): + out = R.call_tir( + Module_erf.erf, + ( + input, + R.const(0.003186821002586215, "float32"), + R.const(0, "int32"), + R.const(0.002939393251118067, "float32"), + R.const(0, "int32"), + ), + out_sinfo=R.Tensor((1, 2, 2, 2), dtype="uint8"), + ) + return out + + @T.prim_func + def erf( + rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(2), T.int64(2)), "uint8"), + rxplaceholder_1: T.Buffer((), "float32"), + rxplaceholder_2: T.Buffer((), "int32"), + rxplaceholder_3: T.Buffer((), "float32"), + rxplaceholder_4: T.Buffer((), "int32"), + compute: T.Buffer((T.int64(1), T.int64(2), T.int64(2), T.int64(2)), "uint8"), + ): + T.func_attr({"tir.noalias": True}) + pass + + +@tvm.script.ir_module +class Module_sigmoid: + @R.function + def main( + input: R.Tensor((1, 2, 2, 2), dtype="uint8"), + ) -> R.Tensor((1, 2, 2, 2), dtype="uint8"): + out = R.call_tir( + Module_sigmoid.sigmoid, + ( + input, + R.const(0.003186821002586215, "float32"), + R.const(0, "int32"), + R.const(0.002631544131858676, "float32"), + R.const(0, "int32"), + ), + out_sinfo=R.Tensor((1, 2, 2, 2), dtype="uint8"), + ) + return out + + @T.prim_func + def sigmoid( + rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(2), T.int64(2)), "uint8"), + rxplaceholder_1: T.Buffer((), "float32"), + rxplaceholder_2: T.Buffer((), "int32"), + rxplaceholder_3: T.Buffer((), "float32"), + rxplaceholder_4: T.Buffer((), "int32"), + compute: T.Buffer((T.int64(1), T.int64(2), T.int64(2), T.int64(2)), "uint8"), + ): + T.func_attr({"tir.noalias": True}) + pass + + +@tvm.script.ir_module +class Module_hardswish: + @R.function + def main( + input: R.Tensor((1, 2, 2, 2), dtype="uint8"), + ) -> R.Tensor((1, 2, 2, 2), dtype="uint8"): + out = R.call_tir( + Module_hardswish.hardswish, + ( + input, + R.const(0.003186821002586215, "float32"), + R.const(0, "int32"), + R.const(0.0020250332087720325, "float32"), + R.const(0, "int32"), + ), + out_sinfo=R.Tensor((1, 2, 2, 2), dtype="uint8"), + ) + return out + + @T.prim_func + def hardswish( + rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(2), T.int64(2)), "uint8"), + rxplaceholder_1: T.Buffer((), "float32"), + rxplaceholder_2: T.Buffer((), "int32"), + rxplaceholder_3: T.Buffer((), "float32"), + rxplaceholder_4: T.Buffer((), "int32"), + compute: T.Buffer((T.int64(1), T.int64(2), T.int64(2), T.int64(2)), "uint8"), + ): + T.func_attr({"tir.noalias": True}) + pass + + +@tvm.script.ir_module +class Module_log: + @R.function + def main( + input: R.Tensor((1, 2, 2, 2), dtype="uint8"), + ) -> R.Tensor((1, 2, 2, 2), dtype="uint8"): + out = R.call_tir( + Module_log.log, + ( + input, + R.const(0.003186821002586215, "float32"), + R.const(0, "int32"), + R.const(0.0057414634248614226, "float32"), + R.const(255, "int32"), + ), + out_sinfo=R.Tensor((1, 2, 2, 2), dtype="uint8"), + ) + return out + + @T.prim_func + def log( + rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(2), T.int64(2)), "uint8"), + rxplaceholder_1: T.Buffer((), "float32"), + rxplaceholder_2: T.Buffer((), "int32"), + rxplaceholder_3: T.Buffer((), "float32"), + rxplaceholder_4: T.Buffer((), "int32"), + compute: T.Buffer((T.int64(1), T.int64(2), T.int64(2), T.int64(2)), "uint8"), + ): + T.func_attr({"tir.noalias": True}) + pass + + +@tvm.script.ir_module +class Module_abs: + @R.function + def main( + input: R.Tensor((1, 2, 2, 2), dtype="uint8"), + ) -> R.Tensor((1, 2, 2, 2), dtype="uint8"): + out = R.call_tir( + Module_abs.abs, + ( + input, + R.const(0.003186821002586215, "float32"), + R.const(0, "int32"), + R.const(0.0031868210196078434, "float32"), + R.const(0, "int32"), + ), + out_sinfo=R.Tensor((1, 2, 2, 2), dtype="uint8"), + ) + return out + + @T.prim_func + def abs( + rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(2), T.int64(2)), "uint8"), + rxplaceholder_1: T.Buffer((), "float32"), + rxplaceholder_2: T.Buffer((), "int32"), + rxplaceholder_3: T.Buffer((), "float32"), + rxplaceholder_4: T.Buffer((), "int32"), + compute: T.Buffer((T.int64(1), T.int64(2), T.int64(2), T.int64(2)), "uint8"), + ): + T.func_attr({"tir.noalias": True}) + pass + + +# data = np.random.random([1, 2, 2, 2]).astype("float32") : Need to hadcode the data +# so that we can get the quantization parameters and use them as input to the main func +data = [ + [ + [[0.3034368, 0.60848576], [0.29697746, 0.67340654]], + [[0.656068, 0.23129226], [0.42117321, 0.81263936]], + ] +] +dtype = "uint8" + +# Quantizing input : scale is returned as float64 and zp is returned as int32 +inp_quant, inp_scale, inp_zero_point = quantize_np(data, dtype) +inp_quant = tvm.nd.array(inp_quant.astype(np.uint8)) + + +# Test the implementations value output with numpy data. First the IR is runn through pass +# to replace unary op with take op. Followed by value testing. +def test_value(): + ops = ["tanh", "sqrt", "rsqrt", "exp", "erf", "sigmoid", "hardswish", "log", "abs"] + + atol_val = 2 + for op_name in ops: + if op_name == "tanh": + op_val = np.tanh(data) + before = Module_tanh + elif op_name == "sqrt": + op_val = np.sqrt(data) + before = Module_sqrt + elif op_name == "rsqrt": + op_val = 1 / np.sqrt(data) + before = Module_rsqrt + elif op_name == "exp": + op_val = np.exp(data) + before = Module_exp + elif op_name == "erf": + op_val = special.erf(data) + before = Module_erf + elif op_name == "sigmoid": + op_val = 1 / (1 + np.exp(np.negative(data))) + atol_val = 15 + before = Module_sigmoid + elif op_name == "hardswish": + op_val = hexagon_unary_ops.hardswish_func(data) + before = Module_hardswish + elif op_name == "log": + op_val = np.log(data) + before = Module_log + elif op_name == "abs": + op_val = np.abs(data) + before = Module_abs + + # Quantizing output : scale is returned as float64 and zp is returned as int32 + out_quant, out_scale, out_zero_point = quantize_np(op_val, dtype) + + after = generate_take_op.PassReplaceWithTakeOpPrimFuncs()(before) + target = tvm.target.Target("llvm", host="llvm") + ex = relax.build(after, target, exec_mode="compiled") + vm = relax.VirtualMachine(ex, tvm.cpu()) + res = vm["main"](inp_quant) + + tvm.testing.assert_allclose(res.numpy(), out_quant, atol=atol_val) + print("Passed Value : ", op_name) + + +# Testing the structural implementation, if the unary op is replaced with take op. +def test_structural(): + Modules = [ + Module_tanh, + Module_sqrt, + Module_rsqrt, + Module_exp, + Module_erf, + Module_sigmoid, + Module_hardswish, + Module_log, + Module_abs, + ] + for mod in Modules: + after = generate_take_op.PassReplaceWithTakeOpPrimFuncs()(mod) + print(after) + assert not tvm.ir.structural_equal(after["main"], mod["main"]) + print("Passed Structural") From 853fa95db9abf17ae3e65ab274f86e85d644bac0 Mon Sep 17 00:00:00 2001 From: snigdha dalvi Date: Wed, 23 Aug 2023 17:06:02 -0500 Subject: [PATCH 2/4] Removed extra print statements --- tests/python/contrib/test_hexagon/test_take.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/python/contrib/test_hexagon/test_take.py b/tests/python/contrib/test_hexagon/test_take.py index b37966570a14..2045ff00a2ef 100644 --- a/tests/python/contrib/test_hexagon/test_take.py +++ b/tests/python/contrib/test_hexagon/test_take.py @@ -398,6 +398,5 @@ def test_structural(): ] for mod in Modules: after = generate_take_op.PassReplaceWithTakeOpPrimFuncs()(mod) - print(after) assert not tvm.ir.structural_equal(after["main"], mod["main"]) print("Passed Structural") From de85bea07d79b21a37e409bf493612bc5a4aed62 Mon Sep 17 00:00:00 2001 From: snigdha dalvi Date: Wed, 30 Aug 2023 17:42:05 -0500 Subject: [PATCH 3/4] Fixed the pylint errors --- .../tvm/contrib/hexagon/generate_take_op.py | 16 +++--- .../tvm/contrib/hexagon/hexagon_unary_ops.py | 55 +++++++++++-------- 2 files changed, 41 insertions(+), 30 deletions(-) diff --git a/python/tvm/contrib/hexagon/generate_take_op.py b/python/tvm/contrib/hexagon/generate_take_op.py index 0ca6aefb52f0..4763c5b9f502 100644 --- a/python/tvm/contrib/hexagon/generate_take_op.py +++ b/python/tvm/contrib/hexagon/generate_take_op.py @@ -14,9 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - -import numpy as np - +"""Pass to replace unary ops with Look Up Table and take op""" import tvm import tvm.testing from tvm import relax @@ -24,6 +22,7 @@ def op_replace(call_node): + """Checks if the op in the graph matched the list of unary ops which can be replaced""" def is_op(op_name: str, call_node: relax.Call) -> bool: if not isinstance(call_node, relax.Call): return False @@ -42,12 +41,13 @@ def is_op(op_name: str, call_node: relax.Call) -> bool: @relax.expr_functor.mutator class Tanh2TakeReplace(tvm.relax.PyExprMutator): + """Pass which iterated over the nodes, checks for unary ops and replaces them with LUT and take op""" def __init__(self, mod: tvm.IRModule) -> None: super().__init__(mod) self.mod_ = mod def transform(self) -> tvm.IRModule: - # Iterate over all the nodes to check for the node replaceable + """Iterates over all the nodes""" for global_var, func in self.mod_.functions.items(): # Skip non-relax functions if not isinstance(func, relax.Function): @@ -61,9 +61,9 @@ def transform(self) -> tvm.IRModule: def visit_call_(self, call_node: relax.Call) -> relax.Call: if call_node.args[1][0].struct_info.dtype == "uint8": if op_replace(call_node): - inp, inp_scale, inp_zp, out_scale, out_zp = [x for x in call_node.args[1]] + inp, inp_scale, inp_zp, out_scale, out_zp = list(call_node.args[1]) # LUT node creation - LUT = hexagon_unary_ops.LUT_generation( + lut = hexagon_unary_ops.lut_generation( inp_scale, inp_zp, out_scale, out_zp, call_node.args[0].name_hint ) # Take operation node creation @@ -72,7 +72,7 @@ def visit_call_(self, call_node: relax.Call) -> relax.Call: take_node = relax.call_tir( take_func_gv, relax.expr.Tuple( - [call_node.args[1][0], relax.expr.Constant(tvm.nd.array(LUT))] + [call_node.args[1][0], relax.expr.Constant(tvm.nd.array(lut))] ), call_node.struct_info, ) @@ -80,7 +80,7 @@ def visit_call_(self, call_node: relax.Call) -> relax.Call: return call_node -@tvm.ir.transform.module_pass(opt_level=2, name="replace_tanh_take") +@tvm.ir.transform.module_pass(opt_level=2, name="replace_unaryop_take") class PassReplaceWithTakeOpPrimFuncs: def transform_module(self, mod, ctx): return Tanh2TakeReplace(mod).transform() diff --git a/python/tvm/contrib/hexagon/hexagon_unary_ops.py b/python/tvm/contrib/hexagon/hexagon_unary_ops.py index c4cdf48997a4..798a1dcf244c 100644 --- a/python/tvm/contrib/hexagon/hexagon_unary_ops.py +++ b/python/tvm/contrib/hexagon/hexagon_unary_ops.py @@ -14,17 +14,12 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - +"""Primitive Function for lut and Take Op""" import numpy as np from scipy import special - +from typing import List from tvm import te -from tvm.script import tir as T - -###################################################################### -#################### PRIMFUNC FOR LUT and Take Op #################### -###################################################################### - +from tvm.tir.function import PrimFunc def saturate(x: te.Tensor, dtype: str): """Saturate value for the specified data type""" @@ -32,13 +27,15 @@ def saturate(x: te.Tensor, dtype: str): def hardswish_func(x): - x2 = np.add(x, 3.0) - x2 = np.clip(x2, 0.0, 6.0) - return x * x2 / 6.0 + """Hardswich Function""" + x_2 = np.add(x, 3.0) + x_2 = np.clip(x_2, 0.0, 6.0) + return x * x_2 / 6.0 -def LUT_generation(inp_scale, inp_zp, out_scale, out_zp, op_name) -> None: - LUT = [] +def lut_generation(inp_scale, inp_zp, out_scale, out_zp, op_name) -> List[np.uint8]: + """Generating the Look Up Table for unary ops""" + lut = [] for i in range(256): i = np.int32(i) # converting the constants to the numpy value @@ -73,21 +70,35 @@ def LUT_generation(inp_scale, inp_zp, out_scale, out_zp, op_name) -> None: # Quantizing the value generated and appending in the Look Up Table quant = np.round((op_val) / o_scale) + o_zp val = np.maximum(0, np.minimum(quant, 255)).astype(np.uint8) - LUT.append(val) - return LUT + lut.append(val) + return lut + +def generate_take_primfunc(inp, struct_info) -> PrimFunc: + """Generating the take op -def generate_take_primfunc(inp, struct_info): - # Generating the take op - N, H, W, C = inp.struct_info.shape - data = te.placeholder((N, H, W, C), dtype=struct_info.dtype, name="data") - LUT_func = te.placeholder((256,), dtype="uint8", name="LUT") + Parameters + ---------- + inp : expr.Var + The input to be searched in the lut and whose take op needs to be returned + + struct_info : TensorStructInfo + The struct info of the input data + + Returns + ---------- + mod : PrimFunc + The take op primitive function + """ + n, h, w, c = inp.struct_info.shape + data = te.placeholder((n, h, w, c), dtype=struct_info.dtype, name="data") + lut_func = te.placeholder((256,), dtype="uint8", name="lut") take = te.compute( struct_info.shape, lambda *indices: saturate( - (LUT_func[data[indices].astype("uint8")]), struct_info.dtype + (lut_func[data[indices].astype("uint8")]), struct_info.dtype ).astype(struct_info.dtype), name="take_op", ) - mod = te.create_prim_func([data, LUT_func, take]) + mod = te.create_prim_func([data, lut_func, take]) return mod From 37933597c86801c1249dfa1403cac8e33017cfa1 Mon Sep 17 00:00:00 2001 From: snigdha dalvi Date: Wed, 30 Aug 2023 22:47:33 -0500 Subject: [PATCH 4/4] Black formatted --- python/tvm/contrib/hexagon/generate_take_op.py | 2 ++ python/tvm/contrib/hexagon/hexagon_unary_ops.py | 7 ++++--- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/python/tvm/contrib/hexagon/generate_take_op.py b/python/tvm/contrib/hexagon/generate_take_op.py index 4763c5b9f502..c4569cffc121 100644 --- a/python/tvm/contrib/hexagon/generate_take_op.py +++ b/python/tvm/contrib/hexagon/generate_take_op.py @@ -23,6 +23,7 @@ def op_replace(call_node): """Checks if the op in the graph matched the list of unary ops which can be replaced""" + def is_op(op_name: str, call_node: relax.Call) -> bool: if not isinstance(call_node, relax.Call): return False @@ -42,6 +43,7 @@ def is_op(op_name: str, call_node: relax.Call) -> bool: @relax.expr_functor.mutator class Tanh2TakeReplace(tvm.relax.PyExprMutator): """Pass which iterated over the nodes, checks for unary ops and replaces them with LUT and take op""" + def __init__(self, mod: tvm.IRModule) -> None: super().__init__(mod) self.mod_ = mod diff --git a/python/tvm/contrib/hexagon/hexagon_unary_ops.py b/python/tvm/contrib/hexagon/hexagon_unary_ops.py index 798a1dcf244c..0261b1880662 100644 --- a/python/tvm/contrib/hexagon/hexagon_unary_ops.py +++ b/python/tvm/contrib/hexagon/hexagon_unary_ops.py @@ -21,6 +21,7 @@ from tvm import te from tvm.tir.function import PrimFunc + def saturate(x: te.Tensor, dtype: str): """Saturate value for the specified data type""" return te.max(te.min_value(dtype), te.min(x, te.max_value(dtype))) @@ -79,12 +80,12 @@ def generate_take_primfunc(inp, struct_info) -> PrimFunc: Parameters ---------- - inp : expr.Var + inp : expr.Var The input to be searched in the lut and whose take op needs to be returned - + struct_info : TensorStructInfo The struct info of the input data - + Returns ---------- mod : PrimFunc