Skip to content
This repository has been archived by the owner on May 22, 2023. It is now read-only.

Commit

Permalink
Remove attributes of relax.print, assert and unique (#443)
Browse files Browse the repository at this point in the history
* Remove attributes of assert, print and unique

* use relax.null_value

* Fix comments

* Fix comments. Remove null_value

* Fix 421
  • Loading branch information
yongwww authored Feb 23, 2023
1 parent 9b34d04 commit b0c63b5
Show file tree
Hide file tree
Showing 11 changed files with 130 additions and 209 deletions.
62 changes: 0 additions & 62 deletions include/tvm/relax/attrs/set.h

This file was deleted.

21 changes: 0 additions & 21 deletions include/tvm/relax/op_attr_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,27 +49,6 @@ using FInferStructInfo =
*/
using FCallPacked = String;

/*! \brief Attributes carried by the relax.print operator. */
struct PrintAttrs : public tvm::AttrsNode<PrintAttrs> {
  // Python-style format string applied when displaying the input; per .describe(), ignored if empty.
  std::string format;
  TVM_DECLARE_ATTRS(PrintAttrs, "relax.attrs.PrintAttrs") {
    TVM_ATTR_FIELD(format)
        .describe("Python-style format string to use for displaying the input. Ignored if empty.")
        .set_default("");
  }
};

/*! \brief Attributes carried by the relax.assert_op operator. */
struct AssertOpAttrs : public tvm::AttrsNode<AssertOpAttrs> {
  // Format string for the failure message; per .describe(), ignored if empty.
  std::string format;
  TVM_DECLARE_ATTRS(AssertOpAttrs, "relax.attrs.AssertOpAttrs") {
    TVM_ATTR_FIELD(format)
        .describe(
            "Python-style format string to use for displaying "
            "an error message if the assert fails. "
            "Ignored if empty.")
        .set_default("");
  }
};

} // namespace relax
} // namespace tvm
#endif // TVM_RELAX_OP_ATTR_TYPES_H_
21 changes: 14 additions & 7 deletions python/tvm/relax/op/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from tvm.runtime.object import Object

from . import _ffi_api
from ..expr import Expr, ShapeExpr, Call, ExternFunc
from ..expr import Expr, StringImm, ShapeExpr, Call, ExternFunc
from ..expr import Tuple as RxTuple
from ..struct_info import StructInfo, TensorStructInfo
from ...ir import PrimExpr
Expand Down Expand Up @@ -237,22 +237,25 @@ def relax_print(format_str: str, *format_args: tvm.Object) -> None:
py_print(format_str.format(*val_strs))


def print(*values: List[Expr], format: Union[str, Expr] = "") -> Expr:
    """Print op to print the values.

    Parameters
    ----------
    values : List[Expr]
        The values to print.

    format : Union[str, Expr]
        The format string or StringImm. A plain ``str`` is wrapped in a
        ``StringImm`` before being passed to the FFI call.

    Returns
    -------
    result : Expr
        A relax Call, which will print the value during runtime.
    """
    # Normalize a plain Python string into the StringImm expression the
    # underlying operator expects.
    if isinstance(format, str):
        format = StringImm(format)

    return _ffi_api.print(values, format)  # type: ignore # pylint: disable=no-member


