Relax AST Design
Authors(alphabetical): @altanh, @electriclilies, @jroesch, @junrushao1994, @mbs-octoml, @mikepapadim, @tkonolige, @tqchen, @YuchenJin, @ZihengJiang
This doc is meant to serve as a design overview of the Relax AST. For a broad background of Relax, please refer to Relax Architecture Overview.
To support the key goals (G0: support dynamic shape workloads, and G1: dataflow block as a first class citizen) in the architecture overview, Relax adds the following constructs to the AST.
class ShapeExpr(Expr):
    """corresponds to a shape containing symbolic PrimExpr"""
    values: List[PrimExpr]

class Var(Expr):
    """global scope visible vars"""
    vid: Id
    type_annotation: Optional[Type]

class DataflowVar(Var):
    """dataflow scope visible vars"""
    pass

class Binding(ObjectRef):
    """the base class of bindings"""
    pass

class VarBinding(Binding):
    """variable bindings, bind the value to the var"""
    var: Var
    value: Expr

class MatchShape(Binding):
    """binding represents to match a shape"""
    value: Expr
    pattern: List[PrimExpr]
    var: Var

class BindingBlock(Node):
    """base class of binding block, bindings inside can be impure (with side effect or control flow)"""
    bindings: List[Binding]

class DataflowBlock(BindingBlock):
    """dataflow block, bindings inside are pure (no side effect and no control flow)"""
    pass

class SeqExpr(Expr):
    """sequence of BindingBlocks, can serve as the body of a Function"""
    blocks: List[BindingBlock]
    body: Expr

class Function(BaseFunc):
    """represents a Relax function"""
    params: List[Var]
    body: Expr
    ret_type: Type

class ExternFunc(BaseFunc):
    """extern function, which can represent a TIR PrimFunc or a PackedFunc."""
    global_symbol: String
- A Function's body can be a SeqExpr.
- A SeqExpr consists of a list of BindingBlocks followed by a body expression.
- DataflowBlock is a special kind of BindingBlock that is identical to a pure computational graph. The bindings inside a DataflowBlock have no side effects and no control flow.
- A BindingBlock consists of a list of Bindings.
- A Binding can be either a VarBinding or a MatchShape.
- The scope of a DataflowVar is its DataflowBlock, whereas a normal Var in a DataflowBlock escapes to the scope containing the block (which could be the function scope or some other scope like an if branch). Note that TIR vars (bound by MatchShape) have the same scoping rules as normal Vars.
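The scoping rule in the last bullet can be sketched as a small well-formedness check. This is a standalone illustration using plain Python dataclasses, not the TVM API: the `uses` field is a hypothetical stand-in for the variables read by a binding's value, and `escaping_dataflow_vars` is an invented helper name.

```python
from dataclasses import dataclass
from typing import List


@dataclass(frozen=True)
class Var:
    name: str


@dataclass(frozen=True)
class DataflowVar(Var):
    pass


@dataclass
class VarBinding:
    var: Var
    uses: List[Var]  # hypothetical stand-in for the vars read by the bound Expr


@dataclass
class BindingBlock:
    bindings: List[VarBinding]


@dataclass
class DataflowBlock(BindingBlock):
    pass


def escaping_dataflow_vars(blocks: List[BindingBlock], body_uses: List[Var]) -> List[str]:
    """Return names of DataflowVars that are used outside their defining block."""
    leaked = []
    for i, block in enumerate(blocks):
        local = {b.var for b in block.bindings if isinstance(b.var, DataflowVar)}
        # Any use in a later block, or in the SeqExpr body, lies outside the block.
        outside = [v for later in blocks[i + 1:] for b in later.bindings for v in b.uses]
        outside += body_uses
        leaked += [v.name for v in outside if v in local]
    return leaked


x, w = Var("x"), Var("w")
lv0, gv0, gv1 = DataflowVar("lv0"), Var("gv0"), Var("gv1")
df = DataflowBlock([VarBinding(lv0, [x, w]), VarBinding(gv0, [lv0])])
ok = BindingBlock([VarBinding(gv1, [gv0])])   # reads gv0, which escapes the block
bad = BindingBlock([VarBinding(gv1, [lv0])])  # reads lv0, which must not escape

print(escaping_dataflow_vars([df, ok], [gv1]))
print(escaping_dataflow_vars([df, bad], [gv1]))
```

The first call reports no leaks because only the normal Var gv0 is read outside the DataflowBlock; the second reports lv0.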
Let's take the following Relax program as an example: relax_func contains a SeqExpr, and the SeqExpr contains a DataflowBlock (with two VarBindings) and a BindingBlock with one VarBinding.
from tvm.script import relax as R

@R.func
def relax_func(x: R.Tensor[(n, k), "f32"], w: R.Tensor[(k, m), "f32"]):
    # start a DataflowBlock
    with R.dataflow():  ## <= DataflowBlock
        lv0: R.Tensor[(n, m), "f32"] = R.dot(x, w)  ## <= VarBinding, lv0 is a DataflowVar
        gv0: R.Tensor[(n * m,), "f32"] = R.flatten(lv0)  ## <= VarBinding, gv0 is a Var that escapes to the outer scope
        R.outputs(gv0)

    # start a BindingBlock
    gv1 = R.call_packed("custom_inplace_update", gv0)  ## <= side-effect binding
    return gv1
Most pass writers are ML researchers and ML engineers who have no compiler or PL background, so they write passes on the simple assumption that they are mutating a pure computational graph. Relay is not explicit about which expressions have side effects and which are pure; as a result, many optimizations are unsound in the presence of side effects. In Relax, DataflowBlock represents a computational graph region where all the bindings inside are pure (no side effects, no control flow). Clearly separating the graph region and making it a first-class citizen makes it easy for end users to write graph passes. Because of this clear separation between the "pure" and "impure" regions, a Function's body can be composed of one or more pure or impure blocks, so a SeqExpr holds its blocks as an Array<BindingBlock>.
MatchShape(value: Expr, pattern: List[PrimExpr], var: Var)
- MatchShape has two overloaded semantics. Suppose x is a 2-D tensor:
  - (1) MatchShape(x.shape, [m, n], var) → matches x.shape to the symbolic variables (m, n), and returns a Shape to the return var;
  - (2) MatchShape(x, [m, n], var) → matches x.shape to the symbolic variables (m, n), and returns a 2-D tensor with the same shape as tensor x (but with explicit shape field [m, n]) to the output var.
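The matching step shared by both overloads can be sketched in plain Python. This is a simplified illustration, not the Relax implementation: strings stand in for symbolic TIR vars, patterns are limited to vars and integer constants (no compound PrimExprs such as n * m), and `match_shape` with its `env` dictionary is a hypothetical helper.

```python
from typing import Dict, List, Sequence, Union

Dim = Union[int, str]  # a str stands in for a symbolic TIR var (assumption)


def match_shape(shape: Sequence[int], pattern: Sequence[Dim], env: Dict[str, int]) -> List[int]:
    """Match a concrete runtime shape against a pattern, binding symbols in env."""
    if len(shape) != len(pattern):
        raise ValueError(f"rank mismatch: {len(shape)} vs {len(pattern)}")
    for actual, expected in zip(shape, pattern):
        if isinstance(expected, str):
            # The first occurrence defines the symbol; later ones must agree.
            bound = env.setdefault(expected, actual)
            if bound != actual:
                raise ValueError(f"{expected} is bound to {bound}, got {actual}")
        elif expected != actual:
            raise ValueError(f"expected dimension {expected}, got {actual}")
    return list(shape)


env: Dict[str, int] = {}
match_shape((3, 4), ["m", "n"], env)  # binds m and n from a 2-D shape
print(env)
match_shape((4, 4), ["n", "n"], env)  # consistent with the existing binding of n
```

Once m and n are bound, a later match that disagrees with them (for example a shape of (5, 4) against ["m", "n"]) raises an error in this sketch, which mirrors the idea that the symbolic variables stay in scope after the binding.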
- DataflowBlock is a self-contained data structure that contains a list of bindings. Pass writers can visit and transform a pure dataflow block using the DataflowMutator interface. This can be more approachable for pass writers with only an ML background who are familiar with computational dataflow graphs, because they only need to deal with the simple concept of a DataflowBlock and override visitors in DataflowMutator.
- A BindingBlock is a list of Bindings. The ExprMutator works on ANF programs, so the visitor can traverse the bindings without memoization, and there is no stack overflow because the bindings are visited without recursion.
- The ExprMutator has an internal BlockBuilder that can emit bindings to the newly created block. Why have an internal BlockBuilder in the ExprMutator?
  - BlockBuilder provides APIs for emitting bindings to a block. We often want to fold several bindings into one (n → 1) or rewrite one binding into multiple bindings (1 → n). Using BlockBuilder to emit bindings in the visitor makes both easy.
  - The BlockBuilder can do eager shape and type inference, so the shape_ and checked_type_ fields of both the lhs var and the rhs expr can be filled in when emitting new bindings.
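The 1 → n case can be sketched with a toy builder. This is a standalone Python illustration of the emit-while-visiting pattern, not the Relax BlockBuilder API: the `Binding` record, the `emit` signature, and the operator names `dense`, `matmul`, `add`, and `relu` are all hypothetical, and the fresh-name scheme does not guard against clashes with existing names.

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class Binding:
    var: str
    op: str
    args: Tuple[str, ...]


class BlockBuilder:
    """Collects the bindings of the block being rebuilt (toy version)."""

    def __init__(self) -> None:
        self.bindings: List[Binding] = []
        self._fresh = 0

    def emit(self, op: str, args: Tuple[str, ...], var: str = "") -> str:
        # Allocate a temporary name when the caller does not supply one.
        if not var:
            var = f"tmp{self._fresh}"
            self._fresh += 1
        self.bindings.append(Binding(var, op, args))
        return var


def decompose_dense(bindings: List[Binding]) -> List[Binding]:
    """Rewrite each dense(x, w, b) into matmul followed by add (1 -> n)."""
    builder = BlockBuilder()
    for b in bindings:  # flat loop over ANF bindings: no recursion, no memo table
        if b.op == "dense":
            x, w, bias = b.args
            mm = builder.emit("matmul", (x, w))
            builder.emit("add", (mm, bias), var=b.var)  # keep the original name
        else:
            builder.emit(b.op, b.args, var=b.var)
    return builder.bindings


block = [Binding("lv0", "dense", ("x", "w", "b")), Binding("gv0", "relu", ("lv0",))]
for b in decompose_dense(block):
    print(b.var, "=", b.op, b.args)
```

Because the rewritten binding keeps the original variable name for its final result, later bindings that read lv0 need no changes. An n → 1 fold would follow the same shape, with the visitor buffering matched bindings and emitting a single replacement.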
TVM performs IR analysis and rewriting through a recursive-descent visitor pattern over the expression's abstract syntax tree. For example, to write an analysis pass that counts the number of relax.Var definitions within a relax.Function, we can subclass relax::ExprVisitor via class CountVarDef : public relax::ExprVisitor.
The hierarchy of the default visitor pattern is shown below. To count the number of variable definitions, we only need to override the VisitVarDef function:
class CountVarDef : public relax::ExprVisitor {
 public:
  size_t n_vardef{0};
  // Called once for every variable definition site.
  void VisitVarDef(const Var& var) override { ++n_vardef; }
};

// Export to Python front end.
TVM_REGISTER_GLOBAL("relax.analysis.count_vardef").set_body_typed([](Function f) {
  // A fresh counter starts at zero, so VisitExpr needs no reset; resetting there
  // would clear the count each time a nested expression is visited.
  CountVarDef counter{};
  counter.VisitExpr(f);
  return counter.n_vardef;
});
Similarly, to count only DataflowVar definitions, according to the figure we only need to override VisitVarDef_(const DataflowVarNode*), because by default (in ExprVisitor's implementation) VisitVarDef dynamically dispatches to the corresponding VisitVarDef_ according to the variable's type (Var or DataflowVar; see also the top-left blue blocks).