
Relax Compilation MVP Design

Authors (alphabetical): @altanh, @tqchen, @YuchenJin, @ZihengJiang

Key Goals

  • Compile a Relax program to a format that the VM can execute
  • A multi-stage compilation pipeline that enables composable compilation transformations
    • Every transformation is IRModule → IRModule
    • Users may run part of the program with third-party libraries such as cuDNN; we need to be able to optimize the remaining parts.

Example Relax Program

Let's take compiling the following simple Relax program as a running example.

import tvm.script
from tvm.script import tir as T, relax as R

@tvm.script.ir_module
class MyIRModule:
    @T.prim_func
    def tirexp(x: T.handle, y: T.handle):
        n1, m1 = T.var("int32"), T.var("int32")
        X = T.match_buffer(x, (n1, m1))
        Y = T.match_buffer(y, (n1, m1))
        for i, j in T.grid(n1, m1):
            with T.block("exp"):
                vi, vj = T.axis.remap("SS", [i, j])
                Y[vi, vj] = T.exp(X[vi, vj])

    @R.function
    def myfunc(x: Tensor[(n, m)]):
        with R.dataflow():
            lv0: Tensor[(n, m)] = R.call_tir((n, m), tirexp, [x])
            gv0: Tensor[(m*n,)] = R.call_tir((m*n,), "flatten", [lv0])
            R.outputs(gv0)
        return gv0

We introduce a new intrinsic, relax.call_tir, in Relax and use it to construct the program. A program in this call_tir form has the following properties:

  • No side effect
  • With shape annotation
  • Core expression: call_tir(output_shape, func, [arg0, arg1, ...], Optional<shape_expr>) -> Expr
    • This means that we take a function in the TIR calling convention (destination-passing style), e.g., mylog(in, out), and wrap it with call_tir; call_tir(out_shape, mylog, in) returns the output (see the sketch after this list)
    • func can be a TIR function or a packed function
    • shape_expr is an optional argument for passing integer (symbolic shape) values as arguments
    • output_shape can be a ShapeExpr or a Tuple

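Conceptually, one can think of call_tir as the following pseudo-implementation, which allocates the output and calls the destination-passing-style function. This is only an illustrative sketch (alloc_tensor here is a hypothetical helper), not the actual lowering:

# Illustrative sketch of call_tir semantics (not the actual implementation)
def call_tir(output_shape, func, args):
    out = alloc_tensor(output_shape)   # hypothetical helper that allocates the output tensor
    func(*args, out)                   # destination-passing style: the output is the last argument
    return out
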
For a more detailed introduction to call_tir, please refer to the Relax Architecture Overview.

Challenges of lowering this program to VM instructions

  • C0: Every call_tir needs to be lowered (relax VM only supports call instruction to call a packed function directly) → We need to insert explicit output memory allocation with memory planning

  • C1: The symbolic shape n and m are not something that the runtime can represent (the relax VM only supports NDArray and ShapeTuple runtime data structures) → We need to use the heap in the VM to do shape calculations

First, to address C0, we lower call_tir to an explicit memory allocation form

A program in explicit memory allocation form has the following properties:

  • Explicitly allocate and kill storage and tensors
  • Has side effect
  • No shape annotation
  • Core expression: call(func_name, arg0, arg1, ...) -> optional<Expr>, this maps to the Call instruction that VM can directly execute

Four intrinsics/builtin-functions:

  • relax.vm.builtin.alloc_storage(size, device) -> storage: Allocate a storage that can be used to create tensors.

  • relax.vm.builtin.alloc_tensor(storage, shape, offset, dtype) -> tensor: Allocate a tensor in a storage

  • relax.vm.builtin.free_storage(storage): Free the allocated storage

  • relax.vm.builtin.free_tensor(tensor): Free the allocated tensor

Since alloc_storage and alloc_tensor take integer and dtype arguments that cannot be represented as Expr arguments to a CallNode, they are designed as intrinsics in Relax, with attributes that hold the int and dtype values.

from tvm.script import tir as T, relax as R

# Program after call_tir lowering

@R.function
def myfunc(x):
    # has side effects, so it is now in a BindingBlock instead of a DataflowBlock
    n, m = R.match_shape(x.shape)

    storage0 = relax.vm.builtin.alloc_storage(size=[n*m], device=cpu)
    tensor0 = relax.vm.builtin.alloc_tensor(storage0, shape=[n, m], offset=0, dtype="f32")
    R.call_packed("tirexp", x, tensor0)

    storage1 = relax.vm.builtin.alloc_storage(size=[n*m], device=cpu)
    tensor1 = relax.vm.builtin.alloc_tensor(storage1, shape=[m*n,], offset=0, dtype="f32")
    R.call_packed("flatten", tensor0, tensor1)

    R.call_packed("free_tensor", tensor0)
    R.call_packed("free_storage", storage0)
    return tensor1

Next, to address C1, we do the shape lowering via VM heap manipulation

  • Three intrinsics/builtin-functions:

    • relax.vm.builtin.alloc_heap(size) -> heap: Allocate the heap (an NDArray) with a specific size to execute shape computation

      (We can use alloc_tensor to achieve the same goal)

    • relax.vm.builtin.store_shape(shape, heap, idx0, ...): Store a shape into specific indices in the vm heap

    • relax.vm.builtin.load_shape(heap, idx0, ...) -> shape: Construct a shape from the vm heap according to the indices

    (Since store_shape and load_shape take indices (an array of integers) as arguments that cannot be represented as Expr, they are also designed as intrinsics in Relax, with attributes that hold the integer indices.)
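
To make the heap manipulation concrete, below is a minimal Python model of what these builtins do at runtime. It is purely illustrative and does not reflect the actual VM implementation:

import numpy as np

# Illustrative model only: the real VM keeps the shape heap as an int64 NDArray
def alloc_shape_heap(size):
    return np.zeros(size, dtype="int64")

def store_shape(shape, heap, *indices):
    # write each dimension of `shape` into the heap slot given by the matching index
    for dim, idx in zip(shape, indices):
        heap[idx] = dim

def load_shape(heap, *indices):
    # reconstruct a shape tuple from the given heap slots
    return tuple(int(heap[i]) for i in indices)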

from tvm.script import tir as T, relax as R

# Program after shape lowering

@R.function
def myfunc(x):
    shape_heap = relax.call_packed("vm.builtin.alloc_shape_heap", size=k)
    relax.vm.builtin.store_shape(x.shape, shape_heap, 0, 1)
    sh = relax.vm.builtin.load_shape(shape_heap, 0, 1)
    # this product_shape function (to compute n*m) is generated as a TIR primfunc
    # when visiting ShapeExpr in the shape lowering pass (see the sketch after this program)
    shape_size = product_shape(sh)

    storage0 = relax.vm.builtin.alloc_storage(size=shape_size, device=cpu)
    gv0 = relax.vm.builtin.alloc_tensor(storage0, sh, 0, "f32")
    R.call_packed("tirexp", x, gv0)

    sh1 = relax.vm.builtin.load_shape(shape_heap, 0, 1)
    storage1 = relax.vm.builtin.alloc_storage(size=shape_size, device=cpu)
    gv1 = relax.vm.builtin.alloc_tensor(storage1, sh1, 0, "f32")
    R.call_packed("flatten", gv0, gv1)

    R.call_packed("free_tensor", gv0)
    R.call_packed("free_storage", storage0)
    return gv1
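
As a rough illustration, the generated product_shape primfunc might look like the sketch below. The name, buffer layout, and destination-passing signature are assumptions for illustration; the actual code emitted by the shape lowering pass may differ (e.g., it may read from and write back to the shape heap directly):

from tvm.script import tir as T

# Hypothetical sketch of a generated shape-product primfunc (illustrative only)
@T.prim_func
def product_shape(sh: T.handle, out: T.handle):
    SH = T.match_buffer(sh, (2,), "int64")    # the loaded shape (n, m)
    OUT = T.match_buffer(out, (1,), "int64")  # the computed product n*m
    with T.block("product"):
        OUT[0] = SH[0] * SH[1]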

Relax Compilation and Build Workflow

The four lowering passes in relax.vm.build(ir_mod, target) map to the multi-stage pipeline described above.

After the lowering passes, relax.vm.build(ir_mod, target) calls tvm::build to build all TIR primfuncs in the IRModule, and uses CodeGenVM to visit all the Relax functions in the input IRModule, generating the VM executable during the visit.
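
As a usage sketch: the snippet below assumes the Relax prototype's Python API; the exact names (e.g., whether build returns the executable directly, and the VirtualMachine wrapper) are assumptions for illustration:

import numpy as np
import tvm
from tvm import relax

# Build the IRModule above: the lowering passes run first, then tvm::build compiles the
# TIR primfuncs and CodeGenVM emits the VM executable (API details are assumptions)
ex = relax.vm.build(MyIRModule, target="llvm")

# Run the compiled function on the Relax VM
vm = relax.VirtualMachine(ex, tvm.cpu())
x = tvm.nd.array(np.random.rand(3, 4).astype("float32"))
out = vm["myfunc"](x)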