This repository has been archived by the owner on May 22, 2023. It is now read-only.

Relax AST Design

Jiawei Liu edited this page Jun 30, 2022 · 8 revisions

Authors (alphabetical): @altanh, @electriclilies, @jroesch, @junrushao1994, @mbs-octoml, @mikepapadim, @tkonolige, @tqchen, @YuchenJin, @ZihengJiang

This doc is meant to serve as a design overview of the Relax AST. For broader background on Relax, please refer to the Relax Architecture Overview.

To support the key goals stated in the architecture overview (G0: support dynamic-shape workloads; G1: make the dataflow block a first-class citizen), Relax adds the following constructs to the AST:

class ShapeExpr(Expr):
    """corresponds to a shape containing symbolic PrimExpr"""
    values: List[PrimExpr]

class Var(Expr):
    """vars visible outside dataflow blocks (e.g., in the enclosing function scope)"""
    vid: Id
    type_annotation: Optional[Type]

class DataflowVar(Var):
    """vars visible only inside their enclosing DataflowBlock"""
    pass

class Binding(ObjectRef):
    """the base class of bindings"""
    pass

class VarBinding(Binding):
    """variable binding: binds value to var"""
    var: Var
    value: Expr

class MatchShape(Binding):
    """shape-matching binding: matches the shape of value against pattern and binds the result to var"""
    value: Expr
    pattern: List[PrimExpr]
    var: Var

class BindingBlock(Node):
    """base class of binding blocks; bindings inside can be impure (with side effects or control flow)"""
    bindings: List[Binding]

class DataflowBlock(BindingBlock):
    """dataflow block; bindings inside are pure (no side effects and no control flow)"""
    pass

class SeqExpr(Expr):
    """sequence of BindingBlocks followed by a body expression; can serve as the body of a Function"""
    blocks: List[BindingBlock]
    body: Expr

class Function(BaseFunc):
    """represents a Relax function"""
    params: List[Var]
    body: Expr
    ret_type: Type

class ExternFunc(BaseFunc):
    """extern function, which can represent a TIR PrimFunc or a PackedFunc."""
    global_symbol: String

Overall structure of a Relax Function

  • A Function's body can be a SeqExpr.
  • A SeqExpr consists of a list of BindingBlocks followed by a body expression.
  • A DataflowBlock is a special kind of BindingBlock that is equivalent to a pure computational graph: the bindings inside a DataflowBlock have no side effects and no control flow.
  • A BindingBlock consists of a list of Bindings.
  • A Binding is either a VarBinding or a MatchShape.
  • The scope of a DataflowVar is its DataflowBlock; a normal Var defined in a DataflowBlock escapes to the scope containing the block (which could be the function scope or another scope, such as an if branch). Note that TIR vars (bound by MatchShape) follow the same scoping rules as normal Vars.
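The scoping rule in the last bullet can be checked mechanically. The following is a minimal sketch in plain Python, not the TVM API: `Var`, `DataflowVar`, `Block`, and `check_dataflow_scope` are hypothetical stand-ins that model only variable names, uses, and block boundaries. The checker reports any DataflowVar used outside the DataflowBlock that defines it.

```python
from dataclasses import dataclass, field
from typing import List, Tuple


@dataclass(frozen=True)
class Var:
    """Hypothetical stand-in for relax.Var: stays visible after its block ends."""
    name: str


@dataclass(frozen=True)
class DataflowVar(Var):
    """Hypothetical stand-in for relax.DataflowVar: scoped to its DataflowBlock."""


@dataclass
class Block:
    """A binding block; each binding is (defined var, vars used by its value)."""
    is_dataflow: bool
    bindings: List[Tuple[Var, List[Var]]] = field(default_factory=list)


def check_dataflow_scope(blocks: List[Block], result: Var) -> List[str]:
    """Return names of DataflowVars used outside the block that defines them."""
    errors = []
    for block in blocks:
        local = set()  # DataflowVars defined so far in this block
        for var, uses in block.bindings:
            for use in uses:
                if isinstance(use, DataflowVar) and use not in local:
                    errors.append(use.name)
            if isinstance(var, DataflowVar):
                if not block.is_dataflow:
                    errors.append(var.name)  # defined outside any DataflowBlock
                local.add(var)
        # `local` is dropped here: DataflowVars never escape their block.
    if isinstance(result, DataflowVar):
        errors.append(result.name)  # the function result lives outside all blocks
    return errors


x, w = Var("x"), Var("w")
lv0, gv0 = DataflowVar("lv0"), Var("gv0")
ok = [Block(True, [(lv0, [x, w]), (gv0, [lv0])])]
print(check_dataflow_scope(ok, gv0))  # []

# Using lv0 in a later, impure block violates the scoping rule.
bad = ok + [Block(False, [(Var("gv1"), [lv0])])]
print(check_dataflow_scope(bad, Var("gv1")))  # ['lv0']
```

Only DataflowVars are tracked per block; a normal Var needs no bookkeeping here because it is allowed to escape.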

Take the following Relax program as an example: relax_func contains a SeqExpr, which contains a DataflowBlock (with two VarBindings) followed by a BindingBlock with one VarBinding.

from tvm.script import relax as R

@R.func
def relax_func(x: R.Tensor[(n, k), "f32"], w: R.Tensor[(k, m), "f32"]):
    # start a DataflowBlock
    with R.dataflow(): ## <= DataflowBlock
        lv0: R.Tensor[(n, m), "f32"] = R.dot(x, w) ## <= VarBinding, lv0 is a DataflowVar
        gv0: R.Tensor[(n * m,), "f32"] = R.flatten(lv0) ## <= VarBinding, gv0 is a Var that escapes to the outer scope
        R.outputs(gv0)

    # start a BindingBlock
    gv1 = R.call_packed("custom_inplace_update", gv0) ## <= side-effect binding
    return gv1
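To make the nesting concrete, the sketch below rebuilds the structure of `relax_func` with plain Python dataclasses whose field names mirror the class definitions above. It is a hypothetical, simplified model, not `tvm.relax`: operators are represented by a generic `Call` node that carries the op name as a string, and type and shape annotations are omitted.

```python
from dataclasses import dataclass
from typing import List


class Expr:
    """Base class for expressions in this simplified model."""


@dataclass
class Var(Expr):
    name: str


@dataclass
class DataflowVar(Var):
    pass


@dataclass
class Call(Expr):
    op: str           # illustrative op name, not a real TVM op handle
    args: List[Expr]


@dataclass
class VarBinding:
    var: Var
    value: Expr


@dataclass
class BindingBlock:
    bindings: List[VarBinding]


@dataclass
class DataflowBlock(BindingBlock):
    pass


@dataclass
class SeqExpr(Expr):
    blocks: List[BindingBlock]
    body: Expr


@dataclass
class Function(Expr):
    params: List[Var]
    body: Expr


x, w = Var("x"), Var("w")
lv0 = DataflowVar("lv0")            # visible only inside the DataflowBlock
gv0, gv1 = Var("gv0"), Var("gv1")   # visible in the function scope

relax_func = Function(
    params=[x, w],
    body=SeqExpr(
        blocks=[
            DataflowBlock([
                VarBinding(lv0, Call("dot", [x, w])),
                VarBinding(gv0, Call("flatten", [lv0])),
            ]),
            # Stands in for R.call_packed("custom_inplace_update", gv0).
            BindingBlock([VarBinding(gv1, Call("custom_inplace_update", [gv0]))]),
        ],
        body=gv1,
    ),
)

for block in relax_func.body.blocks:
    print(type(block).__name__, len(block.bindings))
# DataflowBlock 2
# BindingBlock 1
```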

Why separate DataflowBlock and BindingBlock?

Most pass writers are ML researchers and ML engineers without a compiler or PL background, so they write passes under the simple assumption that the passes mutate a pure computational graph. Relay is not explicit about which expressions have side effects and which are pure; as a result, many optimizations are unsound in the presence of side effects. In Relax, a DataflowBlock represents a computational-graph region where all the bindings inside are pure (no side effects, no control flow). Clearly separating the graph region and making it a first-class citizen makes it easy for end users to write graph passes. Because of this clear separation between "pure" and "impure" regions, a Function's body can be composed of one or more pure or impure blocks, which is why a SeqExpr holds an Array<BindingBlock> (its blocks field) in addition to its body expression.
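One way to see why the separation helps: a rewrite such as dead-binding elimination is safe inside a DataflowBlock, where an unused pure binding can be dropped without changing behavior, but not inside an impure BindingBlock, where an unused call may still have a side effect. The sketch below is a hypothetical, self-contained illustration in plain Python (bindings are modeled as `(var, op, args)` tuples and `eliminate_dead_bindings` is an illustrative helper), not a TVM pass.

```python
from typing import List, Set, Tuple

# A binding is (defined var name, op name, argument var names).
Binding = Tuple[str, str, List[str]]


def eliminate_dead_bindings(bindings: List[Binding], is_dataflow: bool,
                            live_out: Set[str]) -> List[Binding]:
    """Drop bindings whose results are never used.

    Only dataflow blocks are pruned: every binding there is pure, so removing an
    unused one cannot change observable behavior. Impure blocks are left as-is.
    """
    if not is_dataflow:
        return list(bindings)
    live = set(live_out)
    kept = []
    # Walk backwards so that uses are seen before the definitions they need.
    for var, op, args in reversed(bindings):
        if var in live:
            kept.append((var, op, args))
            live.update(args)
    kept.reverse()
    return kept


dataflow = [("lv0", "dot", ["x", "w"]),
            ("lv1", "exp", ["lv0"]),       # never used
            ("gv0", "flatten", ["lv0"])]
pruned = eliminate_dead_bindings(dataflow, True, {"gv0"})
print(pruned)
# [('lv0', 'dot', ['x', 'w']), ('gv0', 'flatten', ['lv0'])]

impure = [("gv1", "custom_inplace_update", ["gv0"])]  # unused result, but a side effect
kept_impure = eliminate_dead_bindings(impure, False, set())
print(kept_impure)
# [('gv1', 'custom_inplace_update', ['gv0'])]
```

The pass never has to reason about effects per call: the block kind alone tells it whether pruning is allowed.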

MatchShape is a kind of Binding:

MatchShape(value: Expr, pattern: List[PrimExpr], var: Var)

  • MatchShape has two overloaded semantics:
    • Suppose x is a 2-D tensor:
      • (1) MatchShape(x.shape, [m, n], var) → matches x.shape against the symbolic variables (m, n) and binds the resulting Shape to var;
      • (2) MatchShape(x, [m, n], var) → matches x.shape against the symbolic variables (m, n) and binds to var a 2-D tensor with the same shape as x, but with an explicit shape field [m, n].
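The matching step shared by both forms can be sketched as follows. This is a hypothetical, simplified model, not TVM's implementation: symbolic variables are plain strings, concrete extents are ints, `Tensor` is a stand-in class that carries only a shape, and `match_shape` is an illustrative helper. A symbolic variable seen for the first time is bound to the corresponding extent; one that is already bound must agree with it.

```python
from dataclasses import dataclass
from typing import Dict, List, Tuple, Union

Dim = Union[int, str]  # concrete extent or symbolic variable name


@dataclass
class Tensor:
    """Stand-in for a tensor value; only its shape matters here."""
    shape: Tuple[int, ...]
    shape_annotation: Tuple[Dim, ...] = ()  # explicit symbolic shape, if any


def match_shape(value, pattern: List[Dim], env: Dict[str, int]):
    """Match value's shape against pattern, recording new symbol bindings in env.

    A shape (tuple) input returns the matched shape, like form (1).
    A Tensor input returns the same tensor shape annotated with the pattern, like form (2).
    """
    shape = value.shape if isinstance(value, Tensor) else tuple(value)
    if len(shape) != len(pattern):
        raise ValueError(f"rank mismatch: {shape} vs {pattern}")
    for dim, pat in zip(shape, pattern):
        if isinstance(pat, str):
            bound = env.setdefault(pat, dim)  # bind on first occurrence
            if bound != dim:
                raise ValueError(f"{pat} is already {bound}, got {dim}")
        elif pat != dim:
            raise ValueError(f"expected {pat}, got {dim}")
    if isinstance(value, Tensor):
        return Tensor(value.shape, tuple(pattern))
    return shape


env: Dict[str, int] = {}
x = Tensor((4, 8))
print(match_shape(x.shape, ["m", "n"], env))  # (4, 8)
print(env)                                    # {'m': 4, 'n': 8}
y = match_shape(x, ["m", "n"], env)
print(y.shape_annotation)                     # ('m', 'n')
```

Sharing one environment across calls is what lets a later match reuse symbols (m, n) bound by an earlier one.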

Implications for IR transformations:

  • A DataflowBlock is a self-contained data structure that contains a list of bindings. Pass writers can visit and transform a pure dataflow block using the DataflowMutator interface. This can be more user-friendly for pass writers with only an ML background who are familiar with computational dataflow, because they only need to understand the simple concept of a DataflowBlock and override visitors in DataflowMutator.
  • A BindingBlock is a list of Bindings. Because ExprMutator works on programs in A-normal form (ANF), the visitor can traverse the bindings without memoization, and there is no risk of stack overflow because iterating over the bindings involves no recursion.
  • The ExprMutator has an internal BlockBuilder that can emit bindings to the newly created block. Why have an internal BlockBuilder in the ExprMutator?
    • BlockBuilder provides APIs for emitting bindings to a block. We often want to fold several bindings into one (n → 1) or rewrite one binding into multiple bindings (1 → n). Emitting bindings through the BlockBuilder in the visitor handles both cases easily.
    • The BlockBuilder can do eager shape and type inference, so the shape_ and checked_type_ fields of both the lhs var and the rhs expr can be filled in when emitting new bindings.
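The two ideas above (visiting an ANF block's bindings with a plain loop, and emitting any number of new bindings per visited binding through a builder) can be sketched as below. `MiniBlockBuilder`, `rewrite_block`, and the `softplus` expansion are hypothetical illustrations in plain Python; they are not TVM's BlockBuilder or ExprMutator APIs, and they perform no shape or type inference.

```python
from typing import Callable, Dict, List, Tuple

Binding = Tuple[str, str, List[str]]  # (defined var, op name, argument names)


class MiniBlockBuilder:
    """Collects bindings for one new block and hands out fresh variable names."""

    def __init__(self) -> None:
        self.bindings: List[Binding] = []
        self._counter = 0

    def emit(self, op: str, args: List[str]) -> str:
        """Append a binding to a fresh var and return that var's name."""
        var = f"t{self._counter}"
        self._counter += 1
        self.bindings.append((var, op, args))
        return var


def rewrite_block(block: List[Binding],
                  rewrite: Callable[[MiniBlockBuilder, str, List[str]], str]
                  ) -> Tuple[List[Binding], Dict[str, str]]:
    """Visit bindings with a plain loop (no recursion) and rebuild the block."""
    builder = MiniBlockBuilder()
    renamed: Dict[str, str] = {}  # old var -> new var now holding its value
    for var, op, args in block:
        new_args = [renamed.get(a, a) for a in args]
        renamed[var] = rewrite(builder, op, new_args)
    return builder.bindings, renamed


def lower_softplus(builder: MiniBlockBuilder, op: str, args: List[str]) -> str:
    """1 -> n rewrite: softplus(a) becomes log(add(exp(a), 1)); other ops pass through."""
    if op == "softplus":
        e = builder.emit("exp", args)
        s = builder.emit("add", [e, "1"])  # "1" stands for a constant
        return builder.emit("log", [s])
    return builder.emit(op, args)


block = [("lv0", "dot", ["x", "w"]),
         ("lv1", "softplus", ["lv0"])]
new_block, renamed = rewrite_block(block, lower_softplus)
for b in new_block:
    print(b)
# ('t0', 'dot', ['x', 'w'])
# ('t1', 'exp', ['t0'])
# ('t2', 'add', ['t1', '1'])
# ('t3', 'log', ['t2'])
print(renamed["lv1"])  # t3
```

An n → 1 fold fits the same shape: the rewrite callback returns an existing variable name without emitting anything for some bindings.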

Hierarchy of Visitor Pattern in Relax

TVM performs IR analysis and rewriting through a recursive-descent visitor pattern over the expression's abstract syntax tree. For example, to write an analysis pass that counts the number of relax.Var definitions within a relax.Function, we can subclass relax::ExprVisitor (class CountVarDef : public relax::ExprVisitor).

The hierarchy of the default visitor pattern is shown below. To count the number of variable definitions, we only need to override the VisitVarDef function:

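#include <tvm/relax/expr_functor.h>
#include <tvm/runtime/registry.h>

namespace tvm {
namespace relax {

// Counts every variable definition (Var or DataflowVar) reached by the traversal.
class CountVarDef : public ExprVisitor {
public:
    size_t n_vardef{0};
    // Invoked for each variable definition the traversal reaches.
    void VisitVarDef(const Var& var) override { ++n_vardef; }
    // No VisitExpr override: VisitExpr is re-entered for every sub-expression,
    // so resetting the counter there would discard counts mid-traversal.
};

// Export to the Python front end. Each call builds a fresh counter.
TVM_REGISTER_GLOBAL("relax.analysis.count_vardef").set_body_typed([](Function f) {
    CountVarDef counter;
    counter.VisitExpr(f);
    return static_cast<int64_t>(counter.n_vardef);  // integer FFI values are int64_t
});

}  // namespace relax
}  // namespace tvm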

[Figure: irvisit_hierarchy, the hierarchy of the default Relax visitor pattern]

Similarly, to count only DataflowVar definitions, we only need to override VisitVarDef_(const DataflowVarNode*): by default (ExprVisitor's implementation), VisitVarDef automatically dispatches at run time to the VisitVarDef_ overload that matches the variable's type (Var or DataflowVar; see the top-left blue blocks in the figure).
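The dispatch just described can be mirrored in a few lines of plain Python. This is a hypothetical analog for illustration, not TVM's `relax::ExprVisitor` (which is written in C++): `visit_var_def` picks `visit_var_def_var` or `visit_var_def_dataflow_var` from the variable's class, using `isinstance` as a simplification of the C++ dispatch, and `visit_definitions` is a stand-in for the real traversal. A subclass that overrides only the DataflowVar hook therefore counts only DataflowVar definitions.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class Var:
    name: str


@dataclass
class DataflowVar(Var):
    pass


class MiniExprVisitor:
    """Models the VisitVarDef -> VisitVarDef_ dispatch; default hooks do nothing."""

    def visit_var_def(self, var: Var) -> None:
        # Test the subclass first, since every DataflowVar is also a Var.
        if isinstance(var, DataflowVar):
            self.visit_var_def_dataflow_var(var)
        else:
            self.visit_var_def_var(var)

    def visit_var_def_var(self, var: Var) -> None:
        pass

    def visit_var_def_dataflow_var(self, var: DataflowVar) -> None:
        pass

    def visit_definitions(self, defs: List[Var]) -> None:
        """Stand-in for the traversal: report each defined variable once."""
        for var in defs:
            self.visit_var_def(var)


class CountDataflowVarDef(MiniExprVisitor):
    def __init__(self) -> None:
        self.n_dataflow_vardef = 0

    def visit_var_def_dataflow_var(self, var: DataflowVar) -> None:
        self.n_dataflow_vardef += 1


# Definitions in relax_func: params x, w; lv0 (a DataflowVar); gv0 and gv1.
defs = [Var("x"), Var("w"), DataflowVar("lv0"), Var("gv0"), Var("gv1")]
counter = CountDataflowVarDef()
counter.visit_definitions(defs)
print(counter.n_dataflow_vardef)  # 1
```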