[relax][frontend] add relax frontend torch op: tan, asin, acos, atan, sinh, cosh, tanh, asinh, … #15610

Merged 1 commit on Aug 25, 2023
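
For context, here is a minimal sketch of how the newly mapped operators are exercised from the PyTorch side. It assumes the existing torch.fx tracing path and the relax.frontend.torch.from_fx entry point; the module name, shape, and dtype below are illustrative, not part of this PR.

# Hedged sketch: trace a torch module that uses one of the new ops and
# import it into Relax. TanModel, the shape, and the dtype are assumptions.
import torch
from torch import fx, nn

from tvm.relax.frontend.torch import from_fx


class TanModel(nn.Module):
    def forward(self, x):
        return torch.tan(x)


# symbolic_trace yields a torch.fx GraphModule; its call_function node for
# torch.tan is what the translator's convert map dispatches on.
graph_model = fx.symbolic_trace(TanModel())
input_info = [([1, 3, 10, 10], "float32")]
mod = from_fx(graph_model, input_info)  # IRModule whose main emits R.tan
mod.show()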
python/tvm/relax/frontend/torch/fx_translator.py: 21 changes (12 additions, 9 deletions)
@@ -138,15 +138,9 @@ def _call_binary_op(self, op, lhs, rhs):

########## Arithmetic ##########

- def _cos(self, node: fx.node.Node) -> relax.Var:
- return self.block_builder.emit(relax.op.cos(self.env[node.args[0]]))

def _exp(self, node: fx.node.Node) -> relax.Var:
return self.block_builder.emit(relax.op.exp(self.env[node.args[0]]))

- def _sin(self, node: fx.node.Node) -> relax.Var:
- return self.block_builder.emit(relax.op.sin(self.env[node.args[0]]))

def _sigmoid(self, node: fx.node.Node) -> relax.Var:
return self.block_builder.emit(relax.op.sigmoid(self.env[node.args[0]]))

@@ -1291,9 +1285,19 @@ def create_convert_map(self):
nn.modules.sparse.Embedding: self._embedding,
nn.CrossEntropyLoss: self._cross_entropy,
# call_function and call_method
"cos": self._cos,
"sin": lambda node: self.block_builder.emit(relax.op.sin(self.env[node.args[0]])),
"cos": lambda node: self.block_builder.emit(relax.op.cos(self.env[node.args[0]])),
"tan": lambda node: self.block_builder.emit(relax.op.tan(self.env[node.args[0]])),
"asin": lambda node: self.block_builder.emit(relax.op.asin(self.env[node.args[0]])),
"acos": lambda node: self.block_builder.emit(relax.op.acos(self.env[node.args[0]])),
"atan": lambda node: self.block_builder.emit(relax.op.atan(self.env[node.args[0]])),
"sinh": lambda node: self.block_builder.emit(relax.op.sinh(self.env[node.args[0]])),
"cosh": lambda node: self.block_builder.emit(relax.op.cosh(self.env[node.args[0]])),
"tanh": lambda node: self.block_builder.emit(relax.op.tanh(self.env[node.args[0]])),
"asinh": lambda node: self.block_builder.emit(relax.op.asinh(self.env[node.args[0]])),
"acosh": lambda node: self.block_builder.emit(relax.op.acosh(self.env[node.args[0]])),
"atanh": lambda node: self.block_builder.emit(relax.op.atanh(self.env[node.args[0]])),
"exp": self._exp,
"sin": self._sin,
"iadd": self._add,
"add": self._add,
"floordiv": self._floordiv,
@@ -1350,7 +1354,6 @@ def create_convert_map(self):
"leaky_relu": self._leakyrelu,
"gelu": self._gelu,
"silu": lambda node: self.block_builder.emit(relax.op.nn.silu(self.env[node.args[0]])),
"tanh": lambda node: self.block_builder.emit(relax.op.tanh(self.env[node.args[0]])),
"interpolate": self._interpolate,
"size": self._size,
"getattr": self._getattr,
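Each new entry in the convert map follows the same unary pattern: the translator looks up the FX node's target name, and the registered lambda emits the matching elementwise Relax op on the node's first argument, which has already been translated into a relax.Var stored in self.env. A compressed restatement of that pattern, using a hypothetical make_unary helper that is not part of the PR:

# Hedged sketch of the unary-op pattern; make_unary is a hypothetical helper
# shown only to restate what each lambda in the convert map does.
from tvm import relax


def make_unary(relax_op):
    # Returns a converter: an FX node is mapped to an emitted relax.Var
    # holding relax_op applied to the node's first (and only) input.
    def convert(self, node):
        return self.block_builder.emit(relax_op(self.env[node.args[0]]))

    return convert


# e.g. make_unary(relax.op.tan) behaves like the "tan" lambda registered above.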
tests/python/relax/test_frontend_from_fx.py: 208 changes (204 additions, 4 deletions)
@@ -1807,7 +1807,7 @@ def forward(self, input):
return torch.sin(input)

@tvm.script.ir_module
- class expected1:
+ class expected_sin:
@R.function
def main(
input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
@@ -1819,15 +1819,15 @@ def main(
R.output(gv)
return gv

- verify_model(Sin(), input_info, {}, expected1)
+ verify_model(Sin(), input_info, {}, expected_sin)

# cos
class Cos(Module):
def forward(self, input):
return torch.cos(input)

@tvm.script.ir_module
- class expected2:
+ class expected_cos:
@R.function
def main(
input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
@@ -1839,7 +1839,207 @@ def main(
R.output(gv)
return gv

- verify_model(Cos(), input_info, {}, expected2)
+ verify_model(Cos(), input_info, {}, expected_cos)

# tan
class Tan(Module):
def forward(self, input):
return torch.tan(input)

@tvm.script.ir_module
class expected_tan:
@R.function
def main(
input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
# block 0
with R.dataflow():
lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.tan(input_1)
gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
R.output(gv)
return gv

verify_model(Tan(), input_info, {}, expected_tan)

# asin
class Asin(Module):
def forward(self, input):
return torch.asin(input)

@tvm.script.ir_module
class expected_asin:
@R.function
def main(
input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
# block 0
with R.dataflow():
lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.asin(input_1)
gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
R.output(gv)
return gv

verify_model(Asin(), input_info, {}, expected_asin)

# acos
class Acos(Module):
def forward(self, input):
return torch.acos(input)

@tvm.script.ir_module
class expected_acos:
@R.function
def main(
input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
# block 0
with R.dataflow():
lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.acos(input_1)
gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
R.output(gv)
return gv

verify_model(Acos(), input_info, {}, expected_acos)

# atan
class Atan(Module):
def forward(self, input):
return torch.atan(input)

@tvm.script.ir_module
class expected_atan:
@R.function
def main(
input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
# block 0
with R.dataflow():
lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.atan(input_1)
gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
R.output(gv)
return gv

verify_model(Atan(), input_info, {}, expected_atan)

# sinh
class Sinh(Module):
def forward(self, input):
return torch.sinh(input)

@tvm.script.ir_module
class expected_sinh:
@R.function
def main(
input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
# block 0
with R.dataflow():
lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.sinh(input_1)
gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
R.output(gv)
return gv

verify_model(Sinh(), input_info, {}, expected_sinh)

# cosh
class Cosh(Module):
def forward(self, input):
return torch.cosh(input)

@tvm.script.ir_module
class expected_cosh:
@R.function
def main(
input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
# block 0
with R.dataflow():
lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.cosh(input_1)
gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
R.output(gv)
return gv

verify_model(Cosh(), input_info, {}, expected_cosh)

# tanh
class Tanh(Module):
def forward(self, input):
return torch.tanh(input)

@tvm.script.ir_module
class expected_tanh:
@R.function
def main(
input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
# block 0
with R.dataflow():
lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.tanh(input_1)
gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
R.output(gv)
return gv

verify_model(Tanh(), input_info, {}, expected_tanh)

# asinh
class Asinh(Module):
def forward(self, input):
return torch.asinh(input)

@tvm.script.ir_module
class expected_asinh:
@R.function
def main(
input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
# block 0
with R.dataflow():
lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.asinh(input_1)
gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
R.output(gv)
return gv

verify_model(Asinh(), input_info, {}, expected_asinh)

# acosh
class Acosh(Module):
def forward(self, input):
return torch.acosh(input)

@tvm.script.ir_module
class expected_acosh:
@R.function
def main(
input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
# block 0
with R.dataflow():
lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.acosh(input_1)
gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
R.output(gv)
return gv

verify_model(Acosh(), input_info, {}, expected_acosh)

# atanh
class Atanh(Module):
def forward(self, input):
return torch.atanh(input)

@tvm.script.ir_module
class expected_atanh:
@R.function
def main(
input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
# block 0
with R.dataflow():
lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.atanh(input_1)
gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
R.output(gv)
return gv

verify_model(Atanh(), input_info, {}, expected_atanh)

# exp
class Exp(Module):
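Every new test above follows the same shape as the existing sin and cos cases: define a one-op torch module, write the expected Relax IRModule in TVMScript, and hand both to verify_model. The helper itself is defined earlier in test_frontend_from_fx.py; the sketch below is an assumption about what such a check amounts to, not a copy of the real helper.

# Hedged sketch of a verify_model-style check; the real helper may differ in
# details such as parameter binding for modules with weights.
import torch.fx as fx

import tvm
from tvm.relax.frontend.torch import from_fx


def verify_unary_sketch(torch_model, input_info, expected):
    graph_model = fx.symbolic_trace(torch_model)
    mod = from_fx(graph_model, input_info)
    # The imported IRModule should match the hand-written TVMScript expectation.
    tvm.ir.assert_structural_equal(mod, expected)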