Expand Down Expand Up @@ -310,7 +313,9 @@ def relax_assert_op(condition: tvm.Object, format_str: str, *format_args: tvm.Ob


def assert_op(
condition: Expr, format_args: Optional[Union[Expr, List[Expr]]] = None, format: str = ""
condition: Expr,
format_args: Optional[Union[Expr, List[Expr]]] = None,
format: Union[str, Expr] = "",
) -> Expr:
"""
Create a call to Relax's assert_op operation (`assert` is reserved in Python,
Expand All @@ -324,8 +329,8 @@ def assert_op(
format_args: Optional[Union[Expr, List[Expr]]]
Format arguments for the error message if the condition fails.
format_str: str
The format string for the error message.
format: Union[str, Expr]
The format string or StringImm for the error message.
Returns
-------
Expand All @@ -336,6 +341,8 @@ def assert_op(
format_args = []
if isinstance(format_args, Expr): # type: ignore
format_args = [format_args]
if isinstance(format, str):
format = StringImm(format)
return _ffi_api.assert_op(condition, format_args, format) # type: ignore


Expand Down
16 changes: 12 additions & 4 deletions python/tvm/relax/op/builtin/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,24 +15,27 @@
# specific language governing permissions and limitations
"""The builtin Relax operators."""

from ...expr import Call, Expr
from typing import Union
from ...expr import Call, Expr, PrimValue, DataTypeImm
from ...utils import args_converter
from . import _ffi_api


@args_converter.auto
def alloc_tensor(shape: Expr, dtype: str, runtime_device_index: int) -> Call:
def alloc_tensor(
shape: Expr, dtype: Union[str, Expr], runtime_device_index: Union[int, Expr]
) -> Call:
"""Construct a Call to allocate a tensor with specific shape, dtype, runtime_device_index.
Parameters
----------
shape : Expr
The shape of the tensor to be allocated.
dtype : str
dtype : Union[str, Expr]
The datatype of the tensor to be allocated.
runtime_device_index : int
runtime_device_index : Union[int, Expr]
The device index indicating on which device the tensor is to be allocated at runtime.
Index -1 is reserved for the host device.
Expand All @@ -41,4 +44,9 @@ def alloc_tensor(shape: Expr, dtype: str, runtime_device_index: int) -> Call:
result : Call
A relax Call, which gets the allocated tensor.
"""
if isinstance(dtype, str):
dtype = DataTypeImm(dtype)
if isinstance(runtime_device_index, int):
runtime_device_index = PrimValue(runtime_device_index)

return _ffi_api.alloc_tensor(shape, dtype, runtime_device_index) # type: ignore
15 changes: 0 additions & 15 deletions python/tvm/relax/op/op_attrs.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,6 @@
import tvm._ffi


# FFI object registered under the key "relax.attrs.PrintAttrs".
@tvm._ffi.register_object("relax.attrs.PrintAttrs")
class PrintAttrs(Attrs):
    """Attributes used for the print operator"""


# FFI object registered under the key "relax.attrs.AssertOpAttrs".
@tvm._ffi.register_object("relax.attrs.AssertOpAttrs")
class AssertOpAttrs(Attrs):
    """Attributes used for the assert operator"""


# FFI object registered under the key "relax.attrs.StatisticalAttrs".
@tvm._ffi.register_object("relax.attrs.StatisticalAttrs")
class StatisticalAttrs(Attrs):
    """Attributes used in statistical operator"""
Expand All @@ -44,11 +34,6 @@ class TriluAttrs(Attrs):
"""Attributes used in tril and triu operator"""


# FFI object registered under the key "relax.attrs.UniqueAttrs".
@tvm._ffi.register_object("relax.attrs.UniqueAttrs")
class UniqueAttrs(Attrs):
    """Attributes used for the unique operator"""


# FFI object registered under the key "relax.attrs.ConcatAttrs".
@tvm._ffi.register_object("relax.attrs.ConcatAttrs")
class ConcatAttrs(Attrs):
    """Attributes for concat operator"""
Expand Down
33 changes: 21 additions & 12 deletions python/tvm/relax/op/set.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,22 +16,22 @@
# under the License.
# pylint: disable=import-outside-toplevel, redefined-builtin, unused-argument
"""Set operators."""
from typing import Optional
from typing import Optional, Union

import numpy as np # type: ignore
import tvm

from . import _ffi_api
from ..expr import Expr
from ..expr import Expr, PrimValue


def unique(
x: Expr,
sorted: bool = True,
return_index: bool = False,
return_inverse: bool = False,
return_counts: bool = False,
axis: Optional[int] = None,
sorted: Union[bool, Expr] = True,
return_index: Union[bool, Expr] = False,
return_inverse: Union[bool, Expr] = False,
return_counts: Union[bool, Expr] = False,
axis: Optional[Union[int, Expr]] = None,
) -> Expr:
"""Find the unique elements in a given tensor.
In addition, it optionally returns
Expand All @@ -44,19 +44,19 @@ def unique(
x : relax.Expr
The input tensor.
sorted : bool
sorted : Union[bool, Expr]
Whether to sort the unique elements in ascending order before
returning as output.
return_index : bool
return_index : Union[bool, Expr]
Whether to return an additional tensor with indices for where elements in
the unique tensor come from the original input.
return_inverse : bool
return_inverse : Union[bool, Expr]
Whether to return an additional tensor with indices for where elements in
the original input ended up in the returned unique list.
return_counts : bool
return_counts : Union[bool, Expr]
Whether to return an additional tensor with counts of each unique elements.
axis : Optional
Expand All @@ -69,6 +69,16 @@ def unique(
The created relax call with
"""

if isinstance(sorted, bool):
sorted = PrimValue(sorted)
if isinstance(return_index, bool):
return_index = PrimValue(return_index)
if isinstance(return_inverse, bool):
return_inverse = PrimValue(return_inverse)
if isinstance(return_counts, bool):
return_counts = PrimValue(return_counts)
if axis and isinstance(axis, int):
axis = PrimValue(axis)
return _ffi_api.unique( # type: ignore
x, sorted, return_index, return_inverse, return_counts, axis
)
Expand All @@ -81,7 +91,6 @@ def numpy_unique(
return_index: int,
return_inverse: int,
return_counts: int,
axis: Optional[int],
) -> tvm.nd.array:
"""Returns the unique elements of the input tensor.
Expand Down
40 changes: 1 addition & 39 deletions src/relax/backend/vm/codegen_vm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
* \brief A codegen to generate VM executable from a Relax IRModule.
*/
#include <tvm/driver/driver_api.h>
#include <tvm/relax/attrs/set.h>
#include <tvm/relax/attrs/shape.h>
#include <tvm/relax/exec_builder.h>
#include <tvm/relax/expr_functor.h>
Expand Down Expand Up @@ -364,43 +363,9 @@ class CodeGenVM : public ExprFunctor<Instruction::Arg(const Expr&)> {
builder_->EmitCall(func, args, dst_reg);
}

// TODO(relax-team) revisit after PrimValue.
// Emit the `call_node` attributes as constants and append these constants to `args` vector.
void AppendAttrsAsConstants(const Call& call_node, std::vector<Instruction::Arg>& args) {
auto attrs = call_node->attrs;
if (!attrs.defined()) return;

if (call_node->op == unique_op_) {
auto unique_attrs = call_node->attrs.as<UniqueAttrs>();
args.push_back(builder_->ConvertConstant(unique_attrs->sorted));
args.push_back(builder_->ConvertConstant(unique_attrs->return_index));
args.push_back(builder_->ConvertConstant(unique_attrs->return_inverse));
args.push_back(builder_->ConvertConstant(unique_attrs->return_counts));
args.push_back(builder_->ConvertConstant(unique_attrs->axis));
return;
}
if (call_node->op == print_op_) {
auto print_attrs = call_node->attrs.as<PrintAttrs>();
// format string is the first argument
args.insert(args.begin(), builder_->ConvertConstant(print_attrs->format));
return;
}
if (call_node->op == assert_op_) {
auto assert_attrs = call_node->attrs.as<AssertOpAttrs>();
// format string comes before the format args
args.insert(args.begin() + 1, builder_->ConvertConstant(assert_attrs->format));
return;
}
LOG(FATAL) << "Support for attributes of Op " << call_node->op
<< " has not been implemented yet.";
return;
}

// Emits call to packed function `name` with arguments copied over from `call_node` args and
// attributes.
// Emits call to packed function `name` with arguments copied over from `call_node` args
void EmitPackedFuncCall(const Call& call_node, const FCallPacked& name, RegName dst_reg) {
std::vector<Instruction::Arg> args = VisitArray(call_node->args);
AppendAttrsAsConstants(call_node, args);
builder_->EmitCall(name, args, dst_reg);
}

Expand Down Expand Up @@ -428,9 +393,6 @@ class CodeGenVM : public ExprFunctor<Instruction::Arg(const Expr&)> {
const Op& alloc_tensor_op_ = Op::Get("relax.vm.alloc_tensor");
const Op& call_builtin_with_ctx_op_ = Op::Get("relax.call_builtin_with_ctx");
const Op& null_value_op_ = Op::Get("relax.null_value");
const Op& unique_op_ = Op::Get("relax.unique");
const Op& print_op_ = Op::Get("relax.print");
const Op& assert_op_ = Op::Get("relax.assert_op");
};

/*!
Expand Down
Loading

0 comments on commit b0c63b5

Please sign in to comment.