diff --git a/include/tvm/relax/attrs/op.h b/include/tvm/relax/attrs/op.h
index 8e3e9d92f554..e9792d6b53a6 100644
--- a/include/tvm/relax/attrs/op.h
+++ b/include/tvm/relax/attrs/op.h
@@ -83,9 +83,11 @@ struct ToVDeviceAttrs : public tvm::AttrsNode<ToVDeviceAttrs> {
 struct HintOnDeviceAttrs : public tvm::AttrsNode<HintOnDeviceAttrs> {
   int32_t dev_type;
   int32_t dev_id;
+  MemoryScope memory_scope;
   TVM_DECLARE_ATTRS(HintOnDeviceAttrs, "relax.attrs.HintOnDeviceAttrs") {
     TVM_ATTR_FIELD(dev_type).describe("The device type where the data is supposed to be executed.");
     TVM_ATTR_FIELD(dev_id).describe("The device id.");
+    TVM_ATTR_FIELD(memory_scope).set_default("global").describe("The device memory scope.");
   }
 };  // struct HintOnDeviceAttrs
diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h
index eaad44a93ace..5f9016ee5062 100644
--- a/include/tvm/relax/transform.h
+++ b/include/tvm/relax/transform.h
@@ -243,9 +243,11 @@ TVM_DLL Pass FoldConstant();
  * will override the default one.
  * \param enable_warning A boolean value indicating if to print warnings for TIR functions not
  * showing up in the database.
+ * \param add_attributes A boolean value indicating whether to add operator call attributes to the
+ * legalized TIR functions.
  * \return The Pass.
  */
-TVM_DLL Pass LegalizeOps(Optional<Map<String, PackedFunc>> cmap, bool enable_warning = false);
+TVM_DLL Pass LegalizeOps(Optional<Map<String, PackedFunc>> cmap, bool enable_warning = false,
+                         bool add_attributes = false);

 /*!
  * \brief Propagate virtual device information.
@@ -666,6 +668,24 @@ TVM_DLL Pass RewriteCUDAGraph();
  */
 TVM_DLL Pass FewShotTuning(int valid_count, bool benchmark);

+/*!
+ * \brief This pass annotates memory scope information via the VDevice attribute.
+ * It needs operator attributes, which in general vanish after legalization, so FuseOps and
+ * FuseTIR are modified to forward the operator-specific attributes as well as the op_pattern
+ * details as part of the PrimFunc. This pass is Adreno specific and annotates each binding
+ * var with an appropriate hint_on_device; the RealizeVDevice pass that follows handles these
+ * hints. After this pass we also invoke SpecializePrimFuncBasedOnCallSite, which updates the
+ * var-to-buffer map based on the new VDevice information.
+ */
+TVM_DLL Pass AnnotateCustomMemoryScope(Target target);
+
+/*!
+ * \brief This pass updates the var-to-buffer mapping of PrimFuncs from the call_tir info.
+ * It is primarily used to update the VDevice information if any changes occurred at the
+ * caller. The pass recreates the buffers and updates the map.
+ */
+TVM_DLL Pass SpecializePrimFuncBasedOnCallSite();
+
 }  // namespace transform
 }  // namespace relax
 }  // namespace tvm
diff --git a/python/tvm/dlight/__init__.py b/python/tvm/dlight/__init__.py
index 421d4017d1bd..b549663b8c72 100644
--- a/python/tvm/dlight/__init__.py
+++ b/python/tvm/dlight/__init__.py
@@ -16,6 +16,7 @@
 # under the License.
 """DLight package provides efficient schedules out-of-box for deep learning workloads."""
 from . import gpu
+from . import adreno
 from .base import (
     ApplyDefaultSchedule,
     BlockInfo,
diff --git a/python/tvm/dlight/adreno/__init__.py b/python/tvm/dlight/adreno/__init__.py
new file mode 100644
index 000000000000..ea2781455989
--- /dev/null
+++ b/python/tvm/dlight/adreno/__init__.py
@@ -0,0 +1,20 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Adreno schedule rules. +""" +from .convolution import Conv2d diff --git a/python/tvm/dlight/adreno/base.py b/python/tvm/dlight/adreno/base.py new file mode 100644 index 000000000000..d043706c2fc5 --- /dev/null +++ b/python/tvm/dlight/adreno/base.py @@ -0,0 +1,41 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Base schedule rule for Adreno operators.""" + +from tvm.target import Target + +from ..base import ScheduleRule + + +class AdrenoScheduleRule(ScheduleRule): # pylint: disable=too-few-public-methods + """The Schedule Rule specific to Adreno targets, + will return None if the target is not Adreno.""" + + def is_target_available(self, target: Target) -> bool: + """Check whether the target is available for Adreno rule. + + Parameters + ---------- + target : Target + The compilation target to check. + + Returns + ------- + available : bool + Whether the target is available for this rule. + """ + return super().is_target_available(target) and "adreno" in target.keys diff --git a/python/tvm/dlight/adreno/convolution.py b/python/tvm/dlight/adreno/convolution.py new file mode 100644 index 000000000000..f084885dad73 --- /dev/null +++ b/python/tvm/dlight/adreno/convolution.py @@ -0,0 +1,230 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# pylint: disable=missing-docstring, invalid-name
+"""A Conv2d schedule rule for Adreno GPU operators."""
+from dataclasses import dataclass
+from typing import List, Optional
+
+from tvm import tir
+from tvm.target import Target
+from tvm.tir import IterVar
+from tvm.tir.schedule.schedule import BlockRV
+
+from ..base import analysis, BlockInfo, IterInfo
+from .base import AdrenoScheduleRule
+
+
+def is_spatial_block(sch: tir.Schedule, block: BlockRV) -> bool:
+    block_stmt = sch.get(block)
+    iter_types = {iter_var.iter_type for iter_var in block_stmt.iter_vars}
+    return iter_types == {IterVar.DataPar}
+
+
+def is_reduction_block(sch: tir.Schedule, block: BlockRV) -> bool:
+    block_stmt = sch.get(block)
+    iter_types = {iter_var.iter_type for iter_var in block_stmt.iter_vars}
+    return iter_types == {IterVar.CommReduce, IterVar.DataPar}
+
+
+def _collect_producers(sch: tir.Schedule, block: tir.schedule.BlockRV):
+    result = []
+    for producer in sch.get_producers(block):
+        result.append(producer)
+        result.extend(_collect_producers(sch, producer))
+    return result
+
+
+def _collect_consumers(sch: tir.Schedule, block: tir.schedule.BlockRV):
+    result = []
+    for consumer in sch.get_consumers(block):
+        result.append(consumer)
+        result.extend(_collect_consumers(sch, consumer))
+    return result
+
+
+def get_block_info(sch: tir.Schedule, block: tir.schedule.BlockRV) -> BlockInfo:
+    def _iter_kind(loop: tir.IterVar) -> str:
+        return {tir.IterVar.DataPar: "S", tir.IterVar.CommReduce: "R"}.get(loop.iter_type, "O")
+
+    def _is_reduction_block(block: tir.schedule.BlockRV):
+        for iter_var in sch.get(block).iter_vars:
+            if _iter_kind(iter_var) == "R":
+                return True
+        return False
+
+    return BlockInfo(
+        name=sch.get(block).name_hint,
+        iters=[
+            IterInfo(
+                kind=_iter_kind(iter_var),
+                var=iter_var.var,
+                dom=iter_var.dom.extent,
+                loop_rv=loop_rv,
+            )
+            for loop_rv, iter_var in zip(sch.get_loops(block), sch.get(block).iter_vars)
+        ],
+        block_rv=block,
+        reduction_block=_is_reduction_block(block),
+    )
+
+
+def get_reduction_blocks(
+    sch: tir.Schedule, blocks: List[tir.schedule.BlockRV]
+) -> Optional[BlockRV]:
+    # NOTE: We assume there is only one reduction block in the function;
+    # all blocks are required to be either spatial or reduction.
+    if not all(is_reduction_block(sch, block) or is_spatial_block(sch, block) for block in blocks):
+        return None
+
+    # There should be exactly one reduction block.
+    reduction_blocks = [block for block in blocks if is_reduction_block(sch, block)]
+    if len(reduction_blocks) != 1:
+        return None
+
+    return reduction_blocks[0]
+
+
+def is_convolution(sch: tir.Schedule, block: tir.schedule.BlockRV) -> bool:
+    # TODO: Use buffer access patterns to discover convolution type kernels instead of using name.
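+    # Heuristic: match by PrimFunc name plus loop pattern, i.e. five spatial axes
+    # (n, oc-chunk, oh, ow, oc-block) followed by three reduction axes (ic, kh, kw).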
+    return (
+        "conv2d_NCHWc_OIHWo" in sch.get(block).name_hint
+        and "".join(iter_info.kind for iter_info in get_block_info(sch, block).iters)
+        == "SSSSSRRR"
+    )
+
+
+class Conv2d(AdrenoScheduleRule):
+    """The schedule rule for convolution computation."""
+
+    @dataclass
+    class Config:
+        block_size_x: int = 8
+        block_size_y: int = 8
+        vector_size: int = 1
+        unroll: int = 256  # 0 means no unroll
+        use_shared: bool = True
+        storage_align: bool = False
+        inner_x: bool = False
+
+    def get_configs(self, target: Target) -> Config:
+        """Get the schedule config for the target."""
+        if target.kind.name in ("cuda", "rocm"):
+            return Conv2d.Config(
+                block_size_x=8,
+                block_size_y=16,
+                vector_size=2,
+                unroll=256,
+                use_shared=True,
+                storage_align=True,
+                inner_x=False,
+            )
+        elif target.kind.name == "opencl" and (
+            ("android" in str(target.host)) or ("adreno" in str(target.attrs))
+        ):
+            return Conv2d.Config(
+                block_size_x=32,
+                block_size_y=4,
+                vector_size=8,
+                unroll=16,
+                use_shared=False,
+                storage_align=False,
+                inner_x=True,
+            )
+        else:
+            return Conv2d.Config()
+
+    def apply(  # pylint: disable=too-many-locals,missing-docstring
+        self,
+        func: tir.PrimFunc,
+        target: Target,
+        _: bool,
+    ) -> Optional[tir.Schedule]:
+        if not isinstance(func, tir.PrimFunc) or not self.is_target_available(target):
+            return None
+
+        sch = tir.Schedule(func)
+
+        # config = self.get_configs(target)
+        root_block = analysis.get_root_block(sch)
+        blocks = sch.get_child_blocks(root_block)
+        reduction_block = get_reduction_blocks(sch, blocks)
+
+        if reduction_block is None:
+            return None
+        if not is_convolution(sch, reduction_block):
+            return None
+
+        def schedule_data_pad(blk):
+            axes = sch.get_loops(blk)
+            axes, vec = axes[:-1], axes[-1]
+            axis = sch.fuse(*axes)
+            bx, ty, tx = sch.split(axis, [None, 16, 16])
+            sch.bind(bx, "blockIdx.x")
+            sch.bind(ty, "threadIdx.y")
+            sch.bind(tx, "threadIdx.x")
+            sch.vectorize(vec)
+
+        def schedule_conv2d(blk):
+            # TODO: The loop pattern may not be reliable; need to perform better analysis.
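+            # Loop nest: n, oc, oh, ow and ob are spatial; ic, kh, kw are reductions.
+            # Fuse the outer spatial loops onto a 16x16 thread block, split ic by 4,
+            # cache reads/writes in local memory, and vectorize the inner
+            # output-channel block (ob).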
+            n, oc, oh, ow, ob, ic, kh, kw = sch.get_loops(blk)
+            sch.reorder(n, oc, oh, ow, ic, kh, kw, ob)
+            main_lp = sch.fuse(n, oc, oh, ow)
+            bx, ty, tx = sch.split(main_lp, [None, 16, 16])
+            sch.bind(tx, "threadIdx.x")
+            sch.bind(ty, "threadIdx.y")
+            sch.bind(bx, "blockIdx.x")
+
+            ico, icv = sch.split(ic, [None, 4])
+            sch.reorder(ico, kh, kw, icv, ob)
+            rblk = sch.cache_read(blk, 0, "local")
+            sch.compute_at(rblk, kw)
+            sch.vectorize(sch.get_loops(rblk)[-1])
+            wblk = sch.cache_write(blk, 0, "local")
+            sch.reverse_compute_at(wblk, tx)
+            sch.vectorize(sch.get_loops(wblk)[-1])
+            sch.vectorize(ob)
+            init_blk = sch.decompose_reduction(blk, ico)
+            sch.vectorize(sch.get_loops(init_blk)[-1])
+
+        def is_data_pad(block: tir.stmt.Block):
+            return is_spatial_block(sch, block) and tir.analysis.has_if_then_else(sch.get(block))
+
+        def schedule_conv2d_blocks():
+            # Do analysis to find the block type
+            blocks = sch.get_child_blocks(root_block)
+            passed_reduction = False
+            for blk in blocks:
+                if is_reduction_block(sch, blk):
+                    schedule_conv2d(blk)
+                    passed_reduction = True
+                elif is_data_pad(blk):
+                    schedule_data_pad(blk)
+                elif is_spatial_block(sch, blk):
+                    try:
+                        if not passed_reduction:
+                            sch.compute_inline(blk)
+                        else:
+                            sch.reverse_compute_inline(blk)
+                    except:  # pylint: disable=W0702
+                        pass
+                else:
+                    raise TypeError("Cannot schedule this block", sch.get(blk))
+
+        schedule_conv2d_blocks()
+        return sch
diff --git a/python/tvm/relax/op/base.py b/python/tvm/relax/op/base.py
index 03e86a4633a6..300bb33325ec 100644
--- a/python/tvm/relax/op/base.py
+++ b/python/tvm/relax/op/base.py
@@ -823,7 +823,7 @@ def to_vdevice(data, dst_vdevice) -> Expr:
     return _ffi_api.to_vdevice(data, dst_vdevice)  # type: ignore


-def hint_on_device(data, dst_vdevice) -> Expr:
+def hint_on_device(data, dst_vdevice, memory_scope="global") -> Expr:
     """It provides a hint specifying the device on which the input data should be executed.
     This hint is utilized by RealizeVDevice to propagate the virtual device."

     Parameters
     ----------
     data : Expr
         The tensor to be copied.

-    dst_device : VDevice
+    dst_vdevice : Device
         The destination device where the data is supposed to be executed.

+    memory_scope : str
+        The memory scope of the buffer on the target device.
+
     Returns
     -------
     result : Expr
         The result.
     """
-    return _ffi_api.hint_on_device(data, dst_vdevice)  # type: ignore
+    return _ffi_api.hint_on_device(data, dst_vdevice, memory_scope)  # type: ignore
diff --git a/python/tvm/relax/transform/__init__.py b/python/tvm/relax/transform/__init__.py
index 16e4800ca33d..803e9161683a 100644
--- a/python/tvm/relax/transform/__init__.py
+++ b/python/tvm/relax/transform/__init__.py
@@ -83,6 +83,8 @@
     UpdateVDevice,
     VMBuiltinLower,
     VMShapeLower,
+    AnnotateCustomMemoryScope,
+    SpecializePrimFuncBasedOnCallSite,
     dataflowblock_pass,
     function_pass,
 )
diff --git a/python/tvm/relax/transform/legalize_ops/__init__.py b/python/tvm/relax/transform/legalize_ops/__init__.py
index b4aba0291fc1..11b9a293407a 100644
--- a/python/tvm/relax/transform/legalize_ops/__init__.py
+++ b/python/tvm/relax/transform/legalize_ops/__init__.py
@@ -31,3 +31,6 @@
 from . import search
 from . import statistical
 from . import unary
+
+# Device specific legalizations
+from .
import adreno diff --git a/python/tvm/relax/transform/legalize_ops/adreno/__init__.py b/python/tvm/relax/transform/legalize_ops/adreno/__init__.py new file mode 100644 index 000000000000..f2b3f4a781d2 --- /dev/null +++ b/python/tvm/relax/transform/legalize_ops/adreno/__init__.py @@ -0,0 +1,18 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Legalize high-level operator calls in Relax functions to call_tir.""" +from .convolution import conv2d_NCHWc_OIHWo diff --git a/python/tvm/relax/transform/legalize_ops/adreno/convolution.py b/python/tvm/relax/transform/legalize_ops/adreno/convolution.py new file mode 100644 index 000000000000..eb0bf30cfbf2 --- /dev/null +++ b/python/tvm/relax/transform/legalize_ops/adreno/convolution.py @@ -0,0 +1,36 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-docstring, invalid-name +"""A Convolution impl for Adreno GPU.""" + +from tvm import relax +from tvm import topi + + +def conv2d_NCHWc_OIHWo(bb: relax.BlockBuilder, call: relax.Call) -> relax.Expr: + return bb.call_te( + topi.nn.conv2d_NCHWc_OIHWo, + data=call.args[0], + kernel=call.args[1], + stride=call.attrs.strides, + padding=call.attrs.padding, + dilation=call.attrs.dilation, + layout=call.attrs.data_layout, + out_layout=call.attrs.out_layout, + # out_dtype=call.attrs.out_dtype, + primfunc_name_hint="conv2d_NCHWc_OIHWo", + ) diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py index 603211b59ebc..9de8d000495b 100644 --- a/python/tvm/relax/transform/transform.py +++ b/python/tvm/relax/transform/transform.py @@ -30,6 +30,7 @@ from tvm.relax.dpl import DFPattern from tvm.runtime import NDArray, Object from tvm.tir import IndexMap, PrimFunc +from tvm.target import Target from . 
import _ffi_api
 from .legalize_ops.common import LegalizeFunc
@@ -1061,7 +1062,9 @@ def BundleModelParams(param_tuple_name: Optional[str] = None) -> tvm.ir.transform.Pass:

 def LegalizeOps(
-    customize_legalize_map: Optional[Dict[str, LegalizeFunc]] = None, enable_warning: bool = False
+    customize_legalize_map: Optional[Dict[str, LegalizeFunc]] = None,
+    enable_warning: bool = False,
+    add_attributes: bool = False,
 ):
     """Legalize high-level operator calls in Relax functions to call_tir
     with corresponding low-level TIR PrimFuncs.
@@ -1092,6 +1095,10 @@ def LegalizeOps(
         legalization function is not registered. By default we don't print
         warnings.

+    add_attributes : bool
+        A boolean value indicating whether LegalizeOps should attach the original operator
+        attributes to the attributes of the legalized PrimFunc. By default it is False.
+
     Returns
     -------
     ret : tvm.transform.Pass
@@ -1166,7 +1173,9 @@ def multiply(
                     T_multiply[v_ax0, v_ax1] = A[v_ax0, v_ax1] * B[v_ax0, v_ax1]
     """
-    return _ffi_api.LegalizeOps(customize_legalize_map, enable_warning)  # type: ignore
+    return _ffi_api.LegalizeOps(
+        customize_legalize_map, enable_warning, add_attributes  # type: ignore
+    )


 def RealizeVDevice() -> tvm.ir.transform.Pass:
@@ -1604,6 +1613,32 @@ def AllocateWorkspace() -> tvm.ir.transform.Pass:
     return _ffi_api.AllocateWorkspace()  # type: ignore


+def AnnotateCustomMemoryScope(target: Optional[Target] = None) -> tvm.ir.transform.Pass:
+    """Annotate memory scope information. This is an Adreno-specific pass that annotates
+    the memory scope information, realizes it with the RealizeVDevice pass that follows,
+    and then updates the PrimFunc var-to-buffer mapping using
+    SpecializePrimFuncBasedOnCallSite.
+
+    Returns
+    -------
+    ret: tvm.ir.transform.Pass
+        The registered pass for annotating custom memory scopes.
+    """
+    return _ffi_api.AnnotateCustomMemoryScope(target)  # type: ignore
+
+
+def SpecializePrimFuncBasedOnCallSite() -> tvm.ir.transform.Pass:
+    """This pass updates the var-to-buffer mapping of PrimFuncs from the call_tir info.
+    It is primarily used to update the VDevice information if any changes occurred at the
+    caller. The pass recreates the buffers and updates the map.
+
+    Returns
+    -------
+    ret: tvm.ir.transform.Pass
+        The registered pass for specializing PrimFuncs based on their call sites.
+ """ + return _ffi_api.SpecializePrimFuncBasedOnCallSite() # type: ignore + + def _wrap_class_function_pass(pass_cls, pass_info): """Wrap a python class as function pass.""" diff --git a/python/tvm/tir/analysis/analysis.py b/python/tvm/tir/analysis/analysis.py index 67eb7471d22d..e8bf0a0922e8 100644 --- a/python/tvm/tir/analysis/analysis.py +++ b/python/tvm/tir/analysis/analysis.py @@ -19,7 +19,7 @@ from typing import Dict, List, Optional, Union import tvm -from tvm import Object +from tvm import Object, _ffi from tvm.ir import IRModule from tvm.tir.expr import Var from tvm.tir.stmt import Block, BufferRegion, PrimExpr @@ -407,6 +407,10 @@ def find_anchor_block(mod: IRModule) -> Block: return _ffi_api.find_anchor_block(mod) # type: ignore # pylint: disable=no-member +def has_if_then_else(stmt: Stmt) -> bool: + return _ffi.get_global_func("tir.schedule.HasIfThenElse")(stmt) + + def get_vtcm_compaction_passes() -> List[tvm.transform.Pass]: """Utility function to get the list of lowering passes to be applied to calculate the compacted VTCM allocation size diff --git a/python/tvm/topi/nn/conv2d.py b/python/tvm/topi/nn/conv2d.py index 205730ff22d6..34207aa2968c 100644 --- a/python/tvm/topi/nn/conv2d.py +++ b/python/tvm/topi/nn/conv2d.py @@ -482,6 +482,135 @@ def conv2d_NCHWc(data, kernel, stride, padding, dilation, layout, out_layout, ou ) +def conv2d_NCHWc_OIHWo( + data: te.Tensor, kernel, stride, padding, dilation, layout, out_layout, out_dtype="float32" +): + """Conv2D operator for nChw[x]c layout. + + Parameters + ---------- + data : tvm.te.Tensor + 5-D with shape [batch, in_channel_chunk, in_height, in_width, in_channel_block] + + kernel : tvm.te.Tensor + 6-D with shape + [num_filter_chunk, in_channel_chunk, filter_height, filter_width, + num_filter_block] + + stride : int or a list/tuple of two ints + stride size, or [stride_height, stride_width] + + padding : int or a list/tuple of 2 or 4 ints + padding size, or + [pad_height, pad_width] for 2 ints, or + [pad_top, pad_left, pad_bottom, pad_right] for 4 ints + + dilation: int or a list/tuple of two ints + dilation size, or [dilation_height, dilation_width] + + layout : str + Input data layout + + out_layout : str + Output data layout + + out_dtype : str + output data type + + Returns + ------- + output : tvm.te.Tensor + 5-D with shape [batch, out_channel_chunk, out_height, out_width, out_channel_block] + """ + + # layout and out_layout are not used here, + # we keep them for debug convenience when dumping autotvm workload + HSTR, WSTR = stride if isinstance(stride, (tuple, list)) else (stride, stride) + dilation_h, dilation_w = ( + dilation if isinstance(dilation, (tuple, list)) else (dilation, dilation) + ) + + n, ic_chunk, ih, iw, ic_bn = get_const_tuple(data.shape) + in_channel = ic_chunk * ic_bn + kernel_shape = get_const_tuple(kernel.shape) + if len(kernel_shape) == 6: # OIHW4i4o + oc_chunk, ic_chunk_group, kernel_height, kernel_width, kernel_ic_bn, oc_bn = kernel_shape + groups = in_channel // (ic_chunk_group * kernel_ic_bn) + else: # OIHW4o + oc_chunk, ic, kernel_height, kernel_width, oc_bn = kernel_shape + groups = in_channel // ic + + num_filter = oc_chunk * oc_bn + dilated_kernel_h = (kernel_height - 1) * dilation_h + 1 + dilated_kernel_w = (kernel_width - 1) * dilation_w + 1 + + pad_top, pad_left, pad_down, pad_right = get_pad_tuple( + padding, (dilated_kernel_h, dilated_kernel_w) + ) + HPAD = pad_top + pad_down + WPAD = pad_left + pad_right + + # output shape + out_height = (ih + HPAD - dilated_kernel_h) // HSTR + 1 + 
out_width = (iw + WPAD - dilated_kernel_w) // WSTR + 1 + oshape = (n, oc_chunk, out_height, out_width, oc_bn) + pad_before = (0, 0, pad_top, pad_left, 0) + pad_after = (0, 0, pad_down, pad_right, 0) + + # DOPAD + DOPAD = HPAD != 0 or WPAD != 0 + if DOPAD: + data_pad = pad(data, pad_before, pad_after, name="conv2d_data_pad") + else: + data_pad = data + + kh = te.reduce_axis((0, kernel_height), name="kh") + kw = te.reduce_axis((0, kernel_width), name="kw") + + idxdiv = tvm.tir.indexdiv + idxmod = tvm.tir.indexmod + + def compute_conv2d(*args): + n, occ, oh, ow, ocb = args + ic = te.reduce_axis((0, in_channel // groups), name="ic") + if groups == 1: + data_pad_ = data_pad[ + n, + idxdiv(ic, ic_bn), + oh * HSTR + kh * dilation_h, + ow * WSTR + kw * dilation_w, + idxmod(ic, ic_bn), + ] + else: + data_pad_ = data_pad[ + n, + (occ // (oc_chunk // groups)) * (ic_chunk // groups) + idxdiv(ic, ic_bn), + oh * HSTR + kh * dilation_h, + ow * WSTR + kw * dilation_w, + idxmod(ic, ic_bn), + ] + if len(kernel_shape) == 5: + kernel_ = kernel[occ, ic, kh, kw, ocb] + else: + kernel_ = kernel[occ, idxdiv(ic, oc_bn), kh, kw, idxmod(ic, oc_bn), ocb] + + if out_dtype is not None: + data_pad_ = data_pad_.astype(out_dtype) + kernel_ = kernel_.astype(out_dtype) + + return te.sum( + data_pad_ * kernel_, + axis=[ic, kh, kw], + ) + + return te.compute( + oshape, + lambda *indices: compute_conv2d(*indices), # pylint: disable=W0108 + name="conv2d_NCHWc_OIHWo", + tag="conv2d_NCHWc_OIHWo", + ) + + def conv2d_NCHWc_int8( data, kernel, stride, padding, dilation, layout, out_layout, out_dtype="int32", n_elems=4 ): diff --git a/src/relax/op/nn/convolution.cc b/src/relax/op/nn/convolution.cc index 3ebbc544f470..dc4d3d7f7ff1 100644 --- a/src/relax/op/nn/convolution.cc +++ b/src/relax/op/nn/convolution.cc @@ -301,6 +301,8 @@ InferLayoutOutput InferLayoutConv2d(const Call& call, ICHECK(attrs) << "Invalid Call"; LayoutDecision data_layout, weight_layout, output_layout; + data_layout = GetLayoutDecision(var_layout_map, call->args[0]); + weight_layout = GetLayoutDecision(var_layout_map, call->args[1]); ObjectPtr new_attrs = make_object(*attrs); if (it != desired_layouts.end()) { @@ -346,14 +348,16 @@ InferLayoutOutput InferLayoutConv2d(const Call& call, new_attrs->kernel_layout = (*it).second[1]; new_attrs->out_layout = (*it).second.size() == 3 ? (*it).second[2] : (*it).second[0]; return InferLayoutOutput({data_layout, weight_layout}, {output_layout}, Attrs(new_attrs)); + } else { + data_layout = LayoutDecision(InitialLayout(4)); + weight_layout = LayoutDecision(InitialLayout(4)); } } } // We don't have a desired layout for conv2d or desired layouts not compatible. // We can just propagate the layout from the input. 
- data_layout = GetLayoutDecision(var_layout_map, call->args[0]); - weight_layout = GetLayoutDecision(var_layout_map, call->args[1]); + output_layout = data_layout; new_attrs->data_layout = TransposeLike(attrs->data_layout, InitialLayout(4), data_layout->layout).name(); diff --git a/src/relax/op/op.cc b/src/relax/op/op.cc index a7d97a59a100..968c28fadbf2 100644 --- a/src/relax/op/op.cc +++ b/src/relax/op/op.cc @@ -1360,15 +1360,25 @@ RELAY_REGISTER_OP("relax.hint_on_device") .set_attr("FInferStructInfo", InferHintOnDeviceStructInfo) .set_attr("FPurity", Bool(true)); -Expr MakeHintOnDevice(Expr data, Device device) { +Expr MakeHintOnDevice(Expr data, Device device, String memory_scope = "global") { static const Op& op = Op::Get("relax.hint_on_device"); ObjectPtr attrs = make_object(); attrs->dev_type = static_cast(device.device_type); attrs->dev_id = device.device_id; + attrs->memory_scope = memory_scope; return Call(op, {data}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relax.op.hint_on_device").set_body_typed(MakeHintOnDevice); +TVM_REGISTER_GLOBAL("relax.op.hint_on_device").set_body([](TVMArgs args, TVMRetValue* rv) { + Expr data = args[0]; + Device device = args[1]; + if (args.size() == 3) { + String scope = args[2]; + *rv = MakeHintOnDevice(data, device, scope); + } else { + *rv = MakeHintOnDevice(data, device); + } +}); } // namespace relax } // namespace tvm diff --git a/src/relax/op/tensor/binary.cc b/src/relax/op/tensor/binary.cc index 4a63993d507c..7246a10df038 100644 --- a/src/relax/op/tensor/binary.cc +++ b/src/relax/op/tensor/binary.cc @@ -157,15 +157,21 @@ InferLayoutOutput InferLayoutBinaryEwise(const Call& call, Optional shape1 = GetRef(x1_sinfo->shape.as()); Optional shape2 = GetRef(x2_sinfo->shape.as()); + // Lets handle sub indexing as long as primal dims are matching - if (layout1->layout.ndim_primal() == layout2->layout.ndim_primal()) { - if ((layout1->layout.ndim() >= layout2->layout.ndim()) && shape2.defined()) { - if (CanProveLayoutTransform(layout2->layout, layout1->layout, shape2.value()->values)) { - return InferLayoutOutput({layout1, layout1}, {layout1}, Attrs(call->attrs)); - } - } else if (shape1.defined()) { - if (CanProveLayoutTransform(layout1->layout, layout2->layout, shape1.value()->values)) { - return InferLayoutOutput({layout2, layout2}, {layout2}, Attrs(call->attrs)); + if ((layout1->layout.ndim() != layout1->layout.ndim_primal()) || + (layout2->layout.ndim() != layout2->layout.ndim_primal())) { + if (layout1->layout.ndim_primal() == layout2->layout.ndim_primal()) { + if ((layout1->layout.ndim() >= layout2->layout.ndim()) && shape2.defined()) { + if (CanProveLayoutTransform(InitialLayout(shape2.value()->values.size()), layout1->layout, + shape2.value()->values)) { + return InferLayoutOutput({layout1, layout1}, {layout1}, Attrs(call->attrs)); + } + } else if (shape1.defined()) { + if (CanProveLayoutTransform(InitialLayout(shape1.value()->values.size()), layout2->layout, + shape1.value()->values)) { + return InferLayoutOutput({layout2, layout2}, {layout2}, Attrs(call->attrs)); + } } } } diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc index 452b1f223a80..d1df2aaee97c 100644 --- a/src/relax/op/tensor/manipulate.cc +++ b/src/relax/op/tensor/manipulate.cc @@ -304,12 +304,54 @@ InferLayoutOutput InferLayoutConcat(const Call& call, const auto* attrs = call->attrs.as(); ICHECK(attrs != nullptr) << "Invalid Call"; + NLayout nlayout = GetNLayout(var_layout_map, call->args[0]); ICHECK(nlayout.IsNested()); 
ICHECK(nlayout.NestedArray()[0].IsLeaf());

   int n_tensor = nlayout.NestedArray().size();
   LayoutDecision layout = nlayout.NestedArray()[0].LeafValue();
+
+  // We may expect a mix of sub-indexed and regular layouts here.
+  // Pick the first sub-indexed layout and try to prove it for all tensors.
+  // On any failure, select the first occurring regular layout for all.
+  auto nlayout_array = nlayout.NestedArray();
+  for (auto n_layout : nlayout_array) {
+    ICHECK(n_layout.IsLeaf());
+    LayoutDecision in_layout = n_layout.LeafValue();
+    if (in_layout->layout.ndim() != in_layout->layout.ndim_primal()) {
+      const auto* tuple_sinfo = GetStructInfoAs<TupleStructInfoNode>(call->args[0]);
+      ICHECK(tuple_sinfo != nullptr)
+          << " expects the input to be a Tuple of Tensors. However, the given input is "
+          << call->args[0]->struct_info_->GetTypeKey();
+      for (size_t i = 0; i < tuple_sinfo->fields.size(); ++i) {
+        StructInfo field_sinfo = tuple_sinfo->fields[i];
+        const auto* field_tensor_sinfo = field_sinfo.as<TensorStructInfoNode>();
+        ICHECK(field_tensor_sinfo != nullptr)
+            << call->op
+            << " expects the input to be a Tuple of Tensors. However, the given input is "
+            << call->args[0]->struct_info_;
+        auto t_sinfo = GetRef<TensorStructInfo>(field_tensor_sinfo);
+        Optional<ShapeExpr> t_shape = GetRef<ShapeExpr>(t_sinfo->shape.as<ShapeExprNode>());
+        LayoutDecision curr_layout = nlayout_array[i].LeafValue();
+        if (!CanProveLayoutTransform(curr_layout->layout, in_layout->layout,
+                                     t_shape.value()->values)) {
+          // Some tensor is unhappy with the sub-indexed layout; pick the first regular layout.
+          for (auto pick_layout : nlayout_array) {
+            if (pick_layout.LeafValue()->layout.ndim() ==
+                pick_layout.LeafValue()->layout.ndim_primal()) {
+              in_layout = pick_layout.LeafValue();
+              break;
+            }
+          }
+          break;
+        }
+      }
+      layout = in_layout;
+      break;
+    }
+  }
+
   Array<NLayout> input_layouts, output_layouts;
   for (int i = 0; i < n_tensor; ++i) {
     input_layouts.push_back(layout);
diff --git a/src/relax/transform/annotate_custom_storage.cc b/src/relax/transform/annotate_custom_storage.cc
new file mode 100644
index 000000000000..9c00edb0a9ab
--- /dev/null
+++ b/src/relax/transform/annotate_custom_storage.cc
@@ -0,0 +1,488 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file src/relax/transform/annotate_custom_storage.cc
+ * \brief Texture storage annotation pass.
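+ *
+ * The pass inspects every call_tir, derives a texture-friendly memory scope for each
+ * argument from the consumer's operator attributes and op_pattern, and injects
+ * hint_on_device calls that RealizeVDevice later materializes as VDevice annotations.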
+ */ + +#include +#include +#include +#include +#include +#include +#include + +#include + +#include "../op/tensor/manipulate.h" +#include "infer_layout_utils.h" +#include "utils.h" + +namespace tvm { +namespace relax { + +using tvm::tir::Buffer; + +static Array GetShapeFromTensorStructInfo(const TensorStructInfo& tensor_sinfo) { + auto shape = tensor_sinfo->GetShape(); + ICHECK(shape.defined()); + return shape.value(); +} + +class CollectProduserScopeInfo : public ExprVisitor { + public: + using ExprVisitor::VisitExpr_; + + Map Collect(const IRModule& mod, Function func, + const Map>>& scope_info, + const Target& target) { + mod_ = mod; + scope_info_ = scope_info; + target_ = target; + VisitExpr(func->body); + + return producer_sinfo; + } + + void VisitBinding_(const VarBindingNode* binding, const CallNode* call) final { + ExprVisitor::VisitBinding_(binding, call); + + static const Op& call_tir_op = Op::Get("relax.call_tir"); + StructInfo out_sinfo; + + if (call->op == call_tir_op) { + out_sinfo = call->sinfo_args[0]; + } else { + return; + } + + std::unordered_map scope_count; + + auto arg_var = binding->var.as(); + if (scope_info_.find(GetRef(arg_var)) != scope_info_.end()) { + for (const auto& val : scope_info_[GetRef(arg_var)]) { + auto call_node = Downcast(val.first); + if (scope_count.find(val.second[0]) == scope_count.end()) { + scope_count.insert({val.second[0], 1}); + } else { + auto curr_count = scope_count[val.second[0]]; + scope_count.emplace(val.second[0], curr_count + 1); + } + } + } + String final_scope = "global"; + int count = 0; + for (const auto& sval : scope_count) { + if (sval.second > count) { + final_scope = sval.first; + count = sval.second; + } + } + StructInfo updated_ret_sinfo = UpdateStructInfo(out_sinfo, {final_scope}); + producer_sinfo.Set(GetRef(call), updated_ret_sinfo); + } + + private: + StructInfo UpdateStructInfo(const StructInfo& out_sinfo, Array scope) { + if (out_sinfo->IsInstance()) { + auto tensor_sinfo = Downcast(out_sinfo); + auto shape_arr = GetShapeFromTensorStructInfo(tensor_sinfo); + return TensorStructInfo(ShapeExpr(shape_arr), tensor_sinfo->dtype, + VDevice(target_, 0, scope[0])); + } + + ICHECK(out_sinfo->IsInstance()) + << "Expect output struct info of call_tir to be either TupleStructInfo or " + "TensorStructInfo, but got " + << out_sinfo; + + const auto& tuple_sinfo = Downcast(out_sinfo); + Array sinfo_fields; + for (const auto& si : tuple_sinfo->fields) { + ICHECK(si->IsInstance()) + << "Fields of TupleStructInfo must be TensorStructInfo for call_tir " + "output structinfo, but got " + << si; + auto sinfo = Downcast(si); + auto shape_arr = GetShapeFromTensorStructInfo(sinfo); + sinfo_fields.push_back( + TensorStructInfo(ShapeExpr(shape_arr), sinfo->dtype, VDevice(target_, 0, scope[0]))); + } + return TupleStructInfo(sinfo_fields); + } + + Map>> scope_info_; + Map producer_sinfo; + IRModule mod_; + Target target_; +}; + +class CollectConsumerScopeInfo : public ExprVisitor { + public: + using ExprVisitor::VisitExpr_; + + std::pair>, Map>>> Collect( + const IRModule& mod, Function func, const Target& target) { + mod_ = mod; + target_ = target; + VisitExpr(func->body); + for (const auto& val : arg_to_binding) { + if (scope_info.find(val.first) != scope_info.end()) { + if (scope_info.find(val.second) == scope_info.end()) { + scope_info.Set(val.second, scope_info[val.first]); + } else { + auto ent = scope_info[val.second]; + for (auto ent_val : scope_info[val.first]) { + ent.Set(ent_val.first, ent_val.second); + } + 
scope_info.Set(val.second, ent); + } + } + } + + return std::make_pair(call_scope_info, scope_info); + } + + void VisitBinding_(const VarBindingNode* binding, + const TupleGetItemNode* tuple_get_item_node) final { + if (arg_to_binding.find(GetRef(binding->var.get())) == arg_to_binding.end()) { + arg_to_binding.Set(GetRef(binding->var.get()), + GetRef(tuple_get_item_node->tuple.get())); + } + } + + void VisitExpr_(const CallNode* call) final { + static const Op& call_tir_op = Op::Get("relax.call_tir"); + GlobalVar gv; + Array op_attrs; + Optional op_pattern = Integer(static_cast(relay::kOpaque)); + Tuple func_args; + + StructInfo out_sinfo; + + if (call->op == call_tir_op) { + gv = Downcast(call->args[0]); + tir::PrimFunc pfunc = Downcast(mod_->Lookup(gv)); + op_attrs = ExtractAttrs(pfunc); + op_pattern = ExtractPattern(pfunc); + out_sinfo = call->sinfo_args[0]; + func_args = Downcast(call->args[1]); + } else { + return; + } + + bool is_texture_supported = SupportsTexture(op_attrs, op_pattern.value()); + + Array arg_scope; + for (auto arg : func_args->fields) { + auto sinfo = GetStructInfo(arg); + if (auto tensor_sinfo = sinfo.as()) { + auto scope = is_texture_supported + ? Scope(GetShapeFromTensorStructInfo(tensor_sinfo.value())) + : "global"; + Map> ent_call; + const VarNode* arg_var = arg.as(); + if (scope_info.find(GetRef(arg_var)) != scope_info.end()) { + ent_call = scope_info[GetRef(arg_var)]; + } + ent_call.Set(GetRef(call), {scope}); + scope_info.Set(GetRef(arg_var), ent_call); + arg_scope.push_back(scope); + } + } + call_scope_info.Set(GetRef(call), arg_scope); + } + + private: + template + Array ExtractAttrs(const T& func) { + Array op_attrs; + Optional attrs = func->template GetAttr("op_attrs"); + if (attrs) { + if (auto val = attrs.value().as()) { + op_attrs.push_back(val.value()); + } else if (auto val = attrs.value().as>()) { + op_attrs = val.value(); + } + } + return std::move(op_attrs); + } + + template + Optional ExtractPattern(const T& func) { + Optional op_pat = func->template GetAttr("op_pattern"); + return std::move(op_pat); + } + + bool SupportsTexture(const Array& op_attrs, Integer op_pattern) { + if (op_pattern.IntValue() < relay::kCommReduce) return true; + + for (auto attr : op_attrs) { + if (auto conv_attr = attr.as()) { + if (conv_attr->data_layout == "NCHW4c" && conv_attr->kernel_layout == "OIHW4o") { + return true; + } + } else if (auto pool_attrs = attr.as()) { + if (pool_attrs->layout == "NCHW4c") { + return true; + } + } else if (auto avg_attrs = attr.as()) { + if (avg_attrs->layout == "NCHW4c") { + return true; + } + } else if (attr.as()) { + return true; + } + } + + return false; + } + + std::string Scope(Array shape) { + // currently we support only textures been made from 5d tensors + // 5d requirement is not limitation of textures in general, it is limitation how + // we are representing memory scopes/layout and flattening of textures in tir + if (shape.size() == 5 && shape[4].as()->value == 4) { + std::map diffs; + int spatial_limit = + target_->GetAttr("texture_spatial_limit").value_or(Integer(16384))->value; + int depth_limit = + target_->GetAttr("texture_depth_limit").value_or(Integer(2048))->value; + int a0 = shape[0].as()->value; + int a1 = shape[1].as()->value; + int a2 = shape[2].as()->value; + int a3 = shape[3].as()->value; + + int d1r = a0 * a1; + int d2r = a2 * a3; + int d3r = a1 * a2 * a3; + std::string scope = "global"; + if (a0 < spatial_limit && d3r < spatial_limit) + scope += ".texture-weight"; + else if (a0 < depth_limit && a1 < 
spatial_limit && d2r < spatial_limit) + scope += ".texture-nhwc"; + else if (d1r < depth_limit && a2 < spatial_limit && a3 < spatial_limit) + scope += ".texture"; + return scope; + } + return "global"; + } + + Map>> scope_info; + Map> call_scope_info; + Map arg_to_binding; + IRModule mod_; + Target target_; +}; + +class DefineVDevice : ExprMutator { + public: + explicit DefineVDevice(const Target& target) : target_(target) {} + + IRModule Run(IRModule& mod) { + mod_ = mod; + for (const auto& [gv, func] : mod_->functions) { + if (func->IsInstance()) { + const auto& base_func = mod_->Lookup(gv); + // Only non primitive relax functions + if (base_func->HasNonzeroAttr(attr::kPrimitive)) { + continue; + } + auto info = CollectConsumerScopeInfo().Collect(mod_, Downcast(func), target_); + call_scope_info_ = info.first; + scope_info_ = info.second; + producer_sinfo_ = CollectProduserScopeInfo().Collect(mod_, Downcast(func), + scope_info_, target_); + relax::Function update_func = Downcast(VisitExpr(func)); + updates_->Add(gv, update_func); + } + } + mod_.CopyOnWrite()->Update(updates_); + + Array global_vdevices_; + for (auto vdev : vdevices_) { + global_vdevices_.push_back(vdev.as().value()); + } + mod_.CopyOnWrite()->global_infos.Set("vdevice", global_vdevices_); + + mod_ = relax::transform::SpecializePrimFuncBasedOnCallSite()(mod_); + mod_ = relax::transform::DeadCodeElimination()(mod_); + mod_ = relax::transform::RealizeVDevice()(mod_); + mod_ = relax::transform::SpecializePrimFuncBasedOnCallSite()(mod_); + + return mod_; + } + + using ExprMutator::VisitExpr_; + + Expr VisitExpr_(const CallNode* call_node) override { + auto call = Downcast(ExprMutator::VisitExpr_(call_node)); + static const Op& call_tir_op = Op::Get("relax.call_tir"); + + GlobalVar gv; + Tuple func_args; + + StructInfo out_sinfo; + + if (call->op == call_tir_op) { + gv = Downcast(call->args[0]); + tir::PrimFunc pfunc = Downcast(mod_->Lookup(gv)); + out_sinfo = call->sinfo_args[0]; + func_args = Downcast(call->args[1]); + } else { + return call; + } + + Array new_args; + StructInfo updated_ret_sinfo = producer_sinfo_[GetRef(call_node)]; + + if (updated_ret_sinfo->IsInstance()) { + auto tensor_sinfo = Downcast(updated_ret_sinfo); + auto shape = tensor_sinfo->shape.value(); + auto dtype = tensor_sinfo->dtype; + if (tensor_sinfo->vdevice.defined()) { + auto vdev = tensor_sinfo->vdevice.value(); + const VDevice& vdev_global = MakeGlobalVDevice(vdev); + updated_ret_sinfo = TensorStructInfo(shape, dtype, vdev_global); + } + } else { + ICHECK(updated_ret_sinfo->IsInstance()) + << "Expect output struct info of call_tir to be either TupleStructInfo or " + "TensorStructInfo, but got " + << updated_ret_sinfo; + + const auto& tuple_sinfo = Downcast(updated_ret_sinfo); + Array sinfo_fields; + for (const auto& si : tuple_sinfo->fields) { + ICHECK(si->IsInstance()) + << "Fields of TupleStructInfo must be TensorStructInfo for call_tir " + "output structinfo, but got " + << si; + auto sinfo = Downcast(si); + + auto shape_arr = GetShapeFromTensorStructInfo(sinfo); + + auto shape = sinfo->shape.value(); + auto dtype = sinfo->dtype; + if (sinfo->vdevice.defined()) { + auto vdev = sinfo->vdevice.value(); + const VDevice& vdev_global = MakeGlobalVDevice(vdev); + sinfo_fields.push_back(TensorStructInfo(shape, dtype, vdev_global)); + } else { + sinfo_fields.push_back(sinfo); + } + } + updated_ret_sinfo = TupleStructInfo(sinfo_fields); + } + + int arg_idx = 0; + for (auto arg : func_args->fields) { + auto sinfo = GetStructInfo(arg); + if (auto 
tensor_sinfo = sinfo.as()) { + String scope = "global"; + if (call_scope_info_.find(GetRef(call_node)) != call_scope_info_.end()) { + scope = call_scope_info_[GetRef(call_node)][arg_idx]; + } + new_args.push_back(HintArg(arg, scope)); + arg_idx++; + } else { + new_args.push_back(arg); + } + } + + auto updated_call = Call(call_tir_op, {gv, Tuple(new_args)}, call->attrs, {updated_ret_sinfo}); + return builder_->Normalize(updated_call); + } + + private: + VDevice MakeGlobalVDevice(VDevice vdev) { + int device_type = vdev->target->GetTargetDeviceType(); + for (size_t i = 0; i < vdevices_.size(); ++i) { + int dev_type = vdevices_[i]->target->GetTargetDeviceType(); + if (dev_type == device_type && vdevices_[i]->vdevice_id == vdev->vdevice_id && + vdevices_[i]->memory_scope == vdev->memory_scope) { + return vdevices_[i]; + } + } + vdevices_.push_back(vdev); + return (vdevices_.back()); + } + + Expr HintArg(const Expr& arg, String scope) { + if (arg->IsInstance()) { + if (auto tsinfo = arg->struct_info_.as()) { + if (!tsinfo->vdevice.defined()) { + const VDevice& vdev = MakeGlobalVDevice(VDevice(target_, 0, scope)); + CHECK(tsinfo->shape.defined()) << "Shape not defined for a constant tensor ..!"; + arg->struct_info_ = + TensorStructInfo(tsinfo->shape.value(), tsinfo->dtype, vdev, tsinfo->span); + return arg; + } + } + } + ObjectPtr attrs = make_object(); + const VDevice& vdev = MakeGlobalVDevice(VDevice(target_, 0, scope)); + attrs->dev_type = vdev->target->GetTargetDeviceType(); + attrs->dev_id = vdev->vdevice_id; + attrs->memory_scope = vdev->memory_scope; + + Expr new_arg = Call(hint_on_device_op_, {arg}, Attrs{std::move(attrs)}, {}); + + return std::move(new_arg); + } + + Optional GetTarget(const StructInfo& sinfo) { + auto tinfo = sinfo.as(); + if (tinfo->vdevice.defined()) { + auto vdevice = tinfo->vdevice.value(); + if (vdevice->target.defined()) { + return vdevice->target; + } + } + return NullOpt; + } + + const Op& hint_on_device_op_ = Op::Get("relax.hint_on_device"); + IRModule mod_; + IRModule updates_; + Target target_; + Array vdevices_; + Map>> scope_info_; + Map producer_sinfo_; + Map> call_scope_info_; +}; + +namespace transform { + +Pass AnnotateCustomMemoryScope(Target target) { + runtime::TypedPackedFunc pass_func = + [=](IRModule mod, PassContext pc) { return relax::DefineVDevice(target).Run(mod); }; + return CreateModulePass(/*pass_function=*/pass_func, + /*opt_level=*/0, + /*pass_name=*/"AnnotateCustomMemoryScope", + /*required=*/{}); +} + +TVM_REGISTER_GLOBAL("relax.transform.AnnotateCustomMemoryScope") + .set_body_typed(AnnotateCustomMemoryScope); + +} // namespace transform +} // namespace relax +} // namespace tvm diff --git a/src/relax/transform/fuse_tir.cc b/src/relax/transform/fuse_tir.cc index fe247645dc24..ae10b4b5d62f 100644 --- a/src/relax/transform/fuse_tir.cc +++ b/src/relax/transform/fuse_tir.cc @@ -637,6 +637,14 @@ class FusedTIRConstructor : public ExprVisitor { // Step 1. Get Global var and PrimFunc GlobalVar gv = Downcast(call->args[0]); tir::PrimFunc prim_func_ = Downcast(mod_->Lookup(gv)); + if (prim_func_->GetAttr("op_attrs")) { + func_info_.op_attrs.push_back(prim_func_->GetAttr("op_attrs").value()); + } + + if (prim_func_->GetAttr("op_pattern")) { + auto op_pattern = prim_func_->GetAttr("op_pattern").value(); + func_info_.op_pattern.push_back(static_cast(op_pattern.IntValue())); + } // Step 2. 
Renew all vars/buffer definitions and blocks to avoid duplication tir::PrimFunc prim_func = tir::RenewDefs(prim_func_); @@ -953,6 +961,14 @@ class FusedTIRConstructor : public ExprVisitor { tir::PrimFunc ConstructFunc() { Map attr_map; attr_map.Set("tir.noalias", tir::const_true()); + if (!func_info_.op_attrs.empty()) { + attr_map.Set("op_attrs", func_info_.op_attrs); + } + if (!func_info_.op_pattern.empty()) { + int op_pattern = relay::kOpaque; + op_pattern = *max_element(func_info_.op_pattern.begin(), func_info_.op_pattern.end()); + attr_map.Set("op_pattern", Integer(static_cast(op_pattern))); + } tir::FuseTIRBufferSubstitutor subst(func_info_.buffer_subst_map, func_info_.symbolic_var_remap); ICHECK(func_info_.global_name != "fused"); // Remove output buffers from func_info_.alloc_buffers @@ -1004,6 +1020,8 @@ class FusedTIRConstructor : public ExprVisitor { Array alloc_buffers; /*! \brief The bodies of the original funcs, which is also the body of the fused func. */ Array bodies; + Array op_attrs; + std::vector op_pattern; /*! \brief The params of the fused function*/ Array params; /*! diff --git a/src/relax/transform/legalize_ops.cc b/src/relax/transform/legalize_ops.cc index 4a6b44bf2839..b0a54dc6ddea 100644 --- a/src/relax/transform/legalize_ops.cc +++ b/src/relax/transform/legalize_ops.cc @@ -60,8 +60,11 @@ bool KnowAllShapeValues(const StructInfo& sinfo) { class LegalizeMutator : public ExprMutator { public: explicit LegalizeMutator(const IRModule& mod, const Optional>& cmap, - bool enable_warning) - : ExprMutator(mod), mod_(std::move(mod)), enable_warning_(enable_warning) { + bool enable_warning, bool add_attributes) + : ExprMutator(mod), + mod_(std::move(mod)), + enable_warning_(enable_warning), + add_attributes_(add_attributes) { if (cmap) { cmap_ = std::move(cmap.value()); } @@ -152,6 +155,32 @@ class LegalizeMutator : public ExprMutator { return NullOpt; } + Expr AttributeOpAttrs(Expr expr, Attrs attrs) { + if (!expr->IsInstance()) { + return expr; + } + + auto call = Downcast(expr); + if (call->args.empty()) { + return expr; + } + + auto gvar = call->args[0].as(); + if (!gvar.defined()) { + return expr; + } + + auto base_func = builder_->GetContextIRModule()->Lookup(gvar.value()); + auto opt_prim_func = base_func.as(); + if (!opt_prim_func) { + return expr; + } + auto prim_func = opt_prim_func.value(); + auto new_prim_func = WithAttr(prim_func, "op_attrs", attrs); + builder_->UpdateFunction(gvar.value(), new_prim_func); + return call; + } + Expr BindTarget(Expr expr) { if (!expr->IsInstance()) { // FLegalize returned something other than a relax::Call. This @@ -342,6 +371,10 @@ class LegalizeMutator : public ExprMutator { } Expr legalized = legalization_func(builder_, visited_call); + if (call->attrs.as() && add_attributes_) { + legalized = AttributeOpAttrs(legalized, call->attrs); + } + // Append the target attribute to any PrimFunc generated in // legalization. legalized = BindTarget(legalized); @@ -385,17 +418,21 @@ class LegalizeMutator : public ExprMutator { * legalization function is not registered. */ bool enable_warning_; + /*! 
+   * \brief Whether this pass should add operator attributes to the PrimFunc attributes.
+   */
+  bool add_attributes_;
 };

 namespace transform {

-Pass LegalizeOps(Optional<Map<String, PackedFunc>> cmap, bool enable_warning) {
+Pass LegalizeOps(Optional<Map<String, PackedFunc>> cmap, bool enable_warning,
+                 bool add_attributes) {
   runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
       [=](IRModule mod, PassContext pc) {
         bool apply_legalize_ops =
             pc->GetConfig<Bool>("relax.transform.apply_legalize_ops").value_or(Bool(true))->value;
         if (apply_legalize_ops) {
-          mod = LegalizeMutator(mod, cmap, enable_warning).Transform();
+          mod = LegalizeMutator(mod, cmap, enable_warning, add_attributes).Transform();
         }
         return mod;
       };
diff --git a/src/relax/transform/realize_vdevice.cc b/src/relax/transform/realize_vdevice.cc
index 0df86515dbcc..672d63e65f56 100644
--- a/src/relax/transform/realize_vdevice.cc
+++ b/src/relax/transform/realize_vdevice.cc
@@ -55,6 +55,7 @@ class VDeviceLookup {
     ICHECK(attrs);
     int32_t device_type = attrs->dev_type;
     int32_t device_id = attrs->dev_id;
+    String memory_scope = attrs->memory_scope;

     CHECK(opt_vdevices_.defined())
         << "ValueError: The target VDevice in the GlobalInfos was not found.";
@@ -65,7 +66,8 @@

     for (auto vdevice : vdevices) {
       int dev_type = vdevice->target->GetTargetDeviceType();
-      if (dev_type == device_type && vdevice->vdevice_id == device_id) {
+      if (dev_type == device_type && vdevice->vdevice_id == device_id &&
+          memory_scope == vdevice->memory_scope) {
         return vdevice;
       }
     }
diff --git a/src/relax/transform/specialize_primfunc_based_on_callsite.cc b/src/relax/transform/specialize_primfunc_based_on_callsite.cc
new file mode 100644
index 000000000000..fe2cc9329860
--- /dev/null
+++ b/src/relax/transform/specialize_primfunc_based_on_callsite.cc
@@ -0,0 +1,171 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file src/relax/transform/specialize_primfunc_based_on_callsite.cc
+ * \brief Update PrimFunc buffers based on updated scope (or structure) info.
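+ *
+ * For every call_tir, the callee PrimFunc's buffers are re-declared with the memory scope
+ * carried by the VDevice of each argument and return value, and the PrimFunc is then
+ * specialized against these new buffers.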
+ */ + +#include +#include +#include +#include +#include +#include +#include + +#include + +#include "../op/tensor/manipulate.h" +#include "infer_layout_utils.h" +#include "utils.h" + +namespace tvm { +namespace relax { + +using tvm::tir::Buffer; + +static Array GetShapeFromTensorStructInfo(const TensorStructInfo& tensor_sinfo) { + auto shape = tensor_sinfo->GetShape(); + ICHECK(shape.defined()); + return shape.value(); +} + +class SpecializeTIRCallArgs : ExprMutator { + public: + IRModule Run(IRModule mod) { + mod_ = mod; + for (const auto& [gv, func] : mod->functions) { + if (func->IsInstance()) { + const auto& base_func = mod->Lookup(gv); + // Only non primitive relax functions + if (base_func->HasNonzeroAttr(attr::kPrimitive)) { + continue; + } + relax::Function update_func = Downcast(VisitExpr(func)); + updates_->Add(gv, update_func); + } + } + mod_.CopyOnWrite()->Update(updates_); + return mod_; + } + + using ExprMutator::VisitExpr_; + + Expr VisitExpr_(const CallNode* call_node) override { + auto call = Downcast(ExprMutator::VisitExpr_(call_node)); + static const Op& call_tir_op = Op::Get("relax.call_tir"); + if (call->op == call_tir_op) { + return SpecializeTirPrimFunc(call); + } + return call; + } + + private: + Expr SpecializeTirPrimFunc(Call call) { + auto gv = Downcast(call->args[0]); + auto pfunc = Downcast(mod_->Lookup(gv)); + auto args = Downcast(call->args[1])->fields; + Map> param_map; + + for (size_t i = 0; i < args.size(); ++i) { + auto sinfo = GetStructInfo(args[i]); + CHECK(sinfo->IsInstance()) + << "Expected Tensor struct Info for call :" << call->op; + auto tensor_sinfo = Downcast(sinfo); + CHECK(tensor_sinfo->shape.defined()) << "Shape undefined for call:" << call->args[0]; + String scope = "global"; + if (tensor_sinfo->vdevice.defined()) { + scope = tensor_sinfo->vdevice.value()->memory_scope; + } + String name; + if (args[i]->IsInstance()) { + name = Downcast(args[i])->name_hint(); + } else { + name = std::string({static_cast('A' + i)}); + } + + const Buffer& buffer = tir::decl_buffer(GetShapeFromTensorStructInfo(tensor_sinfo), + tensor_sinfo->dtype, name, scope); + param_map.Set(pfunc->params[i], buffer); + } + String scope = "global"; + auto out_sinfo = call->sinfo_args[0]; + if (out_sinfo->IsInstance()) { + auto sinfo = Downcast(out_sinfo); + if (sinfo->vdevice.defined()) { + scope = sinfo->vdevice.value()->memory_scope; + } + const Buffer& buffer = + tir::decl_buffer(GetShapeFromTensorStructInfo(sinfo), sinfo->dtype, "ret_val", scope); + param_map.Set(pfunc->params[pfunc->params.size() - 1], buffer); + } else { + ICHECK(out_sinfo->IsInstance()) + << "Expect output struct info of call_tir to be either TupleStructInfo or " + "TensorStructInfo, but got " + << out_sinfo; + + const auto& tuple_sinfo = Downcast(out_sinfo); + Array sinfo_fields; + int index = 0; + for (const auto& si : tuple_sinfo->fields) { + ICHECK(si->IsInstance()) + << "Fields of TupleStructInfo must be TensorStructInfo for call_tir " + "output structinfo, but got " + << si; + auto sinfo = Downcast(si); + if (sinfo->vdevice.defined()) { + scope = sinfo->vdevice.value()->memory_scope; + } + const Buffer& buffer = + tir::decl_buffer(GetShapeFromTensorStructInfo(sinfo), sinfo->dtype, "ret_val", scope); + param_map.Set(pfunc->params[args.size() + index], buffer); + index++; + } + } + + auto new_pfunc = Specialize(pfunc, param_map); + for (const auto& [var, buffer] : new_pfunc->buffer_map) { + auto* ptr = buffer->data->type_annotation.as(); + ICHECK(ptr) << "Buffer Var's type annotation must be of 
PointerType"; + } + auto new_prim_func = WithAttr(new_pfunc, "scoped", Integer(1)); + updates_->Add(gv, new_prim_func); + return call; + } + IRModule mod_; + IRModule updates_; +}; + +namespace transform { + +Pass SpecializePrimFuncBasedOnCallSite() { + runtime::TypedPackedFunc pass_func = + [=](IRModule mod, PassContext pc) { return relax::SpecializeTIRCallArgs().Run(mod); }; + return CreateModulePass(/*pass_function=*/pass_func, + /*opt_level=*/0, + /*pass_name=*/"SpecializePrimFuncBasedOnCallSite", + /*required=*/{}); +} + +TVM_REGISTER_GLOBAL("relax.transform.SpecializePrimFuncBasedOnCallSite") + .set_body_typed(SpecializePrimFuncBasedOnCallSite); + +} // namespace transform +} // namespace relax +} // namespace tvm diff --git a/src/relax/transform/utils.h b/src/relax/transform/utils.h index 55e355b4bac2..87fa2f23b4b9 100644 --- a/src/relax/transform/utils.h +++ b/src/relax/transform/utils.h @@ -387,7 +387,7 @@ inline String GetCodegenName(const std::string& composite_name) { inline int GetDeviceIndex(const IRModule& mod, const VDevice& vdevice) { Array vdevices = mod->global_infos["vdevice"]; for (int i = 0; i < static_cast(vdevices.size()); ++i) { - if (vdevices[i] == vdevice) { + if (vdevices[i].same_as(vdevice)) { return i; } } diff --git a/src/script/printer/relax/call.cc b/src/script/printer/relax/call.cc index ef9438350ce0..9aed336ca0fa 100644 --- a/src/script/printer/relax/call.cc +++ b/src/script/printer/relax/call.cc @@ -219,7 +219,11 @@ Optional PrintHintOnDevice(const relax::Call& n, const ObjectPath& n_p, if (n->attrs.as()) { AttrPrinter printer(n_p->Attr("attrs"), d, &kwargs_keys, &kwargs_values); const_cast(n->attrs.get())->VisitAttrs(&printer); + ExprDoc scope_val = kwargs_values.back(); + kwargs_keys.pop_back(); + kwargs_values.pop_back(); args.push_back(Relax(d, "device")->Call({}, kwargs_keys, kwargs_values)); + args.push_back(scope_val); } return Relax(d, "hint_on_device")->Call(args); } @@ -242,7 +246,8 @@ Optional PrintToVDevice(const relax::Call& n, const ObjectPath& n_p, int dev_index = FindVDeviceIndexByTargetKind(vdev, d); kwargs_keys.push_back("dst_vdevice"); kwargs_values.push_back( - LiteralDoc::Str(dev_kind + ":" + std::to_string(dev_index), n_p->Attr("dst_vdevice"))); + LiteralDoc::Str(dev_kind + ":" + std::to_string(dev_index) + ":" + vdev->memory_scope, + n_p->Attr("dst_vdevice"))); } return Relax(d, "to_vdevice")->Call(args, kwargs_keys, kwargs_values); } diff --git a/src/script/printer/relax/struct_info.cc b/src/script/printer/relax/struct_info.cc index 7043952c7c15..66643fc2e9fd 100644 --- a/src/script/printer/relax/struct_info.cc +++ b/src/script/printer/relax/struct_info.cc @@ -126,8 +126,9 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) kwargs_keys.push_back("vdevice"); std::string dev_kind = n->vdevice.value()->target->kind->name; int dev_index = FindVDeviceIndexByTargetKind(n->vdevice.value(), d); - kwargs_values.push_back( - LiteralDoc::Str(dev_kind + ":" + std::to_string(dev_index), n_p->Attr("vdevice"))); + kwargs_values.push_back(LiteralDoc::Str( + dev_kind + ":" + std::to_string(dev_index) + ":" + n->vdevice.value()->memory_scope, + n_p->Attr("vdevice"))); } if (args.empty() && kwargs_keys.empty()) { return Relax(d, "Tensor"); diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc index 6195313fddae..549c63f374c5 100644 --- a/src/tir/schedule/analysis/analysis.cc +++ b/src/tir/schedule/analysis/analysis.cc @@ -2144,6 +2144,9 @@ TVM_REGISTER_GLOBAL("tir.schedule.GetLoopIterType") return "O"; } }); 
+TVM_REGISTER_GLOBAL("tir.schedule.HasIfThenElse").set_body_typed([](const Stmt& stmt) -> bool {
+  return HasIfThenElse(stmt);
+});
 }  // namespace tir
 }  // namespace tvm
diff --git a/tests/python/relax/test_transform.py b/tests/python/relax/test_transform.py
index e3274aea886a..b0bec5e858af 100644
--- a/tests/python/relax/test_transform.py
+++ b/tests/python/relax/test_transform.py
@@ -17,6 +17,7 @@
 import pytest
 import tvm
+import tvm.testing
 from tvm import relax
 import tvm.script
diff --git a/tests/python/relax/test_transform_annotate_custom_scope.py b/tests/python/relax/test_transform_annotate_custom_scope.py
new file mode 100644
index 000000000000..303966c88f7b
--- /dev/null
+++ b/tests/python/relax/test_transform_annotate_custom_scope.py
@@ -0,0 +1,1141 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import tvm
+from tvm import relax
+import tvm.testing
+from tvm.script.parser import ir as I, relax as R, tir as T
+from tvm.relax.transform.legalize_ops import adreno as legalize_adreno
+from tvm.ir.module import IRModule
+from tvm.relax.expr_functor import PyExprMutator, PyExprVisitor, mutator, visitor
+
+
+@visitor
+class ValidateScope(PyExprVisitor):  # pylint: disable=abstract-method
+    """Checks that each call_tir argument and result carries the expected memory scope."""
+
+    def __init__(self, scope_info: dict) -> None:
+        self.scope_info = scope_info
+        self.matched = True
+
+    def visit(self, mod: IRModule) -> bool:
+        """Entry point: visit every Relax function in the module."""
+        for _, func in mod.functions_items():
+            if isinstance(func, relax.Function):
+                self.visit_expr(func)
+        return self.matched
+
+    def visit_call_(self, call: relax.Call) -> None:  # pylint: disable=arguments-renamed
+        if call.op.name == "relax.call_tir":
+            for idx, arg in enumerate(call.args[1]):
+                arg_sinfo = arg.struct_info
+                assert isinstance(
+                    arg_sinfo, relax.TensorStructInfo
+                ), f"Expected TensorStructInfo but got {type(arg_sinfo)}"
+                assert (
+                    arg_sinfo.vdevice.memory_scope
+                    == self.scope_info[call.args[0].name_hint][0][idx]
+                ), f"Scope mismatched for argument {idx} in {call.args[0].name_hint}"
+            if isinstance(call.sinfo_args[0], relax.TensorStructInfo):
+                assert (
+                    call.sinfo_args[0].vdevice.memory_scope
+                    == self.scope_info[call.args[0].name_hint][1][0]
+                ), f"Scope mismatched for return scope: {call.args[0].name_hint}"
+            else:
+                assert isinstance(
+                    call.sinfo_args[0], relax.TupleStructInfo
+                ), f"Expected TupleStructInfo but got {type(call.sinfo_args[0])}"
+                for idx, sinfo in enumerate(call.sinfo_args[0].fields):
+                    assert (
+                        sinfo.vdevice.memory_scope
+                        == self.scope_info[call.args[0].name_hint][1][idx]
+                    ), f"Scope mismatched for return scope for {idx} in {call.args[0].name_hint}"
+
+
+def verify(mod, expected):
+    tgt = tvm.target.Target("opencl --device=adreno", host="llvm")
+    with tgt:
+        mod = 
tvm.tir.transform.BindTarget(tvm.target.Target.current(allow_none=False))(mod) + mod = tvm.relax.transform.DecomposeOpsForInference()(mod) + mod = tvm.relax.transform.FoldConstant()(mod) + desired_layouts = {"relax.nn.conv2d": ["NCHW4c", "OIHW4o", "NCHW4c"]} + mod = tvm.relax.transform.ConvertLayout(desired_layouts)(mod) + mod = tvm.relax.transform.Normalize()(mod) + mod = tvm.relax.transform.FoldConstant()(mod) + mod = tvm.relax.transform.LegalizeOps(add_attributes=True)(mod) + mod = tvm.relax.transform.LegalizeOps( + {"relax.nn.conv2d": legalize_adreno.conv2d_NCHWc_OIHWo}, + add_attributes=True, + )(mod) + mod = tvm.relax.transform.AnnotateTIROpPattern()(mod) + mod = tvm.relax.transform.FoldConstant()(mod) + mod = tvm.relax.transform.FuseOps()(mod) + mod = tvm.relax.transform.FuseTIR()(mod) + mod = tvm.relax.transform.DeadCodeElimination()(mod) + mod = tvm.relax.transform.AnnotateCustomMemoryScope(tgt)(mod) + mod = tvm.relax.transform.Normalize()(mod) + + ValidateScope(expected).visit(mod) + + +def test_conv2d(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 64, 56, 56), "float32"), w: R.Tensor((32, 64, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 32, 54, 54), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + R.output(gv) + return gv + + Expected = { + "te_layout_transform": (["global"], ["global.texture-nhwc"]), + "te_layout_transform1": (["global"], ["global.texture-weight"]), + "conv2d_NCHWc_OIHWo": (["global.texture-nhwc", "global.texture-weight"], ["global"]), + "te_layout_transform2": (["global"], ["global"]), + } + + verify(Input, Expected) + + +def test_conv2d_NCHW_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), "float32"), w: R.Tensor((4, 16, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d( + x, + w, + data_layout="NCHW", + kernel_layout="OIHW", + out_dtype="float32", + ) + R.output(gv) + return gv + + Expected = { + "te_layout_transform": (["global"], ["global.texture-weight"]), + "te_layout_transform1": (["global"], ["global.texture-weight"]), + "conv2d_NCHWc_OIHWo": (["global.texture-weight", "global.texture-weight"], ["global"]), + "te_layout_transform2": (["global"], ["global"]), + } + + verify(Input, Expected) + + +def test_conv2d_NHWC_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 28, 28, 16), "float32"), w: R.Tensor((4, 3, 3, 16), "float32") + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 26, 26, 4), "float32") = R.nn.conv2d( + x, + w, + data_layout="NHWC", + kernel_layout="OHWI", + out_dtype="float32", + ) + R.output(gv) + return gv + + Expected = { + "te_layout_transform": (["global"], ["global.texture-weight"]), + "te_layout_transform1": (["global"], ["global.texture-weight"]), + "conv2d_NCHWc_OIHWo": (["global.texture-weight", "global.texture-weight"], ["global"]), + "te_layout_transform2": (["global"], ["global"]), + } + + verify(Input, Expected) + + +def _test_conv2d_symbolic_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor("float32", ndim=4), w: R.Tensor("float32", ndim=4) + ) -> R.Tensor("float32", ndim=4): + with R.dataflow(): + N, C, H, W = T.int64(), T.int64(16), T.int64(), T.int64() + Nw, Cw, Hw, Ww = T.int64(4), T.int64(16), T.int64(), T.int64() + lv0 = R.match_cast(x, R.Tensor((N, C, H, W), "float32")) + lv1 = 
R.match_cast(w, R.Tensor((Nw, Cw, Hw, Ww), "float32")) + gv: R.Tensor( + (N, T.int64(4), H + T.int64(1) - Hw, W + T.int64(1) - Ww), "float32" + ) = R.nn.conv2d(lv0, lv1, out_dtype="float32") + R.output(gv) + return gv + + Expected = { + "te_layout_transform": (["global"], ["global.texture-weight"]), + "te_layout_transform1": (["global"], ["global.texture-weight"]), + "conv2d_NCHWc_OIHWo": (["global.texture-weight", "global.texture-weight"], ["global"]), + "te_layout_transform2": (["global"], ["global"]), + } + + verify(Input, Expected) + + +def test_conv2d_relu_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), "float32"), w: R.Tensor((4, 16, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((2, 4, 26, 26), "float32") = R.nn.relu(gv) + R.output(gv2) + return gv2 + + Expected = { + "te_layout_transform": (["global"], ["global.texture-weight"]), + "te_layout_transform1": (["global"], ["global.texture-weight"]), + "fused_conv2d_NCHWc_OIHWo_relu": ( + ["global.texture-weight", "global.texture-weight"], + ["global"], + ), + "te_layout_transform2": (["global"], ["global"]), + } + + verify(Input, Expected) + + +def test_relu_conv2d_relu_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), "float32"), w: R.Tensor((4, 16, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + x0: R.Tensor((2, 16, 28, 28), "float32") = R.nn.relu(x) + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x0, w, out_dtype="float32") + gv2: R.Tensor((2, 4, 26, 26), "float32") = R.nn.relu(gv) + R.output(gv2) + return gv2 + + Expected = { + "relu": (["global"], ["global"]), + "te_layout_transform": (["global"], ["global.texture-weight"]), + "te_layout_transform1": (["global"], ["global.texture-weight"]), + "fused_conv2d_NCHWc_OIHWo_relu1": ( + ["global.texture-weight", "global.texture-weight"], + ["global"], + ), + "te_layout_transform2": (["global"], ["global"]), + } + + verify(Input, Expected) + + +def test_conv2d_relu_tanh_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), "float32"), w: R.Tensor((4, 16, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((2, 4, 26, 26), "float32") = R.nn.relu(gv) + gv3: R.Tensor((2, 4, 26, 26), "float32") = R.tanh(gv2) + R.output(gv3) + return gv3 + + Expected = { + "te_layout_transform": (["global"], ["global.texture-weight"]), + "te_layout_transform1": (["global"], ["global.texture-weight"]), + "fused_conv2d_NCHWc_OIHWo_relu_tir_tanh": ( + ["global.texture-weight", "global.texture-weight"], + ["global"], + ), + "te_layout_transform2": (["global"], ["global"]), + } + verify(Input, Expected) + + +def test_conv2d_add_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), "float32"), + w: R.Tensor((4, 16, 3, 3), "float32"), + bias: R.Tensor((2, 4, 26, 26), "float32"), + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((2, 4, 26, 26), "float32") = R.add(gv, bias) + R.output(gv2) + return gv2 + + Expected = { + "te_layout_transform": (["global"], ["global.texture-weight"]), + "te_layout_transform1": 
(["global"], ["global.texture-weight"]), + "te_layout_transform2": (["global"], ["global.texture-weight"]), + "fused_conv2d_NCHWc_OIHWo_add": ( + ["global.texture-weight", "global.texture-weight", "global.texture-weight"], + ["global"], + ), + "te_layout_transform3": (["global"], ["global"]), + } + verify(Input, Expected) + + +def test_conv2d_fma_relu_conv2d_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 4, 28, 28), "float32"), + w: R.Tensor((4, 4, 3, 3), "float32"), + scale: R.Tensor((2, 4, 26, 26), dtype="float32"), + bias: R.Tensor((2, 4, 26, 26), "float32"), + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((2, 4, 26, 26), "float32") = R.ewise_fma(gv, scale, bias) + gv3: R.Tensor((2, 4, 26, 26), "float32") = R.nn.relu(gv2) + gv4: R.Tensor((2, 4, 24, 24), "float32") = R.nn.conv2d(gv3, w, out_dtype="float32") + R.output(gv4) + return gv4 + + Expected = { + "te_layout_transform": (["global"], ["global.texture-weight"]), + "te_layout_transform1": (["global"], ["global.texture-weight"]), + "conv2d_NCHWc_OIHWo": (["global.texture-weight", "global.texture-weight"], ["global"]), + "relu": (["global"], ["global"]), + "te_layout_transform2": (["global"], ["global"]), + "te_layout_transform3": (["global"], ["global.texture-weight"]), + "te_layout_transform4": (["global"], ["global.texture-weight"]), + "conv2d_NCHWc_OIHWo1": (["global.texture-weight", "global.texture-weight"], ["global"]), + "te_layout_transform5": (["global"], ["global"]), + } + verify(Input, Expected) + + +def test_conv2d_sum_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), "float32"), w: R.Tensor((4, 16, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=2): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((2, 4), "float32") = R.sum(gv, axis=[2, 3]) + R.output(gv2) + return gv2 + + Expected = { + "te_layout_transform": (["global"], ["global.texture-weight"]), + "te_layout_transform1": (["global"], ["global.texture-weight"]), + "conv2d_NCHWc_OIHWo": (["global.texture-weight", "global.texture-weight"], ["global"]), + "sum": (["global"], ["global"]), + "te_layout_transform2": (["global"], ["global"]), + } + verify(Input, Expected) + + +def test_conv2d_sum_keepdims_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), "float32"), w: R.Tensor((4, 16, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=2): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((2, 4, 1, 1), "float32") = R.sum(gv, axis=[2, 3], keepdims=True) + R.output(gv2) + return gv2 + + Expected = { + "te_layout_transform": (["global"], ["global.texture-weight"]), + "te_layout_transform1": (["global"], ["global.texture-weight"]), + "conv2d_NCHWc_OIHWo": (["global.texture-weight", "global.texture-weight"], ["global"]), + "sum": (["global"], ["global"]), + "te_layout_transform2": (["global"], ["global"]), + } + verify(Input, Expected) + + +def test_conv2d_sum_reduce_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), "float32"), w: R.Tensor((4, 16, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=2): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: 
R.Tensor((2, 26), "float32") = R.sum(gv, axis=[1, 2]) + R.output(gv2) + return gv2 + + Expected = { + "te_layout_transform": (["global"], ["global.texture-weight"]), + "te_layout_transform1": (["global"], ["global.texture-weight"]), + "conv2d_NCHWc_OIHWo": (["global.texture-weight", "global.texture-weight"], ["global"]), + "sum": (["global"], ["global"]), + "te_layout_transform2": (["global"], ["global"]), + } + verify(Input, Expected) + + +def test_conv2d_transpose_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), "float32"), w: R.Tensor((4, 16, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((26, 26, 4, 2), "float32") = R.permute_dims(gv, axes=[3, 2, 1, 0]) + R.output(gv2) + return gv2 + + Expected = { + "te_layout_transform": (["global"], ["global.texture-weight"]), + "te_layout_transform1": (["global"], ["global.texture-weight"]), + "conv2d_NCHWc_OIHWo": (["global.texture-weight", "global.texture-weight"], ["global"]), + "te_layout_transform2": (["global"], ["global"]), + "transpose": (["global"], ["global"]), + } + verify(Input, Expected) + + +def test_conv2d_expand_dims_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), "float32"), w: R.Tensor((4, 16, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=6): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((2, 1, 4, 1, 26, 26), "float32") = R.expand_dims(gv, axis=(-3, 1)) + R.output(gv2) + return gv2 + + Expected = { + "te_layout_transform": (["global"], ["global.texture-weight"]), + "te_layout_transform1": (["global"], ["global.texture-weight"]), + "conv2d_NCHWc_OIHWo": (["global.texture-weight", "global.texture-weight"], ["global"]), + "te_layout_transform2": (["global"], ["global"]), + "expand_dims": (["global"], ["global"]), + } + verify(Input, Expected) + + +def test_conv2d_squeeze_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((1, 16, 28, 28), "float32"), w: R.Tensor((4, 16, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=3): + with R.dataflow(): + gv: R.Tensor((1, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((4, 26, 26), "float32") = R.squeeze(gv, axis=[0]) + R.output(gv2) + return gv2 + + Expected = { + "te_layout_transform": (["global"], ["global.texture-weight"]), + "te_layout_transform1": (["global"], ["global.texture-weight"]), + "conv2d_NCHWc_OIHWo": (["global.texture-weight", "global.texture-weight"], ["global"]), + "te_layout_transform2": (["global"], ["global"]), + "squeeze": (["global"], ["global"]), + } + verify(Input, Expected) + + +def test_conv2d_strided_slice_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), "float32"), w: R.Tensor((4, 16, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((2, 2, 9, 7), dtype="float32") = R.strided_slice( + gv, begin=[0, 0, 0], end=[4, 26, 26], strides=[2, 3, 4], axes=[1, 2, 3] + ) + R.output(gv2) + return gv2 + + Expected = { + "te_layout_transform": (["global"], ["global.texture-weight"]), + "te_layout_transform1": (["global"], ["global.texture-weight"]), + "conv2d_NCHWc_OIHWo": (["global.texture-weight", 
"global.texture-weight"], ["global"]), + "te_layout_transform2": (["global"], ["global"]), + "strided_slice": (["global"], ["global"]), + } + verify(Input, Expected) + + +def test_conv2d_relu_concat_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), "float32"), w: R.Tensor((4, 16, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((2, 4, 26, 26), "float32") = R.nn.relu(gv) + gv3: R.Tensor((2, 8, 26, 26), "float32") = R.concat((gv, gv2), axis=1) + R.output(gv3) + return gv3 + + Expected = { + "te_layout_transform": (["global"], ["global.texture-weight"]), + "te_layout_transform1": (["global"], ["global.texture-weight"]), + "conv2d_NCHWc_OIHWo": ( + ["global.texture-weight", "global.texture-weight"], + ["global.texture-weight"], + ), + "fused_relu_concatenate": (["global.texture-weight"], ["global"]), + "te_layout_transform2": (["global"], ["global"]), + } + verify(Input, Expected) + + +def test_conv2d_relu_concat_split_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main(x: R.Tensor((2, 16, 28, 28), "float32"), w: R.Tensor((4, 16, 3, 3), "float32")): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((2, 4, 26, 26), "float32") = R.nn.relu(gv) + gv3: R.Tensor((2, 8, 26, 26), "float32") = R.concat((gv, gv2), axis=1) + gv4 = R.split(gv3, indices_or_sections=2, axis=1) + R.output(gv4) + return gv4 + + Expected = { + "te_layout_transform": (["global"], ["global.texture-weight"]), + "te_layout_transform1": (["global"], ["global.texture-weight"]), + "conv2d_NCHWc_OIHWo": ( + ["global.texture-weight", "global.texture-weight"], + ["global.texture-weight"], + ), + "fused_relu_concatenate_split": (["global.texture-weight"], ["global", "global"]), + "te_layout_transform2": (["global"], ["global"]), + "te_layout_transform3": (["global"], ["global"]), + } + verify(Input, Expected) + + +def test_conv2d_relu_concat_split_transpose_concat_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main(x: R.Tensor((2, 16, 28, 28), "float32"), w: R.Tensor((4, 16, 3, 3), "float32")): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((2, 4, 26, 26), "float32") = R.nn.relu(gv) + gv3: R.Tensor((2, 8, 26, 26), "float32") = R.concat((gv, gv2), axis=1) + gv4 = R.split(gv3, indices_or_sections=2, axis=1) + gv5: R.Tensor((26, 26, 4, 2), "float32") = R.permute_dims(gv4[0], axes=[3, 2, 1, 0]) + gv6: R.Tensor((26, 26, 4, 2), "float32") = R.permute_dims(gv4[1], axes=[3, 2, 1, 0]) + gv7: R.Tensor((26, 26, 8, 2), "float32") = R.concat((gv5, gv6), axis=2) + R.output(gv7) + return gv7 + + Expected = { + "te_layout_transform": (["global"], ["global.texture-weight"]), + "te_layout_transform1": (["global"], ["global.texture-weight"]), + "conv2d_NCHWc_OIHWo": ( + ["global.texture-weight", "global.texture-weight"], + ["global.texture-weight"], + ), + "fused_relu_concatenate_split": (["global.texture-weight"], ["global", "global"]), + "te_layout_transform2": (["global"], ["global"]), + "te_layout_transform3": (["global"], ["global"]), + "fused_transpose_transpose1_concatenate1": (["global", "global"], ["global"]), + } + verify(Input, Expected) + + +def test_conv2d_maxpool2d_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), "float32"), w: 
R.Tensor((4, 16, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2 = R.nn.max_pool2d( + gv, + pool_size=[2, 2], + strides=[2, 2], + padding=[0, 0], + layout="NCHW", + out_layout="NCHW", + ) + R.output(gv2) + return gv2 + + Expected = { + "te_layout_transform": (["global"], ["global.texture-weight"]), + "te_layout_transform1": (["global"], ["global.texture-weight"]), + "conv2d_NCHWc_OIHWo": ( + ["global.texture-weight", "global.texture-weight"], + ["global.texture-weight"], + ), + "max_pool2d": (["global.texture-weight"], ["global"]), + "te_layout_transform2": (["global"], ["global"]), + } + verify(Input, Expected) + + +def test_conv2d_avgpool2d_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), "float32"), w: R.Tensor((4, 16, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2 = R.nn.adaptive_avg_pool2d(gv, output_size=[13, 13], layout="NCHW") + R.output(gv2) + return gv2 + + Expected = { + "te_layout_transform": (["global"], ["global.texture-weight"]), + "te_layout_transform1": (["global"], ["global.texture-weight"]), + "conv2d_NCHWc_OIHWo": ( + ["global.texture-weight", "global.texture-weight"], + ["global.texture-weight"], + ), + "adaptive_avg_pool2d": (["global.texture-weight"], ["global"]), + "te_layout_transform2": (["global"], ["global"]), + } + verify(Input, Expected) + + +def test_conv2d_softmax_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), "float32"), w: R.Tensor((4, 16, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2 = R.nn.softmax(gv, axis=1) + R.output(gv2) + return gv2 + + Expected = { + "te_layout_transform": (["global"], ["global.texture-weight"]), + "te_layout_transform1": (["global"], ["global.texture-weight"]), + "conv2d_NCHWc_OIHWo": (["global.texture-weight", "global.texture-weight"], ["global"]), + "te_layout_transform2": (["global"], ["global"]), + "softmax": (["global"], ["global"]), + } + verify(Input, Expected) + + +def test_conv2d_layernorm_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), "float32"), + w: R.Tensor((4, 16, 3, 3), "float32"), + gamma: R.Tensor((26, 26), dtype="float32"), + beta: R.Tensor((26, 26), dtype="float32"), + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((2, 4, 26, 26), dtype="float32") = R.nn.layer_norm( + gv, gamma, beta, axes=[-2, -1] + ) + R.output(gv2) + return gv2 + + Expected = { + "te_layout_transform": (["global"], ["global.texture-weight"]), + "te_layout_transform1": (["global"], ["global.texture-weight"]), + "conv2d_NCHWc_OIHWo": ( + ["global.texture-weight", "global.texture-weight"], + ["global.texture-weight"], + ), + "layer_norm": (["global.texture-weight", "global", "global"], ["global"]), + "te_layout_transform2": (["global"], ["global"]), + } + verify(Input, Expected) + + +def test_binary_broadcast_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), "float32"), + w: R.Tensor((4, 16, 3, 3), "float32"), + bias: R.Tensor((26, 26), "float32"), + ) -> 
R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((2, 4, 26, 26), "float32") = R.add(gv, bias) + R.output(gv2) + return gv2 + + Expected = { + "te_layout_transform": (["global"], ["global.texture-weight"]), + "te_layout_transform1": (["global"], ["global.texture-weight"]), + "conv2d_NCHWc_OIHWo": (["global.texture-weight", "global.texture-weight"], ["global"]), + "te_layout_transform2": (["global"], ["global"]), + "add": (["global", "global"], ["global"]), + } + verify(Input, Expected) + + +def test_binary_ewise_scalar_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), "float32"), w: R.Tensor((4, 16, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((2, 4, 26, 26), "float32") = R.add(gv, R.const(1, "float32")) + R.output(gv2) + return gv2 + + Expected = { + "te_layout_transform": (["global"], ["global.texture-weight"]), + "te_layout_transform1": (["global"], ["global.texture-weight"]), + "fused_conv2d_NCHWc_OIHWo_add": ( + ["global.texture-weight", "global.texture-weight"], + ["global"], + ), + "te_layout_transform2": (["global"], ["global"]), + } + verify(Input, Expected) + + +def test_residual_block(): + """ + - some kind of residual block followed by convolution to have texture after residual block + - scalar data type verification which should be mapped to global memory scope + layout_transform (NCHW->NCHW4c) + | <- buffer + conv2d (1) <- to get textures as output + / \ + conv2d (2) | + \ / + add <- add should be fused into conv2d (2) + multiply to scalar <- buffer to the input of multiply scalar value + relu + | <- texture in intermediate tensor + conv2d (3) + relu + | <- buffer + layout_transform (NCHW4c->NCHW) + """ + + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 32, 40, 40), "float32"), + w1: R.Tensor((32, 32, 2, 2), "float32"), + w2: R.Tensor((32, 32, 1, 1), "float32"), + w3: R.Tensor((32, 32, 2, 2), "float32"), + bias: R.Tensor((1, 32, 1, 1), "float32"), + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv = R.nn.conv2d(x, w1, strides=[2, 2], out_dtype="float32") + gv1 = R.add(gv, bias) + gv2 = R.nn.relu(gv1) + gv3 = R.nn.conv2d(gv2, w2, strides=[1, 1], out_dtype="float32") + bias_1 = R.multiply(bias, R.const(0.15, "float32")) + gv4 = R.add(gv3, bias_1) + gv5 = R.nn.relu(gv4) + gv6 = R.nn.conv2d(gv5, w3, strides=[2, 2], out_dtype="float32") + gv7 = R.nn.relu(gv6) + R.output(gv7) + return gv7 + + Expected = { + "te_layout_transform": (["global"], ["global.texture-weight"]), + "te_layout_transform1": (["global"], ["global.texture-weight"]), + "te_layout_transform2": (["global"], ["global.texture-weight"]), + "fused_conv2d_NCHWc_OIHWo_add_relu": ( + ["global.texture-weight", "global.texture-weight", "global.texture-weight"], + ["global.texture-weight"], + ), + "te_layout_transform3": (["global"], ["global.texture-weight"]), + "multiply": (["global"], ["global"]), + "te_layout_transform4": (["global"], ["global.texture-weight"]), + "fused_conv2d_NCHWc_OIHWo1_add_relu": ( + ["global.texture-weight", "global.texture-weight", "global.texture-weight"], + ["global.texture-weight"], + ), + "te_layout_transform5": (["global"], ["global.texture-weight"]), + "fused_conv2d_NCHWc_OIHWo2_relu1": ( + ["global.texture-weight", "global.texture-weight"], + ["global"], + ), + 
"te_layout_transform6": (["global"], ["global"]), + } + verify(Input, Expected) + + +def test_conv2d_conv2d_fallback_to_buffer_conv2d(): + """ + layout_transform (NCHW->NCHW4c) + | <- texture + conv2d (1) <- textures as output + / \ + conv2d (2) conv2d (3) <- conv2d (2) emits texture, conv2d (3) emits buffer + \ / <- concat shouldn't support textures here + concatenation + | <- buffer + layout_transform (NCHW4c->NCHW) + """ + + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 32, 40, 40), "float32"), + w1: R.Tensor((96, 32, 2, 2), "float32"), + w2: R.Tensor((32, 96, 2, 2), "float32"), + w3: R.Tensor((5, 96, 2, 2), "float32"), + bias1: R.Tensor((1, 96, 1, 1), "float32"), + bias2: R.Tensor((1, 32, 1, 1), "float32"), + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv = R.nn.conv2d(x, w1, strides=[2, 2], out_dtype="float32") + gv1 = R.add(gv, bias1) + gv2 = R.nn.relu(gv1) + gv3 = R.nn.conv2d(gv2, w2, strides=[2, 2], out_dtype="float32") + gv4 = R.add(gv3, bias2) + gv5 = R.nn.relu(gv4) + gv6 = R.nn.conv2d(gv2, w3, strides=[2, 2], out_dtype="float32") + gv7 = R.concat((gv3, gv6), axis=1) + R.output(gv7) + return gv7 + + Expected = { + "te_layout_transform": (["global"], ["global.texture-weight"]), + "te_layout_transform1": (["global"], ["global.texture-weight"]), + "te_layout_transform2": (["global"], ["global.texture-weight"]), + "fused_conv2d_NCHWc_OIHWo_add_relu": ( + ["global.texture-weight", "global.texture-weight", "global.texture-weight"], + ["global"], + ), + "te_layout_transform3": (["global"], ["global.texture-weight"]), + "conv2d_NCHWc_OIHWo1": (["global.texture-weight", "global.texture-weight"], ["global"]), + "te_layout_transform4": (["global"], ["global"]), + "conv2d": (["global", "global"], ["global"]), + "te_layout_transform5": (["global"], ["global"]), + "concatenate": (["global", "global"], ["global"]), + } + verify(Input, Expected) + + +def test_conv2d_conv2d_conv2d_concat(): + """ + layout_transform (NCHW->NCHW4c) + | <- texture + conv2d (1) <- textures as output + / \ + conv2d (2) conv2d (3) + \ / <- concat does support textures here + concatenation + | <- buffer + layout_transform (NCHW4c->NCHW) + """ + + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 32, 40, 40), "float32"), + w1: R.Tensor((96, 32, 2, 2), "float32"), + w2: R.Tensor((32, 96, 2, 2), "float32"), + w3: R.Tensor((8, 96, 2, 2), "float32"), + bias1: R.Tensor((1, 96, 1, 1), "float32"), + bias2: R.Tensor((1, 32, 1, 1), "float32"), + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv = R.nn.conv2d(x, w1, strides=[2, 2], out_dtype="float32") + gv1 = R.add(gv, bias1) + gv2 = R.nn.relu(gv1) + gv3 = R.nn.conv2d(gv2, w2, strides=[2, 2], out_dtype="float32") + gv4 = R.add(gv3, bias2) + gv5 = R.nn.relu(gv4) + gv6 = R.nn.conv2d(gv2, w3, strides=[2, 2], out_dtype="float32") + gv7 = R.concat((gv3, gv6), axis=1) + R.output(gv7) + return gv7 + + Expected = { + "te_layout_transform": (["global"], ["global.texture-weight"]), + "te_layout_transform1": (["global"], ["global.texture-weight"]), + "te_layout_transform2": (["global"], ["global.texture-weight"]), + "fused_conv2d_NCHWc_OIHWo_add_relu": ( + ["global.texture-weight", "global.texture-weight", "global.texture-weight"], + ["global.texture-weight"], + ), + "te_layout_transform3": (["global"], ["global.texture-weight"]), + "conv2d_NCHWc_OIHWo1": ( + ["global.texture-weight", "global.texture-weight"], + ["global.texture-weight"], + ), + "te_layout_transform4": (["global"], ["global.texture-weight"]), 
+ "conv2d_NCHWc_OIHWo2": ( + ["global.texture-weight", "global.texture-weight"], + ["global.texture-weight"], + ), + "concatenate": (["global.texture-weight", "global.texture-weight"], ["global"]), + "te_layout_transform5": (["global"], ["global"]), + } + verify(Input, Expected) + + +def test_pooling_branching_texture_params(): + """ + Verification of the pooling and many branches having textures + layout_transform (NCHW->NCHW4c) + | <- texture + conv2d (0) <- to get textures + | <- textures + pooling + / \ \ <- textures + conv2d (1) conv2d (2) conv2d (3) + \ / | + add | <- to have the only one output, will be fused + \ / + add <- to have the only one output, will be fused + | <- buffer + layout_transform (NCHW4c->NCHW) + """ + + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 32, 40, 40), "float32"), + w1: R.Tensor((32, 32, 1, 1), "float32"), + w2: R.Tensor((32, 32, 2, 2), "float32"), + w3: R.Tensor((32, 32, 1, 1), "float32"), + w4: R.Tensor((32, 32, 2, 2), "float32"), + bias1: R.Tensor((1, 32, 1, 1), "float32"), + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv = R.nn.conv2d(x, w1, strides=[1, 1], out_dtype="float32") + gv1 = R.nn.max_pool2d(gv, pool_size=[2, 2], strides=[2, 2]) + gv2 = R.nn.conv2d( + gv1, w2, padding=[0, 0, 1, 1], strides=[1, 1], out_dtype="float32" + ) + gv3 = R.add(gv2, bias1) + gv4 = R.nn.relu(gv3) + gv5 = R.nn.conv2d( + gv1, w3, padding=[0, 0, 0, 0], strides=[1, 1], out_dtype="float32" + ) + gv6 = R.nn.conv2d( + gv1, w4, padding=[0, 1, 1, 0], strides=[1, 1], out_dtype="float32" + ) + gv7 = R.nn.relu(gv6) + gv8 = R.add(gv2, gv5) + gv9 = R.add(gv8, gv6) + R.output(gv9) + return gv9 + + Expected = { + "te_layout_transform": (["global"], ["global.texture-weight"]), + "te_layout_transform1": (["global"], ["global.texture-weight"]), + "conv2d_NCHWc_OIHWo": ( + ["global.texture-weight", "global.texture-weight"], + ["global.texture-weight"], + ), + "max_pool2d": (["global.texture-weight"], ["global.texture-weight"]), + "te_layout_transform2": (["global"], ["global.texture-weight"]), + "te_layout_transform3": (["global"], ["global.texture-weight"]), + "conv2d_NCHWc_OIHWo2": ( + ["global.texture-weight", "global.texture-weight"], + ["global.texture-weight"], + ), + "fused_conv2d_NCHWc_OIHWo1_add": ( + ["global.texture-weight", "global.texture-weight", "global.texture-weight"], + ["global.texture-weight"], + ), + "te_layout_transform4": (["global"], ["global.texture-weight"]), + "fused_conv2d_NCHWc_OIHWo3_add": ( + ["global.texture-weight", "global.texture-weight", "global.texture-weight"], + ["global"], + ), + "te_layout_transform5": (["global"], ["global"]), + } + verify(Input, Expected) + + +def test_injective_inputs1(): + """ + Input + / \ + / | + | / + conv2d (1) / + | / + conv2d (2) mean / + / \ / + | | \ / + | | (3) add + | | | + | \ / + \ mul + \ / + add + + """ + + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((1, 4, 40, 40), "float32"), + w1: R.Tensor((4, 4, 3, 3), "float32"), + w2: R.Tensor((4, 4, 3, 3), "float32"), + w3: R.Tensor((4, 4, 3, 3), "float32"), + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + mean = R.mean(x, axis=1, keepdims=True) + conv1 = R.nn.conv2d( + x, w1, padding=[1, 1, 1, 1], strides=[1, 1], out_dtype="float32" + ) + conv2 = R.nn.conv2d( + conv1, w2, padding=[1, 1, 1, 1], strides=[1, 1], out_dtype="float32" + ) + ad3 = R.add(conv1, conv2) + ad1 = R.add(mean, conv1) + ad2 = R.multiply(ad1, conv1) + gv = R.add(ad3, ad2) + R.output(gv) + return gv + + Expected = { + 
"te_layout_transform": (["global"], ["global.texture-weight"]), + "te_layout_transform1": (["global"], ["global.texture-weight"]), + "conv2d_NCHWc_OIHWo": ( + ["global.texture-weight", "global.texture-weight"], + ["global.texture-weight"], + ), + "te_layout_transform3": (["global"], ["global"]), + "fused_mean_add1": (["global", "global"], ["global"]), + "te_layout_transform2": (["global"], ["global.texture-weight"]), + "te_layout_transform4": (["global"], ["global.texture-weight"]), + "fused_conv2d_NCHWc_OIHWo1_add_multiply_add": ( + ["global.texture-weight", "global.texture-weight", "global.texture-weight"], + ["global"], + ), + "te_layout_transform5": (["global"], ["global"]), + } + verify(Input, Expected) + + +def test_injective_nwo_inputs2(): + """ + Input + / \ + | \ + conv2d \ + | / + conv2d mean / + / \ / + add | \ | + | | \ | + | | \ / + | | (3) add + | | | + | \ / + | \ / + \ mul + \ / + add + + """ + + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((1, 4, 40, 40), "float32"), + w1: R.Tensor((4, 4, 3, 3), "float32"), + w2: R.Tensor((4, 4, 3, 3), "float32"), + w3: R.Tensor((4, 4, 3, 3), "float32"), + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + mean = R.mean(x, axis=1, keepdims=True) + conv1 = R.nn.conv2d( + x, w1, padding=[1, 1, 1, 1], strides=[1, 1], out_dtype="float32" + ) + conv2 = R.nn.conv2d( + conv1, w2, padding=[1, 1, 1, 1], strides=[1, 1], out_dtype="float32" + ) + ad3 = R.add(conv1, conv2) + ad1 = R.add(mean, conv1) + ad2 = R.multiply(ad1, conv2) + gv = R.add(ad2, ad3) + R.output(gv) + return gv + + Expected = { + "te_layout_transform": (["global"], ["global.texture-weight"]), + "te_layout_transform1": (["global"], ["global.texture-weight"]), + "conv2d_NCHWc_OIHWo": ( + ["global.texture-weight", "global.texture-weight"], + ["global.texture-weight"], + ), + "te_layout_transform3": (["global"], ["global"]), + "fused_mean_add1": (["global", "global"], ["global"]), + "te_layout_transform2": (["global"], ["global.texture-weight"]), + "te_layout_transform4": (["global"], ["global.texture-weight"]), + "fused_conv2d_NCHWc_OIHWo1_add_multiply_add": ( + ["global.texture-weight", "global.texture-weight", "global.texture-weight"], + ["global"], + ), + "te_layout_transform5": (["global"], ["global"]), + } + verify(Input, Expected) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_transform_convert_layout.py b/tests/python/relax/test_transform_convert_layout.py index db4130f947d1..3b601d2f4d29 100644 --- a/tests/python/relax/test_transform_convert_layout.py +++ b/tests/python/relax/test_transform_convert_layout.py @@ -206,10 +206,9 @@ def main( lv2: R.Tensor((N, H, W, C), dtype="float32") = R.match_cast( lv0, R.Tensor((N, H, W, C), dtype="float32") ) - lv3: R.Tensor((N, C, H, W), dtype="float32") = R.permute_dims( - lv2, axes=[0, 3, 1, 2] - ) - gv: R.Tensor(dtype="float32", ndim=4) = R.add(lv3, w) + lv3: R.Tensor(dtype="float32", ndim=4) = R.permute_dims(w, axes=[0, 2, 3, 1]) + lv4: R.Tensor(dtype="float32", ndim=4) = R.add(lv2, lv3) + gv: R.Tensor(dtype="float32", ndim=4) = R.permute_dims(lv4, axes=[0, 3, 1, 2]) R.output(gv) return gv @@ -4585,5 +4584,413 @@ def main( verify(Input, Expected_NCHW4c, {"relax.nn.conv2d": ["NCHW4c", "OIHW4o"]}) +def test_conv2d_conv2d_conv2d_concat(): + """ + layout_transform (NCHW->NCHW4c) + | <- texture + conv2d (1) <- textures as output + / \ + conv2d (2) conv2d (3) + \ / <- concat does support textures here + concatenation + | <- buffer + layout_transform (NCHW4c->NCHW) + 
""" + + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 32, 40, 40), "float32"), + w1: R.Tensor((96, 32, 2, 2), "float32"), + w2: R.Tensor((32, 96, 2, 2), "float32"), + w3: R.Tensor((8, 96, 2, 2), "float32"), + bias1: R.Tensor((1, 96, 1, 1), "float32"), + bias2: R.Tensor((1, 32, 1, 1), "float32"), + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv = R.nn.conv2d(x, w1, strides=[2, 2], out_dtype="float32") + gv1 = R.add(gv, bias1) + gv2 = R.nn.relu(gv1) + gv3 = R.nn.conv2d(gv2, w2, strides=[2, 2], out_dtype="float32") + gv4 = R.add(gv3, bias2) + gv5 = R.nn.relu(gv4) + gv6 = R.nn.conv2d(gv2, w3, strides=[2, 2], out_dtype="float32") + gv7 = R.concat((gv3, gv6), axis=1) + R.output(gv7) + return gv7 + + @I.ir_module + class Expected_NCHW4c: + @R.function + def main( + x: R.Tensor((2, 32, 40, 40), dtype="float32"), + w1: R.Tensor((96, 32, 2, 2), dtype="float32"), + w2: R.Tensor((32, 96, 2, 2), dtype="float32"), + w3: R.Tensor((8, 96, 2, 2), dtype="float32"), + bias1: R.Tensor((1, 96, 1, 1), dtype="float32"), + bias2: R.Tensor((1, 32, 1, 1), dtype="float32"), + ) -> R.Tensor((2, 40, 10, 10), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((2, 8, 40, 40, 4), dtype="float32") = R.layout_transform( + x, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i1 // 4, i2, i3, i1 % 4), index_dtype="int32" + ), + ) + lv1: R.Tensor((24, 32, 2, 2, 4), dtype="float32") = R.layout_transform( + w1, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i1, i2, i3, i0 % 4), index_dtype="int32" + ), + ) + gv: R.Tensor((2, 24, 20, 20, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + strides=[2, 2], + data_layout="NCHW4c", + kernel_layout="OIHW4o", + out_layout="NCHW4c", + out_dtype="float32", + ) + lv2: R.Tensor((1, 24, 1, 1, 4), dtype="float32") = R.layout_transform( + bias1, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i1 // 4, i2, i3, i1 % 4), index_dtype="int32" + ), + ) + gv1: R.Tensor((2, 24, 20, 20, 4), dtype="float32") = R.add(gv, lv2) + gv2: R.Tensor((2, 24, 20, 20, 4), dtype="float32") = R.nn.relu(gv1) + lv3: R.Tensor((8, 96, 2, 2, 4), dtype="float32") = R.layout_transform( + w2, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i1, i2, i3, i0 % 4), index_dtype="int32" + ), + ) + gv3: R.Tensor((2, 8, 10, 10, 4), dtype="float32") = R.nn.conv2d( + gv2, + lv3, + strides=[2, 2], + data_layout="NCHW4c", + kernel_layout="OIHW4o", + out_layout="NCHW4c", + out_dtype="float32", + ) + lv4: R.Tensor((1, 8, 1, 1, 4), dtype="float32") = R.layout_transform( + bias2, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i1 // 4, i2, i3, i1 % 4), index_dtype="int32" + ), + ) + gv4: R.Tensor((2, 8, 10, 10, 4), dtype="float32") = R.add(gv3, lv4) + gv5: R.Tensor((2, 8, 10, 10, 4), dtype="float32") = R.nn.relu(gv4) + lv5: R.Tensor((2, 96, 2, 2, 4), dtype="float32") = R.layout_transform( + w3, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i1, i2, i3, i0 % 4), index_dtype="int32" + ), + ) + gv6: R.Tensor((2, 2, 10, 10, 4), dtype="float32") = R.nn.conv2d( + gv2, + lv5, + strides=[2, 2], + data_layout="NCHW4c", + kernel_layout="OIHW4o", + out_layout="NCHW4c", + out_dtype="float32", + ) + lv6: R.Tensor((2, 10, 10, 10, 4), dtype="float32") = R.concat((gv3, gv6), axis=1) + gv7: R.Tensor((2, 40, 10, 10), dtype="float32") = R.layout_transform( + lv6, + index_map=T.index_map( + lambda i0, i1, i2, i3, i4: (i0, i1 * 4 + i4, i2, i3), index_dtype="int32" + ), + ) + R.output(gv7) + return gv7 + + verify(Input, Expected_NCHW4c, {"relax.nn.conv2d": 
["NCHW4c", "OIHW4o"]}) + + +def test_conv2d_conv2d_callback_to_buffer_conv2d_concat(): + """ + layout_transform (NCHW->NCHW4c) + | <- texture + conv2d (1) <- textures as output + / \ + conv2d (2) conv2d (3) <- conv2d (2) emits texture, conv2d (3) emits buffer + \ / <- concat shouldn't support textures here + concatenation + | <- buffer + layout_transform (NCHW4c->NCHW) + """ + + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 32, 40, 40), "float32"), + w1: R.Tensor((96, 32, 2, 2), "float32"), + w2: R.Tensor((32, 96, 2, 2), "float32"), + w3: R.Tensor((5, 96, 2, 2), "float32"), + bias1: R.Tensor((1, 96, 1, 1), "float32"), + bias2: R.Tensor((1, 32, 1, 1), "float32"), + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv = R.nn.conv2d(x, w1, strides=[2, 2], out_dtype="float32") + gv1 = R.add(gv, bias1) + gv2 = R.nn.relu(gv1) + gv3 = R.nn.conv2d(gv2, w2, strides=[2, 2], out_dtype="float32") + gv4 = R.add(gv3, bias2) + gv5 = R.nn.relu(gv4) + gv6 = R.nn.conv2d(gv2, w3, strides=[2, 2], out_dtype="float32") + gv7 = R.concat((gv3, gv6), axis=1) + R.output(gv7) + return gv7 + + @I.ir_module + class Expected_NCHW4c: + @R.function + def main( + x: R.Tensor((2, 32, 40, 40), dtype="float32"), + w1: R.Tensor((96, 32, 2, 2), dtype="float32"), + w2: R.Tensor((32, 96, 2, 2), dtype="float32"), + w3: R.Tensor((5, 96, 2, 2), dtype="float32"), + bias1: R.Tensor((1, 96, 1, 1), dtype="float32"), + bias2: R.Tensor((1, 32, 1, 1), dtype="float32"), + ) -> R.Tensor((2, 37, 10, 10), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((2, 8, 40, 40, 4), dtype="float32") = R.layout_transform( + x, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i1 // 4, i2, i3, i1 % 4), index_dtype="int32" + ), + ) + lv1: R.Tensor((24, 32, 2, 2, 4), dtype="float32") = R.layout_transform( + w1, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i1, i2, i3, i0 % 4), index_dtype="int32" + ), + ) + gv: R.Tensor((2, 24, 20, 20, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + strides=[2, 2], + data_layout="NCHW4c", + kernel_layout="OIHW4o", + out_layout="NCHW4c", + out_dtype="float32", + ) + lv2: R.Tensor((1, 24, 1, 1, 4), dtype="float32") = R.layout_transform( + bias1, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i1 // 4, i2, i3, i1 % 4), index_dtype="int32" + ), + ) + gv1: R.Tensor((2, 24, 20, 20, 4), dtype="float32") = R.add(gv, lv2) + gv2: R.Tensor((2, 24, 20, 20, 4), dtype="float32") = R.nn.relu(gv1) + lv3: R.Tensor((8, 96, 2, 2, 4), dtype="float32") = R.layout_transform( + w2, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i1, i2, i3, i0 % 4), index_dtype="int32" + ), + ) + gv3: R.Tensor((2, 8, 10, 10, 4), dtype="float32") = R.nn.conv2d( + gv2, + lv3, + strides=[2, 2], + data_layout="NCHW4c", + kernel_layout="OIHW4o", + out_layout="NCHW4c", + out_dtype="float32", + ) + lv4: R.Tensor((1, 8, 1, 1, 4), dtype="float32") = R.layout_transform( + bias2, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i1 // 4, i2, i3, i1 % 4), index_dtype="int32" + ), + ) + gv4: R.Tensor((2, 8, 10, 10, 4), dtype="float32") = R.add(gv3, lv4) + gv5: R.Tensor((2, 8, 10, 10, 4), dtype="float32") = R.nn.relu(gv4) + lv5: R.Tensor((2, 96, 20, 20), dtype="float32") = R.layout_transform( + gv2, + index_map=T.index_map( + lambda i0, i1, i2, i3, i4: (i0, i1 * 4 + i4, i2, i3), index_dtype="int32" + ), + ) + gv6: R.Tensor((2, 5, 10, 10), dtype="float32") = R.nn.conv2d( + lv5, + w3, + strides=[2, 2], + data_layout="NCHW", + kernel_layout="OIHW", + out_layout="NCHW", + out_dtype="float32", 
+ ) + lv6: R.Tensor((2, 32, 10, 10), dtype="float32") = R.layout_transform( + gv3, + index_map=T.index_map( + lambda i0, i1, i2, i3, i4: (i0, i1 * 4 + i4, i2, i3), index_dtype="int32" + ), + ) + gv7: R.Tensor((2, 37, 10, 10), dtype="float32") = R.concat((lv6, gv6), axis=1) + R.output(gv7) + return gv7 + + verify(Input, Expected_NCHW4c, {"relax.nn.conv2d": ["NCHW4c", "OIHW4o"]}) + + +def test_pooling_branching_texture_params(): + """ + Verification of the pooling and many branches having textures + layout_transform (NCHW->NCHW4c) + | <- texture + conv2d (0) <- to get textures + | <- textures + pooling + / \ \ <- textures + conv2d (1) conv2d (2) conv2d (3) + \ / | + add | <- to have the only one output, will be fused + \ / + add <- to have the only one output, will be fused + | <- buffer + layout_transform (NCHW4c->NCHW) + """ + + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 32, 40, 40), "float32"), + w1: R.Tensor((32, 32, 1, 1), "float32"), + w2: R.Tensor((32, 32, 2, 2), "float32"), + w3: R.Tensor((32, 32, 1, 1), "float32"), + w4: R.Tensor((32, 32, 2, 2), "float32"), + bias1: R.Tensor((1, 32, 1, 1), "float32"), + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv = R.nn.conv2d(x, w1, strides=[1, 1], out_dtype="float32") + gv1 = R.nn.max_pool2d(gv, pool_size=[2, 2], strides=[2, 2]) + gv2 = R.nn.conv2d( + gv1, w2, padding=[0, 0, 1, 1], strides=[1, 1], out_dtype="float32" + ) + gv3 = R.add(gv2, bias1) + gv4 = R.nn.relu(gv3) + gv5 = R.nn.conv2d( + gv1, w3, padding=[0, 0, 0, 0], strides=[1, 1], out_dtype="float32" + ) + gv6 = R.nn.conv2d( + gv1, w4, padding=[0, 1, 1, 0], strides=[1, 1], out_dtype="float32" + ) + gv7 = R.nn.relu(gv6) + gv8 = R.add(gv2, gv5) + gv9 = R.add(gv8, gv6) + R.output(gv9) + return gv9 + + @I.ir_module + class Expected_NCHW4c: + @R.function + def main( + x: R.Tensor((2, 32, 40, 40), dtype="float32"), + w1: R.Tensor((32, 32, 1, 1), dtype="float32"), + w2: R.Tensor((32, 32, 2, 2), dtype="float32"), + w3: R.Tensor((32, 32, 1, 1), dtype="float32"), + w4: R.Tensor((32, 32, 2, 2), dtype="float32"), + bias1: R.Tensor((1, 32, 1, 1), dtype="float32"), + ) -> R.Tensor((2, 32, 20, 20), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((2, 8, 40, 40, 4), dtype="float32") = R.layout_transform( + x, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i1 // 4, i2, i3, i1 % 4), index_dtype="int32" + ), + ) + lv1: R.Tensor((8, 32, 1, 1, 4), dtype="float32") = R.layout_transform( + w1, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i1, i2, i3, i0 % 4), index_dtype="int32" + ), + ) + gv: R.Tensor((2, 8, 40, 40, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + data_layout="NCHW4c", + kernel_layout="OIHW4o", + out_layout="NCHW4c", + out_dtype="float32", + ) + gv1: R.Tensor((2, 8, 20, 20, 4), dtype="float32") = R.nn.max_pool2d( + gv, pool_size=[2, 2], strides=[2, 2], layout="NCHW4c", out_layout="NCHW4c" + ) + lv2: R.Tensor((8, 32, 2, 2, 4), dtype="float32") = R.layout_transform( + w2, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i1, i2, i3, i0 % 4), index_dtype="int32" + ), + ) + gv2: R.Tensor((2, 8, 20, 20, 4), dtype="float32") = R.nn.conv2d( + gv1, + lv2, + padding=[0, 0, 1, 1], + data_layout="NCHW4c", + kernel_layout="OIHW4o", + out_layout="NCHW4c", + out_dtype="float32", + ) + lv3: R.Tensor((1, 8, 1, 1, 4), dtype="float32") = R.layout_transform( + bias1, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i1 // 4, i2, i3, i1 % 4), index_dtype="int32" + ), + ) + gv3: R.Tensor((2, 8, 20, 20, 4), dtype="float32") 
= R.add(gv2, lv3) + gv4: R.Tensor((2, 8, 20, 20, 4), dtype="float32") = R.nn.relu(gv3) + lv4: R.Tensor((8, 32, 1, 1, 4), dtype="float32") = R.layout_transform( + w3, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i1, i2, i3, i0 % 4), index_dtype="int32" + ), + ) + gv5: R.Tensor((2, 8, 20, 20, 4), dtype="float32") = R.nn.conv2d( + gv1, + lv4, + data_layout="NCHW4c", + kernel_layout="OIHW4o", + out_layout="NCHW4c", + out_dtype="float32", + ) + lv5: R.Tensor((8, 32, 2, 2, 4), dtype="float32") = R.layout_transform( + w4, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i1, i2, i3, i0 % 4), index_dtype="int32" + ), + ) + gv6: R.Tensor((2, 8, 20, 20, 4), dtype="float32") = R.nn.conv2d( + gv1, + lv5, + strides=[1, 1], + padding=[0, 1, 1, 0], + data_layout="NCHW4c", + kernel_layout="OIHW4o", + out_layout="NCHW4c", + out_dtype="float32", + ) + gv7: R.Tensor((2, 8, 20, 20, 4), dtype="float32") = R.nn.relu(gv6) + gv8: R.Tensor((2, 8, 20, 20, 4), dtype="float32") = R.add(gv2, gv5) + lv6: R.Tensor((2, 8, 20, 20, 4), dtype="float32") = R.add(gv8, gv6) + gv9: R.Tensor((2, 32, 20, 20), dtype="float32") = R.layout_transform( + lv6, + index_map=T.index_map( + lambda i0, i1, i2, i3, i4: (i0, i1 * 4 + i4, i2, i3), index_dtype="int32" + ), + ) + R.output(gv9) + return gv9 + + verify(Input, Expected_NCHW4c, {"relax.nn.conv2d": ["NCHW4c", "OIHW4o"]}) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/relax/test_transform_fuse_tir.py b/tests/python/relax/test_transform_fuse_tir.py index 99e7a5d2b737..0e8dbaeb11b7 100644 --- a/tests/python/relax/test_transform_fuse_tir.py +++ b/tests/python/relax/test_transform_fuse_tir.py @@ -1117,7 +1117,7 @@ def fused_concatenate_transpose2( (T.int64(2), T.int64(64), T.int64(64), T.int64(4)), "float32" ), ): - T.func_attr({"tir.noalias": T.bool(True)}) + T.func_attr({"op_pattern": 2, "tir.noalias": T.bool(True)}) T_concat_handle_intermediate = T.alloc_buffer( (T.int64(2), T.int64(4), T.int64(64), T.int64(64)) ) @@ -1307,7 +1307,7 @@ def fused_reshape( (T.int64(4), T.int64(8), T.int64(32), T.int64(64)), "float32" ), ): - T.func_attr({"tir.noalias": T.bool(True)}) + T.func_attr({"op_pattern": 2, "tir.noalias": T.bool(True)}) # with T.block("root"): for ax0, ax1, ax2, ax3 in T.grid(T.int64(4), T.int64(8), T.int64(32), T.int64(64)): with T.block("T_reshape"): diff --git a/tests/python/relax/test_tvmscript_parser_op_manipulate.py b/tests/python/relax/test_tvmscript_parser_op_manipulate.py index 694e7a688cf7..c0ff78ca4c6b 100644 --- a/tests/python/relax/test_tvmscript_parser_op_manipulate.py +++ b/tests/python/relax/test_tvmscript_parser_op_manipulate.py @@ -439,5 +439,20 @@ def foo(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"): _check(foo, bb.get()["foo"]) +def test_hint_on_device_scoped(): + @R.function + def foo(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"): + r = R.hint_on_device(x, R.device(4, 2), "global.texture") + return r + + x = relax.Var("x", R.Tensor((), "int32")) + bb = relax.BlockBuilder() + with bb.function("foo", (x,)): + tensor = bb.emit(relax.op.hint_on_device(x, R.opencl(2), "global.texture")) + bb.emit_func_output(tensor) + + _check(foo, bb.get()["foo"]) + + if __name__ == "__main__": tvm.testing.main()