From 0fc09194fc7a0176bb8ebcc5fdd5f0e9a849c717 Mon Sep 17 00:00:00 2001 From: Lesheng Jin <34279105+LeshengJin@users.noreply.github.com> Date: Thu, 18 Aug 2022 21:51:35 -0700 Subject: [PATCH] [Pass] New Python ExprVisitor/ExprMutator! (#190) Add decorators `visitor` and `mutator` to help users create `ExprVisitor` and `ExprMutator` in Python. Users can customize visit/rewrite/post-order-rewrite function in Python. `PyExprVisitor` and `PyExprMutator` lists the functions users can customize. --- include/tvm/relax/expr_functor.h | 538 +++++++ python/tvm/meta_schedule/utils.py | 50 +- python/tvm/relax/__init__.py | 5 +- python/tvm/relax/expr_functor.py | 1700 +++++++++++++++------ python/tvm/relax/testing/ast_printer.py | 2 +- python/tvm/relax/testing/transform.py | 7 +- python/tvm/relax/transform/fma_rewrite.py | 20 +- src/relax/ir/expr_functor.cc | 123 ++ tests/python/relax/test_expr_functor.py | 689 ++++++++- tests/python/relax/test_pass_manager.py | 7 +- 10 files changed, 2606 insertions(+), 535 deletions(-) diff --git a/include/tvm/relax/expr_functor.h b/include/tvm/relax/expr_functor.h index bcc1bbecb0..29b449c033 100644 --- a/include/tvm/relax/expr_functor.h +++ b/include/tvm/relax/expr_functor.h @@ -65,6 +65,47 @@ class ExprFunctor; return self->VisitExpr_(static_cast(n.get()), std::forward(args)...); \ }); +#define PY_EXPR_VISITOR_DEFAULT(N, PY_FUNC, DEFAULT_FUNC) \ + { \ + if (PY_FUNC != nullptr) \ + PY_FUNC(N); \ + else \ + DEFAULT_FUNC; \ + } + +#define PY_EXPR_MUTATOR_DEFAULT(N, PY_FUNC, DEFAULT_FUNC, RET_TYPE) \ + { \ + if (PY_FUNC != nullptr) { \ + RET_TYPE ret = PY_FUNC(N); \ + return ret; \ + } else { \ + return DEFAULT_FUNC; \ + } \ + } + +#define PY_EXPR_VISITOR_DISPATCH(OP, PY_FUNC) \ + vtable.template set_dispatch([](const ObjectRef& n, TSelf* self) { \ + if (self->PY_FUNC != nullptr) \ + self->PY_FUNC(n); \ + else \ + self->VisitExpr_(static_cast(n.get())); \ + }); + +#define PY_EXPR_MUTATOR_DISPATCH(OP, PY_FUNC) \ + vtable.template set_dispatch([](const ObjectRef& n, TSelf* self) { \ + if (self->PY_FUNC != nullptr) { \ + Expr expr = self->PY_FUNC(n); \ + return expr; \ + } else { \ + return self->VisitExpr_(static_cast(n.get())); \ + } \ + }); + +#define PY_EXPR_MUTATOR_VISIT_EXPR_POST_ORDER_DISPATCH(OP) \ + post_order_vtable.template set_dispatch([](const ObjectRef& n, TSelf* self) { \ + return self->VisitExprPostOrder_(static_cast(n.get())); \ + }); + template class ExprFunctor { private: @@ -338,6 +379,503 @@ class ExprMutator : public ExprMutatorBase { std::unordered_map var_remap_; }; +/*! + * \brief The abstract interface of ExprVisitor. + */ +class PyExprVisitorNode : public Object, public ExprVisitor { + private: + using TSelf = PyExprVisitorNode; + using FType = tvm::NodeFunctor; + + public: + /*! \brief The packed function to the `VisitExpr(const Expr& expr)` function. */ + PackedFunc f_visit_expr{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const ConstantNode* op)` function. */ + PackedFunc f_visit_constant_{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const TupleNode* op)` function. */ + PackedFunc f_visit_tuple_{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const VarNode* op)` function. */ + PackedFunc f_visit_var_{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const DataflowVarNode* op)` function. */ + PackedFunc f_visit_dataflow_var_{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const ShapeExprNode* op)` function. */ + PackedFunc f_visit_shape_expr_{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const RuntimeDepShapeNode* op)` function. */ + PackedFunc f_visit_runtime_dep_shape_{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const ExternFuncNode* op)` function. */ + PackedFunc f_visit_extern_func_{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const GlobalVarNode* op)` function. */ + PackedFunc f_visit_global_var_{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const FunctionNode* op)` function. */ + PackedFunc f_visit_function_{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const CallNode* op)` function. */ + PackedFunc f_visit_call_{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const SeqExprNode* op)` function. */ + PackedFunc f_visit_seq_expr_{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const IfNode* op)` function. */ + PackedFunc f_visit_if_{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const OpNode* op)` function. */ + PackedFunc f_visit_op_{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const TupleGetItemNode* op)` function. */ + PackedFunc f_visit_tuple_getitem_{nullptr}; + /*! \brief The packed function to the `VisitBinding(const Binding& binding)` function. */ + PackedFunc f_visit_binding{nullptr}; + /*! \brief The packed function to the `VisitBinding_(const VarBindingNode* binding)` + * function. */ + PackedFunc f_visit_var_binding_{nullptr}; + /*! \brief The packed function to the `VisitBinding_(const MatchShapeNode* binding)` + * function. */ + PackedFunc f_visit_match_shape_{nullptr}; + /*! \brief The packed function to the `VisitBindingBlock(const BindingBlock& block)` + * function. */ + PackedFunc f_visit_binding_block{nullptr}; + /*! \brief The packed function to the `VisitBindingBlock_(const BindingBlockNode* block)` + * function. */ + PackedFunc f_visit_binding_block_{nullptr}; + /*! \brief The packed function to the `VisitBindingBlock_(const DataflowBlockNode* block)` + * function. */ + PackedFunc f_visit_dataflow_block_{nullptr}; + /*! \brief The packed function to the `VisitVarDef(const Var& var)` function. */ + PackedFunc f_visit_var_def{nullptr}; + /*! \brief The packed function to the `VisitVarDef_(const VarNode* var)` function. */ + PackedFunc f_visit_var_def_{nullptr}; + /*! \brief The packed function to the `VisitVarDef_(const DataflowVarNode* var)` function. */ + PackedFunc f_visit_dataflow_var_def_{nullptr}; + /*! \brief The packed function to the `VisitType(const Type& t)` function. */ + PackedFunc f_visit_type{nullptr}; + /*! \brief The packed function to the `VisitSpan(const Span& span)` function. */ + PackedFunc f_visit_span{nullptr}; + + void VisitExpr(const Expr& expr) { + if (f_visit_expr != nullptr) { + f_visit_expr(expr); + } else { + // Need to init the overwrite VTable + static FType vtable = InitVTable(); + vtable(expr, this); + } + } + + void VisitBinding(const Binding& binding) + PY_EXPR_VISITOR_DEFAULT(binding, f_visit_binding, ExprVisitor::VisitBinding(binding)); + + void VisitBinding_(const VarBindingNode* binding) + PY_EXPR_VISITOR_DEFAULT(GetRef(binding), f_visit_var_binding_, + ExprVisitor::VisitBinding_(binding)); + void VisitBinding_(const MatchShapeNode* binding) + PY_EXPR_VISITOR_DEFAULT(GetRef(binding), f_visit_match_shape_, + ExprVisitor::VisitBinding_(binding)); + + void VisitBindingBlock(const BindingBlock& block) + PY_EXPR_VISITOR_DEFAULT(block, f_visit_binding_block, ExprVisitor::VisitBindingBlock(block)); + + void VisitBindingBlock_(const BindingBlockNode* block) + PY_EXPR_VISITOR_DEFAULT(GetRef(block), f_visit_binding_block_, + ExprVisitor::VisitBindingBlock_(block)); + void VisitBindingBlock_(const DataflowBlockNode* block) + PY_EXPR_VISITOR_DEFAULT(GetRef(block), f_visit_dataflow_block_, + ExprVisitor::VisitBindingBlock_(block)); + + void VisitVarDef(const Var& var) + PY_EXPR_VISITOR_DEFAULT(var, f_visit_var_def, ExprVisitor::VisitVarDef(var)); + void VisitVarDef_(const VarNode* var) + PY_EXPR_VISITOR_DEFAULT(GetRef(var), f_visit_var_def_, ExprVisitor::VisitVarDef_(var)); + void VisitVarDef_(const DataflowVarNode* var) + PY_EXPR_VISITOR_DEFAULT(GetRef(var), f_visit_dataflow_var_def_, + ExprVisitor::VisitVarDef_(var)); + + void VisitType(const Type& t) PY_EXPR_VISITOR_DEFAULT(t, f_visit_type, ExprVisitor::VisitType(t)); + void VisitSpan(const Span& span) + PY_EXPR_VISITOR_DEFAULT(span, f_visit_span, ExprVisitor::VisitSpan(span)); + + void VisitAttrs(AttrVisitor* v) {} + static constexpr const char* _type_key = "expr_functor.PyExprVisitor"; + TVM_DECLARE_BASE_OBJECT_INFO(PyExprVisitorNode, Object); + + private: + // initialize the vtable. + static FType InitVTable() { + FType vtable; + // Set dispatch + PY_EXPR_VISITOR_DISPATCH(ConstantNode, f_visit_constant_); + PY_EXPR_VISITOR_DISPATCH(TupleNode, f_visit_tuple_); + PY_EXPR_VISITOR_DISPATCH(VarNode, f_visit_var_); + PY_EXPR_VISITOR_DISPATCH(DataflowVarNode, f_visit_dataflow_var_); + PY_EXPR_VISITOR_DISPATCH(ShapeExprNode, f_visit_shape_expr_); + PY_EXPR_VISITOR_DISPATCH(RuntimeDepShapeNode, f_visit_runtime_dep_shape_); + PY_EXPR_VISITOR_DISPATCH(ExternFuncNode, f_visit_extern_func_); + PY_EXPR_VISITOR_DISPATCH(GlobalVarNode, f_visit_global_var_); + PY_EXPR_VISITOR_DISPATCH(FunctionNode, f_visit_function_); + PY_EXPR_VISITOR_DISPATCH(CallNode, f_visit_call_); + PY_EXPR_VISITOR_DISPATCH(SeqExprNode, f_visit_seq_expr_); + PY_EXPR_VISITOR_DISPATCH(IfNode, f_visit_if_); + PY_EXPR_VISITOR_DISPATCH(OpNode, f_visit_op_); + PY_EXPR_VISITOR_DISPATCH(TupleGetItemNode, f_visit_tuple_getitem_); + return vtable; + } +}; + +TVM_REGISTER_NODE_TYPE(PyExprVisitorNode); + +/*! + * \brief Managed reference to PyExprVisitorNode. + * \sa PyExprVisitorNode + */ +class PyExprVisitor : public ObjectRef { + public: + /*! + * \brief Create a PyExprVisitor with customized methods on the python-side. + * \param f_visit_expr The packed function of `VisitExpr(const Expr& expr)`. + * \param f_visit_constant_ The packed function of `VisitExpr_(const ConstantNode* op)`. + * \param f_visit_tuple_ The packed function of `VisitExpr_(const TupleNode* op)`. + * \param f_visit_var_ The packed function of `VisitExpr_(const VarNode* op)`. + * \param f_visit_dataflow_var_ The packed function of `VisitExpr_(const DataflowVarNode* op)`. + * \param f_visit_shape_expr_ The packed function of `VisitExpr_(const ShapeExprNode* op)`. + * \param f_visit_runtime_dep_shape_ The packed function of `VisitExpr_(const RuntimeDepShapeNode* + * op)`. + * \param f_visit_extern_func_ The packed function of `VisitExpr_(const ExternFuncNode* op)`. + * \param f_visit_global_var_ The packed function of `VisitExpr_(const GlobalVarNode* op)`. + * \param f_visit_function_ The packed function of `VisitExpr_(const FunctionNode* op)`. + * \param f_visit_call_ The packed function of `VisitExpr_(const CallNode* op)`. + * \param f_visit_seq_expr_ The packed function of `VisitExpr_(const SeqExprNode* op)`. + * \param f_visit_if_ The packed function of `VisitExpr_(const IfNode* op)`. + * \param f_visit_op_ The packed function of `VisitExpr_(const OpNode* op)`. + * \param f_visit_tuple_getitem_ The packed function of `VisitExpr_(const TupleGetItemNode* op)`. + * \param f_visit_binding The packed function of `VisitBinding(const Binding& binding)`. + * \param f_visit_var_binding_ The packed function of `VisitBinding_(const VarBindingNode* + * binding)`. + * \param f_visit_match_shape_ The packed function of `VisitBinding_(const MatchShapeNode* + * binding)`. + * \param f_visit_binding_block The packed function of `VisitBindingBlock(const BindingBlock& + * block)`. + * \param f_visit_binding_block_ The packed function of `VisitBindingBlock_(const + * BindingBlockNode* block)`. + * \param f_visit_dataflow_block_ The packed function of `VisitBindingBlock_(const + * DataflowBlockNode* block)`. + * \param f_visit_var_def The packed function of `VisitVarDef(const Var& var)`. + * \param f_visit_var_def_ The packed function of `VisitVarDef_(const VarNode* var)`. + * \param f_visit_dataflow_var_def_ The packed function of `VisitVarDef_(const DataflowVarNode* + * var)`. + * \param f_visit_type The packed function of `VisitType(const Type& t)`. + * \param f_visit_span The packed function of `VisitSpan(const Span& span)`. + * \return The PyVisitor created. + */ + TVM_DLL static PyExprVisitor MakePyExprVisitor( + PackedFunc f_visit_expr, PackedFunc f_visit_constant_, PackedFunc f_visit_tuple_, + PackedFunc f_visit_var_, PackedFunc f_visit_dataflow_var_, PackedFunc f_visit_shape_expr_, + PackedFunc f_visit_runtime_dep_shape_, PackedFunc f_visit_extern_func_, + PackedFunc f_visit_global_var_, PackedFunc f_visit_function_, PackedFunc f_visit_call_, + PackedFunc f_visit_seq_expr_, PackedFunc f_visit_if_, PackedFunc f_visit_op_, + PackedFunc f_visit_tuple_getitem_, PackedFunc f_visit_binding, + PackedFunc f_visit_var_binding_, PackedFunc f_visit_match_shape_, + PackedFunc f_visit_binding_block, PackedFunc f_visit_binding_block_, + PackedFunc f_visit_dataflow_block_, PackedFunc f_visit_var_def, PackedFunc f_visit_var_def_, + PackedFunc f_visit_dataflow_var_def_, PackedFunc f_visit_type, PackedFunc f_visit_span) { + ObjectPtr n = make_object(); + n->f_visit_expr = f_visit_expr; + n->f_visit_binding = f_visit_binding; + n->f_visit_binding_block = f_visit_binding_block; + n->f_visit_var_def = f_visit_var_def; + n->f_visit_type = f_visit_type; + n->f_visit_span = f_visit_span; + n->f_visit_constant_ = f_visit_constant_; + n->f_visit_tuple_ = f_visit_tuple_; + n->f_visit_var_ = f_visit_var_; + n->f_visit_dataflow_var_ = f_visit_dataflow_var_; + n->f_visit_shape_expr_ = f_visit_shape_expr_; + n->f_visit_runtime_dep_shape_ = f_visit_runtime_dep_shape_; + n->f_visit_extern_func_ = f_visit_extern_func_; + n->f_visit_global_var_ = f_visit_global_var_; + n->f_visit_function_ = f_visit_function_; + n->f_visit_call_ = f_visit_call_; + n->f_visit_seq_expr_ = f_visit_seq_expr_; + n->f_visit_if_ = f_visit_if_; + n->f_visit_op_ = f_visit_op_; + n->f_visit_tuple_getitem_ = f_visit_tuple_getitem_; + n->f_visit_var_binding_ = f_visit_var_binding_; + n->f_visit_match_shape_ = f_visit_match_shape_; + n->f_visit_binding_block_ = f_visit_binding_block_; + n->f_visit_dataflow_block_ = f_visit_dataflow_block_; + n->f_visit_var_def_ = f_visit_var_def_; + n->f_visit_dataflow_var_def_ = f_visit_dataflow_var_def_; + return PyExprVisitor(n); + } + + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(PyExprVisitor, ObjectRef, PyExprVisitorNode); +}; + +/*! + * \brief The abstract interface of ExprMutator. + */ +class PyExprMutatorNode : public Object, public ExprMutator { + private: + using TSelf = PyExprMutatorNode; + using FType = tvm::NodeFunctor; + + public: + /*! \brief The packed function to the `VisitExpr(const Expr& expr)` function. */ + PackedFunc f_visit_expr{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const ConstantNode* op)` function. */ + PackedFunc f_visit_constant_{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const TupleNode* op)` function. */ + PackedFunc f_visit_tuple_{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const VarNode* op)` function. */ + PackedFunc f_visit_var_{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const DataflowVarNode* op)` function. */ + PackedFunc f_visit_dataflow_var_{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const ShapeExprNode* op)` function. */ + PackedFunc f_visit_shape_expr_{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const RuntimeDepShapeNode* op)` function. */ + PackedFunc f_visit_runtime_dep_shape_{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const ExternFuncNode* op)` function. */ + PackedFunc f_visit_extern_func_{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const GlobalVarNode* op)` function. */ + PackedFunc f_visit_global_var_{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const FunctionNode* op)` function. */ + PackedFunc f_visit_function_{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const CallNode* op)` function. */ + PackedFunc f_visit_call_{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const SeqExprNode* op)` function. */ + PackedFunc f_visit_seq_expr_{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const IfNode* op)` function. */ + PackedFunc f_visit_if_{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const OpNode* op)` function. */ + PackedFunc f_visit_op_{nullptr}; + /*! \brief The packed function to the `VisitExpr_(const TupleGetItemNode* op)` function. */ + PackedFunc f_visit_tuple_getitem_{nullptr}; + /*! \brief The packed function to the `VisitBinding(const Binding& binding)` function. */ + PackedFunc f_visit_binding{nullptr}; + /*! \brief The packed function to the `VisitBinding_(const VarBindingNode* binding)` + * function. */ + PackedFunc f_visit_var_binding_{nullptr}; + /*! \brief The packed function to the `VisitBinding_(const MatchShapeNode* binding)` + * function. */ + PackedFunc f_visit_match_shape_{nullptr}; + /*! \brief The packed function to the `VisitBindingBlock(const BindingBlock& block)` + * function. */ + PackedFunc f_visit_binding_block{nullptr}; + /*! \brief The packed function to the `VisitBindingBlock_(const BindingBlockNode* block)` + * function. */ + PackedFunc f_visit_binding_block_{nullptr}; + /*! \brief The packed function to the `VisitBindingBlock_(const DataflowBlockNode* block)` + * function. */ + PackedFunc f_visit_dataflow_block_{nullptr}; + /*! \brief The packed function to the `VisitVarDef(const Var& var)` function. */ + PackedFunc f_visit_var_def{nullptr}; + /*! \brief The packed function to the `VisitVarDef_(const VarNode* var)` function. */ + PackedFunc f_visit_var_def_{nullptr}; + /*! \brief The packed function to the `VisitVarDef_(const DataflowVarNode* var)` function. */ + PackedFunc f_visit_dataflow_var_def_{nullptr}; + /*! \brief The packed function to the `VisitType(const Type& t)` function. */ + PackedFunc f_visit_type{nullptr}; + /*! \brief The packed function to the `VisitSpan(const Span& span)` function. */ + PackedFunc f_visit_span{nullptr}; + + Expr VisitExpr(const Expr& expr) { + if (f_visit_expr != nullptr) { + return builder_->Normalize(f_visit_expr(expr)); + } else { + static FType vtable = InitVTable(); + return builder_->Normalize(vtable(expr, this)); + } + } + + void VisitBinding(const Binding& binding) { + if (f_visit_binding != nullptr) + f_visit_binding(binding); + else + ExprMutator::VisitBinding(binding); + } + + void VisitBinding_(const VarBindingNode* binding) { + if (f_visit_var_binding_ != nullptr) + f_visit_var_binding_(GetRef(binding)); + else + ExprMutator::VisitBinding_(binding); + } + + void VisitBinding_(const MatchShapeNode* binding) { + if (f_visit_match_shape_ != nullptr) + f_visit_match_shape_(GetRef(binding)); + else + ExprMutator::VisitBinding_(binding); + } + + BindingBlock VisitBindingBlock(const BindingBlock& block) + PY_EXPR_MUTATOR_DEFAULT(block, f_visit_binding_block, ExprMutator::VisitBindingBlock(block), + BindingBlock); + + BindingBlock VisitBindingBlock_(const BindingBlockNode* block) + PY_EXPR_MUTATOR_DEFAULT(GetRef(block), f_visit_binding_block_, + ExprMutator::VisitBindingBlock_(block), BindingBlock); + BindingBlock VisitBindingBlock_(const DataflowBlockNode* block) + PY_EXPR_MUTATOR_DEFAULT(GetRef(block), f_visit_dataflow_block_, + ExprMutator::VisitBindingBlock_(block), BindingBlock); + + Var VisitVarDef(const Var& var) + PY_EXPR_MUTATOR_DEFAULT(var, f_visit_var_def, ExprMutator::VisitVarDef(var), Var); + Var VisitVarDef_(const VarNode* var) PY_EXPR_MUTATOR_DEFAULT(GetRef(var), f_visit_var_def_, + ExprMutator::VisitVarDef_(var), Var); + Var VisitVarDef_(const DataflowVarNode* var) + PY_EXPR_MUTATOR_DEFAULT(GetRef(var), f_visit_dataflow_var_def_, + ExprMutator::VisitVarDef_(var), Var); + + Type VisitType(const Type& t) + PY_EXPR_MUTATOR_DEFAULT(t, f_visit_type, ExprMutator::VisitType(t), Type); + + /*! + * \brief Dispatcher for post-order rewrite. + * \param expr The Expr to be rewritten. + * \return The Expr after post-order rewritten. + */ + Expr VisitExprPostOrder(const Expr& expr) { + static FType post_order_vtable = InitPostOrderVTable(); + return post_order_vtable(expr, this); + } + + using ExprMutator::builder_; + using ExprMutator::LookupBinding; + using ExprMutator::var_remap_; + using ExprMutator::VisitWithNewScope; + using ExprMutator::WithShapeAndType; + + void VisitAttrs(AttrVisitor* v) { v->Visit("builder_", &builder_); } + static constexpr const char* _type_key = "expr_functor.PyExprMutator"; + TVM_DECLARE_BASE_OBJECT_INFO(PyExprMutatorNode, Object); + + private: + // initialize the vtable. + static FType InitVTable() { + FType vtable; + // Set dispatch + PY_EXPR_MUTATOR_DISPATCH(ConstantNode, f_visit_constant_); + PY_EXPR_MUTATOR_DISPATCH(TupleNode, f_visit_tuple_); + PY_EXPR_MUTATOR_DISPATCH(VarNode, f_visit_var_); + PY_EXPR_MUTATOR_DISPATCH(DataflowVarNode, f_visit_dataflow_var_); + PY_EXPR_MUTATOR_DISPATCH(ShapeExprNode, f_visit_shape_expr_); + PY_EXPR_MUTATOR_DISPATCH(RuntimeDepShapeNode, f_visit_runtime_dep_shape_); + PY_EXPR_MUTATOR_DISPATCH(ExternFuncNode, f_visit_extern_func_); + PY_EXPR_MUTATOR_DISPATCH(GlobalVarNode, f_visit_global_var_); + PY_EXPR_MUTATOR_DISPATCH(FunctionNode, f_visit_function_); + PY_EXPR_MUTATOR_DISPATCH(CallNode, f_visit_call_); + PY_EXPR_MUTATOR_DISPATCH(SeqExprNode, f_visit_seq_expr_); + PY_EXPR_MUTATOR_DISPATCH(IfNode, f_visit_if_); + PY_EXPR_MUTATOR_DISPATCH(OpNode, f_visit_op_); + PY_EXPR_MUTATOR_DISPATCH(TupleGetItemNode, f_visit_tuple_getitem_); + return vtable; + } + + // initialize the vtable for post order visit. + static FType InitPostOrderVTable() { + FType post_order_vtable; + // Set dispatch + PY_EXPR_MUTATOR_VISIT_EXPR_POST_ORDER_DISPATCH(ConstantNode); + PY_EXPR_MUTATOR_VISIT_EXPR_POST_ORDER_DISPATCH(TupleNode); + PY_EXPR_MUTATOR_VISIT_EXPR_POST_ORDER_DISPATCH(VarNode); + PY_EXPR_MUTATOR_VISIT_EXPR_POST_ORDER_DISPATCH(DataflowVarNode); + PY_EXPR_MUTATOR_VISIT_EXPR_POST_ORDER_DISPATCH(ShapeExprNode); + PY_EXPR_MUTATOR_VISIT_EXPR_POST_ORDER_DISPATCH(RuntimeDepShapeNode); + PY_EXPR_MUTATOR_VISIT_EXPR_POST_ORDER_DISPATCH(ExternFuncNode); + PY_EXPR_MUTATOR_VISIT_EXPR_POST_ORDER_DISPATCH(GlobalVarNode); + PY_EXPR_MUTATOR_VISIT_EXPR_POST_ORDER_DISPATCH(FunctionNode); + PY_EXPR_MUTATOR_VISIT_EXPR_POST_ORDER_DISPATCH(CallNode); + PY_EXPR_MUTATOR_VISIT_EXPR_POST_ORDER_DISPATCH(SeqExprNode); + PY_EXPR_MUTATOR_VISIT_EXPR_POST_ORDER_DISPATCH(IfNode); + PY_EXPR_MUTATOR_VISIT_EXPR_POST_ORDER_DISPATCH(OpNode); + PY_EXPR_MUTATOR_VISIT_EXPR_POST_ORDER_DISPATCH(TupleGetItemNode); + return post_order_vtable; + } +}; + +TVM_REGISTER_NODE_TYPE(PyExprMutatorNode); + +/*! + * \brief Managed reference to PyExprMutatorNode. + * \sa PyExprMutatorNode + */ +class PyExprMutator : public ObjectRef { + public: + /*! + * \brief Create a PyExprMutator with customized methods on the python-side. + * \param f_visit_expr The packed function of `VisitExpr(const Expr& expr)`. + * \param f_visit_constant_ The packed function of `VisitExpr_(const ConstantNode* op)`. + * \param f_visit_tuple_ The packed function of `VisitExpr_(const TupleNode* op)`. + * \param f_visit_var_ The packed function of `VisitExpr_(const VarNode* op)`. + * \param f_visit_dataflow_var_ The packed function of `VisitExpr_(const DataflowVarNode* op)`. + * \param f_visit_shape_expr_ The packed function of `VisitExpr_(const ShapeExprNode* op)`. + * \param f_visit_runtime_dep_shape_ The packed function of `VisitExpr_(const RuntimeDepShapeNode* + * op)`. + * \param f_visit_extern_func_ The packed function of `VisitExpr_(const ExternFuncNode* op)`. + * \param f_visit_global_var_ The packed function of `VisitExpr_(const GlobalVarNode* op)`. + * \param f_visit_function_ The packed function of `VisitExpr_(const FunctionNode* op)`. + * \param f_visit_call_ The packed function of `VisitExpr_(const CallNode* op)`. + * \param f_visit_seq_expr_ The packed function of `VisitExpr_(const SeqExprNode* op)`. + * \param f_visit_if_ The packed function of `VisitExpr_(const IfNode* op)`. + * \param f_visit_op_ The packed function of `VisitExpr_(const OpNode* op)`. + * \param f_visit_tuple_getitem_ The packed function of `VisitExpr_(const TupleGetItemNode* op)`. + * \param f_visit_binding The packed function of `VisitBinding(const Binding& binding)`. + * \param f_visit_var_binding_ The packed function of `VisitBinding_(const VarBindingNode* + * binding)`. + * \param f_visit_match_shape_ The packed function of `VisitBinding_(const MatchShapeNode* + * binding)`. + * \param f_visit_binding_block The packed function of `VisitBindingBlock(const BindingBlock& + * block)`. + * \param f_visit_binding_block_ The packed function of `VisitBindingBlock_(const + * BindingBlockNode* block)`. + * \param f_visit_dataflow_block_ The packed function of `VisitBindingBlock_(const + * DataflowBlockNode* block)`. + * \param f_visit_var_def The packed function of `VisitVarDef(const Var& var)`. + * \param f_visit_var_def_ The packed function of `VisitVarDef_(const VarNode* var)`. + * \param f_visit_dataflow_var_def_ The packed function of `VisitVarDef_(const DataflowVarNode* + * var)`. + * \param f_visit_type The packed function of `VisitType(const Type& t)`. + * \param f_visit_span The packed function of `VisitSpan(const Span& span)`. + * \return The PyExprMutator created. + */ + TVM_DLL static PyExprMutator MakePyExprMutator( + BlockBuilder builder_, PackedFunc f_visit_expr, PackedFunc f_visit_constant_, + PackedFunc f_visit_tuple_, PackedFunc f_visit_var_, PackedFunc f_visit_dataflow_var_, + PackedFunc f_visit_shape_expr_, PackedFunc f_visit_runtime_dep_shape_, + PackedFunc f_visit_extern_func_, PackedFunc f_visit_global_var_, PackedFunc f_visit_function_, + PackedFunc f_visit_call_, PackedFunc f_visit_seq_expr_, PackedFunc f_visit_if_, + PackedFunc f_visit_op_, PackedFunc f_visit_tuple_getitem_, PackedFunc f_visit_binding, + PackedFunc f_visit_var_binding_, PackedFunc f_visit_match_shape_, + PackedFunc f_visit_binding_block, PackedFunc f_visit_binding_block_, + PackedFunc f_visit_dataflow_block_, PackedFunc f_visit_var_def, PackedFunc f_visit_var_def_, + PackedFunc f_visit_dataflow_var_def_, PackedFunc f_visit_type, PackedFunc f_visit_span) { + ObjectPtr n = make_object(); + n->builder_ = builder_; + n->f_visit_expr = f_visit_expr; + n->f_visit_constant_ = f_visit_constant_; + n->f_visit_tuple_ = f_visit_tuple_; + n->f_visit_var_ = f_visit_var_; + n->f_visit_dataflow_var_ = f_visit_dataflow_var_; + n->f_visit_shape_expr_ = f_visit_shape_expr_; + n->f_visit_runtime_dep_shape_ = f_visit_runtime_dep_shape_; + n->f_visit_extern_func_ = f_visit_extern_func_; + n->f_visit_global_var_ = f_visit_global_var_; + n->f_visit_function_ = f_visit_function_; + n->f_visit_call_ = f_visit_call_; + n->f_visit_seq_expr_ = f_visit_seq_expr_; + n->f_visit_if_ = f_visit_if_; + n->f_visit_op_ = f_visit_op_; + n->f_visit_tuple_getitem_ = f_visit_tuple_getitem_; + n->f_visit_binding = f_visit_binding; + n->f_visit_var_binding_ = f_visit_var_binding_; + n->f_visit_match_shape_ = f_visit_match_shape_; + n->f_visit_binding_block = f_visit_binding_block; + n->f_visit_binding_block_ = f_visit_binding_block_; + n->f_visit_dataflow_block_ = f_visit_dataflow_block_; + n->f_visit_var_def = f_visit_var_def; + n->f_visit_var_def_ = f_visit_var_def_; + n->f_visit_dataflow_var_def_ = f_visit_dataflow_var_def_; + n->f_visit_type = f_visit_type; + n->f_visit_span = f_visit_span; + return PyExprMutator(n); + } + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(PyExprMutator, ObjectRef, PyExprMutatorNode); +}; + } // namespace relax } // namespace tvm #endif // TVM_RELAX_EXPR_FUNCTOR_H_ diff --git a/python/tvm/meta_schedule/utils.py b/python/tvm/meta_schedule/utils.py index 401fdab08a..9132402b4c 100644 --- a/python/tvm/meta_schedule/utils.py +++ b/python/tvm/meta_schedule/utils.py @@ -75,14 +75,27 @@ def _extract(inst: type, name: str): def method(*args, **kwargs): return getattr(inst, name)(*args, **kwargs) - if getattr(base, name) is getattr(cls, name) and name != "__str__": - # for task scheduler return None means calling default function - # otherwise it will trigger a TVMError of method not implemented - # on the c++ side when you call the method, __str__ not required - return None - return method + for inherit_cls, base_cls in zip(cls.__mro__, cls.__mro__[1:]): + # extract functions that differ from the base class + if not hasattr(base_cls, name): + continue + if getattr(base_cls, name) is getattr(inherit_cls, name) and name != "__str__": + continue + return method + + # for task scheduler return None means calling default function + # otherwise it will trigger a TVMError of method not implemented + # on the c++ side when you call the method, __str__ not required + return None assert isinstance(cls.__base__, type) + if hasattr(cls, "_type") and cls._type == "TVMDerivedObject": + raise TypeError( + ( + f"Inheritance from a decorated object `{cls.__name__}` is not allowed. " + f"Please inherit from `{cls.__name__}._cls`." + ) + ) assert hasattr( cls, "_tvm_metadata" ), "Please use the user-facing method overriding class, i.e., PyRunner." @@ -95,6 +108,9 @@ def method(*args, **kwargs): class TVMDerivedObject(metadata["cls"]): # type: ignore """The derived object to avoid cyclic dependency.""" + _cls = cls + _type = "TVMDerivedObject" + def __init__(self, *args, **kwargs): """Constructor.""" self.handle = None @@ -111,12 +127,22 @@ def __init__(self, *args, **kwargs): # using weakref to avoid cyclic dependency self._inst._outer = weakref.ref(self) - def __getattr__(self, name: str): - """Bridge the attribute function.""" - try: - return self._inst.__getattribute__(name) - except AttributeError: - return super(TVMDerivedObject, self).__getattr__(name) + def __getattr__(self, name): + # fall back to instance attribute if there is not any + # return self._inst.__getattribute__(name) + import inspect # pylint: disable=import-outside-toplevel + + result = self._inst.__getattribute__(name) + if inspect.ismethod(result): + + def method(*args, **kwargs): + return result(*args, **kwargs) + + # set __own__ to aviod implicit deconstruction + setattr(method, "__own__", self) + return method + + return result def __setattr__(self, name, value): if name not in ["_inst", "key", "handle"]: diff --git a/python/tvm/relax/__init__.py b/python/tvm/relax/__init__.py index 3db32611ea..ae61cec2c3 100644 --- a/python/tvm/relax/__init__.py +++ b/python/tvm/relax/__init__.py @@ -77,6 +77,5 @@ # ExprFunctor ExprFunctor = expr_functor.ExprFunctor -ExprVisitor = expr_functor.ExprVisitor -ExprMutatorBase = expr_functor.ExprMutatorBase -ExprMutator = expr_functor.ExprMutator +PyExprVisitor = expr_functor.PyExprVisitor +PyExprMutator = expr_functor.PyExprMutator diff --git a/python/tvm/relax/expr_functor.py b/python/tvm/relax/expr_functor.py index 28e11613df..38129faa76 100644 --- a/python/tvm/relax/expr_functor.py +++ b/python/tvm/relax/expr_functor.py @@ -16,11 +16,13 @@ # under the License. # pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name, arguments-differ """The expression functor of Relax.""" -from typing import Optional +from typing import Optional, Callable + +import tvm +from tvm.runtime import Object from tvm.ir import Op -from tvm.ir.base import structural_equal -from tvm.ir.module import IRModule -from .ty import DynTensorType +from tvm.meta_schedule.utils import derived_object + from .expr import Type, Span, Expr from .expr import Function, ExternFunc from .expr import Constant, Var, DataflowVar @@ -29,14 +31,85 @@ from .expr import Call, If, TupleGetItem from .expr import Binding, MatchShape, VarBinding from .expr import BindingBlock, DataflowBlock -from .expr import _update_shape, _update_type +from ..relay import Id +from ..ir.module import IRModule from .block_builder import BlockBuilder +from . import _ffi_api + +visitor = derived_object +""" +A decorator to wrap user-customized PyExprVisitor as TVM object _PyExprVisitor. + +Parameters +---------- +visitor_cls : PyExprVisitor + The user-customized PyExprVisitor. + +Returns +------- +cls : _PyExprVisitor + The decorated TVM object _PyExprVisitor(ExprVisitor on the C++ side). + +Example +------- +.. code-block:: python + + @relax.expr_functor.visitor + class MyExprVisitor(PyExprVisitor): + # customize visit function + def visit_call_(self, op: Call) -> None: + # just for demo purposes + ... + # myvisitor is now a special visitor that visit every Call with + # user-customized visit_call_ + myvisitor = MyExprVisitor() + # apply myvisitor to Expr/Binding/BindingBlock/VarDef + myvisitor.visit_expr(expr) + myvisitor.visit_binding(binding) + myvisitor.visit_binding_block(bindingblock) + myvisitor.visit_var_def(var) +""" + +mutator = derived_object +""" +A decorator to wrap user-customized PyExprMutator as TVM object _PyExprMutator. +Note: Cannot override visit function and post-order rewrite at the same time. + +Parameters +---------- +mutator_cls : PyExprMutator + The user-customized PyExprMutator. + +Returns +------- +cls : _PyExprMutator + The decorated TVM object _PyExprMutator(ExprMutator on the C++ side). + +Example +------- +.. code-block:: python + + @relax.expr_functor.mutator + class MyExprMutator(PyExprMutator): + # customize rewrite function + def visit_tuple_(self, op: Tuple) -> Expr: + # just for demo purposes + ... + + # mymutator is now a special mutator that rewrite every Tuple with + # user-customized visit_tuple_ + mymutator = MyExprMutator() + # apply mymutator to Expr/Binding/BindingBlock/VarDef + mymutator.visit_expr(expr) + mymutator.visit_binding(binding) + mymutator.visit_binding_block(bindingblock) + mymutator.visit_var_def(var) +""" class ExprFunctor: """ An abstract visitor defined over Expr. - Defines the default dispatch over expressions, and implements memoization. """ @@ -118,123 +191,23 @@ def visit_op_(self, op: Op): def visit_tuple_getitem_(self, op: TupleGetItem): raise NotImplementedError() - -class ExprVisitor(ExprFunctor): - """ - A visitor over Expr. - - The default behavior recursively traverses the AST. - """ - - def visit_expr(self, expr: Expr) -> None: - ExprFunctor.visit_expr(self, expr) - - def visit_constant_(self, op: Constant) -> None: - self.visit_span(op.span) - - if op.shape_: - self.visit_expr(op.shape_) - - def visit_global_var_(self, op: GlobalVar) -> None: - self.visit_span(op.span) - - def visit_tuple_(self, op: Tuple) -> None: - self.visit_span(op.span) - for field in op.fields: - self.visit_expr(field) - - if op.shape_: - self.visit_expr(op.shape_) - - def visit_var_(self, op: Var) -> None: - self.visit_span(op.span) - - def visit_dataflow_var_(self, op: DataflowVar) -> None: - self.visit_span(op.span) - - def visit_function_(self, op: Function) -> None: - self.visit_span(op.span) - for param in op.params: - self.visit_var_def(param) - - self.visit_expr(op.body) - - def visit_call_(self, op: Call) -> None: - self.visit_span(op.span) - self.visit_expr(op.op) - - for ty_arg in op.type_args: - self.visit_type(ty_arg) - - for arg in op.args: - self.visit_expr(arg) - - if op.shape_: - self.visit_expr(op.shape_) - - def visit_if_(self, op: If) -> None: - self.visit_span(op.span) - self.visit_expr(op.cond) - self.visit_expr(op.true_branch) - self.visit_expr(op.false_branch) - - def visit_op_(self, op: Op) -> None: - pass - - def visit_tuple_getitem_(self, op: TupleGetItem) -> None: - self.visit_span(op.span) - self.visit_expr(op.tuple_value) - - def visit_shape_expr_(self, op: ShapeExpr) -> None: - self.visit_span(op.span) - - def visit_runtime_dep_shape_(self, op: RuntimeDepShape) -> None: - self.visit_span(op.span) - - def visit_extern_func_(self, op: ExternFunc) -> None: - self.visit_span(op.span) - - def visit_seq_expr_(self, op: SeqExpr) -> None: - self.visit_span(op.span) - for block in op.blocks: - self.visit_binding_block(block) - self.visit_expr(op.body) - - def visit_type(self, t: Type) -> None: - pass - - def visit_span(self, span: Span) -> None: - pass - def visit_var_binding_(self, binding: VarBinding) -> None: - self.visit_expr(binding.value) - self.visit_var_def(binding.var) + raise NotImplementedError() def visit_match_shape_(self, binding: MatchShape) -> None: - self.visit_expr(binding.value) - self.visit_expr(ShapeExpr(binding.pattern)) - if binding.var: - self.visit_var_def(binding.var) + raise NotImplementedError() def visit_binding_block_(self, block: BindingBlock) -> None: - for binding in block.bindings: - self.visit_binding(binding) + raise NotImplementedError() def visit_dataflow_block_(self, block: DataflowBlock) -> None: - for binding in block.bindings: - self.visit_binding(binding) + raise NotImplementedError() def visit_var_def_(self, var: Var) -> None: - self.visit_span(var.span) - - if var.shape_: - self.visit_expr(var.shape_) + raise NotImplementedError() def visit_dataflow_var_def_(self, var: DataflowVar) -> None: - self.visit_span(var.span) - - if var.shape_: - self.visit_expr(var.shape_) + raise NotImplementedError() def visit_binding(self, binding: Binding) -> None: if isinstance(binding, MatchShape): @@ -261,454 +234,1233 @@ def visit_var_def(self, var: Var): raise TypeError("Invalid type: {0}".format(type(var))) -class ExprMutatorBase(ExprFunctor): +@tvm._ffi.register_object("expr_functor.PyExprVisitor") +class _PyExprVisitor(Object): """ - A mutator works in unnormalized form. + A TVM object to support customization of ExprVisitor on the python side. + This is the decorated result returned from visitor decorator. - ExprMutatorBase expects input AST to be in the unnormalized form, - i.e., _checked_type_ and shape_ of expressions can be None, - and the expressions may nest (and as a result the AST is not in ANF). + WARNING: This is NOT the user facing class for method overwriting inheritance. + + See also: visitor, PyExprVisitor """ - def visit_expr(self, expr: Expr) -> Expr: - return ExprFunctor.visit_expr(self, expr) + def __init__( + self, + f_visit_expr: Callable = None, + f_visit_constant_: Callable = None, + f_visit_tuple_: Callable = None, + f_visit_var_: Callable = None, + f_visit_dataflow_var_: Callable = None, + f_visit_shape_expr_: Callable = None, + f_visit_runtime_dep_shape_: Callable = None, + f_visit_extern_func_: Callable = None, + f_visit_global_var_: Callable = None, + f_visit_function_: Callable = None, + f_visit_call_: Callable = None, + f_visit_seq_expr_: Callable = None, + f_visit_if_: Callable = None, + f_visit_op_: Callable = None, + f_visit_tuple_getitem_: Callable = None, + f_visit_binding: Callable = None, + f_visit_var_binding_: Callable = None, + f_visit_match_shape_: Callable = None, + f_visit_binding_block: Callable = None, + f_visit_binding_block_: Callable = None, + f_visit_dataflow_block_: Callable = None, + f_visit_var_def: Callable = None, + f_visit_var_def_: Callable = None, + f_visit_dataflow_var_def_: Callable = None, + f_visit_type: Callable = None, + f_visit_span: Callable = None, + ) -> None: + """Constructor.""" + + self.__init_handle_by_constructor__( + _ffi_api.MakePyExprVisitor, + f_visit_expr, + f_visit_constant_, + f_visit_tuple_, + f_visit_var_, + f_visit_dataflow_var_, + f_visit_shape_expr_, + f_visit_runtime_dep_shape_, + f_visit_extern_func_, + f_visit_global_var_, + f_visit_function_, + f_visit_call_, + f_visit_seq_expr_, + f_visit_if_, + f_visit_op_, + f_visit_tuple_getitem_, + f_visit_binding, + f_visit_var_binding_, + f_visit_match_shape_, + f_visit_binding_block, + f_visit_binding_block_, + f_visit_dataflow_block_, + f_visit_var_def, + f_visit_var_def_, + f_visit_dataflow_var_def_, + f_visit_type, + f_visit_span, + ) - def visit_constant_(self, op: Constant) -> Expr: - return op + def visit_expr(self, expr: Expr) -> None: + """Generic dispatcher for Expr. - def visit_global_var_(self, op: GlobalVar) -> Expr: - return op + Parameters + ---------- + expr : Expr + The expr to be visited. + """ + return _ffi_api.PyExprVisitorVisitExpr(self, expr) - def visit_tuple_(self, op: Tuple) -> Expr: - unchanged = True - fields = [] - for field in op.fields: - new_field = self.visit_expr(field) - fields.append(new_field) - unchanged &= field.same_as(new_field) - - if unchanged: - return op - else: - return Tuple(fields, op.span) + def visit_binding(self, binding: Binding) -> None: + """Generic dispatcher for Binding. - def visit_var_(self, op: Var) -> Expr: - return op + Parameters + ---------- + binding : Binding + The binding to be visited. + """ + return _ffi_api.PyExprVisitorVisitBinding(self, binding) - def visit_dataflow_var_(self, op: DataflowVar) -> Expr: - return op + def visit_binding_block(self, block: BindingBlock) -> None: + """Generic dispatcher for BindingBlock. - def visit_function_(self, op: Function) -> Expr: - body = self.visit_expr(op.body) + Parameters + ---------- + block : BindingBlock + The block to be visited. + """ + return _ffi_api.PyExprVisitorVisitBindingBlock(self, block) - if op.body.same_as(body): - return op - else: - return Function(op.params, body, op.ret_type, op.attrs, op.span) - - def visit_call_(self, call_node: Call) -> Expr: - new_op = self.visit_expr(call_node.op) - unchanged = call_node.op.same_as(new_op) - - ty_args = [] - for ty_arg in call_node.type_args: - new_ty_arg = self.visit_type(ty_arg) - ty_args.append(new_ty_arg) - unchanged &= ty_arg.same_as(new_ty_arg) - - call_args = [] - for arg in call_node.args: - new_arg = self.visit_expr(arg) - call_args.append(new_arg) - unchanged &= arg.same_as(new_arg) - - if unchanged: - return call_node - else: - return Call(new_op, call_args, call_node.attrs, ty_args, call_node.span) + def visit_var_def(self, var: Var) -> None: + """Generic dispatcher for visiting the var definition site. + Note that visit_var_() will only visit the usage site of an Var. - def visit_if_(self, op: If) -> Expr: - guard = self.visit_expr(op.cond) - true_b = self.visit_expr(op.true_branch) - false_b = self.visit_expr(op.false_branch) - if ( - op.cond.same_as(guard) - and op.true_branch.same_as(true_b) - and op.false_branch.same_as(false_b) - ): - return op - else: - return If(guard, true_b, false_b, op.span) + Parameters + ---------- + var : Var + The var to be visited. + """ + return _ffi_api.PyExprVisitorVisitVarDef(self, var) - def visit_op_(self, op: Op) -> Expr: - return op - def visit_tuple_getitem_(self, op: TupleGetItem) -> Expr: - t = self.visit_expr(op.tuple_value) - if op.tuple_value.same_as(t): - return op - else: - return TupleGetItem(t, op.index) +class PyExprVisitor: + """ + An abstract ExprVisitor with customized methods on the python-side. + This is the user facing class for method overwriting inheritance. + _tvm_metadata discribes the class to inherit("cls"), the methods + that users can overwrite("methods"). - def visit_shape_expr_(self, op: ShapeExpr) -> Expr: - return op + Note: @relax.expr_functor.visitor is required for proper usage of any inherited class. - def visit_runtime_dep_shape_(self, op: RuntimeDepShape) -> Expr: - return op + See also: visitor, _PyExprVisitor - def visit_extern_func_(self, op: ExternFunc) -> Expr: - return op + Example: + @relax.expr_functor.visitor + def MyExprVisitor(PyExprVisitor): + ... + """ - def visit_seq_expr_(self, op: SeqExpr) -> Expr: - all_blocks_unchanged = True - blocks = [] - for block in op.blocks: - new_block = self.visit_binding_block(block) - if new_block.bindings: - blocks.append(new_block) - all_blocks_unchanged &= block.same_as(new_block) - - body = self.visit_expr(op.body) - if all_blocks_unchanged and op.body.same_as(body): - return op - else: - return SeqExpr(blocks, body, op.span) + _tvm_metadata = { + "cls": _PyExprVisitor, + "methods": [ + "visit_expr", + "visit_constant_", + "visit_tuple_", + "visit_var_", + "visit_dataflow_var_", + "visit_shape_expr_", + "visit_runtime_dep_shape_", + "visit_extern_func_", + "visit_global_var_", + "visit_function_", + "visit_call_", + "visit_seq_expr_", + "visit_if_", + "visit_op_", + "visit_tuple_getitem_", + "visit_binding", + "visit_var_binding_", + "visit_match_shape_", + "visit_binding_block", + "visit_binding_block_", + "visit_dataflow_block_", + "visit_var_def", + "visit_var_def_", + "visit_dataflow_var_def_", + "visit_type", + "visit_span", + ], + } - def visit_binding_block(self, block: BindingBlock) -> BindingBlock: - """Mutate BindingBlock. + def visit_expr(self, expr: Expr) -> None: + """Generic dispatcher for Expr. + Users can customized this function to overwrite VisitExpr(const Expr& expr) on the C++ side. Parameters ---------- - block: BindingBlock - The binding block to be visited. - - Returns - ------- - block: BindingBlock - The binding block after transformation. + expr : Expr + The expr to be visited. """ - bindings = [] - if isinstance(block, BindingBlock): - for binding in block.bindings: - if isinstance(binding, VarBinding): - new_value = self.visit_expr(binding.value) - bindings.append(VarBinding(binding.var, new_value, binding.span)) - elif isinstance(binding, MatchShape): - new_value = self.visit_expr(binding.value) - bindings.append( - MatchShape(new_value, binding.pattern, binding.var, binding.span) - ) - else: - raise TypeError("Invalid type: {0}".format(type(block))) - else: - raise TypeError("Invalid type: {0}".format(type(block))) - if isinstance(block, DataflowBlock): - return DataflowBlock(bindings) - else: - return BindingBlock(bindings) + # Using self._outer() to ref _PyExprVisitor + return _ffi_api.PyExprVisitorVisitExpr(self._outer(), expr) - def visit_type(self, t: Type) -> Type: - return t + def visit_binding(self, binding: Binding) -> None: + """Generic dispatcher for Binding. + Users can customized this function to overwrite VisitBinding(const Binding& binding) + on the C++ side. + Parameters + ---------- + binding : Binding + The binding to be visited. + """ + # Using self._outer() to ref _PyExprVisitor + return _ffi_api.PyExprVisitorVisitBinding(self._outer(), binding) -class ExprMutator(ExprMutatorBase): - """ - A mutator works in normal form. + def visit_binding_block(self, block: BindingBlock) -> None: + """Generic dispatcher for BindingBlock. + Users can customized this function to overwrite VisitBindingBlock(const BindingBlock& block) + on the C++ side. - ExprMutator expects input AST to be in the normal form, i.e., the expressions are normalized(no - nesting and hence the AST is in ANF), and all checked_type_ and shape_ of expressions are - available. Note: We can use relax.transform.Normalize()(mod) to transform relax IR into - the normal form. - """ + Parameters + ---------- + block : BindingBlock + The block to be visited. + """ + # Using self._outer() to ref _PyExprVisitor + return _ffi_api.PyExprVisitorVisitBindingBlock(self._outer(), block) - def __init__(self, mod: Optional[IRModule] = None) -> None: - super().__init__() - self.builder_ = BlockBuilder(mod) - self.var_remap_ = dict() + def visit_var_def(self, var: Var) -> None: + """Generic dispatcher for visiting the var definition site. + Users can customized this function to overwrite VisitVarDef(const Var& var) on the C++ side. + Note that visit_var_() will only visit the usage site of an Var. - def visit_expr(self, expr) -> Expr: - return self.builder_.normalize(ExprFunctor.visit_expr(self, expr)) + Parameters + ---------- + var : Var + The var to be visited. + """ + # Using self._outer() to ref _PyExprVisitor + return _ffi_api.PyExprVisitorVisitVarDef(self._outer(), var) - def visit_tuple_(self, op: Tuple) -> Expr: - unchanged = True - fields = [] - for field in op.fields: - new_field = self.visit_expr(field) - fields.append(new_field) - unchanged &= field.same_as(new_field) - - if unchanged: - return op - else: - new_tuple = Tuple(fields, op.span) - return new_tuple + def visit_constant_(self, op: Constant) -> None: + """Visit Constant. + Users can customized this function to overwrite VisitExpr_(const ConstantNode* op) + on the C++ side. - def visit_var_(self, op: Var) -> Expr: - if op.vid in self.var_remap_: - return self.var_remap_[op.vid] + Parameters + ---------- + op : Constant + The Constant to be visited. + """ + # Using self._outer() to ref _PyExprVisitor + return _ffi_api.ExprVisitorVisitExpr(self._outer(), op) - return op + def visit_tuple_(self, op: Tuple) -> None: + """Visit Tuple. + Users can customized this function to overwrite VisitExpr_(const TupleNode* op) + on the C++ side. - def visit_dataflow_var_(self, op: DataflowVar) -> Expr: - if op.vid in self.var_remap_: - return self.var_remap_[op.vid] + Parameters + ---------- + op : Tuple + The Tuple to be visited. + """ + # Using self._outer() to ref _PyExprVisitor + return _ffi_api.ExprVisitorVisitExpr(self._outer(), op) - return op + def visit_var_(self, op: Var) -> None: + """Visit Var. + Users can customized this function to overwrite VisitExpr_(const VarNode* op) + on the C++ side. - def visit_function_(self, op: Function) -> Expr: - params = [] - all_params_unchanged = True - for param in op.params: - new_param = self.visit_var_def(param) - params.append(new_param) - all_params_unchanged &= param.same_as(new_param) - - ret_type = self.visit_type(op.ret_type) - body = self.visit_with_new_scope(op.body) - - # TODO(@lesheng): op.ret_type.same_as(ret_type) after Type.same_as is fixed - if all_params_unchanged and (op.ret_type == ret_type) and op.body.same_as(body): - return op - else: - return Function(params, body, ret_type, op.attrs, op.span) + Parameters + ---------- + op : Var + The Var to be visited. + """ + # Using self._outer() to ref _PyExprVisitor + return _ffi_api.ExprVisitorVisitExpr(self._outer(), op) - def visit_if_(self, op: If) -> Expr: - guard = self.visit_expr(op.cond) - true_b = self.visit_with_new_scope(op.true_branch) - false_b = self.visit_with_new_scope(op.false_branch) - if ( - op.cond.same_as(guard) - and op.true_branch.same_as(true_b) - and op.false_branch.same_as(false_b) - ): - return op - else: - return If(guard, true_b, false_b, op.span) + def visit_dataflow_var_(self, op: DataflowVar) -> None: + """Visit DataflowVar. + Users can customized this function to overwrite VisitExpr_(const DataflowVarNode* op) + on the C++ side. - def visit_seq_expr_(self, op: SeqExpr) -> Expr: - all_blocks_unchanged = True - blocks = [] - for block in op.blocks: - new_block = self.visit_binding_block(block) - if new_block.bindings: - blocks.append(new_block) - all_blocks_unchanged &= block.same_as(new_block) - - self.builder_._begin_binding_block() - body = self.visit_expr(op.body) - prologue = self.builder_._end_block() - if prologue.bindings: - blocks.append(prologue) - all_blocks_unchanged = False - - if all_blocks_unchanged and op.body.same_as(body): - return op - else: - return SeqExpr(blocks, body, op.span) + Parameters + ---------- + op : DataflowVar + The DataflowVar to be visited. + """ + # Using self._outer() to ref _PyExprVisitor + return _ffi_api.ExprVisitorVisitExpr(self._outer(), op) - def visit_var_binding_(self, binding: VarBinding) -> None: - """Visit VarBinding, a new VarBinding will be emitted + def visit_shape_expr_(self, op: ShapeExpr) -> None: + """Visit ShapeExpr. + Users can customized this function to overwrite VisitExpr_(const ShapeExprNode* op) + on the C++ side. Parameters ---------- - binding: VarBinding - The VarBinding to be visited. + op : ShapeExpr + The ShapeExpr to be visited. """ - new_value = self.visit_expr(binding.value) - new_var = self.visit_var_def(binding.var) + # Using self._outer() to ref _PyExprVisitor + return _ffi_api.ExprVisitorVisitExpr(self._outer(), op) - def emit(b: VarBinding): - if self.builder_.current_block_is_dataflow() and not isinstance(b.var, DataflowVar): - self.builder_.emit_output_var_binding(b) - else: - self.builder_.emit_var_binding(b) + def visit_runtime_dep_shape_(self, op: RuntimeDepShape) -> None: + """Visit RuntimeDepShape. + Users can customized this function to overwrite VisitExpr_(const RuntimeDepShapeNode* op) + on the C++ side. - if binding.var.same_as(new_var) and binding.value.same_as(new_value): - emit(binding) - return + Parameters + ---------- + op : RuntimeDepShape + The RuntimeDepShape to be visited. + """ + # Using self._outer() to ref _PyExprVisitor + return _ffi_api.ExprVisitorVisitExpr(self._outer(), op) - temp = self.with_shape_and_type(new_var, new_value.shape_, new_value._checked_type_) - if not temp.same_as(new_var): - new_var = temp - self.var_remap_[binding.var.vid] = new_var + def visit_extern_func_(self, op: ExternFunc) -> None: + """Visit ExternFunc. + Users can customized this function to overwrite VisitExpr_(const ExternFuncNode* op) + on the C++ side. - emit(VarBinding(new_var, new_value)) + Parameters + ---------- + op : ExternFunc + The ExternFunc to be visited. + """ + # Using self._outer() to ref _PyExprVisitor + return _ffi_api.ExprVisitorVisitExpr(self._outer(), op) - def visit_match_shape_(self, binding: MatchShape) -> None: - """Visit MatchShape, a new MatchShape will be emitted + def visit_global_var_(self, op: GlobalVar) -> None: + """Visit GlobalVar. + Users can customized this function to overwrite VisitExpr_(const GlobalVarNode* op) + on the C++ side. Parameters ---------- - binding: MatchShape - The MatchShape binding to be visited. + op : GlobalVar + The GlobalVar to be visited. """ - new_value = self.visit_expr(binding.value) - new_pattern = self.visit_expr(ShapeExpr(binding.pattern)) + # Using self._outer() to ref _PyExprVisitor + return _ffi_api.ExprVisitorVisitExpr(self._outer(), op) - if binding.var: - new_shape = None - if new_value._checked_type_ and isinstance(new_value._checked_type_, DynTensorType): - new_shape = new_pattern - new_var = self.visit_var_def(binding.var) - temp = self.with_shape_and_type(new_var, new_shape, new_value._checked_type_) - if not temp.same_as(new_var): - new_var = temp - self.var_remap_[binding.var.vid] = new_var - - if binding.value.same_as(new_value) and binding.pattern.same_as(new_pattern): - if not binding.var or (binding.var and binding.var.same_as(new_var)): - self.builder_.match_shape_binding(binding) - return + def visit_function_(self, op: Function) -> None: + """Visit Function. + Users can customized this function to overwrite VisitExpr_(const FunctionNode* op) + on the C++ side. - self.builder_.match_shape_binding(MatchShape(new_value, new_pattern.values, new_var)) + Parameters + ---------- + op : Function + The Function to be visited. + """ + # Using self._outer() to ref _PyExprVisitor + return _ffi_api.ExprVisitorVisitExpr(self._outer(), op) - def visit_binding_block_(self, block: BindingBlock) -> BindingBlock: - self.builder_._begin_binding_block() - for binding in block.bindings: - self.visit_binding(binding) - return self.builder_._end_block() + def visit_call_(self, op: Call) -> None: + """Visit Call. + Users can customized this function to overwrite VisitExpr_(const CallNode* op) + on the C++ side. - def visit_dataflow_block_(self, block: DataflowBlock) -> BindingBlock: - self.builder_._begin_dataflow_block() - for binding in block.bindings: - self.visit_binding(binding) - return self.builder_._end_block() + Parameters + ---------- + op : Call + The Call to be visited. + """ + # Using self._outer() to ref _PyExprVisitor + return _ffi_api.ExprVisitorVisitExpr(self._outer(), op) - def visit_dataflow_var_def_(self, var: DataflowVar) -> Var: - """Rewrite the dataflow var definition site. + def visit_seq_expr_(self, op: SeqExpr) -> None: + """Visit SeqExpr. + Users can customized this function to overwrite VisitExpr_(const SeqExprNode* op) + on the C++ side. Parameters ---------- - var: DataflowVar - The dataflow var to be visited. + op : SeqExpr + The SeqExpr to be visited. + """ + # Using self._outer() to ref _PyExprVisitor + return _ffi_api.ExprVisitorVisitExpr(self._outer(), op) - Returns - ------- - var: Dataflowvar - The dataflow var after post-order rewritten. + def visit_if_(self, op: If) -> None: + """Visit If. + Users can customized this function to overwrite VisitExpr_(const IfNode* op) + on the C++ side. + + Parameters + ---------- + op : If + The If to be visited. """ - shape_unchanged = True - new_shape = None - if var.shape_: - new_shape = self.visit_expr(var.shape_) - shape_unchanged &= var.shape_.same_as(new_shape) + # Using self._outer() to ref _PyExprVisitor + return _ffi_api.ExprVisitorVisitExpr(self._outer(), op) - if shape_unchanged: - return var - else: - new_var = DataflowVar(var.vid, None, var._checked_type_, var.span) - _update_shape(new_var, new_shape) + def visit_op_(self, op: Op) -> None: + """Visit Op. + Users can customized this function to overwrite VisitExpr_(const OpNode* op) + on the C++ side. - self.var_remap_[var.vid] = new_var - return new_var + Parameters + ---------- + op : Op + The Op to be visited. + """ + # Using self._outer() to ref _PyExprVisitor + return _ffi_api.ExprVisitorVisitExpr(self._outer(), op) - def visit_var_def_(self, var: Var) -> Var: - """Rewrite the var definition site. + def visit_tuple_getitem_(self, op: TupleGetItem) -> None: + """Visit TupleGetItem. + Users can customized this function to overwrite VisitExpr_(const TupleGetItemNode* op) + on the C++ side. Parameters ---------- - var: Var - The var to be visited. + op : TupleGetItem + The TupleGetItem to be visited. + """ + # Using self._outer() to ref _PyExprVisitor + return _ffi_api.ExprVisitorVisitExpr(self._outer(), op) - Returns - ------- - var: Var - The var after post-order rewritten. + def visit_var_binding_(self, binding: VarBinding) -> None: + """Visit VarBinding. + Users can customized this function to overwrite VisitBinding_(const VarBindingNode* binding) + on the C++ side. + + Parameters + ---------- + binding : VarBinding + The VarBinding to be visited. """ - shape_unchanged = True - new_shape = None - if var.shape_: - new_shape = self.visit_expr(var.shape_) - shape_unchanged &= var.shape_.same_as(new_shape) + # Using self._outer() to ref _PyExprVisitor + return _ffi_api.ExprVisitorVisitBinding(self._outer(), binding) - if shape_unchanged: - return var - else: - new_var = Var(var.vid, None, var._checked_type_, var.span) - _update_shape(new_var, new_shape) + def visit_match_shape_(self, binding: MatchShape) -> None: + """Visit MatchShape. + Users can customized this function to overwrite VisitBinding_(const MatchShapeNode* binding) + on the C++ side. - self.var_remap_[var.vid] = new_var - return new_var + Parameters + ---------- + binding : MatchShape + The MatchShape to be visited. + """ + # Using self._outer() to ref _PyExprVisitor + return _ffi_api.ExprVisitorVisitBinding(self._outer(), binding) - def visit_binding(self, binding: Binding) -> None: - if isinstance(binding, MatchShape): - self.visit_match_shape_(binding) - elif isinstance(binding, VarBinding): - self.visit_var_binding_(binding) - else: - raise TypeError("Invalid type: {0}".format(type(binding))) + def visit_binding_block_(self, block: BindingBlock) -> None: + """Visit BindingBlock. + Users can customized this function to overwrite VisitBindingBlock_(const BindingBlockNode* + block) on the C++ side. - def visit_binding_block(self, block: BindingBlock) -> BindingBlock: - if isinstance(block, DataflowBlock): - ret = self.visit_dataflow_block_(block) - elif isinstance(block, BindingBlock): - ret = self.visit_binding_block_(block) - else: - raise TypeError("Invalid type: {0}".format(type(block))) + Parameters + ---------- + block : BindingBlock + The BindingBlock to be visited. + """ + # Using self._outer() to ref _PyExprVisitor + return _ffi_api.ExprVisitorVisitBindingBlock(self._outer(), block) - return ret + def visit_dataflow_block_(self, block: DataflowBlock) -> None: + """Visit DataflowBlock. + Users can customized this function to overwrite VisitBindingBlock_(const DataflowBlockNode* + block) on the C++ side. - def visit_var_def(self, var: Var) -> Var: - ret = None - if isinstance(var, DataflowVar): - ret = self.visit_dataflow_var_def_(var) - elif isinstance(var, Var): - ret = self.visit_var_def_(var) - else: - raise TypeError("Invalid type: {0}".format(type(var))) - return ret + Parameters + ---------- + block : DataflowBlock + The DataflowBlock to be visited. + """ + # Using self._outer() to ref _PyExprVisitor + return _ffi_api.ExprVisitorVisitBindingBlock(self._outer(), block) - def visit_with_new_scope(self, expr: Expr) -> Expr: - self.builder_._begin_binding_block() - ret = self.visit_expr(expr) - prologue = self.builder_._end_block() - if prologue.bindings: - ret = SeqExpr([prologue], ret) - return ret + def visit_var_def_(self, var: Var) -> None: + """Visit the Var definition site. + Users can customized this function to overwrite VisitVarDef_(const VarNode* var) + on the C++ side. + + Parameters + ---------- + var : Var + The Var to be visited. + """ + # Using self._outer() to ref _PyExprVisitor + return _ffi_api.ExprVisitorVisitVarDef(self._outer(), var) - def with_shape_and_type(self, var: Var, shape: Optional[Expr], t: Type) -> Var: - """Create a new var with specified shape and type if the original var's shape or type - does not match with the specified ones. + def visit_dataflow_var_def_(self, var: DataflowVar) -> None: + """Visit the DataflowVar definition site. + Users can customized this function to overwrite VisitVarDef_(const DataflowVarNode* var) + on the C++ side. Parameters ---------- - var: Var - The var to be updated. - shape: Optional[Expr] - The specified shape. - t: Type - The specified type. + var : DataflowVar + The DataflowVar to be visited. + """ + # Using self._outer() to ref _PyExprVisitor + return _ffi_api.ExprVisitorVisitVarDef(self._outer(), var) - Returns - ------- - var: Var - The var filled with shape and type. + def visit_type(self, t: Type) -> None: + """Visit Type. + Users can customized this function to overwrite VisitType(const Type& t) on the C++ side. + + Parameters + ---------- + t : Type + The Type to be visited. """ - shape_changed = (var.shape_ is not None) ^ (shape is not None) - shape_changed |= ( - var.shape_ and shape and not self.builder_.can_prove_shape_equal(var.shape_, shape) - ) + # Using self._outer() to ref _PyExprVisitor + return _ffi_api.ExprVisitorVisitType(self._outer(), t) - type_changed = (var._checked_type_ is not None) ^ (t is not None) - type_changed |= var._checked_type_ and t and not structural_equal(var._checked_type_, t) + def visit_span(self, span: Span) -> None: + """Visit Span. + Users can customized this function to overwrite VisitSpan(const Span& span) on the C++ side. - if shape_changed or type_changed: - new_var = ( - DataflowVar(var.vid, None, None, var.span) - if isinstance(var, DataflowVar) - else Var(var.vid, None, None, var.span) - ) - _update_shape(new_var, var.shape_) - _update_type(new_var, var._checked_type_) - var = new_var + Parameters + ---------- + span : Span + The Span to be visited. + """ + # Using self._outer() to ref _PyExprVisitor + return _ffi_api.ExprVisitorVisitSpan(self._outer(), span) - if shape_changed: - var.shape_ = shape - if type_changed: - var._checked_type_ = t +@tvm._ffi.register_object("expr_functor.PyExprMutator") +class _PyExprMutator(Object): + """ + A TVM object to support customization of ExprMutator on the python side. + This is the decorated result returned from mutator decorator. - return var + WARNING: This is NOT the user facing class for method overwriting inheritance. + + See also: mutator, PyExprmutator + """ + + def __init__( + self, + builder: BlockBuilder = None, + f_visit_expr: Callable = None, + f_visit_constant_: Callable = None, + f_visit_tuple_: Callable = None, + f_visit_var_: Callable = None, + f_visit_dataflow_var_: Callable = None, + f_visit_shape_expr_: Callable = None, + f_visit_runtime_dep_shape_: Callable = None, + f_visit_extern_func_: Callable = None, + f_visit_global_var_: Callable = None, + f_visit_function_: Callable = None, + f_visit_call_: Callable = None, + f_visit_seq_expr_: Callable = None, + f_visit_if_: Callable = None, + f_visit_op_: Callable = None, + f_visit_tuple_getitem_: Callable = None, + f_visit_binding: Callable = None, + f_visit_var_binding_: Callable = None, + f_visit_match_shape_: Callable = None, + f_visit_binding_block: Callable = None, + f_visit_binding_block_: Callable = None, + f_visit_dataflow_block_: Callable = None, + f_visit_var_def: Callable = None, + f_visit_var_def_: Callable = None, + f_visit_dataflow_var_def_: Callable = None, + f_visit_type: Callable = None, + f_visit_span: Callable = None, + ) -> None: + """Constructor.""" + + self.__init_handle_by_constructor__( + _ffi_api.MakePyExprMutator, + builder, + f_visit_expr, + f_visit_constant_, + f_visit_tuple_, + f_visit_var_, + f_visit_dataflow_var_, + f_visit_shape_expr_, + f_visit_runtime_dep_shape_, + f_visit_extern_func_, + f_visit_global_var_, + f_visit_function_, + f_visit_call_, + f_visit_seq_expr_, + f_visit_if_, + f_visit_op_, + f_visit_tuple_getitem_, + f_visit_binding, + f_visit_var_binding_, + f_visit_match_shape_, + f_visit_binding_block, + f_visit_binding_block_, + f_visit_dataflow_block_, + f_visit_var_def, + f_visit_var_def_, + f_visit_dataflow_var_def_, + f_visit_type, + f_visit_span, + ) + + def visit_expr(self, expr: Expr) -> Expr: + """Generic dispatcher for Expr. + + Parameters + ---------- + expr : Expr + The expr to be visited. + + Returns + ------- + result : Expr + The Expr after transformation. + """ + return _ffi_api.PyExprMutatorVisitExpr(self, expr) + + def visit_binding(self, binding: Binding) -> None: + """Generic dispatcher for Binding. + + Parameters + ---------- + binding : Binding + The binding to be visited. + """ + return _ffi_api.PyExprMutatorVisitBinding(self, binding) + + def visit_binding_block(self, block: BindingBlock) -> BindingBlock: + """Generic dispatcher for BindingBlock. + + Parameters + ---------- + block : BindingBlock + The block to be visited. + + Returns + ------- + result : BindingBlock + The binding block after transformation. + """ + return _ffi_api.PyExprMutatorVisitBindingBlock(self, block) + + def visit_var_def(self, var: Var) -> Var: + """Generic dispatcher for visiting the var definition site. + Note that visit_var_() will only visit the usage site of an Var. + + Parameters + ---------- + var : Var + The var to be visited. + + Returns + ------- + result : Var + The var after post-order rewritten. + """ + return _ffi_api.PyExprMutatorVisitVarDef(self, var) + + +class PyExprMutator: + """ + An abstract ExprMutator with customized methods on the python-side. + This is the user facing class for method overwriting inheritance. + _tvm_metadata discribes the class to inherit("cls"), the methods that users can + overwrite("methods"), the constructor's parameters("fields") + + Note: @relax.expr_functor.mutator is required for proper usage of any inherited class. + + See also: visitor, _PyExprVisitor + + Example: + @relax.expr_functor.mutator + def MyExprMutator(PyExprMutator): + ... + """ + + _tvm_metadata = { + "cls": _PyExprMutator, + "fields": ["builder_"], + "methods": [ + "visit_expr", + "visit_constant_", + "visit_tuple_", + "visit_var_", + "visit_dataflow_var_", + "visit_shape_expr_", + "visit_runtime_dep_shape_", + "visit_extern_func_", + "visit_global_var_", + "visit_function_", + "visit_call_", + "visit_seq_expr_", + "visit_if_", + "visit_op_", + "visit_tuple_getitem_", + "visit_binding", + "visit_var_binding_", + "visit_match_shape_", + "visit_binding_block", + "visit_binding_block_", + "visit_dataflow_block_", + "visit_var_def", + "visit_var_def_", + "visit_dataflow_var_def_", + "visit_type", + "visit_span", + ], + } + + def __init__(self, mod: Optional[IRModule] = None) -> None: + """Constructor""" + self.builder_ = BlockBuilder(mod) + + def visit_expr(self, expr: Expr) -> Expr: + """Generic dispatcher for Expr. + Users can customized this function to overwrite VisitExpr(const Expr& expr) on the C++ side. + + Parameters + ---------- + expr : Expr + The expr to be visited. + + Returns + ------- + result : Expr + The Expr after transformation. + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.PyExprMutatorVisitExpr(self._outer(), expr) + + def visit_binding(self, binding: Binding) -> None: + """Generic dispatcher for Binding. + Users can customized this function to overwrite VisitBinding(const Binding& binding) + on the C++ side. + + Parameters + ---------- + binding : Binding + The binding to be visited. + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.PyExprMutatorVisitBinding(self._outer(), binding) + + def visit_binding_block(self, block: BindingBlock) -> BindingBlock: + """Generic dispatcher for BindingBlock. + Users can customized this function to overwrite VisitBindingBlock(const BindingBlock& block) + on the C++ side. + + Parameters + ---------- + block : BindingBlock + The block to be visited. + + Returns + ------- + result : BindingBlock + The binding block after transformation. + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.PyExprMutatorVisitBindingBlock(self._outer(), block) + + def visit_var_def(self, var: Var) -> Var: + """Generic dispatcher for visiting the var definition site. + Users can customized this function to overwrite VisitVarDef(const Var& var) on the C++ side. + Note that visit_var_() will only visit the usage site of an Var. + + Parameters + ---------- + var : Var + The var to be visited. + + Returns + ------- + result: Var + The var after post-order rewritten. + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.PyExprMutatorVisitVarDef(self._outer(), var) + + def visit_constant_(self, op: Constant) -> Expr: + """Visit Constant. + Users can customized this function to overwrite VisitExpr_(const ConstantNode* op) + on the C++ side. + + Parameters + ---------- + op : Constant + The Constant to be visited. + + Returns + ------- + result : Expr + The Expr after transformation + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.ExprMutatorVisitExpr(self._outer(), op) + + def visit_tuple_(self, op: Tuple) -> Expr: + """Visit Tuple. + Users can customized this function to overwrite VisitExpr_(const TupleNode* op) + on the C++ side. + + Parameters + ---------- + op : Tuple + The Tuple to be visited. + + Returns + ------- + result : Expr + The Expr after transformation + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.ExprMutatorVisitExpr(self._outer(), op) + + def visit_var_(self, op: Var) -> Expr: + """Visit Var. + Users can customized this function to overwrite VisitExpr_(const VarNode* op) + on the C++ side. + + Parameters + ---------- + op : Var + The Var to be visited. + + Returns + ------- + result : Expr + The Expr after transformation + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.ExprMutatorVisitExpr(self._outer(), op) + + def visit_dataflow_var_(self, op: DataflowVar) -> Expr: + """Visit DataflowVar. + Users can customized this function to overwrite VisitExpr_(const DataflowVarNode* op) + on the C++ side. + + Parameters + ---------- + op : DataflowVar + The DataflowVar to be visited. + + Returns + ------- + result : Expr + The Expr after transformation + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.ExprMutatorVisitExpr(self._outer(), op) + + def visit_shape_expr_(self, op: ShapeExpr) -> Expr: + """Visit ShapeExpr. + Users can customized this function to overwrite VisitExpr_(const ShapeExprNode* op) + on the C++ side. + + Parameters + ---------- + op : ShapeExpr + The ShapeExpr to be visited. + + Returns + ------- + result : Expr + The Expr after transformation + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.ExprMutatorVisitExpr(self._outer(), op) + + def visit_runtime_dep_shape_(self, op: RuntimeDepShape) -> Expr: + """Visit RuntimeDepShape. + Users can customized this function to overwrite VisitExpr_(const RuntimeDepShapeNode* op) + on the C++ side. + + Parameters + ---------- + op : RuntimeDepShape + The RuntimeDepShape to be visited. + + Returns + ------- + result : Expr + The Expr after transformation + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.ExprMutatorVisitExpr(self._outer(), op) + + def visit_extern_func_(self, op: ExternFunc) -> Expr: + """Visit ExternFunc. + Users can customized this function to overwrite VisitExpr_(const ExternFuncNode* op) + on the C++ side. + + Parameters + ---------- + op : ExternFunc + The ExternFunc to be visited. + + Returns + ------- + result : Expr + The Expr after transformation + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.ExprMutatorVisitExpr(self._outer(), op) + + def visit_global_var_(self, op: GlobalVar) -> Expr: + """Visit GlobalVar. + Users can customized this function to overwrite VisitExpr_(const GlobalVarNode* op) + on the C++ side. + + Parameters + ---------- + op : GlobalVar + The GlobalVar to be visited. + + Returns + ------- + result : Expr + The Expr after transformation + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.ExprMutatorVisitExpr(self._outer(), op) + + def visit_function_(self, op: Function) -> Expr: + """Visit Function. + Users can customized this function to overwrite VisitExpr_(const FunctionNode* op) + on the C++ side. + + Parameters + ---------- + op : Function + The Function to be visited. + + Returns + ------- + result : Expr + The Expr after transformation + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.ExprMutatorVisitExpr(self._outer(), op) + + def visit_call_(self, op: Call) -> Expr: + """Visit Call. + Users can customized this function to overwrite VisitExpr_(const CallNode* op) + on the C++ side. + + Parameters + ---------- + op : Call + The Call to be visited. + + Returns + ------- + result : Expr + The Expr after transformation + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.ExprMutatorVisitExpr(self._outer(), op) + + def visit_seq_expr_(self, op: SeqExpr) -> Expr: + """Visit SeqExpr. + Users can customized this function to overwrite VisitExpr_(const SeqExprNode* op) + on the C++ side. + + Parameters + ---------- + op : SeqExpr + The SeqExpr to be visited. + + Returns + ------- + result : Expr + The Expr after transformation + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.ExprMutatorVisitExpr(self._outer(), op) + + def visit_if_(self, op: If) -> Expr: + """Visit If. + Users can customized this function to overwrite VisitExpr_(const IfNode* op) + on the C++ side. + + Parameters + ---------- + op : If + The If to be visited. + + Returns + ------- + result : Expr + The Expr after transformation + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.ExprMutatorVisitExpr(self._outer(), op) + + def visit_op_(self, op: Op) -> Expr: + """Visit Op. + Users can customized this function to overwrite VisitExpr_(const OpNode* op) + on the C++ side. + + Parameters + ---------- + op : Op + The Op to be visited. + + Returns + ------- + result : Expr + The Expr after transformation + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.ExprMutatorVisitExpr(self._outer(), op) + + def visit_tuple_getitem_(self, op: TupleGetItem) -> Expr: + """Visit TupleGetItem. + Users can customized this function to overwrite VisitExpr_(const TupleGetItemNode* op) + on the C++ side. + + Parameters + ---------- + op : TupleGetItem + The TupleGetItem to be visited. + + Returns + ------- + result : Expr + The Expr after transformation + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.ExprMutatorVisitExpr(self._outer(), op) + + def visit_var_binding_(self, binding: VarBinding) -> None: + """Visit VarBinding. + Users can customized this function to overwrite VisitBinding_(const VarBindingNode* binding) + on the C++ side. + + Parameters + ---------- + binding : VarBinding + The VarBinding to be visited. + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.ExprMutatorVisitBinding(self._outer(), binding) + + def visit_match_shape_(self, binding: MatchShape) -> None: + """Visit MatchShape. + Users can customized this function to overwrite VisitBinding_(const MatchShapeNode* binding) + on the C++ side. + + Parameters + ---------- + binding : MatchShape + The MatchShape to be visited. + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.ExprMutatorVisitBinding(self._outer(), binding) + + def visit_binding_block_(self, block: BindingBlock) -> BindingBlock: + """Visit BindingBlock. + Users can customized this function to overwrite VisitBindingBlock_(const BindingBlockNode* + block) on the C++ side. + + Parameters + ---------- + block : BindingBlock + The BindingBlock to be visited. + + Returns + ------- + result : BindingBlock + The binding block after transformation + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.ExprMutatorVisitBindingBlock(self._outer(), block) + + def visit_dataflow_block_(self, block: DataflowBlock) -> BindingBlock: + """Visit DataflowBlock. + Users can customized this function to overwrite VisitBindingBlock_(const DataflowBlockNode* + block) on the C++ side. + + Parameters + ---------- + block : DataflowBlock + The DataflowBlock to be visited. + + Returns + ------- + result : BindingBlock + The binding block after transformation + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.ExprMutatorVisitBindingBlock(self._outer(), block) + + def visit_var_def_(self, var: Var) -> Var: + """Visit the Var definition site. + Users can customized this function to overwrite VisitVarDef_(const VarNode* var) + on the C++ side. + + Parameters + ---------- + var : Var + The Var to be visited. + + Returns + ------- + result : Var + The var after post-order rewritten. + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.ExprMutatorVisitVarDef(self._outer(), var) + + def visit_dataflow_var_def_(self, var: DataflowVar) -> Var: + """Visit the DataflowVar definition site. + Users can customized this function to overwrite VisitVarDef_(const DataflowVarNode* var) + on the C++ side. + + Parameters + ---------- + var : DataflowVar + The DataflowVar to be visited. + + Returns + ------- + result : Var + The var after post-order rewritten. + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.ExprMutatorVisitVarDef(self._outer(), var) + + def visit_type(self, t: Type) -> Type: + """Visit Type. + Users can customized this function to overwrite VisitType(const Type& t) on the C++ side. + + Parameters + ---------- + t : Type + The Type to be visited. + + Returns + ------- + result : Type + The type after transformation. + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.ExprMutatorVisitType(self._outer(), t) + + def visit_span(self, span: Span) -> Span: + """Visit Span. + Users can customized this function to overwrite VisitSpan(const Span& span) on the C++ side. + + Parameters + ---------- + span : Span + The Span to be visited. + + Returns + ------- + result : Span + The span after transformation. + """ + raise NotImplementedError + + def visit_expr_post_order(self, expr: Expr) -> Expr: + """Post-order rewrite an Expr and normalize. + + Parameters + ---------- + expr : Expr + The Expr to be rewritten. + + Returns + ------- + result : Expr + The Expr after post-order rewritten. + """ + return _ffi_api.PyExprMutatorVisitExprPostOrder(self._outer(), expr) + + def set_var_remap(self, vid: Id, var: Var) -> None: + """Remap a var to a new var in use-site. + + Parameters + ---------- + vid : Id + The vid of the old var. + var : Var + The new var. + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.PyExprMutatorSetVarRemap(self._outer(), vid, var) + + def get_var_remap(self, vid: Id) -> Var: + """Remap a var to a new var in use-site. + + Parameters + ---------- + vid : Id + The vid of the old var + + Returns + ------- + var : Var + The remapped var. + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.PyExprMutatorGetVarRemap(self._outer(), vid) + + def visit_with_new_scope(self, expr: Expr) -> Expr: + """Rewrite the expr with a new scope, used in a Function's body and the branches of If. + + Parameters + ---------- + expr : Expr + The expr to be visited. + + Returns + ------- + var : Var + The expr after visiting. + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.PyExprMutatorVisitWithNewScope(self._outer(), expr) def lookup_binding(self, var: Var) -> Optional[Expr]: - return self.builder_.lookup_binding(var) + """Look up the value bound to a variable. + Note: For function parameters, this function returns NullOpt. + + Parameters + ---------- + var : Var + The var to be looked up. + + Returns + ------- + var : Var + The value bound to the input var. + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.PyExprMutatorLookupBinding(self._outer(), var) + + def with_shape_and_type(self, var: Var, shape: Optional[Object], t: Type) -> Var: + """Create a new var with specified shape and type if the original var's shape or type does + not match with the specified ones. + + Parameters + ---------- + var : Var + The var to be updated. + shape : Optional[Object] + The specified shape. + t : Type + The specified type. + + Returns + ------- + var : Var + The var filled with shape and type. + """ + # Using self._outer() to ref _PyExprMutator + return _ffi_api.PyExprMutatorWithShapeAndType(self._outer(), var, shape, t) diff --git a/python/tvm/relax/testing/ast_printer.py b/python/tvm/relax/testing/ast_printer.py index e73365acf2..c71b37c49d 100644 --- a/python/tvm/relax/testing/ast_printer.py +++ b/python/tvm/relax/testing/ast_printer.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=redefined-builtin +# pylint: disable=redefined-builtin, abstract-method, arguments-differ """ Utility script for printing Relax modules as AST diagrams, only intended to show how the AST is put together. diff --git a/python/tvm/relax/testing/transform.py b/python/tvm/relax/testing/transform.py index 2abf23cfb9..6b0e664c3d 100644 --- a/python/tvm/relax/testing/transform.py +++ b/python/tvm/relax/testing/transform.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=unused-argument, invalid-name, no-else-return +# pylint: disable=unused-argument, invalid-name, no-else-return, abstract-method, arguments-differ """Relax transformation passes for testing""" from __future__ import annotations @@ -24,7 +24,7 @@ from tvm.ir.transform import PassContext from tvm.target import Target from tvm.ir import transform -from tvm.relax import ExprMutator +from tvm.relax import PyExprMutator from tvm.relax.expr import Call from tvm.relay.backend.te_compiler import select_implementation @@ -69,7 +69,8 @@ def transform_module(self, mod: IRModule, ctx: PassContext) -> IRModule: """ target = self.target - class Lowerer(ExprMutator): + @relax.expr_functor.mutator + class Lowerer(PyExprMutator): """Mutator that performs lowering.""" def visit_call_(self, call_node: Call): diff --git a/python/tvm/relax/transform/fma_rewrite.py b/python/tvm/relax/transform/fma_rewrite.py index 80cda574b9..7e7f307f42 100644 --- a/python/tvm/relax/transform/fma_rewrite.py +++ b/python/tvm/relax/transform/fma_rewrite.py @@ -14,17 +14,18 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=unused-argument, invalid-name +# pylint: disable=unused-argument, invalid-name, abstract-method """Perform fused multiply-add rewriting in Python""" from tvm.ir import Op from tvm.ir.module import IRModule from tvm.ir.transform import module_pass -from ..expr_functor import ExprMutator +from ..expr_functor import mutator, PyExprMutator from ..expr import Call, Function, Var from ..transform import dataflowblock_pass -class EwiseFMARewriter(ExprMutator): +@mutator +class EwiseFMARewriter(PyExprMutator): """Rewrites the relax.add call to a relax.ewise_fma call when detecting the multiply-add pattern. @@ -36,9 +37,8 @@ class EwiseFMARewriter(ExprMutator): z0 = ewise_fma(a, b, c) """ - def visit_call_(self, call_node: Call) -> Call: - call = self.builder_.normalize(ExprMutator.visit_call_(self, call_node)) - + def visit_call_(self, call: Call) -> Call: # pylint: disable=arguments-differ + call = self.visit_expr_post_order(call) add_op = Op.get("relax.add") multiply_op = Op.get("relax.multiply") ewise_fma_op = Op.get("relax.ewise_fma") @@ -62,7 +62,8 @@ def transform_dataflowblock(self, block, mod, ctx): return EwiseFMARewriter().visit_binding_block(block) -class EwiseFuseFMAMutator(ExprMutator): +@mutator +class EwiseFuseFMAMutator(PyExprMutator): """Performs multiply add fusion. The difference of EwiseFMARewriter and this EwiseFuseFMAMutator class is that this mutator generates a sub function(subgraph) whose body is a CallNode that calls to the relax.ewise_fma op, and rewrites the @@ -95,9 +96,8 @@ def transform(self) -> IRModule: return self.builder_.get() - def visit_call_(self, call_node: Call) -> Call: - call = self.builder_.normalize(ExprMutator.visit_call_(self, call_node)) - + def visit_call_(self, call: Call) -> Call: # pylint: disable=arguments-differ + call = self.visit_expr_post_order(call) add_op = Op.get("relax.add") multiply_op = Op.get("relax.multiply") ewise_fma_op = Op.get("relax.ewise_fma") diff --git a/src/relax/ir/expr_functor.cc b/src/relax/ir/expr_functor.cc index bc2c65044f..2e2ea89eb3 100644 --- a/src/relax/ir/expr_functor.cc +++ b/src/relax/ir/expr_functor.cc @@ -664,5 +664,128 @@ Var ExprMutator::WithShapeAndType(Var var, Optional shape, Type type) return var; } +TVM_REGISTER_GLOBAL("relax.MakePyExprVisitor").set_body_typed(PyExprVisitor::MakePyExprVisitor); + +TVM_REGISTER_GLOBAL("relax.PyExprVisitorVisitExpr") + .set_body_typed([](PyExprVisitor visitor, const Expr& expr) { visitor->VisitExpr(expr); }); + +TVM_REGISTER_GLOBAL("relax.PyExprVisitorVisitBinding") + .set_body_typed([](PyExprVisitor visitor, const Binding& binding) { + visitor->VisitBinding(binding); + }); + +TVM_REGISTER_GLOBAL("relax.PyExprVisitorVisitBindingBlock") + .set_body_typed([](PyExprVisitor visitor, const BindingBlock& block) { + visitor->VisitBindingBlock(block); + }); + +TVM_REGISTER_GLOBAL("relax.PyExprVisitorVisitVarDef") + .set_body_typed([](PyExprVisitor visitor, const Var& var) { visitor->VisitVarDef(var); }); + +TVM_REGISTER_GLOBAL("relax.ExprVisitorVisitExpr") + .set_body_typed([](PyExprVisitor visitor, const Expr& expr) { + visitor->ExprVisitor::VisitExpr(expr); + }); + +TVM_REGISTER_GLOBAL("relax.ExprVisitorVisitBinding") + .set_body_typed([](PyExprVisitor visitor, const Binding& binding) { + visitor->ExprVisitor::VisitBinding(binding); + }); + +TVM_REGISTER_GLOBAL("relax.ExprVisitorVisitBindingBlock") + .set_body_typed([](PyExprVisitor visitor, const BindingBlock& block) { + visitor->ExprVisitor::VisitBindingBlock(block); + }); + +TVM_REGISTER_GLOBAL("relax.ExprVisitorVisitVarDef") + .set_body_typed([](PyExprVisitor visitor, const Var& var) { + visitor->ExprVisitor::VisitVarDef(var); + }); + +TVM_REGISTER_GLOBAL("relax.ExprVisitorVisitType") + .set_body_typed([](PyExprVisitor visitor, const Type& type) { + visitor->ExprVisitor::VisitType(type); + }); + +TVM_REGISTER_GLOBAL("relax.ExprVisitorVisitSpan") + .set_body_typed([](PyExprVisitor visitor, const Span& span) { + visitor->ExprVisitor::VisitSpan(span); + }); + +TVM_REGISTER_GLOBAL("relax.MakePyExprMutator").set_body_typed(PyExprMutator::MakePyExprMutator); + +TVM_REGISTER_GLOBAL("relax.PyExprMutatorVisitExpr") + .set_body_typed([](PyExprMutator visitor, const Expr& expr) { + return visitor->VisitExpr(expr); + }); + +TVM_REGISTER_GLOBAL("relax.PyExprMutatorVisitBinding") + .set_body_typed([](PyExprMutator visitor, const Binding& binding) { + visitor->VisitBinding(binding); + }); + +TVM_REGISTER_GLOBAL("relax.PyExprMutatorVisitBindingBlock") + .set_body_typed([](PyExprMutator visitor, const BindingBlock& block) { + return visitor->VisitBindingBlock(block); + }); + +TVM_REGISTER_GLOBAL("relax.PyExprMutatorVisitVarDef") + .set_body_typed([](PyExprMutator visitor, const Var& var) { + return visitor->VisitVarDef(var); + }); + +TVM_REGISTER_GLOBAL("relax.ExprMutatorVisitExpr") + .set_body_typed([](PyExprMutator visitor, const Expr& expr) { + return visitor->ExprMutator::VisitExpr(expr); + }); + +TVM_REGISTER_GLOBAL("relax.ExprMutatorVisitBinding") + .set_body_typed([](PyExprMutator visitor, const Binding& binding) { + return visitor->ExprMutator::VisitBinding(binding); + }); + +TVM_REGISTER_GLOBAL("relax.ExprMutatorVisitBindingBlock") + .set_body_typed([](PyExprMutator visitor, const BindingBlock& block) { + return visitor->ExprMutator::VisitBindingBlock(block); + }); + +TVM_REGISTER_GLOBAL("relax.ExprMutatorVisitVarDef") + .set_body_typed([](PyExprMutator visitor, const Var& var) { + return visitor->ExprMutator::VisitVarDef(var); + }); + +TVM_REGISTER_GLOBAL("relax.ExprMutatorVisitType") + .set_body_typed([](PyExprMutator visitor, const Type& type) { + return visitor->ExprMutator::VisitType(type); + }); + +TVM_REGISTER_GLOBAL("relax.PyExprMutatorVisitExprPostOrder") + .set_body_typed([](PyExprMutator visitor, const Expr& expr) { + return visitor->VisitExprPostOrder(expr); + }); + +TVM_REGISTER_GLOBAL("relax.PyExprMutatorVisitWithNewScope") + .set_body_typed([](PyExprMutator visitor, const Expr& expr) { + return visitor->VisitWithNewScope(expr); + }); + +TVM_REGISTER_GLOBAL("relax.PyExprMutatorLookupBinding") + .set_body_typed([](PyExprMutator visitor, const Var& var) { + return visitor->LookupBinding(var); + }); + +TVM_REGISTER_GLOBAL("relax.PyExprMutatorWithShapeAndType") + .set_body_typed([](PyExprMutator visitor, Var var, Optional shape, Type type) { + return visitor->WithShapeAndType(var, shape, type); + }); + +TVM_REGISTER_GLOBAL("relax.PyExprMutatorSetVarRemap") + .set_body_typed([](PyExprMutator visitor, Id id, Var var) { + return visitor->var_remap_[id] = var; + }); + +TVM_REGISTER_GLOBAL("relax.PyExprMutatorGetVarRemap") + .set_body_typed([](PyExprMutator visitor, Id id) { return visitor->var_remap_[id]; }); + } // namespace relax } // namespace tvm diff --git a/tests/python/relax/test_expr_functor.py b/tests/python/relax/test_expr_functor.py index e9fc4d897e..23344cbb9e 100644 --- a/tests/python/relax/test_expr_functor.py +++ b/tests/python/relax/test_expr_functor.py @@ -18,8 +18,19 @@ import tvm from tvm import relax, tir -from tvm.relax import ExprFunctor, ExprVisitor, ExprMutatorBase, ExprMutator +from tvm.relax import PyExprVisitor, PyExprMutator from tvm.ir.base import assert_structural_equal +from tvm.ir import Op +from tvm.relax.ty import DynTensorType +from tvm.relax.expr import Type, Span, Expr +from tvm.relax.expr import Function, ExternFunc +from tvm.relax.expr import Constant, Var, DataflowVar +from tvm.relax.expr import ShapeExpr, RuntimeDepShape +from tvm.relax.expr import GlobalVar, SeqExpr, Tuple +from tvm.relax.expr import Call, If, TupleGetItem +from tvm.relax.expr import Binding, MatchShape, VarBinding +from tvm.relax.expr import BindingBlock, DataflowBlock +from tvm.relax.expr import _update_shape, _update_type m, n = tir.Var("m", "int64"), tir.Var("n", "int64") type_anno1 = relax.DynTensorType(1, "float32") @@ -29,83 +40,440 @@ bb = relax.BlockBuilder() -def check_visit(expr): +@relax.expr_functor.visitor +class BasicVisitor(PyExprVisitor): + """Default ExprVisitor""" + + +class ASTLog: + """Helper class to log AST""" + + def __init__(self) -> None: + self.log = [] + self.indent = "\t" + self.level = 0 + + def push_scope(self): + self.level += 1 + + def pop_scope(self): + self.level -= 1 + + def add(self, s: str): + self.log.append(self.indent * self.level + s) + + def __str__(self) -> str: + return "\n".join(self.log) + + +@relax.expr_functor.visitor +class ASTPrinter(PyExprVisitor): + """Print relax AST in structured format. The shape of Node is ignored.""" + + def __init__(self) -> None: + super().__init__() + self.log = ASTLog() + + def visit_constant_(self, op: Constant) -> None: + self.log.add("Constant") + + def visit_global_var_(self, op: GlobalVar) -> None: + self.log.add("GlobalVar") + + def visit_tuple_(self, op: Tuple) -> None: + self.log.add("Tuple") + self.log.push_scope() + for field in op.fields: + self.visit_expr(field) + self.log.pop_scope() + + def visit_var_(self, op: Var) -> None: + self.log.add("Var") + + def visit_dataflow_var_(self, op: DataflowVar) -> None: + self.log.add("DataflowVar") + + def visit_function_(self, op: Function) -> None: + self.log.add("Function") + self.log.push_scope() + for param in op.params: + self.visit_var_def(param) + + self.visit_expr(op.body) + self.log.pop_scope() + + def visit_call_(self, op: Call) -> None: + self.log.add("Call") + self.log.push_scope() + self.visit_expr(op.op) + + for arg in op.args: + self.visit_expr(arg) + self.log.pop_scope() + + def visit_if_(self, op: If) -> None: + self.log.add("If") + self.log.push_scope() + self.visit_expr(op.cond) + self.visit_expr(op.true_branch) + self.visit_expr(op.false_branch) + self.log.pop_scope() + + def visit_op_(self, op: Op) -> None: + self.log.add("Op") + + def visit_tuple_getitem_(self, op: TupleGetItem) -> None: + self.log.add("TupleGetItem") + self.log.push_scope() + self.visit_expr(op.tuple_value) + self.log.pop_scope() + + def visit_shape_expr_(self, op: ShapeExpr) -> None: + self.log.add("ShapeExpr") + + def visit_runtime_dep_shape_(self, op: RuntimeDepShape) -> None: + self.log.add("RuntimeDepShape") + + def visit_extern_func_(self, op: ExternFunc) -> None: + self.log.add("ExternFunc") + + def visit_seq_expr_(self, op: SeqExpr) -> None: + self.log.add("SeqExpr") + self.log.push_scope() + for block in op.blocks: + self.visit_binding_block(block) + self.visit_expr(op.body) + self.log.pop_scope() + + def visit_var_binding_(self, binding: VarBinding) -> None: + self.log.add("VarBinding") + self.log.push_scope() + self.visit_expr(binding.value) + self.visit_var_def(binding.var) + self.log.pop_scope() + + def visit_match_shape_(self, binding: MatchShape) -> None: + self.log.add("MatchShape") + self.log.push_scope() + self.visit_expr(binding.value) + self.visit_expr(ShapeExpr(binding.pattern)) + if binding.var: + self.visit_var_def(binding.var) + self.log.pop_scope() + + def visit_binding_block_(self, block: BindingBlock) -> None: + self.log.add("BindingBlock") + self.log.push_scope() + for binding in block.bindings: + self.visit_binding(binding) + self.log.pop_scope() + + def visit_dataflow_block_(self, block: DataflowBlock) -> None: + self.log.add("DataflowBlock") + self.log.push_scope() + for binding in block.bindings: + self.visit_binding(binding) + self.log.pop_scope() + + def visit_var_def_(self, var: Var) -> None: + self.log.add("VarDef") + + def visit_dataflow_var_def_(self, var: DataflowVar) -> None: + self.log.add("DataflowVarDef") + + +@relax.expr_functor.mutator +class BasicMutator(PyExprMutator): + """Default ExprMutator""" + + +@relax.expr_functor.mutator +class ASTPostPrinterMutator(PyExprMutator): + """Print relax AST in the post order format.""" + + def __init__(self) -> None: + super().__init__() + self.log = ASTLog() + + def visit_constant_(self, op: Constant) -> Expr: + op = self.visit_expr_post_order(op) + self.log.add("Constant") + return op + + def visit_global_var_(self, op: GlobalVar) -> Expr: + op = self.visit_expr_post_order(op) + self.log.add("GlobalVar") + return op + + def visit_tuple_(self, op: Tuple) -> Expr: + op = self.visit_expr_post_order(op) + self.log.add("Tuple") + return op + + def visit_var_(self, op: Var) -> Expr: + op = self.visit_expr_post_order(op) + self.log.add("Var") + return op + + def visit_dataflow_var_(self, op: DataflowVar) -> Expr: + op = self.visit_expr_post_order(op) + self.log.add("DataflowVar") + return op + + def visit_function_(self, op: Function) -> Expr: + op = self.visit_expr_post_order(op) + self.log.add("Function") + return op + + def visit_call_(self, op: Call) -> Expr: + op = self.visit_expr_post_order(op) + self.log.add("Call") + return op + + def visit_if_(self, op: If) -> Expr: + op = self.visit_expr_post_order(op) + self.log.add("If") + return op + + def visit_op_(self, op: Op) -> Expr: + op = self.visit_expr_post_order(op) + self.log.add("Op") + return op + + def visit_tuple_getitem_(self, op: TupleGetItem) -> Expr: + op = self.visit_expr_post_order(op) + self.log.add("TupleGetItem") + return op + + def visit_shape_expr_(self, op: ShapeExpr) -> Expr: + op = self.visit_expr_post_order(op) + self.log.add("ShapeExpr") + return op + + def visit_runtime_dep_shape_(self, op: RuntimeDepShape) -> Expr: + op = self.visit_expr_post_order(op) + self.log.add("RuntimeDepShape") + return op + + def visit_extern_func_(self, op: ExternFunc) -> Expr: + op = self.visit_expr_post_order(op) + self.log.add("ExternFunc") + return op + + def visit_seq_expr_(self, op: SeqExpr) -> Expr: + op = self.visit_expr_post_order(op) + self.log.add("SeqExpr") + return op + + def visit_var_binding_(self, binding: VarBinding) -> None: + """Identical with ExprMutator::VisitBinding_(const VarBindingNode* binding) on the C++ side.""" + new_value = self.visit_expr(binding.value) + new_var = self.visit_var_def(binding.var) + + def emit(b: VarBinding): + if self.builder_.current_block_is_dataflow() and not isinstance(b.var, DataflowVar): + self.builder_.emit_output_var_binding(b) + else: + self.builder_.emit_var_binding(b) + + self.log.add("VarBinding") + if binding.var.same_as(new_var) and binding.value.same_as(new_value): + emit(binding) + return + + temp = self.with_shape_and_type(new_var, new_value.shape_, new_value._checked_type_) + if not temp.same_as(new_var): + new_var = temp + self.set_var_remap(binding.var.vid, new_var) + + emit(VarBinding(new_var, new_value)) + + def visit_match_shape_(self, binding: MatchShape) -> None: + """Identical with ExprMutator::VisitBinding_(const MatchShapeNode* binding) on the C++ side.""" + new_value = self.visit_expr(binding.value) + new_pattern = self.visit_expr(ShapeExpr(binding.pattern)) + + if binding.var: + new_shape = None + if new_value._checked_type_ and isinstance(new_value._checked_type_, DynTensorType): + new_shape = new_pattern + new_var = self.visit_var_def(binding.var) + temp = self.with_shape_and_type(new_var, new_shape, new_value._checked_type_) + if not temp.same_as(new_var): + new_var = temp + self.set_var_remap(binding.var.vid, new_var) + + self.log.add("MatchShape") + if binding.value.same_as(new_value) and binding.pattern.same_as(new_pattern): + if not binding.var or (binding.var and binding.var.same_as(new_var)): + self.builder_.match_shape_binding(binding) + return + + self.builder_.match_shape_binding(MatchShape(new_value, new_pattern.values, new_var)) + + def visit_binding_block_(self, block: BindingBlock) -> BindingBlock: + """Identical with ExprMutator::VisitBindingBlock_(const BindingBlockNode* block) on the C++ side.""" + self.builder_._begin_binding_block() + for binding in block.bindings: + self.visit_binding(binding) + self.log.add("BindingBlock") + return self.builder_._end_block() + + def visit_dataflow_block_(self, block: DataflowBlock) -> None: + """Identical with ExprMutator::VisitBindingBlock_(const DataflowBlockNode* block) on the C++ side.""" + self.builder_._begin_dataflow_block() + for binding in block.bindings: + self.visit_binding(binding) + self.log.add("DataflowBlock") + return self.builder_._end_block() + + def visit_var_def_(self, var: Var) -> None: + """Identical with ExprMutator::VisitVarDef_(const VarNode* var) on the C++ side.""" + shape_unchanged = True + new_shape = None + if var.shape_: + new_shape = self.visit_expr(var.shape_) + shape_unchanged &= var.shape_.same_as(new_shape) + + self.log.add("VarDef") + if shape_unchanged: + return var + else: + new_var = Var(var.vid, None, var._checked_type_, var.span) + _update_shape(new_var, new_shape) + + self.set_var_remap(var.vid, new_var) + return new_var + + def visit_dataflow_var_def_(self, var: DataflowVar) -> None: + """Identical with ExprMutator::VisitVarDef_(const DataflowVarNode* var) on the C++ side.""" + shape_unchanged = True + new_shape = None + if var.shape_: + new_shape = self.visit_expr(var.shape_) + shape_unchanged &= var.shape_.same_as(new_shape) + + self.log.add("DataflowVarDef") + if shape_unchanged: + return var + else: + new_var = DataflowVar(var.vid, None, var._checked_type_, var.span) + _update_shape(new_var, new_shape) + + self.set_var_remap(var.vid, new_var) + return new_var + + +def basic_check(expr, visitor_str, mutator_str): def visit(f, expr): if isinstance(expr, relax.Expr): return f.visit_expr(expr) elif isinstance(expr, relax.BindingBlock): return f.visit_binding_block(expr) - if isinstance(expr, relax.Expr): - with pytest.raises(NotImplementedError): - ef = ExprFunctor() - visit(ef, expr) + # check no overloading case + basic_visitor = BasicVisitor() + visit(basic_visitor, expr) - ev = ExprVisitor() - visit(ev, expr) + # check the output log + log_visitor = ASTPrinter() + visit(log_visitor, expr) + assert str(log_visitor.log) == visitor_str - em_base = ExprMutatorBase() - assert_structural_equal(visit(em_base, expr), expr) + # check no overloading case + basic_mutator = BasicMutator() + if isinstance(expr, relax.Expr): + expr = bb.normalize(expr) + assert_structural_equal(visit(basic_mutator, expr), expr) - em = ExprMutator() + # check the output log and return value + post_log_mutator = ASTPostPrinterMutator() if isinstance(expr, relax.Expr): expr = bb.normalize(expr) - assert_structural_equal(visit(em, expr), expr) + assert_structural_equal(visit(post_log_mutator, expr), expr) + assert str(post_log_mutator.log) == mutator_str def test_constant(): - check_visit(relax.const(1.0)) + basic_check(relax.const(1.0), "Constant", "Constant") def test_var(): - check_visit(x) + basic_check(x, "Var", "Var") def test_dataflow_var(): lv = relax.DataflowVar("lv", [n], type_anno1) - check_visit(lv) + basic_check(lv, "DataflowVar", "DataflowVar") def test_tuple(): t = relax.Tuple([x, y]) - check_visit(t) + basic_check(t, "\n".join(["Tuple", "\tVar", "\tVar"]), "\n".join(["Var", "Var", "Tuple"])) def test_global_var(): gv = relax.GlobalVar("gv") - check_visit(gv) + basic_check(gv, "GlobalVar", "GlobalVar") def test_seq_expr(): bindings = [relax.VarBinding(x, relax.const(1))] blocks = [relax.BindingBlock(bindings)] seq_expr = relax.SeqExpr(blocks, x) - check_visit(seq_expr) + basic_check( + seq_expr, + "\n".join( + [ + "SeqExpr", + "\tBindingBlock", + "\t\tVarBinding", + "\t\t\tConstant", + "\t\t\tVarDef", + "\tVar", + ] + ), + "\n".join( + ["Constant", "ShapeExpr", "VarDef", "VarBinding", "BindingBlock", "Var", "SeqExpr"] + ), + ) def test_shape_expr(): x = relax.ShapeExpr([m, n]) - check_visit(x) + basic_check(x, "ShapeExpr", "ShapeExpr") def test_runtime_dep_shape(): runtime_dep_shape = relax.RuntimeDepShape() - check_visit(runtime_dep_shape) + basic_check(runtime_dep_shape, "RuntimeDepShape", "RuntimeDepShape") def test_call(): call_node = relax.op.add(x, y) - check_visit(call_node) + basic_check( + call_node, + "\n".join(["Call", "\tOp", "\tVar", "\tVar"]), + "\n".join(["Op", "Var", "Var", "Call"]), + ) def test_if(): if_node = relax.If(x, x, x) - check_visit(if_node) + basic_check( + if_node, + "\n".join(["If", "\tVar", "\tVar", "\tVar"]), + "\n".join(["Var", "Var", "Var", "If"]), + ) def test_tuple_getitem(): - op = relax.TupleGetItem(relax.Tuple([x, y]), 0) - check_visit(op) + tuple_getitem_node = relax.TupleGetItem(relax.Tuple([x, y]), 0) + basic_check( + tuple_getitem_node, + "\n".join(["TupleGetItem", "\tTuple", "\t\tVar", "\t\tVar"]), + "\n".join(["Var", "Var", "Tuple", "TupleGetItem"]), + ) def test_binding_block(): @@ -113,7 +481,41 @@ def test_binding_block(): gv0 = bb.emit(relax.op.add(x, y)) gv1 = bb.match_shape(y, [m, n]) b0 = bb._end_block() - check_visit(b0) + basic_check( + b0, + "\n".join( + [ + "BindingBlock", + "\tVarBinding", + "\t\tCall", + "\t\t\tOp", + "\t\t\tVar", + "\t\t\tVar", + "\t\tVarDef", + "\tMatchShape", + "\t\tVar", + "\t\tShapeExpr", + "\t\tVarDef", + ] + ), + "\n".join( + [ + "Op", + "Var", + "Var", + "Call", + "ShapeExpr", + "VarDef", + "VarBinding", + "Var", + "ShapeExpr", + "ShapeExpr", + "VarDef", + "MatchShape", + "BindingBlock", + ] + ), + ) def test_dataflow_block(): @@ -121,7 +523,41 @@ def test_dataflow_block(): lv0 = bb.emit(relax.op.add(x, y)) gv1 = bb.match_shape(y, [m, n]) b0 = bb._end_block() - check_visit(b0) + basic_check( + b0, + "\n".join( + [ + "DataflowBlock", + "\tVarBinding", + "\t\tCall", + "\t\t\tOp", + "\t\t\tVar", + "\t\t\tVar", + "\t\tDataflowVarDef", + "\tMatchShape", + "\t\tVar", + "\t\tShapeExpr", + "\t\tDataflowVarDef", + ] + ), + "\n".join( + [ + "Op", + "Var", + "Var", + "Call", + "ShapeExpr", + "DataflowVarDef", + "VarBinding", + "Var", + "ShapeExpr", + "ShapeExpr", + "DataflowVarDef", + "MatchShape", + "DataflowBlock", + ] + ), + ) def test_function(): @@ -130,12 +566,209 @@ def test_function(): seq_expr = relax.SeqExpr(blocks, x) ret_type = relax.DynTensorType(-1, "float32") func = relax.Function([x], seq_expr, ret_type) - check_visit(func) + basic_check( + func, + "\n".join( + [ + "Function", + "\tVarDef", + "\tSeqExpr", + "\t\tBindingBlock", + "\t\t\tVarBinding", + "\t\t\t\tConstant", + "\t\t\t\tVarDef", + "\t\tVar", + ] + ), + "\n".join( + [ + "ShapeExpr", + "VarDef", + "Constant", + "ShapeExpr", + "VarDef", + "VarBinding", + "BindingBlock", + "Var", + "SeqExpr", + "Function", + ] + ), + ) def test_extern_func(): func = relax.ExternFunc("f") - check_visit(func) + basic_check(func, "ExternFunc", "ExternFunc") + + +def test_inherit(): + # The internal class is not instantiated. + class InternalVisitor(PyExprVisitor): + def __init__(self) -> None: + super().__init__() + self.log = ASTLog() + + def visit_call_(self, op: Call) -> None: + self.log.add("InternalCall") + self.log.push_scope() + self.visit_expr(op.op) + + for arg in op.args: + self.visit_expr(arg) + self.log.pop_scope() + + def visit_var_(self, op: Var) -> None: + self.log.add("Var") + + def visit_op_(self, op: Op) -> None: + self.log.add("Op") + + @relax.expr_functor.visitor + class LeafVisitor(InternalVisitor): + def visit_call_(self, op: Call) -> None: + self.log.add("LeafCall") + self.log.push_scope() + self.visit_expr(op.op) + + for arg in op.args: + self.visit_expr(arg) + self.log.pop_scope() + + call_node = relax.op.add(x, y) + lv = LeafVisitor() + lv.visit_expr(call_node) + assert str(lv.log) == "\n".join(["LeafCall", "\tOp", "\tVar", "\tVar"]) + + +def test_inherit_with_cls(): + # The decorator converts `InternalVisitor` to a wrapper class. + @relax.expr_functor.visitor + class InternalVisitor(PyExprVisitor): + def __init__(self) -> None: + super().__init__() + self.log = ASTLog() + + def visit_call_(self, op: Call) -> None: + self.log.add("InternalCall") + self.log.push_scope() + self.visit_expr(op.op) + + for arg in op.args: + self.visit_expr(arg) + self.log.pop_scope() + + def visit_var_(self, op: Var) -> None: + self.log.add("Var") + + def visit_op_(self, op: Op) -> None: + self.log.add("Op") + + # `InternalVisitor._cls` refers to the original `InternalVisitor` users defined. + @relax.expr_functor.visitor + class LeafVisitor(InternalVisitor._cls): + def visit_call_(self, op: Call) -> None: + self.log.add("LeafCall") + self.log.push_scope() + self.visit_expr(op.op) + + for arg in op.args: + self.visit_expr(arg) + self.log.pop_scope() + + call_node = relax.op.add(x, y) + iv = InternalVisitor() + iv.visit_expr(call_node) + assert str(iv.log) == "\n".join(["InternalCall", "\tOp", "\tVar", "\tVar"]) + + lv = LeafVisitor() + lv.visit_expr(call_node) + assert str(lv.log) == "\n".join(["LeafCall", "\tOp", "\tVar", "\tVar"]) + + +def test_wrong_inherit(): + @relax.expr_functor.visitor + class InternalVisitor(PyExprVisitor): + def visit_call_(self, op: Call) -> None: + pass + + with pytest.raises( + TypeError, + match="Inheritance from a decorated object `LeafVisitor` is not allowed. Please inherit from `LeafVisitor._cls`.", + ): + + @relax.expr_functor.visitor + class LeafVisitor(InternalVisitor): + def visit_call_(self, op: Call) -> None: + pass + + +def test_call_visitor_super(): + @relax.expr_functor.visitor + class InternalVisitor(PyExprVisitor): + def __init__(self) -> None: + super().__init__() + self.log = ASTLog() + + def visit_call_(self, op: Call) -> None: + self.log.add("InternalCall") + super().visit_call_(op) # call PyExprVisitor.visit_call_ + + def visit_var_(self, op: Var) -> None: + self.log.add("Var") + + def visit_op_(self, op: Op) -> None: + self.log.add("Op") + + @relax.expr_functor.visitor + class LeafVisitor(InternalVisitor._cls): + def visit_call_(self, op: Call) -> None: + self.log.add("LeafCall") + super().visit_call_(op) # call InternalVisit.visit_call_ + + call_node = relax.op.add(x, y) + iv = InternalVisitor() + iv.visit_expr(call_node) + assert str(iv.log) == "\n".join(["InternalCall", "Op", "Var", "Var"]) + + lv = LeafVisitor() + lv.visit_expr(call_node) + assert str(lv.log) == "\n".join(["LeafCall", "InternalCall", "Op", "Var", "Var"]) + + +def test_call_mutator_super(): + @relax.expr_functor.mutator + class InternalMutator(PyExprMutator): + def __init__(self) -> None: + super().__init__() + self.log = ASTLog() + + def visit_call_(self, op: Call) -> None: + self.log.add("InternalCall") + return super().visit_call_(op) # call PyExprMutator.visit_call_ + + def visit_var_(self, op: Var) -> None: + self.log.add("Var") + return super().visit_var_(op) # call PyExprMutator.visit_var_ + + def visit_op_(self, op: Op) -> None: + self.log.add("Op") + return super().visit_op_(op) # call PyExprMutator.visit_op_ + + @relax.expr_functor.mutator + class LeafMutator(InternalMutator._cls): + def visit_call_(self, op: Call) -> None: + self.log.add("LeafCall") + return super().visit_call_(op) # call InternalMutator.visit_call_ + + call_node = relax.op.add(x, y) + im = InternalMutator() + im.visit_expr(call_node) + assert str(im.log) == "\n".join(["InternalCall", "Op", "Var", "Var"]) + + lm = LeafMutator() + lm.visit_expr(call_node) + assert str(lm.log) == "\n".join(["LeafCall", "InternalCall", "Op", "Var", "Var"]) if __name__ == "__main__": diff --git a/tests/python/relax/test_pass_manager.py b/tests/python/relax/test_pass_manager.py index 16e31c5025..288fd4cf08 100644 --- a/tests/python/relax/test_pass_manager.py +++ b/tests/python/relax/test_pass_manager.py @@ -21,7 +21,6 @@ import tvm from tvm import relax, ir from tvm.ir.base import assert_structural_equal -from tvm.relax import ExprMutator from tvm.relax.expr import Call import tvm.script @@ -66,13 +65,13 @@ def f2(x: Tensor((m, n), "float32")): # Swap Multiply and Add Ops -class SwapMAVar(ExprMutator): +@relax.expr_functor.mutator +class SwapMAVar(relax.PyExprMutator): def __init__(self) -> None: super().__init__() def visit_call_(self, call: Call) -> Call: - call = ExprMutator.visit_call_(self, call) - + call = self.visit_expr_post_order(call) if call.op == ir.Op.get("relax.add"): new_op = ir.Op.get("relax.multiply") elif call.op == ir.Op.get("relax.multiply"):