Skip to content

Commit

Permalink
[relax][frontend]add relax frontend torch op: tan,asin,acos,atan,sinh…
Browse files Browse the repository at this point in the history
…,cosh,tanh,asinh,… (#15610)

add relax frontend torch op: tan,asin,acos,atan,sinh,cosh,tanh,asinh,acosh,atanh

Co-authored-by: HLearning <[email protected]>
  • Loading branch information
HLearning and HLearning authored Aug 25, 2023
1 parent 71cdd46 commit 4a51382
Show file tree
Hide file tree
Showing 2 changed files with 216 additions and 13 deletions.
21 changes: 12 additions & 9 deletions python/tvm/relax/frontend/torch/fx_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,15 +138,9 @@ def _call_binary_op(self, op, lhs, rhs):

########## Arithmetic ##########

def _cos(self, node: fx.node.Node) -> relax.Var:
return self.block_builder.emit(relax.op.cos(self.env[node.args[0]]))

def _exp(self, node: fx.node.Node) -> relax.Var:
return self.block_builder.emit(relax.op.exp(self.env[node.args[0]]))

def _sin(self, node: fx.node.Node) -> relax.Var:
return self.block_builder.emit(relax.op.sin(self.env[node.args[0]]))

def _sigmoid(self, node: fx.node.Node) -> relax.Var:
return self.block_builder.emit(relax.op.sigmoid(self.env[node.args[0]]))

Expand Down Expand Up @@ -1291,9 +1285,19 @@ def create_convert_map(self):
nn.modules.sparse.Embedding: self._embedding,
nn.CrossEntropyLoss: self._cross_entropy,
# call_function and call_method
"cos": self._cos,
"sin": lambda node: self.block_builder.emit(relax.op.sin(self.env[node.args[0]])),
"cos": lambda node: self.block_builder.emit(relax.op.cos(self.env[node.args[0]])),
"tan": lambda node: self.block_builder.emit(relax.op.tan(self.env[node.args[0]])),
"asin": lambda node: self.block_builder.emit(relax.op.asin(self.env[node.args[0]])),
"acos": lambda node: self.block_builder.emit(relax.op.acos(self.env[node.args[0]])),
"atan": lambda node: self.block_builder.emit(relax.op.atan(self.env[node.args[0]])),
"sinh": lambda node: self.block_builder.emit(relax.op.sinh(self.env[node.args[0]])),
"cosh": lambda node: self.block_builder.emit(relax.op.cosh(self.env[node.args[0]])),
"tanh": lambda node: self.block_builder.emit(relax.op.tanh(self.env[node.args[0]])),
"asinh": lambda node: self.block_builder.emit(relax.op.asinh(self.env[node.args[0]])),
"acosh": lambda node: self.block_builder.emit(relax.op.acosh(self.env[node.args[0]])),
"atanh": lambda node: self.block_builder.emit(relax.op.atanh(self.env[node.args[0]])),
"exp": self._exp,
"sin": self._sin,
"iadd": self._add,
"add": self._add,
"floordiv": self._floordiv,
Expand Down Expand Up @@ -1350,7 +1354,6 @@ def create_convert_map(self):
"leaky_relu": self._leakyrelu,
"gelu": self._gelu,
"silu": lambda node: self.block_builder.emit(relax.op.nn.silu(self.env[node.args[0]])),
"tanh": lambda node: self.block_builder.emit(relax.op.tanh(self.env[node.args[0]])),
"interpolate": self._interpolate,
"size": self._size,
"getattr": self._getattr,
Expand Down
208 changes: 204 additions & 4 deletions tests/python/relax/test_frontend_from_fx.py
Original file line number Diff line number Diff line change
Expand Up @@ -1807,7 +1807,7 @@ def forward(self, input):
return torch.sin(input)

@tvm.script.ir_module
class expected1:
class expected_sin:
@R.function
def main(
input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
Expand All @@ -1819,15 +1819,15 @@ def main(
R.output(gv)
return gv

verify_model(Sin(), input_info, {}, expected1)
verify_model(Sin(), input_info, {}, expected_sin)

# cos
class Cos(Module):
def forward(self, input):
return torch.cos(input)

@tvm.script.ir_module
class expected2:
class expected_cos:
@R.function
def main(
input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
Expand All @@ -1839,7 +1839,207 @@ def main(
R.output(gv)
return gv

verify_model(Cos(), input_info, {}, expected2)
verify_model(Cos(), input_info, {}, expected_cos)

# tan
class Tan(Module):
def forward(self, input):
return torch.tan(input)

@tvm.script.ir_module
class expected_tan:
@R.function
def main(
input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
# block 0
with R.dataflow():
lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.tan(input_1)
gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
R.output(gv)
return gv

verify_model(Tan(), input_info, {}, expected_tan)

# asin
class Asin(Module):
def forward(self, input):
return torch.asin(input)

@tvm.script.ir_module
class expected_asin:
@R.function
def main(
input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
# block 0
with R.dataflow():
lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.asin(input_1)
gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
R.output(gv)
return gv

verify_model(Asin(), input_info, {}, expected_asin)

# acos
class Acos(Module):
def forward(self, input):
return torch.acos(input)

@tvm.script.ir_module
class expected_acos:
@R.function
def main(
input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
# block 0
with R.dataflow():
lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.acos(input_1)
gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
R.output(gv)
return gv

verify_model(Acos(), input_info, {}, expected_acos)

# atan
class Atan(Module):
def forward(self, input):
return torch.atan(input)

@tvm.script.ir_module
class expected_atan:
@R.function
def main(
input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
# block 0
with R.dataflow():
lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.atan(input_1)
gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
R.output(gv)
return gv

verify_model(Atan(), input_info, {}, expected_atan)

# sinh
class Sinh(Module):
def forward(self, input):
return torch.sinh(input)

@tvm.script.ir_module
class expected_sinh:
@R.function
def main(
input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
# block 0
with R.dataflow():
lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.sinh(input_1)
gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
R.output(gv)
return gv

verify_model(Sinh(), input_info, {}, expected_sinh)

# cosh
class Cosh(Module):
def forward(self, input):
return torch.cosh(input)

@tvm.script.ir_module
class expected_cosh:
@R.function
def main(
input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
# block 0
with R.dataflow():
lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.cosh(input_1)
gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
R.output(gv)
return gv

verify_model(Cosh(), input_info, {}, expected_cosh)

# tanh
class Tanh(Module):
def forward(self, input):
return torch.tanh(input)

@tvm.script.ir_module
class expected_tanh:
@R.function
def main(
input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
# block 0
with R.dataflow():
lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.tanh(input_1)
gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
R.output(gv)
return gv

verify_model(Tanh(), input_info, {}, expected_tanh)

# asinh
class Asinh(Module):
def forward(self, input):
return torch.asinh(input)

@tvm.script.ir_module
class expected_asinh:
@R.function
def main(
input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
# block 0
with R.dataflow():
lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.asinh(input_1)
gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
R.output(gv)
return gv

verify_model(Asinh(), input_info, {}, expected_asinh)

# acosh
class Acosh(Module):
def forward(self, input):
return torch.acosh(input)

@tvm.script.ir_module
class expected_acosh:
@R.function
def main(
input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
# block 0
with R.dataflow():
lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.acosh(input_1)
gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
R.output(gv)
return gv

verify_model(Acosh(), input_info, {}, expected_acosh)

# atanh
class Atanh(Module):
def forward(self, input):
return torch.atanh(input)

@tvm.script.ir_module
class expected_atanh:
@R.function
def main(
input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
# block 0
with R.dataflow():
lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.atanh(input_1)
gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
R.output(gv)
return gv

verify_model(Atanh(), input_info, {}, expected_atanh)

# exp
class Exp(Module):
Expand Down

0 comments on commit 4a51382

Please sign in to comment.