Relax Compilation MVP Design
Authors (alphabetical): @altanh, @tqchen, @YuchenJin, @ZihengJiang
- Compile a Relax program to a format that the VM can execute
- A multi-stage compilation pipeline that enables composable compilation transformations
- Every transformation is `IRModule` → `IRModule` (see the pipeline sketch after this list)
- Users might run part of the program with third-party libraries such as cuDNN. We need to be able to optimize the remaining parts.
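For example, the pipeline can be expressed as a sequence of `IRModule` → `IRModule` passes. The sketch below follows the lowering stages described in this document; the concrete pass names are assumptions, not the definitive API.

```python
import tvm
from tvm import relax

# A sketch of composing IRModule -> IRModule transformations.
# The pass names mirror the stages in this doc (call_tir lowering, memory
# lowering, shape lowering) and are assumptions about the concrete API.
seq = tvm.ir.transform.Sequential([
    relax.transform.CallTIRRewrite(),   # call_tir -> explicit allocation + packed calls
    relax.transform.VMMemoryLower(),    # memory intrinsics -> VM builtins
    relax.transform.VMShapeLower(),     # symbolic shapes -> heap store/load + TIR shape funcs
])
lowered_mod = seq(mod)                  # `mod` is any input IRModule; the result is again an IRModule
```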
Let's take compiling the following simple Relax program as a running example.
```python
import tvm.script
from tvm.script import tir as T, relax as R

@tvm.script.ir_module
class MyIRModule:
    @T.prim_func
    def tirexp(a: T.handle, b: T.handle):
        n1, m1 = T.var("n1"), T.var("m1")
        X = T.match_buffer(a, (n1, m1))
        Y = T.match_buffer(b, (n1, m1))
        with T.block([n1, m1]) as [i, j]:
            Y[i, j] = T.exp(X[i, j])

    @R.function
    def myfunc(x: Tensor[(n, m)]):
        with R.dataflow():
            lv0: Tensor[(n, m)] = R.call_tir((n, m), tirexp, [x])
            gv0: Tensor[(m*n,)] = R.call_tir((m*n,), "flatten", [lv0])
            R.outputs(gv0)
        return gv0
```
We introduce a new intrinsic `relax.call_tir` in Relax and use it to construct the program. A program in this CallTIR form has the following properties:
- No side effects
- Shape annotations are present
- Core expression: `call_tir(output_shape, func, [arg0, arg1, ...], Optional<shape_expr>) -> Expr`
  - This means that we take a function that follows the TIR calling convention (destination-passing style), e.g., `mylog(in, out)`, and wrap it with a `call_tir` function; `call_tir(out_shape, mylog, in)` will return the output (see the sketch below)
  - `func` can be a TIR function or a packed function
  - `shape_expr` is an optional argument for passing integers as arguments
  - `output_shape` can be a `ShapeExpr` or a `Tuple`
For a more detailed introduction to CallTIR, please refer to the Relax Architecture Overview.
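Conceptually, `call_tir` wraps a destination-passing-style callee and exposes it as a pure, value-returning expression. The sketch below uses a hypothetical `alloc_tensor` helper to stand in for the explicit allocation that lowering inserts; it is an illustration, not the actual pass implementation.

```python
# A conceptual sketch of what call_tir(output_shape, func, args) means.
def call_tir(output_shape, func, args):
    out = alloc_tensor(output_shape)  # hypothetical: allocate the destination tensor
    func(*args, out)                  # destination-passing style: output is the last argument
    return out                        # the caller sees call_tir as a pure value-returning op
```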
Lowering this program for the VM needs to address two constraints:

- C0: Every `call_tir` needs to be lowered (the Relax VM only supports a `call` instruction to call a packed function directly) → we need to insert explicit output memory allocation with memory planning
- C1: The symbolic shapes `n` and `m` are not something the runtime can represent (the Relax VM only supports the `NDArray` and `ShapeTuple` runtime data structures) → we need to use the heap in the VM to do shape calculations
A program in explicit memory form has the following properties:

- Explicitly allocates and kills storage and tensors
- Has side effects
- No shape annotations
- Core expression: `call(func_name, arg0, arg1, ...) -> Optional<Expr>`, which maps to the `Call` instruction that the VM can directly execute
Four intrinsics/builtin functions:

- `relax.vm.builtin.alloc_storage(size, device) -> storage`: Allocate a storage that can be used to create tensors
- `relax.vm.builtin.alloc_tensor(storage, shape, offset, dtype) -> tensor`: Allocate a tensor in a storage
- `relax.vm.builtin.free_storage(storage)`: Free the allocated storage
- `relax.vm.builtin.free_tensor(tensor)`: Free the allocated tensor
Since `alloc_storage` and `alloc_tensor` take an integer and a dtype among their arguments, which cannot be represented as `Expr` arguments to a `CallNode`, `alloc_storage` and `alloc_tensor` are designed as intrinsics in Relax whose attributes carry the int and dtype values.
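As a rough illustration of that design choice, the class and field names below are hypothetical (not the actual attribute node definitions): the non-`Expr` operands ride on an attributes object attached to the call, while only `Expr` values appear in the argument list.

```python
# Hypothetical sketch: plain ints and dtype strings live on an attrs object,
# while only Expr values (storage, shape) appear in the call's argument list.
class AllocTensorAttrs:               # illustrative only, not the real attrs node
    def __init__(self, offset: int, dtype: str):
        self.offset = offset          # plain Python int, not an Expr
        self.dtype = dtype            # dtype string, not an Expr

attrs = AllocTensorAttrs(offset=0, dtype="float32")
# conceptually: Call(op=alloc_tensor, args=[storage0, shape_expr], attrs=attrs)
```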
```python
from tvm.script import tir as T, relax as R

# Program after call_tir lowering
@R.function
def myfunc(x):
    # has side effects, so it is now in a BindingBlock instead of a DataflowBlock
    n, m = R.match_shape(x.shape)

    storage0 = relax.vm.builtin.alloc_storage(size=[n*m], device=cpu)
    tensor0 = relax.vm.builtin.alloc_tensor(storage0, shape=[n, m], offset=0, dtype="f32")
    R.call_packed("tirexp", x, tensor0)

    storage1 = relax.vm.builtin.alloc_storage(size=[n*m], device=cpu)
    tensor1 = relax.vm.builtin.alloc_tensor(storage1, shape=[m*n,], offset=0, dtype="f32")
    R.call_packed("flatten", tensor0, tensor1)

    R.call_packed("free_tensor", tensor0)
    R.call_packed("free_storage", storage0)
    return tensor1
```
To address C1, the shape lowering pass uses three intrinsics/builtin functions:

- `relax.vm.builtin.alloc_heap(size) -> heap`: Allocate the heap (an `NDArray`) with a specific size to execute shape computations (we could use `alloc_tensor` to achieve the same goal)
- `relax.vm.builtin.store_shape(shape, heap, idx0, ...)`: Store a shape into specific indices in the VM heap
- `relax.vm.builtin.load_shape(heap, idx0, ...) -> shape`: Construct a shape from the VM heap according to the indices

(Since `store_shape` and `load_shape` take indices (an array of integers) as arguments, which cannot be represented as an `Expr`, they are designed as intrinsics in Relax whose attributes carry the integer indices.)
```python
from tvm.script import tir as T, relax as R

# Program after shape lowering
@R.function
def myfunc(x):
    shape_heap = relax.call_packed("vm.builtin.alloc_shape_heap", size=k)
    relax.vm.builtin.store_shape(x.shape, shape_heap, 0, 1)
    sh = relax.vm.builtin.load_shape(shape_heap, 0, 1)
    # this product_shape function (to compute n*m) is generated as a TIR primfunc
    # when visiting ShapeExpr in the shape lowering pass
    shape_size = product_shape(sh)

    storage0 = relax.vm.builtin.alloc_storage(size=shape_size, device=cpu)
    gv0 = relax.vm.builtin.alloc_tensor(storage0, sh, 0, "f32")
    R.call_packed("tirexp", x, gv0)

    sh1 = R.call_packed("load_shape", shape_heap, 0, 1)
    storage1 = relax.vm.builtin.alloc_storage(size=shape_size, device=cpu)
    gv1 = relax.vm.builtin.alloc_tensor(storage1, sh1, 0, "f32")
    R.call_packed("flatten", gv0, gv1)

    R.call_packed("free_tensor", gv0)
    R.call_packed("free_storage", storage0)
    return gv1
```
The four lowering passes in `relax.vm.build(ir_mod, target)` map to the multi-stage pipeline described above. After the lowering passes, `relax.vm.build(ir_mod, target)` calls `tvm::build` to build all TIR primfuncs in the IRModule, and uses `CodeGenVM` to visit all the Relax functions in the input IRModule, generating the VM executable during the visit.
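As a usage sketch, building and running the example looks roughly like the following. It assumes the `relax.vm.build` entry point shown above and the prototype's `relax.VirtualMachine` constructor; exact signatures may differ.

```python
import numpy as np
import tvm
from tvm import relax

# A minimal sketch, assuming the prototype relax.vm.build / VirtualMachine API.
target = tvm.target.Target("llvm")
ex = relax.vm.build(MyIRModule, target)      # lowering passes + tvm::build + CodeGenVM
vm = relax.VirtualMachine(ex, tvm.cpu())     # load the executable into the Relax VM
x = tvm.nd.array(np.random.rand(3, 4).astype("float32"))
res = vm["myfunc"](x)                        # returns the flattened (m*n,) tensor
```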