diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index be95a4880b86..b5cee77d118c 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -138,15 +138,9 @@ def _call_binary_op(self, op, lhs, rhs): ########## Arithmetic ########## - def _cos(self, node: fx.node.Node) -> relax.Var: - return self.block_builder.emit(relax.op.cos(self.env[node.args[0]])) - def _exp(self, node: fx.node.Node) -> relax.Var: return self.block_builder.emit(relax.op.exp(self.env[node.args[0]])) - def _sin(self, node: fx.node.Node) -> relax.Var: - return self.block_builder.emit(relax.op.sin(self.env[node.args[0]])) - def _sigmoid(self, node: fx.node.Node) -> relax.Var: return self.block_builder.emit(relax.op.sigmoid(self.env[node.args[0]])) @@ -1291,9 +1285,19 @@ def create_convert_map(self): nn.modules.sparse.Embedding: self._embedding, nn.CrossEntropyLoss: self._cross_entropy, # call_function and call_method - "cos": self._cos, + "sin": lambda node: self.block_builder.emit(relax.op.sin(self.env[node.args[0]])), + "cos": lambda node: self.block_builder.emit(relax.op.cos(self.env[node.args[0]])), + "tan": lambda node: self.block_builder.emit(relax.op.tan(self.env[node.args[0]])), + "asin": lambda node: self.block_builder.emit(relax.op.asin(self.env[node.args[0]])), + "acos": lambda node: self.block_builder.emit(relax.op.acos(self.env[node.args[0]])), + "atan": lambda node: self.block_builder.emit(relax.op.atan(self.env[node.args[0]])), + "sinh": lambda node: self.block_builder.emit(relax.op.sinh(self.env[node.args[0]])), + "cosh": lambda node: self.block_builder.emit(relax.op.cosh(self.env[node.args[0]])), + "tanh": lambda node: self.block_builder.emit(relax.op.tanh(self.env[node.args[0]])), + "asinh": lambda node: self.block_builder.emit(relax.op.asinh(self.env[node.args[0]])), + "acosh": lambda node: self.block_builder.emit(relax.op.acosh(self.env[node.args[0]])), + "atanh": lambda node: self.block_builder.emit(relax.op.atanh(self.env[node.args[0]])), "exp": self._exp, - "sin": self._sin, "iadd": self._add, "add": self._add, "floordiv": self._floordiv, @@ -1350,7 +1354,6 @@ def create_convert_map(self): "leaky_relu": self._leakyrelu, "gelu": self._gelu, "silu": lambda node: self.block_builder.emit(relax.op.nn.silu(self.env[node.args[0]])), - "tanh": lambda node: self.block_builder.emit(relax.op.tanh(self.env[node.args[0]])), "interpolate": self._interpolate, "size": self._size, "getattr": self._getattr, diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py index 2b95d3897d97..ec312767b4b8 100644 --- a/tests/python/relax/test_frontend_from_fx.py +++ b/tests/python/relax/test_frontend_from_fx.py @@ -1807,7 +1807,7 @@ def forward(self, input): return torch.sin(input) @tvm.script.ir_module - class expected1: + class expected_sin: @R.function def main( input_1: R.Tensor((1, 3, 10, 10), dtype="float32") @@ -1819,7 +1819,7 @@ def main( R.output(gv) return gv - verify_model(Sin(), input_info, {}, expected1) + verify_model(Sin(), input_info, {}, expected_sin) # cos class Cos(Module): @@ -1827,7 +1827,7 @@ def forward(self, input): return torch.cos(input) @tvm.script.ir_module - class expected2: + class expected_cos: @R.function def main( input_1: R.Tensor((1, 3, 10, 10), dtype="float32") @@ -1839,7 +1839,207 @@ def main( R.output(gv) return gv - verify_model(Cos(), input_info, {}, expected2) + verify_model(Cos(), input_info, {}, expected_cos) + + # tan + class Tan(Module): + def forward(self, input): + return torch.tan(input) + + @tvm.script.ir_module + class expected_tan: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.tan(input_1) + gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(Tan(), input_info, {}, expected_tan) + + # asin + class Asin(Module): + def forward(self, input): + return torch.asin(input) + + @tvm.script.ir_module + class expected_asin: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.asin(input_1) + gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(Asin(), input_info, {}, expected_asin) + + # acos + class Acos(Module): + def forward(self, input): + return torch.acos(input) + + @tvm.script.ir_module + class expected_acos: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.acos(input_1) + gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(Acos(), input_info, {}, expected_acos) + + # atan + class Atan(Module): + def forward(self, input): + return torch.atan(input) + + @tvm.script.ir_module + class expected_atan: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.atan(input_1) + gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(Atan(), input_info, {}, expected_atan) + + # sinh + class Sinh(Module): + def forward(self, input): + return torch.sinh(input) + + @tvm.script.ir_module + class expected_sinh: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.sinh(input_1) + gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(Sinh(), input_info, {}, expected_sinh) + + # cosh + class Cosh(Module): + def forward(self, input): + return torch.cosh(input) + + @tvm.script.ir_module + class expected_cosh: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.cosh(input_1) + gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(Cosh(), input_info, {}, expected_cosh) + + # tanh + class Tanh(Module): + def forward(self, input): + return torch.tanh(input) + + @tvm.script.ir_module + class expected_tanh: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.tanh(input_1) + gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(Tanh(), input_info, {}, expected_tanh) + + # asinh + class Asinh(Module): + def forward(self, input): + return torch.asinh(input) + + @tvm.script.ir_module + class expected_asinh: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.asinh(input_1) + gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(Asinh(), input_info, {}, expected_asinh) + + # acosh + class Acosh(Module): + def forward(self, input): + return torch.acosh(input) + + @tvm.script.ir_module + class expected_acosh: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.acosh(input_1) + gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(Acosh(), input_info, {}, expected_acosh) + + # atanh + class Atanh(Module): + def forward(self, input): + return torch.atanh(input) + + @tvm.script.ir_module + class expected_atanh: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.atanh(input_1) + gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(Atanh(), input_info, {}, expected_atanh) # exp class Exp(Module):