diff --git a/CMakeLists.txt b/CMakeLists.txt index 3667ed6ba974..c4b6ae4e3ceb 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -307,6 +307,7 @@ tvm_file_glob(GLOB_RECURSE COMPILER_SRCS src/relax/analysis/*.cc src/relax/transform/*.cc src/relax/backend/vm/*.cc + src/relax/backend/adreno/*.cc src/relax/backend/task_extraction.cc src/relax/backend/pattern_registry.cc src/relax/utils.cc diff --git a/ffi/cmake/Utils/AddLibbacktrace.cmake b/ffi/cmake/Utils/AddLibbacktrace.cmake index 844a8816a6d8..fa5c26b02280 100644 --- a/ffi/cmake/Utils/AddLibbacktrace.cmake +++ b/ffi/cmake/Utils/AddLibbacktrace.cmake @@ -26,6 +26,10 @@ function(_libbacktrace_compile) set(_cmake_c_compiler "${CMAKE_C_COMPILER}") endif() + if(DEFINED CMAKE_C_COMPILER_TARGET) + set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} --target=${CMAKE_C_COMPILER_TARGET}") + endif() + message(STATUS CMAKC_C_COMPILER="${CMAKE_C_COMPILER}") file(MAKE_DIRECTORY ${_libbacktrace_prefix}/include) diff --git a/include/tvm/relax/attrs/op.h b/include/tvm/relax/attrs/op.h index cce78e9fd615..af890bc29599 100644 --- a/include/tvm/relax/attrs/op.h +++ b/include/tvm/relax/attrs/op.h @@ -103,13 +103,15 @@ struct ToVDeviceAttrs : public AttrsNodeReflAdapter { struct HintOnDeviceAttrs : public AttrsNodeReflAdapter { int32_t dev_type; int32_t dev_id; + MemoryScope memory_scope; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; refl::ObjectDef() .def_ro("dev_type", &HintOnDeviceAttrs::dev_type, "The device type where the data is supposed to be executed.") - .def_ro("dev_id", &HintOnDeviceAttrs::dev_id, "The device id."); + .def_ro("dev_id", &HintOnDeviceAttrs::dev_id, "The device id.") + .def_ro("memory_scope", &HintOnDeviceAttrs::memory_scope, "The device memory scope."); } static constexpr const char* _type_key = "relax.attrs.HintOnDeviceAttrs"; diff --git a/include/tvm/relax/backend/adreno/transform.h b/include/tvm/relax/backend/adreno/transform.h new file mode 100644 index 000000000000..891a19187739 --- /dev/null +++ b/include/tvm/relax/backend/adreno/transform.h @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/backend/adreno/transform.h + * \brief Adreno GPU specific transformation passes. + */ +#ifndef TVM_RELAX_BACKEND_ADRENO_TRANSFORM_H_ +#define TVM_RELAX_BACKEND_ADRENO_TRANSFORM_H_ + +#include +#include +namespace tvm { +namespace relax { +namespace backend { +namespace adreno { +namespace transform { + +using Pass = tvm::transform::Pass; +using PassInfo = tvm::transform::PassInfo; +using PassContext = tvm::transform::PassContext; +using Function = tvm::relax::Function; +using DataflowBlock = tvm::relax::DataflowBlock; +using tvm::relax::transform::CreateFunctionPass; +using tvm::transform::CreateModulePass; + +/*! 
+ * \brief This pass is designed to annotate the memory scope information via the VDevice attribute.
+ * This pass needs operator attributes, which in general vanish after legalization.
+ * FuseOps and FuseTIR are modified to pass on the operator specific attributes and also
+ * op_pattern details as part of the PrimFunc. This pass is Adreno specific and annotates each
+ * BindingVar with an appropriate hint_on_device. The RealizeVDevice pass that follows handles
+ * these hints. After this pass we also invoke SpecializePrimFuncBasedOnCallSite, which updates
+ * the var_buffer_map based on the new VDevice information.
+ */
+TVM_DLL Pass AnnotateCustomMemoryScope(Target target);
+
+/*
+ * \brief A texture specific pass that removes unnecessary to_vdevice copies of the form
+ * texture scope -> ToVDevice -> global scope. In such cases the producer can store directly
+ * into the global scope, avoiding the device copy.
+ */
+TVM_DLL Pass FoldVDeviceScopeChange();
+
+}  // namespace transform
+}  // namespace adreno
+}  // namespace backend
+}  // namespace relax
+}  // namespace tvm
+
+#endif  // TVM_RELAX_BACKEND_ADRENO_TRANSFORM_H_
diff --git a/include/tvm/relax/expr.h b/include/tvm/relax/expr.h
index e4049f23873c..3668773dd12d 100644
--- a/include/tvm/relax/expr.h
+++ b/include/tvm/relax/expr.h
@@ -160,9 +160,13 @@ class CallNode : public ExprNode {
 
   /*!
    * \brief The structure info arguments of a CallNode.
-   * sinfo_args is designed to be non-empty only for intrinsic op (e.g.,
+   * sinfo_args is by default designed to be non-empty only for intrinsic ops (e.g.,
    * call_tir, call_builtin_with_ctx, etc.) and calls to ExternFuncs, with the main
    * usage of structure info inference.
+   *
+   * Regular ops may also have sinfo_args defined to specialize partial or complete
+   * structure info, e.g. VDevice customization with mixed input memory scopes.
+   * A customized pass can set this info and the operator specific inference will respect it.
    */
   Array sinfo_args;
 
diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h
index 27f226042864..3a572e89fa97 100644
--- a/include/tvm/relax/transform.h
+++ b/include/tvm/relax/transform.h
@@ -244,11 +244,13 @@ TVM_DLL Pass FoldConstant();
 *
 * \param cmap The customized operator legalization function map. The customized function
 * will override the default one.
+ * \param skip_ops The list of operator names to be skipped during legalization.
 * \param enable_warning A boolean value indicating if to print warnings for TIR functions not
 * showing up in the database.
 * \return The Pass.
 */
-TVM_DLL Pass LegalizeOps(Optional> cmap, bool enable_warning = false);
+TVM_DLL Pass LegalizeOps(Optional> cmap,
+                         Optional> skip_ops, bool enable_warning = false);
 
 /*!
  * \brief Propagate virtual device information.
@@ -677,6 +679,13 @@ TVM_DLL Pass RewriteCUDAGraph();
  */
 TVM_DLL Pass FewShotTuning(int valid_count, bool benchmark);
 
+/*!
+ * \brief This pass updates the var-to-buffer mapping of PrimFuncs from the call_tir info.
+ * Primarily used to update the VDevice information if any changes occurred at the caller.
+ * This pass recreates the buffers and updates the map.
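+ *
+ * A minimal usage sketch (Python side, illustrative; it assumes the Adreno scope annotation
+ * pass has already attached VDevice information at the call_tir sites):
+ *   mod = relax.backend.adreno.transform.AnnotateCustomMemoryScope(target)(mod)
+ *   mod = relax.transform.SpecializePrimFuncBasedOnCallSite()(mod)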
+ */ +TVM_DLL Pass SpecializePrimFuncBasedOnCallSite(); + } // namespace transform } // namespace relax } // namespace tvm diff --git a/include/tvm/runtime/ndarray.h b/include/tvm/runtime/ndarray.h index 6eebe49ff135..4579e26b1ee2 100644 --- a/include/tvm/runtime/ndarray.h +++ b/include/tvm/runtime/ndarray.h @@ -175,6 +175,16 @@ class NDArray : public tvm::ffi::NDArray { */ TVM_DLL static void CopyToBytes(const DLTensor* from, void* to, size_t nbytes, TVMStreamHandle stream = nullptr); + + /*! + * \brief Function to copy data from one array to a byte buffer. + * \param from The source array. + * \param to The target byte buffer. + * \param nbytes The size of the data buffer. + * \param stream The stream used in copy. + */ + TVM_DLL static void CopyFromBytes(const DLTensor* to, void* from, size_t nbytes, + TVMStreamHandle stream = nullptr); }; /*! diff --git a/python/tvm/dlight/__init__.py b/python/tvm/dlight/__init__.py index bd70acf00f90..3d42d1972dcc 100644 --- a/python/tvm/dlight/__init__.py +++ b/python/tvm/dlight/__init__.py @@ -16,6 +16,7 @@ # under the License. """DLight package provides efficient schedules out-of-box for deep learning workloads.""" from . import gpu +from . import adreno from . import cpu from .analysis import ( BlockInfo, diff --git a/python/tvm/dlight/adreno/__init__.py b/python/tvm/dlight/adreno/__init__.py new file mode 100644 index 000000000000..ea2781455989 --- /dev/null +++ b/python/tvm/dlight/adreno/__init__.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Adreno schedule rules. +""" +from .convolution import Conv2d diff --git a/python/tvm/dlight/adreno/base.py b/python/tvm/dlight/adreno/base.py new file mode 100644 index 000000000000..d043706c2fc5 --- /dev/null +++ b/python/tvm/dlight/adreno/base.py @@ -0,0 +1,41 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""Base schedule rule for Adreno operators.""" + +from tvm.target import Target + +from ..base import ScheduleRule + + +class AdrenoScheduleRule(ScheduleRule): # pylint: disable=too-few-public-methods + """The Schedule Rule specific to Adreno targets, + will return None if the target is not Adreno.""" + + def is_target_available(self, target: Target) -> bool: + """Check whether the target is available for Adreno rule. + + Parameters + ---------- + target : Target + The compilation target to check. + + Returns + ------- + available : bool + Whether the target is available for this rule. + """ + return super().is_target_available(target) and "adreno" in target.keys diff --git a/python/tvm/dlight/adreno/convolution.py b/python/tvm/dlight/adreno/convolution.py new file mode 100644 index 000000000000..fc2cc449a1c6 --- /dev/null +++ b/python/tvm/dlight/adreno/convolution.py @@ -0,0 +1,230 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-docstring, invalid-name +"""A Conv2d schedule rule for Adreno GPU operators.""" +from dataclasses import dataclass +from typing import List, Optional + +from tvm import tir +from tvm.target import Target +from tvm.tir import IterVar +from tvm.tir.schedule.schedule import BlockRV + +from ..analysis import BlockInfo, IterInfo +from .base import AdrenoScheduleRule + + +def is_spatial_block(sch: tir.Schedule, block: BlockRV) -> bool: + block_stmt = sch.get(block) + iter_types = {iter_var.iter_type for iter_var in block_stmt.iter_vars} + return iter_types == {IterVar.DataPar} + + +def is_reduction_block(sch: tir.Schedule, block: BlockRV) -> bool: + block_stmt = sch.get(block) + iter_types = {iter_var.iter_type for iter_var in block_stmt.iter_vars} + return iter_types == {IterVar.CommReduce, IterVar.DataPar} + + +def _collect_producers(sch: tir.Schedule, block: tir.schedule.BlockRV): + result = [] + for producer in sch.get_producers(block): + result.append(producer) + result.extend(_collect_producers(sch, producer)) + return result + + +def _collect_consumers(sch: tir.Schedule, block: tir.schedule.BlockRV): + result = [] + for consumer in sch.get_consumers(block): + result.append(consumer) + result.extend(_collect_consumers(sch, consumer)) + return result + + +def get_block_info(sch: tir.Schedule, block: tir.schedule.BlockRV) -> BlockInfo: + def _iter_kind(loop: tir.IterVar) -> str: + return {tir.IterVar.DataPar: "S", tir.IterVar.CommReduce: "R"}.get(loop.iter_type, "O") + + def _is_reduction_block(block: tir.schedule.BlockRV): + for iter_var in sch.get(block).iter_vars: + if _iter_kind(iter_var) == "R": + return True + return False + + return BlockInfo( + name=sch.get(block).name_hint, + iters=[ + IterInfo( + kind=_iter_kind(iter_var), + var=iter_var.var, + dom=iter_var.dom.extent, + loop_rv=loop_rv, + ) + 
for loop_rv, iter_var in zip(sch.get_loops(block), sch.get(block).iter_vars) + ], + block_rv=block, + reduction_block=_is_reduction_block(block), + ) + + +def get_reduction_blocks(sch: tir.Schedule, blocks: List[tir.schedule.BlockRV]) -> bool: + # NOTE: We assume there is only one reduction block in the function + # all blocks are required to be spatial or reduction + if not all( + [is_reduction_block(sch, block) or is_spatial_block(sch, block) for block in blocks] + ): + return None + + # There is only one reduction block + reduction_blocks = [block for block in blocks if is_reduction_block(sch, block)] + if len(reduction_blocks) != 1: + return None + + return reduction_blocks[0] + + +def is_convolution(sch: tir.Schedule, block: tir.schedule.BlockRV): + # TODO: Use buffer access patterns to discover convolution type kernels instead of using name. + return ( + sch.get(block).name_hint.count("conv2d_NCHWc_OIHWo") + and "".join([iter_type.kind for iter_type in get_block_info(sch, block).iters]) + == "SSSSSRRR" + ) + + +class Conv2d(AdrenoScheduleRule): + """The schedule rule for convolution computation""" + + @dataclass + class Config: + block_size_x: int = 8 + block_size_y: int = 8 + vector_size: int = 1 + unroll: int = 256 # 0 means no unroll + use_shared: bool = True + storage_align: bool = False + inner_x: bool = False + + def get_configs(self, target: Target) -> Config: + """Get the schedule config for the target""" + if target.kind.name == "cuda" or target.kind.name == "rocm": + return Conv2d.Config( + block_size_x=8, + block_size_y=16, + vector_size=2, + unroll=256, + use_shared=True, + storage_align=True, + inner_x=False, + ) + elif target.kind.name == "opencl" and ( + ("android" in str(target.host)) or ("adreno" in str(target.attrs)) + ): + return Conv2d.Config( + block_size_x=32, + block_size_y=4, + vector_size=8, + unroll=16, + use_shared=False, + storage_align=False, + inner_x=True, + ) + else: + return Conv2d.Config() + + def apply( # pylint: disable=too-many-locals,missing-docstring + self, + func: tir.PrimFunc, + target: Target, + _: bool, + ) -> Optional[tir.Schedule]: + if not isinstance(func, tir.PrimFunc) or not self.is_target_available(target): + return None + + if isinstance(func, tir.PrimFunc): + sch = tir.Schedule(func) + + # config = self.get_configs(target) + root_block = analysis.get_root_block(sch) + blocks = sch.get_child_blocks(root_block) + reduction_block = get_reduction_blocks(sch, blocks) + + if reduction_block is None: + return None + if not is_convolution(sch, reduction_block): + return None + + def schedule_data_pad(blk): + axes = sch.get_loops(blk) + axes, vec = axes[:-1], axes[-1] + axis = sch.fuse(*axes) + bx, ty, tx = sch.split(axis, [None, 16, 16]) + sch.bind(bx, "blockIdx.x") + sch.bind(ty, "threadIdx.y") + sch.bind(tx, "threadIdx.x") + sch.vectorize(vec) + + def schedule_conv2d(blk): + # TODO: Loop Pattern mayn't be reliable, need to perform better analysis. 
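+            # Assumed loop order for a conv2d_NCHWc_OIHWo block produced by the Adreno
+            # legalization: batch, oc-chunk, oh, ow, oc-block(4), then the ic/kh/kw reduction.
+            # The spatial loops are fused and bound to blockIdx/threadIdx, ic is split by 4
+            # so the cached reads vectorize, and the 4-wide oc-block is vectorized on write.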
+ n, oc, oh, ow, ob, ic, kh, kw = sch.get_loops(blk) + sch.reorder(n, oc, oh, ow, ic, kh, kw, ob) + main_lp = sch.fuse(n, oc, oh, ow) + bx, ty, tx = sch.split(main_lp, [None, 16, 16]) + sch.bind(tx, "threadIdx.x") + sch.bind(ty, "threadIdx.y") + sch.bind(bx, "blockIdx.x") + + ico, icv = sch.split(ic, [None, 4]) + sch.reorder(ico, kh, kw, icv, ob) + rblk = sch.cache_read(blk, 0, "local") + sch.compute_at(rblk, kw) + sch.vectorize(sch.get_loops(rblk)[-1]) + wblk = sch.cache_write(blk, 0, "local") + sch.reverse_compute_at(wblk, tx) + sch.vectorize(sch.get_loops(wblk)[-1]) + sch.vectorize(ob) + init_blk = sch.decompose_reduction(blk, ico) + sch.vectorize(sch.get_loops(init_blk)[-1]) + + def is_data_pad(block: tir.stmt.Block): + return is_spatial_block(sch, block) and tir.analysis.has_if_then_else(sch.get(block)) + + def schedule_conv2d_blocks(): + + # Do analysis to find block type + blocks = sch.get_child_blocks(root_block) + passed_reduction = False + for blk in blocks: + if is_reduction_block(sch, blk): + schedule_conv2d(blk) + passed_reduction = True + elif is_data_pad(blk): + schedule_data_pad(blk) + elif is_spatial_block(sch, blk): + try: + if not passed_reduction: + sch.compute_inline(blk) + else: + sch.reverse_compute_inline(blk) + except: # pylint: disable=W0702 + pass + else: + raise TypeError("Can't Schedule this Block", sch.get(blk)) + + schedule_conv2d_blocks() + return sch diff --git a/python/tvm/relax/backend/adreno/__init__.py b/python/tvm/relax/backend/adreno/__init__.py index b3364f2f4b4a..b97ea399ab19 100644 --- a/python/tvm/relax/backend/adreno/__init__.py +++ b/python/tvm/relax/backend/adreno/__init__.py @@ -15,6 +15,9 @@ # specific language governing permissions and limitations # under the License. """The Relax Adreno backend compilation pipeline and other passes.""" + +from . import transform + from .pipeline import ( finalize_passes, get_default_pipeline, diff --git a/python/tvm/relax/backend/adreno/transform/__init__.py b/python/tvm/relax/backend/adreno/transform/__init__.py new file mode 100644 index 000000000000..abeb56ac488c --- /dev/null +++ b/python/tvm/relax/backend/adreno/transform/__init__.py @@ -0,0 +1,22 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Adreno Relax transformations. """ + +from .transform import ( + AnnotateCustomMemoryScope, + FoldVDeviceScopeChange, +) diff --git a/python/tvm/relax/backend/adreno/transform/_ffi_api.py b/python/tvm/relax/backend/adreno/transform/_ffi_api.py new file mode 100644 index 000000000000..7a19e3380feb --- /dev/null +++ b/python/tvm/relax/backend/adreno/transform/_ffi_api.py @@ -0,0 +1,19 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +"""FFI APIs for Adreno transform""" +import tvm.ffi + +tvm.ffi._init_api("relax.backend.adreno.transform", __name__) diff --git a/python/tvm/relax/backend/adreno/transform/transform.py b/python/tvm/relax/backend/adreno/transform/transform.py new file mode 100644 index 000000000000..9a01d7be97dd --- /dev/null +++ b/python/tvm/relax/backend/adreno/transform/transform.py @@ -0,0 +1,50 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name +"""Adreno Relax transformation passes.""" +from typing import Optional + +import tvm.ir +from tvm.target import Target + +from . import _ffi_api + + +def AnnotateCustomMemoryScope(target: Optional[Target] = None) -> tvm.ir.transform.Pass: + """Allocate the memory scope information. This is Adreno specific pass to annotate + The memory scope information and realize the same with RealizeVDevice pass followed by + updating the Prim Function var_buffer mapping using SpecializePrimFuncBasedOnCallSite. + + Returns + ------- + ret: tvm.ir.transform.Pass + The registered pass for allocating workspace. + """ + return _ffi_api.AnnotateCustomMemoryScope(target) # type: ignore + + +def FoldVDeviceScopeChange() -> tvm.ir.transform.Pass: + """This pass is a texture specific pass that can optimize unnecessary to_device copies. + Like texture_scope -> ToVDevice -> global scope. In this case the producer can directly + store into global scope avoiding unnecessary device copy. + + Returns + ------- + ret: tvm.ir.transform.Pass + The registered pass for allocating workspace. + """ + return _ffi_api.FoldVDeviceScopeChange() # type: ignore diff --git a/python/tvm/relax/op/base.py b/python/tvm/relax/op/base.py index b0570344e5a0..185da3fd7068 100644 --- a/python/tvm/relax/op/base.py +++ b/python/tvm/relax/op/base.py @@ -813,7 +813,7 @@ def to_vdevice(data, dst_vdevice) -> Expr: return _ffi_api.to_vdevice(data, dst_vdevice) # type: ignore -def hint_on_device(data, dst_vdevice) -> Expr: +def hint_on_device(data, dst_vdevice, memory_scope="global") -> Expr: """It provides a hint specifying the device on which the input data should be executed. 
This hint is utilized by RealizeVDevice to propagate the virtual device." @@ -822,12 +822,15 @@ def hint_on_device(data, dst_vdevice) -> Expr: data : Expr The tensor to be copied. - dst_device : VDevice + dst_device : Device The destination device where the data is supposed to be executed. + memory_scope: String + Memory scope of buffer on target device. + Returns ------- result : Expr The result. """ - return _ffi_api.hint_on_device(data, dst_vdevice) # type: ignore + return _ffi_api.hint_on_device(data, dst_vdevice, memory_scope) # type: ignore diff --git a/python/tvm/relax/transform/__init__.py b/python/tvm/relax/transform/__init__.py index 724921e5fee7..dacbc667be2b 100644 --- a/python/tvm/relax/transform/__init__.py +++ b/python/tvm/relax/transform/__init__.py @@ -83,6 +83,7 @@ UpdateVDevice, VMBuiltinLower, VMShapeLower, + SpecializePrimFuncBasedOnCallSite, dataflowblock_pass, function_pass, ) diff --git a/python/tvm/relax/transform/legalize_ops/__init__.py b/python/tvm/relax/transform/legalize_ops/__init__.py index b4aba0291fc1..11b9a293407a 100644 --- a/python/tvm/relax/transform/legalize_ops/__init__.py +++ b/python/tvm/relax/transform/legalize_ops/__init__.py @@ -31,3 +31,6 @@ from . import search from . import statistical from . import unary + +# Device specific legalizations +from . import adreno diff --git a/python/tvm/relax/transform/legalize_ops/adreno/__init__.py b/python/tvm/relax/transform/legalize_ops/adreno/__init__.py new file mode 100644 index 000000000000..f2b3f4a781d2 --- /dev/null +++ b/python/tvm/relax/transform/legalize_ops/adreno/__init__.py @@ -0,0 +1,18 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Legalize high-level operator calls in Relax functions to call_tir.""" +from .convolution import conv2d_NCHWc_OIHWo diff --git a/python/tvm/relax/transform/legalize_ops/adreno/convolution.py b/python/tvm/relax/transform/legalize_ops/adreno/convolution.py new file mode 100644 index 000000000000..959e43778024 --- /dev/null +++ b/python/tvm/relax/transform/legalize_ops/adreno/convolution.py @@ -0,0 +1,37 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-docstring, invalid-name +"""A Convolution impl for Adreno GPU.""" + +from tvm import relax +from tvm import topi + + +def conv2d_NCHWc_OIHWo(bb: relax.BlockBuilder, call: relax.Call) -> relax.Expr: + return bb.call_te( + topi.nn.conv2d_NCHWc_OIHWo, + data=call.args[0], + kernel=call.args[1], + stride=call.attrs.strides, + padding=call.attrs.padding, + dilation=call.attrs.dilation, + layout=call.attrs.data_layout, + out_layout=call.attrs.out_layout, + # out_dtype=call.attrs.out_dtype, + sinfo_args=call.sinfo_args, + primfunc_name_hint="conv2d_NCHWc_OIHWo", + ) diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py index 57627ceebe66..ed81889c9337 100644 --- a/python/tvm/relax/transform/transform.py +++ b/python/tvm/relax/transform/transform.py @@ -1062,7 +1062,9 @@ def BundleModelParams(param_tuple_name: Optional[str] = None) -> tvm.ir.transfor def LegalizeOps( - customize_legalize_map: Optional[Dict[str, LegalizeFunc]] = None, enable_warning: bool = False + customize_legalize_map: Optional[Dict[str, LegalizeFunc]] = None, + skip_ops: Optional[List[str]] = None, + enable_warning: bool = False, ): """Legalize high-level operator calls in Relax functions to call_tir with corresponding low-level TIR PrimFuncs. @@ -1088,6 +1090,9 @@ def LegalizeOps( The customized operator legalization function map. The customized function will override the default one. + skip_ops : Optional,List[str]] + List of ops that need to be skipped from legalization + enable_warning : bool A boolean value indicating if to print warnings for CallNode whose op's legalization function is not registered. By default we don't print @@ -1167,7 +1172,7 @@ def multiply( T_multiply[v_ax0, v_ax1] = A[v_ax0, v_ax1] * B[v_ax0, v_ax1] """ - return _ffi_api.LegalizeOps(customize_legalize_map, enable_warning) # type: ignore + return _ffi_api.LegalizeOps(customize_legalize_map, skip_ops, enable_warning) # type: ignore def RealizeVDevice() -> tvm.ir.transform.Pass: @@ -1605,6 +1610,19 @@ def AllocateWorkspace() -> tvm.ir.transform.Pass: return _ffi_api.AllocateWorkspace() # type: ignore +def SpecializePrimFuncBasedOnCallSite() -> tvm.ir.transform.Pass: + """This pass updates the var_buffer mapping of PrimFunctions from the call_tir info. + Primarily used to update the VDevice information if any changes occured from the caller. + This pass recreates the buffers and updates the map. + + Returns + ------- + ret: tvm.ir.transform.Pass + The registered pass for allocating workspace. + """ + return _ffi_api.SpecializePrimFuncBasedOnCallSite() # type: ignore + + def _wrap_class_function_pass(pass_cls, pass_info): """Wrap a python class as function pass.""" diff --git a/python/tvm/relax/utils.py b/python/tvm/relax/utils.py index 9795631fbe10..5424e599ed5a 100644 --- a/python/tvm/relax/utils.py +++ b/python/tvm/relax/utils.py @@ -347,6 +347,7 @@ def _shape_with_old_tir_var( ) primfunc_attrs = kwargs.pop("primfunc_attrs", None) + custom_out_sinfo = kwargs.pop("sinfo_args", []) te_args = _convert_te_arg(args) te_kwargs = _convert_te_arg(kwargs) @@ -371,14 +372,17 @@ def _shape_with_old_tir_var( # with old set of variables. 
tir_var_inverse_map = {v: k for k, v in tir_var_map.items()} - output_sinfo = [ - TensorStructInfo( - _shape_with_old_tir_var(out.shape, tir_var_inverse_map), - out.dtype, - _get_vdevice(args), - ) - for out in outs - ] + if len(custom_out_sinfo) == 1: + output_sinfo = custom_out_sinfo[0] + else: + output_sinfo = [ + TensorStructInfo( + _shape_with_old_tir_var(out.shape, tir_var_inverse_map), + out.dtype, + _get_vdevice(args), + ) + for out in outs + ] tir_vars = None if len(unbound_tir_vars) > 0: diff --git a/python/tvm/tir/analysis/analysis.py b/python/tvm/tir/analysis/analysis.py index 915b7f765c10..8a84d3ee51fa 100644 --- a/python/tvm/tir/analysis/analysis.py +++ b/python/tvm/tir/analysis/analysis.py @@ -301,6 +301,10 @@ def find_anchor_block(mod: IRModule) -> Block: return _ffi_api.find_anchor_block(mod) # type: ignore # pylint: disable=no-member +def has_if_then_else(stmt: Stmt) -> bool: + return tvm.ffi.get_global_func("tir.schedule.HasIfThenElse")(stmt) + + def get_vtcm_compaction_passes() -> List[tvm.transform.Pass]: """Utility function to get the list of lowering passes to be applied to calculate the compacted VTCM allocation size diff --git a/python/tvm/topi/nn/conv2d.py b/python/tvm/topi/nn/conv2d.py index 531c0a6c6663..ce14df8beddf 100644 --- a/python/tvm/topi/nn/conv2d.py +++ b/python/tvm/topi/nn/conv2d.py @@ -394,6 +394,135 @@ def conv2d_NCHWc(data, kernel, stride, padding, dilation, layout, out_layout, ou ) +def conv2d_NCHWc_OIHWo( + data: te.Tensor, kernel, stride, padding, dilation, layout, out_layout, out_dtype="float32" +): + """Conv2D operator for nChw[x]c layout. + + Parameters + ---------- + data : tvm.te.Tensor + 5-D with shape [batch, in_channel_chunk, in_height, in_width, in_channel_block] + + kernel : tvm.te.Tensor + 6-D with shape + [num_filter_chunk, in_channel_chunk, filter_height, filter_width, + num_filter_block] + + stride : int or a list/tuple of two ints + stride size, or [stride_height, stride_width] + + padding : int or a list/tuple of 2 or 4 ints + padding size, or + [pad_height, pad_width] for 2 ints, or + [pad_top, pad_left, pad_bottom, pad_right] for 4 ints + + dilation: int or a list/tuple of two ints + dilation size, or [dilation_height, dilation_width] + + layout : str + Input data layout + + out_layout : str + Output data layout + + out_dtype : str + output data type + + Returns + ------- + output : tvm.te.Tensor + 5-D with shape [batch, out_channel_chunk, out_height, out_width, out_channel_block] + """ + + # layout and out_layout are not used here, + # we keep them for debug convenience when dumping autotvm workload + HSTR, WSTR = stride if isinstance(stride, (tuple, list)) else (stride, stride) + dilation_h, dilation_w = ( + dilation if isinstance(dilation, (tuple, list)) else (dilation, dilation) + ) + + n, ic_chunk, ih, iw, ic_bn = get_const_tuple(data.shape) + in_channel = ic_chunk * ic_bn + kernel_shape = get_const_tuple(kernel.shape) + if len(kernel_shape) == 6: # OIHW4i4o + oc_chunk, ic_chunk_group, kernel_height, kernel_width, kernel_ic_bn, oc_bn = kernel_shape + groups = in_channel // (ic_chunk_group * kernel_ic_bn) + else: # OIHW4o + oc_chunk, ic, kernel_height, kernel_width, oc_bn = kernel_shape + groups = in_channel // ic + + num_filter = oc_chunk * oc_bn + dilated_kernel_h = (kernel_height - 1) * dilation_h + 1 + dilated_kernel_w = (kernel_width - 1) * dilation_w + 1 + + pad_top, pad_left, pad_down, pad_right = get_pad_tuple( + padding, (dilated_kernel_h, dilated_kernel_w) + ) + HPAD = pad_top + pad_down + WPAD = pad_left + 
pad_right + + # output shape + out_height = (ih + HPAD - dilated_kernel_h) // HSTR + 1 + out_width = (iw + WPAD - dilated_kernel_w) // WSTR + 1 + oshape = (n, oc_chunk, out_height, out_width, oc_bn) + pad_before = (0, 0, pad_top, pad_left, 0) + pad_after = (0, 0, pad_down, pad_right, 0) + + # DOPAD + DOPAD = HPAD != 0 or WPAD != 0 + if DOPAD: + data_pad = pad(data, pad_before, pad_after, name="conv2d_data_pad") + else: + data_pad = data + + kh = te.reduce_axis((0, kernel_height), name="kh") + kw = te.reduce_axis((0, kernel_width), name="kw") + + idxdiv = tvm.tir.indexdiv + idxmod = tvm.tir.indexmod + + def compute_conv2d(*args): + n, occ, oh, ow, ocb = args + ic = te.reduce_axis((0, in_channel // groups), name="ic") + if groups == 1: + data_pad_ = data_pad[ + n, + idxdiv(ic, ic_bn), + oh * HSTR + kh * dilation_h, + ow * WSTR + kw * dilation_w, + idxmod(ic, ic_bn), + ] + else: + data_pad_ = data_pad[ + n, + (occ // (oc_chunk // groups)) * (ic_chunk // groups) + idxdiv(ic, ic_bn), + oh * HSTR + kh * dilation_h, + ow * WSTR + kw * dilation_w, + idxmod(ic, ic_bn), + ] + if len(kernel_shape) == 5: + kernel_ = kernel[occ, ic, kh, kw, ocb] + else: + kernel_ = kernel[occ, idxdiv(ic, oc_bn), kh, kw, idxmod(ic, oc_bn), ocb] + + if out_dtype is not None: + data_pad_ = data_pad_.astype(out_dtype) + kernel_ = kernel_.astype(out_dtype) + + return te.sum( + data_pad_ * kernel_, + axis=[ic, kh, kw], + ) + + return te.compute( + oshape, + lambda *indices: compute_conv2d(*indices), # pylint: disable=W0108 + name="conv2d_NCHWc_OIHWo", + tag="conv2d_NCHWc_OIHWo", + ) + + def conv2d_NCHWc_int8( data, kernel, stride, padding, dilation, layout, out_layout, out_dtype="int32", n_elems=4 ): diff --git a/src/relax/backend/adreno/annotate_custom_storage.cc b/src/relax/backend/adreno/annotate_custom_storage.cc new file mode 100644 index 000000000000..3896acca908d --- /dev/null +++ b/src/relax/backend/adreno/annotate_custom_storage.cc @@ -0,0 +1,751 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file src/relax/backend/adreno/annotate_texture_storage.cc + * \brief Texture Storage Annotation Pass for Adreno GPU targets. + * + * Texture realization for Adreno GPU targets requires fundamentally follows + * Stage 1: Transforming the shapes with inner most dimension being 4 + * Stage 2: Annotate appropriate memory_scope hint in VDevice of StructInfo + * Stage 3: TIR lowering does injects texture load/store builtins looking at this scope + * Stage 4: Finally codegen handles appropriate code looking at buffer types and load/store + * builtins. + * + * Stage 1 is generic and straight forward by using convert_layout pass that transforms the + * shapes as well as injecting layout_transform ops as needed. 
+ *
+ * Stage 2: This pass is responsible for injecting the appropriate VDevice into StructInfo and
+ * for adding copies whenever there is a conflict between producer and consumer scopes.
+ *
+ * After convert_layout the mod looks like below:
+ *
+ * @I.ir_module
+ * class Module:
+ *     @R.function
+ *     def main(
+ *         x: R.Tensor((2, 64, 56, 56), dtype="float32"),
+ *         w: R.Tensor((32, 64, 3, 3), dtype="float32")
+ *     ) -> R.Tensor((2, 32, 54, 54), dtype="float32"):
+ *         with R.dataflow():
+ *             lv: R.Tensor((2, 16, 56, 56, 4), dtype="float32") = R.layout_transform(
+ *                 x,
+ *                 index_map=T.index_map(
+ *                     lambda i0, i1, i2, i3: (i0, i1 // 4, i2, i3, i1 % 4)))
+ *             lv1: R.Tensor((8, 64, 3, 3, 4), dtype="float32") = R.layout_transform(
+ *                 w,
+ *                 index_map=T.index_map(
+ *                     lambda i0, i1, i2, i3: (i0 // 4, i1, i2, i3, i0 % 4)))
+ *             lv2: R.Tensor((2, 8, 54, 54, 4), dtype="float32") = R.nn.conv2d(
+ *                 lv,
+ *                 lv1,
+ *                 data_layout="NCHW4c",
+ *                 kernel_layout="OIHW4o",
+ *                 out_layout="NCHW4c",
+ *                 out_dtype="float32"
+ *             )
+ *             gv: R.Tensor((2, 32, 54, 54), dtype="float32") = R.layout_transform(
+ *                 lv2,
+ *                 index_map=T.index_map(
+ *                     lambda i0, i1, i2, i3, i4: (i0, i1 * 4 + i4, i2, i3)))
+ *             R.output(gv)
+ *         return gv
+ *
+ * Here, the param layout transforms are injected properly and the conv2d op operates on
+ * 5D shapes.
+ *
+ * The scope annotation decisions are made as follows:
+ * - For op_pattern < kCommReduce we just check for a 5D shape whose innermost dimension is 4.
+ * - For op_pattern > kCommReduce we decide selectively. Currently we enable the texture
+ *   scope for Conv2D and the pool ops.
+ * The catch is that while this pass runs we need op_pattern information for the ops below
+ * kCommReduce as well as operator attributes for selective ops like Conv2D and the pool ops.
+ * op_pattern is available after legalization, when the AnnotateTIROpPattern pass does its
+ * analysis. However, op specific attributes no longer exist after legalization.
+ *
+ * To solve this, we legalize in two parts.
+ * First, we run legalization while skipping the ops we do not want legalized yet.
+ * LegalizeOps is enhanced to accept skip_ops for this purpose.
+ * After legalization and AnnotateTIROpPattern done this way, the mod looks like:
+ *
+ * class Module:
+ *     @R.function
+ *     def main(
+ *         x: R.Tensor((2, 64, 56, 56), dtype="float32"),
+ *         w: R.Tensor((32, 64, 3, 3), dtype="float32")
+ *     ) -> R.Tensor((2, 32, 54, 54), dtype="float32"):
+ *         with R.dataflow():
+ *             lv = R.call_tir(cls.te_layout_transform, (x,),
+ *                 out_sinfo=R.Tensor((2, 16, 56, 56, 4), dtype="float32")
+ *             )
+ *             lv1 = R.call_tir(cls.te_layout_transform1, (w,),
+ *                 out_sinfo=R.Tensor((8, 64, 3, 3, 4), dtype="float32")
+ *             )
+ *             lv2: R.Tensor((2, 8, 54, 54, 4), dtype="float32") = R.nn.conv2d(
+ *                 lv,
+ *                 lv1,
+ *                 data_layout="NCHW4c",
+ *                 kernel_layout="OIHW4o",
+ *                 out_layout="NCHW4c",
+ *                 out_dtype="float32"
+ *             )
+ *             gv = R.call_tir(cls.te_layout_transform2, (lv2,),
+ *                 out_sinfo=R.Tensor((2, 32, 54, 54), dtype="float32")
+ *             )
+ *             R.output(gv)
+ *         return gv
+ *
+ * Here, the legalized prim functions do carry the op_pattern attribute.
+ * We now have everything needed to run this pass.
+ *
+ * In principle this pass does scope annotation based on consumer priority, i.e. for any
+ * tensor object we try to assign the scope based on the consumer requirement.
+ * Conflicts and multiple consumers of the same tensor are handled by injecting
+ * appropriate copies. The pass proceeds in three steps:
+ * 1: CollectConsumerScopeInfo: a visitor that collects, for each input, the scope demanded
+ *    by every consumer.
+ * 2: CollectProducerScopeInfo: a visitor that finalizes the scope for each input and output
+ *    based on the consumer scope information, resolving multiple-consumer cases and conflicts.
+ * 3: DefineVDevice: a pass that injects hint_on_device for each argument and also updates the
+ *    output StructInfo with VDevice information. For tir calls this update is straightforward,
+ *    since sinfo_args in CallNode is meant exactly for this purpose. For other calls sinfo_args
+ *    is by design unused, as their output is derived via "FInferStructInfo".
+ *    The problem with the per-op "FInferStructInfo" is that it cannot decide the memory scope,
+ *    which this pass determines from consumer demand. Hence we use sinfo_args to carry this
+ *    information: the pass attaches sinfo_args to regular calls too, and the FInferStructInfo
+ *    implementations take the VDevice information from this hint. This also solves the issue
+ *    of mixed VDevices across the arguments of an op.
+ *
+ * After these steps the mod looks like:
+ *
+ * class Module:
+ *     @R.function
+ *     def main(
+ *         x: R.Tensor((2, 64, 56, 56), dtype="float32"),
+ *         w: R.Tensor((32, 64, 3, 3), dtype="float32")
+ *     ) -> R.Tensor((2, 32, 54, 54), dtype="float32"):
+ *         with R.dataflow():
+ *             lv: R.Tensor((2, 64, 56, 56), dtype="float32") = R.hint_on_device(
+ *                 x, R.device(dev_type=4, dev_id=0), "global"
+ *             )
+ *             lv_1 = R.call_tir(cls.te_layout_transform, (lv,),
+ *                 out_sinfo=R.Tensor((2, 16, 56, 56, 4), dtype="float32",
+ *                     vdevice="opencl:0:global.texture-nhwc"
+ *                 )
+ *             )
+ *             lv1: R.Tensor((32, 64, 3, 3), dtype="float32") = R.hint_on_device(
+ *                 w, R.device(dev_type=4, dev_id=0), "global"
+ *             )
+ *             lv1_1 = R.call_tir(cls.te_layout_transform1, (lv1,),
+ *                 out_sinfo=R.Tensor((8, 64, 3, 3, 4), dtype="float32",
+ *                     vdevice="opencl:2:global.texture-weight"
+ *                 )
+ *             )
+ *             lv2: R.Tensor((2, 16, 56, 56, 4), dtype="float32",
+ *                 vdevice="opencl:0:global.texture-nhwc"
+ *             ) = R.hint_on_device(lv_1, R.device(dev_type=4, dev_id=0), "global.texture-nhwc")
+ *             lv3: R.Tensor((8, 64, 3, 3, 4), dtype="float32",
+ *                 vdevice="opencl:2:global.texture-weight"
+ *             ) = R.hint_on_device(lv1_1, R.device(dev_type=4, dev_id=0), "global.texture-weight")
+ *             lv2_1: R.Tensor((2, 8, 54, 54, 4), dtype="float32",
+ *                 vdevice="opencl:1:global"
+ *             ) = R.nn.conv2d(
+ *                 lv2, lv3,
+ *                 data_layout="NCHW4c", kernel_layout="OIHW4o",
+ *                 out_layout="NCHW4c", out_dtype="float32",
+ *                 sinfo_args=(R.Tensor((2, 8, 54, 54, 4), dtype="float32",
+ *                     vdevice="opencl:1:global"),
+ *                 )
+ *             )
+ *             lv4: R.Tensor((2, 8, 54, 54, 4), dtype="float32",
+ *                 vdevice="opencl:1:global"
+ *             ) = R.hint_on_device(lv2_1, R.device(dev_type=4, dev_id=0), "global")
+ *             gv = R.call_tir(cls.te_layout_transform2, (lv4,),
+ *                 out_sinfo=R.Tensor((2, 32, 54, 54), dtype="float32", vdevice="opencl:1:global")
+ *             )
+ *             R.output(gv)
+ *         return gv
+ *
+ * What we have above are the hint_on_device injections and the out_sinfo for all calls.
+ * Next, we apply RealizeVDevice to formalize the hints, followed by CanonicalizeBindings,
+ * which removes redundant assignments such as
+ *
+ *     lv: R.Tensor((2, 64, 56, 56), dtype="float32", vdevice="opencl:1:global") = x
+ *     lv1: R.Tensor((32, 64, 3, 3), dtype="float32", vdevice="opencl:1:global") = w
+ *
+ * These assignments result from hint_on_device not realizing any copy when consumer and
+ * producer share the same memory scope or vdevice, and they would otherwise impact operator
+ * fusion.
+ *
+ * Now the mod looks like:
+ *
+ * class Module:
+ *     @R.function
+ *     def main(
+ *         x: R.Tensor((2, 64, 56, 56), dtype="float32"),
+ *         w: R.Tensor((32, 64, 3, 3), dtype="float32")
+ *     ) -> R.Tensor((2, 32, 54, 54), dtype="float32"):
+ *         with R.dataflow():
+ *             lv = R.call_tir(cls.te_layout_transform, (x,),
+ *                 out_sinfo=R.Tensor((2, 16, 56, 56, 4), dtype="float32",
+ *                     vdevice="opencl:0:global.texture-nhwc"
+ *                 )
+ *             )
+ *             lv1 = R.call_tir(cls.te_layout_transform1, (w,),
+ *                 out_sinfo=R.Tensor((8, 64, 3, 3, 4), dtype="float32",
+ *                     vdevice="opencl:2:global.texture-weight"
+ *                 )
+ *             )
+ *             lv2: R.Tensor((2, 8, 54, 54, 4), dtype="float32",
+ *                 vdevice="opencl:1:global"
+ *             ) = R.nn.conv2d(
+ *                 lv, lv1,
+ *                 data_layout="NCHW4c", kernel_layout="OIHW4o",
+ *                 out_layout="NCHW4c", out_dtype="float32",
+ *                 sinfo_args=(R.Tensor((2, 8, 54, 54, 4), dtype="float32",
+ *                     vdevice="opencl:1:global"),
+ *                 )
+ *             )
+ *             gv = R.call_tir(cls.te_layout_transform2, (lv2,),
+ *                 out_sinfo=R.Tensor((2, 32, 54, 54), dtype="float32", vdevice="opencl:1:global")
+ *             )
+ *             R.output(gv)
+ *         return gv
+ *
+ * After this, the compilation pipeline runs:
+ * - Legalization of the remaining ops: this legalization forwards the annotated out_sinfo
+ *   VDevice information to the tir calls.
+ * - AnnotateTIROpPattern: TIR op patterns for the newly legalized ops.
+ * - Fusion.
+ * - FoldVDeviceScopeChange: there may be ToVDevice copies from texture to buffer left over;
+ *   this pass removes the copies and updates the producer scope to global.
+ * - SpecializePrimFuncBasedOnCallSite: finally, the buffer var maps are updated according to
+ *   the VDevice scopes.
+ *
+ */
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include
+
+#include "../../op/tensor/manipulate.h"
+#include "../../transform/infer_layout_utils.h"
+#include "../../transform/utils.h"
+
+namespace tvm {
+namespace relax {
+namespace backend {
+namespace adreno {
+
+using tvm::tir::Buffer;
+
+static Array GetShapeFromTensorStructInfo(const TensorStructInfo& tensor_sinfo) {
+  auto shape = tensor_sinfo->GetShape();
+  ICHECK(shape.defined());
+  return shape.value();
+}
+
+/*
+ * \brief Generates consumer information for each var.
+ * \return scope_info: a map that contains, for each var, the call nodes that consume it and
+ * the scope in which each of them expects this input to be.
+ * \return call_scope_info: a map from each call node to an array holding the scope info for
+ * each of its inputs.
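+ *
+ * For the running example in the file header this roughly yields (illustrative only):
+ *   scope_info[lv]  = { conv2d : ["global.texture-nhwc"] }
+ *   scope_info[lv1] = { conv2d : ["global.texture-weight"] }
+ *   call_scope_info[conv2d] = ["global.texture-nhwc", "global.texture-weight"]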
+ */ +class CollectConsumerScopeInfo : public ExprVisitor { + public: + using ExprVisitor::VisitExpr_; + + std::pair>, Map>>> Collect( + const IRModule& mod, Function func, const Target& target) { + mod_ = mod; + target_ = target; + VisitExpr(func->body); + // Extend the scope for tuple items + for (const auto& val : arg_to_binding) { + if (scope_info.find(val.first) != scope_info.end()) { + if (scope_info.find(val.second) == scope_info.end()) { + scope_info.Set(val.second, scope_info[val.first]); + } else { + auto ent = scope_info[val.second]; + for (auto ent_val : scope_info[val.first]) { + ent.Set(ent_val.first, ent_val.second); + } + scope_info.Set(val.second, ent); + } + } + } + + return std::make_pair(call_scope_info, scope_info); + } + + void VisitBinding_(const VarBindingNode* binding, + const TupleGetItemNode* tuple_get_item_node) final { + if (arg_to_binding.find(GetRef(binding->var.get())) == arg_to_binding.end()) { + arg_to_binding.Set(GetRef(binding->var.get()), + GetRef(tuple_get_item_node->tuple.get())); + } + } + + void VisitExpr_(const CallNode* call) final { + static const Op& call_tir_op = Op::Get("relax.call_tir"); + GlobalVar gv; + Array op_attrs; + Optional op_pattern = Integer(static_cast(OpPatternKind::kOpaque)); + Tuple func_args; + + if (call->op == call_tir_op) { + gv = Downcast(call->args[0]); + tir::PrimFunc pfunc = Downcast(mod_->Lookup(gv)); + op_attrs = ExtractAttrs(pfunc); + op_pattern = ExtractPattern(pfunc); + func_args = Downcast(call->args[1]); + } else { + op_attrs = {call->attrs}; + op_pattern = Integer(static_cast(OpPatternKind::kOpaque)); + func_args = Tuple(call->args); + } + + bool is_texture_supported = SupportsTexture(op_attrs, op_pattern.value()); + + Array arg_scope; + for (auto arg : func_args->fields) { + auto sinfo = GetStructInfo(arg); + if (auto tensor_sinfo = sinfo.as()) { + auto scope = is_texture_supported + ? 
Scope(GetShapeFromTensorStructInfo(tensor_sinfo.value())) + : "global"; + Map> ent_call; + const VarNode* arg_var = arg.as(); + if (scope_info.find(GetRef(arg_var)) != scope_info.end()) { + ent_call = scope_info[GetRef(arg_var)]; + } + ent_call.Set(GetRef(call), {scope}); + scope_info.Set(GetRef(arg_var), ent_call); + arg_scope.push_back(scope); + } + } + call_scope_info.Set(GetRef(call), arg_scope); + } + + private: + template + Array ExtractAttrs(const T& func) { + Array op_attrs; + Optional attrs = func->template GetAttr("op_attrs"); + if (attrs) { + if (auto val = attrs.value().as()) { + op_attrs.push_back(val.value()); + } else if (auto val = attrs.value().as>()) { + op_attrs = val.value(); + } + } + return std::move(op_attrs); + } + + template + Optional ExtractPattern(const T& func) { + Optional op_pat = func->template GetAttr("op_pattern"); + return std::move(op_pat); + } + + bool SupportsTexture(const Array& op_attrs, Integer op_pattern) { + if (op_pattern.IntValue() < OpPatternKind::kCommReduce) return true; + + for (auto attr : op_attrs) { + if (auto conv_attr = attr.as()) { + if (conv_attr->data_layout == "NCHW4c" && conv_attr->kernel_layout == "OIHW4o") { + return true; + } + } else if (auto pool_attrs = attr.as()) { + if (pool_attrs->layout == "NCHW4c") { + return true; + } + } else if (auto avg_attrs = attr.as()) { + if (avg_attrs->layout == "NCHW4c") { + return true; + } + } else if (attr.as()) { + return true; + } + } + + return false; + } + + std::string Scope(Array shape) { + // currently we support only textures been made from 5d tensors + // 5d requirement is not limitation of textures in general, it is limitation how + // we are representing memory scopes/layout and flattening of textures in tir + if (shape.size() == 5 && shape[4].as()->value == 4) { + for (auto ind : shape) { + if (!ind.as()) { + // Dynamic tensors + return "global.texture-nchw"; + } + } + std::map diffs; + int spatial_limit = + target_->GetAttr("texture_spatial_limit").value_or(Integer(16384))->value; + int depth_limit = + target_->GetAttr("texture_depth_limit").value_or(Integer(2048))->value; + int a0 = shape[0].as()->value; + int a1 = shape[1].as()->value; + int a2 = shape[2].as()->value; + int a3 = shape[3].as()->value; + + int d1r = a0 * a1; + int d2r = a2 * a3; + int d3r = a1 * a2 * a3; + std::string scope = "global"; + if (a0 < spatial_limit && d3r < spatial_limit) + scope += ".texture-weight"; + else if (a0 < depth_limit && a1 < spatial_limit && d2r < spatial_limit) + scope += ".texture-nhwc"; + else if (d1r < depth_limit && a2 < spatial_limit && a3 < spatial_limit) + scope += ".texture"; + return scope; + } + return "global"; + } + + /* Map of each Var consumption by a call node and its scope */ + Map>> scope_info; + /* A map of call node and scope info for each argument it consunes */ + Map> call_scope_info; + Map arg_to_binding; + IRModule mod_; + Target target_; +}; + +/* + * \brief producer scope information consolidated based on consumer demands. + * \return producer_info which is a map of each call node and corresponding out StructInfo + * This pass considers all consumers and their scope demand. + * Any mismatches here introduces copies as needed. 
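+ *
+ * For the running example in the file header this records, per call node, entries such as
+ * (illustrative): producer_sinfo[conv2d] = R.Tensor((2, 8, 54, 54, 4), dtype="float32",
+ * vdevice="opencl:1:global").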
+ */ +class CollectProducerScopeInfo : public ExprVisitor { + public: + using ExprVisitor::VisitExpr_; + + Map Collect(const IRModule& mod, Function func, + const Map>>& scope_info, + const Target& target, const BlockBuilder& builder) { + mod_ = mod; + scope_info_ = scope_info; + target_ = target; + builder_ = builder; + VisitExpr(func->body); + + return producer_sinfo; + } + + void VisitBinding_(const VarBindingNode* binding, const CallNode* call) final { + ExprVisitor::VisitBinding_(binding, call); + + static const Op& call_tir_op = Op::Get("relax.call_tir"); + StructInfo out_sinfo; + + if (call->op == call_tir_op) { + out_sinfo = call->sinfo_args[0]; + } else { + tvm::OpAttrMap op_map_infer_struct_info_ = + Op::GetAttrMap("FInferStructInfo"); + + auto* op_ptr = call->op.as(); + Op op = GetRef(op_ptr); + ICHECK(op_map_infer_struct_info_.count(op)) + << " Cannot find the FInferStructInfo attribute registered to op: " << op->name; + out_sinfo = op_map_infer_struct_info_[op](GetRef(call), builder_); + } + + std::unordered_map scope_count; + + // Decide the final scope based on the max consumer demand. Rest will use to_device. + auto arg_var = binding->var.as(); + if (scope_info_.find(GetRef(arg_var)) != scope_info_.end()) { + for (const auto& val : scope_info_[GetRef(arg_var)]) { + auto call_node = Downcast(val.first); + if (scope_count.find(val.second[0]) == scope_count.end()) { + scope_count.insert({val.second[0], 1}); + } else { + auto curr_count = scope_count[val.second[0]]; + scope_count.emplace(val.second[0], curr_count + 1); + } + } + } + String final_scope = "global"; + int count = 0; + for (const auto& sval : scope_count) { + if (sval.second > count) { + final_scope = sval.first; + count = sval.second; + } + } + // Applying same scope for outputs + StructInfo updated_ret_sinfo = UpdateStructInfo(out_sinfo, {final_scope}); + producer_sinfo.Set(GetRef(call), updated_ret_sinfo); + } + + private: + StructInfo UpdateStructInfo(const StructInfo& out_sinfo, Array scope) { + if (out_sinfo->IsInstance()) { + auto tensor_sinfo = Downcast(out_sinfo); + auto shape_arr = GetShapeFromTensorStructInfo(tensor_sinfo); + return TensorStructInfo(ShapeExpr(shape_arr), tensor_sinfo->dtype, + VDevice(target_, 0, scope[0])); + } + + ICHECK(out_sinfo->IsInstance()) + << "Expect output struct info of call_tir to be either TupleStructInfo or " + "TensorStructInfo, but got " + << out_sinfo; + + const auto& tuple_sinfo = Downcast(out_sinfo); + Array sinfo_fields; + for (const auto& si : tuple_sinfo->fields) { + ICHECK(si->IsInstance()) + << "Fields of TupleStructInfo must be TensorStructInfo for call_tir " + "output structinfo, but got " + << si; + auto sinfo = Downcast(si); + auto shape_arr = GetShapeFromTensorStructInfo(sinfo); + sinfo_fields.push_back( + TensorStructInfo(ShapeExpr(shape_arr), sinfo->dtype, VDevice(target_, 0, scope[0]))); + } + return TupleStructInfo(sinfo_fields); + } + + Map>> scope_info_; + Map producer_sinfo; + IRModule mod_; + Target target_; + BlockBuilder builder_; +}; + +/* + * \brief main pass that injects hint_on_device for each argument based on producer, + * consumer indormations. This also attributes ret StructInfo for each call node. + * This pass also calls the ReliaseVdevice that formalizes the hints by appropriately injecting + * Vdevice copies as needed. 
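+ *
+ * The injected hints take the form shown in the file header, e.g. (illustrative):
+ *   lv2 = R.hint_on_device(lv_1, R.device(dev_type=4, dev_id=0), "global.texture-nhwc")
+ * RealizeVDevice then materializes a to_vdevice copy only where producer and consumer scopes
+ * actually differ; matching scopes collapse into plain bindings that CanonicalizeBindings
+ * removes afterwards.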
+ */ + +class DefineVDevice : ExprMutator { + public: + explicit DefineVDevice(const Target& target) : target_(target) {} + + IRModule Run(IRModule& mod) { + mod_ = mod; + for (const auto& [gv, func] : mod_->functions) { + if (func->IsInstance()) { + const auto& base_func = mod_->Lookup(gv); + // Only non primitive relax functions + if (base_func->HasNonzeroAttr(attr::kPrimitive)) { + continue; + } + auto info = CollectConsumerScopeInfo().Collect(mod_, Downcast(func), target_); + call_scope_info_ = info.first; + scope_info_ = info.second; + producer_sinfo_ = CollectProducerScopeInfo().Collect(mod_, Downcast(func), + scope_info_, target_, builder_); + relax::Function update_func = Downcast(VisitExpr(func)); + updates_->Add(gv, update_func); + } + } + mod_.CopyOnWrite()->Update(updates_); + + Array global_vdevices_; + for (auto vdev : vdevices_) { + global_vdevices_.push_back(vdev.as().value()); + } + mod_.CopyOnWrite()->global_infos.Set("vdevice", global_vdevices_); + + mod_ = relax::transform::DeadCodeElimination()(mod_); + mod_ = relax::transform::RealizeVDevice()(mod_); + mod_ = relax::transform::CanonicalizeBindings()(mod_); + + return mod_; + } + + using ExprMutator::VisitExpr_; + + Expr VisitExpr_(const CallNode* call_node) override { + auto call = Downcast(ExprMutator::VisitExpr_(call_node)); + static const Op& call_tir_op = Op::Get("relax.call_tir"); + + GlobalVar gv; + Tuple func_args; + + StructInfo out_sinfo; + + if (call->op == call_tir_op) { + gv = Downcast(call->args[0]); + // tir::PrimFunc pfunc = Downcast(mod_->Lookup(gv)); + // out_sinfo = call->sinfo_args[0]; + func_args = Downcast(call->args[1]); + } else { + func_args = Tuple(call->args); + // return call; + } + + Array new_args; + StructInfo updated_ret_sinfo = producer_sinfo_[GetRef(call_node)]; + + if (updated_ret_sinfo->IsInstance()) { + auto tensor_sinfo = Downcast(updated_ret_sinfo); + auto shape = tensor_sinfo->shape.value(); + auto dtype = tensor_sinfo->dtype; + if (tensor_sinfo->vdevice.defined()) { + auto vdev = tensor_sinfo->vdevice.value(); + const VDevice& vdev_global = MakeGlobalVDevice(vdev); + updated_ret_sinfo = TensorStructInfo(shape, dtype, vdev_global); + } + } else { + ICHECK(updated_ret_sinfo->IsInstance()) + << "Expect output struct info of call_tir to be either TupleStructInfo or " + "TensorStructInfo, but got " + << updated_ret_sinfo; + + const auto& tuple_sinfo = Downcast(updated_ret_sinfo); + Array sinfo_fields; + for (const auto& si : tuple_sinfo->fields) { + ICHECK(si->IsInstance()) + << "Fields of TupleStructInfo must be TensorStructInfo for call_tir " + "output structinfo, but got " + << si; + auto sinfo = Downcast(si); + + auto shape_arr = GetShapeFromTensorStructInfo(sinfo); + + auto shape = sinfo->shape.value(); + auto dtype = sinfo->dtype; + if (sinfo->vdevice.defined()) { + auto vdev = sinfo->vdevice.value(); + const VDevice& vdev_global = MakeGlobalVDevice(vdev); + sinfo_fields.push_back(TensorStructInfo(shape, dtype, vdev_global)); + } else { + sinfo_fields.push_back(sinfo); + } + } + updated_ret_sinfo = TupleStructInfo(sinfo_fields); + } + + int arg_idx = 0; + for (auto arg : func_args->fields) { + auto sinfo = GetStructInfo(arg); + if (auto tensor_sinfo = sinfo.as()) { + String scope = "global"; + if (call_scope_info_.find(GetRef(call_node)) != call_scope_info_.end()) { + scope = call_scope_info_[GetRef(call_node)][arg_idx]; + } + new_args.push_back(HintArg(arg, scope)); + arg_idx++; + } else { + new_args.push_back(arg); + } + } + + if (call->op == call_tir_op) { + return 
builder_->Normalize( + Call(call_tir_op, {gv, Tuple(new_args)}, call->attrs, {updated_ret_sinfo})); + } else { + return builder_->Normalize(Call(call->op, new_args, call->attrs, {updated_ret_sinfo})); + } + } + + private: + VDevice MakeGlobalVDevice(VDevice vdev) { + int device_type = vdev->target->GetTargetDeviceType(); + for (size_t i = 0; i < vdevices_.size(); ++i) { + int dev_type = vdevices_[i]->target->GetTargetDeviceType(); + if (dev_type == device_type && vdevices_[i]->vdevice_id == vdev->vdevice_id && + vdevices_[i]->memory_scope == vdev->memory_scope) { + return vdevices_[i]; + } + } + vdevices_.push_back(vdev); + return (vdevices_.back()); + } + + Expr HintArg(const Expr& arg, String scope) { + if (arg->IsInstance()) { + if (auto tsinfo = arg->struct_info_.as()) { + if (!tsinfo->vdevice.defined()) { + const VDevice& vdev = MakeGlobalVDevice(VDevice(target_, 0, scope)); + CHECK(tsinfo->shape.defined()) << "Shape not defined for a constant tensor ..!"; + arg->struct_info_ = + TensorStructInfo(tsinfo->shape.value(), tsinfo->dtype, vdev, tsinfo->span); + return arg; + } + } + } + ObjectPtr attrs = make_object(); + const VDevice& vdev = MakeGlobalVDevice(VDevice(target_, 0, scope)); + attrs->dev_type = vdev->target->GetTargetDeviceType(); + attrs->dev_id = vdev->vdevice_id; + attrs->memory_scope = vdev->memory_scope; + + Expr new_arg = Call(hint_on_device_op_, {arg}, Attrs{std::move(attrs)}, {}); + + return std::move(new_arg); + } + + Optional GetTarget(const StructInfo& sinfo) { + auto tinfo = sinfo.as(); + if (tinfo->vdevice.defined()) { + auto vdevice = tinfo->vdevice.value(); + if (vdevice->target.defined()) { + return vdevice->target; + } + } + return std::nullopt; + } + + const Op& hint_on_device_op_ = Op::Get("relax.hint_on_device"); + IRModule mod_; + IRModule updates_; + Target target_; + Array vdevices_; + Map>> scope_info_; + Map producer_sinfo_; + Map> call_scope_info_; +}; + +namespace transform { + +Pass AnnotateCustomMemoryScope(Target target) { + auto pass_func = [=](IRModule mod, PassContext pc) { + return tvm::relax::backend::adreno::DefineVDevice(target).Run(mod); + }; + return CreateModulePass(/*pass_function=*/pass_func, + /*opt_level=*/0, + /*pass_name=*/"AnnotateCustomMemoryScope", + /*required=*/{}); +} + +TVM_FFI_REGISTER_GLOBAL("relax.backend.adreno.transform.AnnotateCustomMemoryScope") + .set_body_typed(AnnotateCustomMemoryScope); + +} // namespace transform +} // namespace adreno +} // namespace backend +} // namespace relax +} // namespace tvm diff --git a/src/relax/backend/adreno/fold_vdevice_scope_change.cc b/src/relax/backend/adreno/fold_vdevice_scope_change.cc new file mode 100644 index 000000000000..73c1e51acb90 --- /dev/null +++ b/src/relax/backend/adreno/fold_vdevice_scope_change.cc @@ -0,0 +1,188 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file src/relax/backend/adreno/fold_vdevice_scope_change.cc + * \brief This is a texture specific pass that can optimize unnecessary to_device copies. + * Like texture_scope -> ToVDevice -> global scope. In this case the producer can directly + * store into global scope avoiding unnecessary device copy. + */ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include "../../op/tensor/manipulate.h" +#include "../../transform/infer_layout_utils.h" +#include "../../transform/utils.h" + +namespace tvm { +namespace relax { +namespace backend { +namespace adreno { + +namespace { +std::tuple)>> CreatePatterns( + Map> consumers) { + auto pat_gv = WildcardPattern(); + + auto pat_inp = WildcardPattern(); + auto pat_call_tir = IsOp("relax.call_tir")(pat_gv, pat_inp); + auto pattern_out = IsOp("relax.to_vdevice")(pat_call_tir); + + auto rewriter = [=](Expr expr, Map matches) -> Expr { + const auto* call_tir = matches[pat_call_tir].as(); + ICHECK(call_tir) << "InternalError: " + << "Match of relax.call_tir operator should produce Call, " + << "but instead produces " << matches[pat_call_tir] << " with type " + << matches[pat_call_tir]->GetTypeKey(); + + const auto* out = matches[pattern_out].as(); + ICHECK(out) << "InternalError: " + << "Match of relax.to_vdevice operator should produce Call, " + << "but instead produces " << matches[pattern_out] << " with type " + << matches[pattern_out]->GetTypeKey(); + + const auto* vdev_attrs = out->attrs.as(); + ICHECK(vdev_attrs) << "InternalError: " + << "Attributes for relax.to_vdevice operator should be ToVDeviceAttrs, " + << "but were instead " << out->attrs << " with type " << out->GetTypeKey(); + + const auto* tir_out_sinfo = call_tir->sinfo_args[0].as(); + if (!tir_out_sinfo) return expr; + + if (!tir_out_sinfo->vdevice.defined()) return expr; + + const VarNode* arg_var = out->args[0].as(); + if (consumers.find(GetRef(arg_var)) != consumers.end()) { + if (consumers[GetRef(arg_var)].size() > 1) { + /* Don't do to_device optimization as we are not the only consumer */ + return expr; + } + } + + if ((std::string(tir_out_sinfo->vdevice.value()->memory_scope).find("texture") != + std::string::npos) && + (vdev_attrs->dst_vdevice->memory_scope == "global")) { + auto shape_arr = tir_out_sinfo->GetShape().value(); + auto new_sinfo = + TensorStructInfo(ShapeExpr(shape_arr), tir_out_sinfo->dtype, vdev_attrs->dst_vdevice); + + return Call(call_tir->op, call_tir->args, call_tir->attrs, {new_sinfo}); + } + return expr; + }; + + return {pattern_out, rewriter}; +} + +} // namespace + +class CollectConsumerDetails : public ExprVisitor { + public: + using ExprVisitor::VisitExpr_; + + Map> Collect(const IRModule& mod, Function func, const Target& target) { + mod_ = mod; + target_ = target; + VisitExpr(func->body); + // Extend the consumer details for tuple items + for (const auto& val : arg_to_binding) { + if (consumers.find(val.first) != consumers.end()) { + if (consumers.find(val.second) == consumers.end()) { + consumers.Set(val.second, consumers[val.first]); + } else { + auto ent = consumers[val.second]; + for (auto ent_val : consumers[val.first]) { + ent.push_back(ent_val); + } + consumers.Set(val.second, ent); + } + } + } + return consumers; + } + + void VisitBinding_(const VarBindingNode* binding, + const TupleGetItemNode* tuple_get_item_node) final { + if (arg_to_binding.find(GetRef(binding->var.get())) 
== arg_to_binding.end()) { + arg_to_binding.Set(GetRef(binding->var.get()), + GetRef(tuple_get_item_node->tuple.get())); + } + } + + void VisitExpr_(const CallNode* call) final { + static const Op& call_tir_op = Op::Get("relax.call_tir"); + Tuple func_args; + + if (call->op == call_tir_op) { + func_args = Downcast(call->args[1]); + } else { + func_args = Tuple(call->args); + } + + for (auto arg : func_args->fields) { + auto sinfo = GetStructInfo(arg); + if (auto tensor_sinfo = sinfo.as()) { + Array call_list; + + const VarNode* arg_var = arg.as(); + + if (consumers.find(GetRef(arg_var)) != consumers.end()) { + call_list = consumers[GetRef(arg_var)]; + } + call_list.push_back(GetRef(call)); + consumers.Set(GetRef(arg_var), call_list); + } + } + } + + private: + /* Map of each Var consumption by a call node */ + Map> consumers; + Map arg_to_binding; + IRModule mod_; + Target target_; +}; + +namespace transform { + +Pass FoldVDeviceScopeChange() { + auto pass_func = [=](Function func, IRModule mod, PassContext pc) { + /* here Target doesn't matter as the consumers we use only to find multiple consumers */ + auto consumers = + CollectConsumerDetails().Collect(mod, Downcast(func), Target("opencl")); + auto [pattern, rewriter] = CreatePatterns(consumers); + return RewriteCall(pattern, rewriter, func); + }; + return CreateFunctionPass(pass_func, 1, "FoldVDeviceScopeChange", {}); +} +TVM_FFI_REGISTER_GLOBAL("relax.backend.adreno.transform.FoldVDeviceScopeChange") + .set_body_typed(FoldVDeviceScopeChange); +} // namespace transform +} // namespace adreno +} // namespace backend +} // namespace relax +} // namespace tvm diff --git a/src/relax/op/nn/convolution.cc b/src/relax/op/nn/convolution.cc index 4f07c78458a1..13b1174376bc 100644 --- a/src/relax/op/nn/convolution.cc +++ b/src/relax/op/nn/convolution.cc @@ -309,6 +309,8 @@ InferLayoutOutput InferLayoutConv2d(const Call& call, ICHECK(attrs) << "Invalid Call"; LayoutDecision data_layout, weight_layout, output_layout; + data_layout = GetLayoutDecision(var_layout_map, call->args[0]); + weight_layout = GetLayoutDecision(var_layout_map, call->args[1]); ObjectPtr new_attrs = make_object(*attrs); if (it != desired_layouts.end()) { @@ -354,14 +356,16 @@ InferLayoutOutput InferLayoutConv2d(const Call& call, new_attrs->kernel_layout = (*it).second[1]; new_attrs->out_layout = (*it).second.size() == 3 ? (*it).second[2] : (*it).second[0]; return InferLayoutOutput({data_layout, weight_layout}, {output_layout}, Attrs(new_attrs)); + } else { + data_layout = LayoutDecision(InitialLayout(4)); + weight_layout = LayoutDecision(InitialLayout(4)); } } } // We don't have a desired layout for conv2d or desired layouts not compatible. // We can just propagate the layout from the input. 
-  data_layout = GetLayoutDecision(var_layout_map, call->args[0]);
-  weight_layout = GetLayoutDecision(var_layout_map, call->args[1]);
+
   output_layout = data_layout;
   new_attrs->data_layout =
       TransposeLike(attrs->data_layout, InitialLayout(4), data_layout->layout).name();
diff --git a/src/relax/op/op.cc b/src/relax/op/op.cc
index 9ecf19e7ae11..a430c6464a39 100644
--- a/src/relax/op/op.cc
+++ b/src/relax/op/op.cc
@@ -1367,15 +1367,24 @@ TVM_REGISTER_OP("relax.hint_on_device")
     .set_attr("FInferStructInfo", InferHintOnDeviceStructInfo)
     .set_attr("FPurity", Bool(true));
 
-Expr MakeHintOnDevice(Expr data, Device device) {
+Expr MakeHintOnDevice(Expr data, Device device, String memory_scope = "global") {
   static const Op& op = Op::Get("relax.hint_on_device");
   ObjectPtr attrs = make_object();
   attrs->dev_type = static_cast(device.device_type);
   attrs->dev_id = device.device_id;
+  attrs->memory_scope = memory_scope;
   return Call(op, {data}, Attrs(attrs), {});
 }
 
-TVM_FFI_REGISTER_GLOBAL("relax.op.hint_on_device").set_body_typed(MakeHintOnDevice);
+TVM_FFI_REGISTER_GLOBAL("relax.op.hint_on_device")
+    .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) {
+      if (args.size() == 3) {
+        *ret =
+            MakeHintOnDevice(args[0].cast(), args[1].cast(), args[2].cast());
+      } else {
+        *ret = MakeHintOnDevice(args[0].cast(), args[1].cast());
+      }
+    });
 }  // namespace relax
 }  // namespace tvm
diff --git a/src/relax/op/op_common.h b/src/relax/op/op_common.h
index d7d50f8fa714..2df75738101b 100644
--- a/src/relax/op/op_common.h
+++ b/src/relax/op/op_common.h
@@ -348,6 +348,15 @@ inline Optional InferBinaryArithOpOutVDevice(const Call& call, const Bl
     }
   };
 
+  /*
+   * This is the case where the output VDevice is defined by a customization pass,
+   * e.g. targets that support mixed VDevices (differentiated by memory_scope on Adreno)
+   * and have a specialized derivation for the output VDevice.
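+   *
+   * Illustrative example (hypothetical, for documentation only): if a pass rewrites
+   *   C = R.add(A, B)
+   * so that the call carries sinfo_args = [R.Tensor((2, 4, 26, 26), "float32",
+   *   vdevice="opencl:0:global.texture-weight")], the check below returns that VDevice
+   * directly instead of deriving it from the VDevices of A and B.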
+ */ + if (call->sinfo_args.size() > 0) { + return get_vdevice(call->sinfo_args[0]); + } + auto lhs_vdevice = get_vdevice(lhs_sinfo); auto rhs_vdevice = get_vdevice(rhs_sinfo); @@ -357,6 +366,7 @@ inline Optional InferBinaryArithOpOutVDevice(const Call& call, const Bl if (!rhs_vdevice.defined() || !rhs_vdevice.value()->target.defined()) { return lhs_vdevice; } + if (lhs_vdevice.value() != rhs_vdevice.value()) { ctx->ReportFatal(Diagnostic::Error(call) << "TypeErorr: " diff --git a/src/relax/op/tensor/binary.cc b/src/relax/op/tensor/binary.cc index 74ae8e9cbc5c..eb40a6a8e669 100644 --- a/src/relax/op/tensor/binary.cc +++ b/src/relax/op/tensor/binary.cc @@ -157,15 +157,21 @@ InferLayoutOutput InferLayoutBinaryEwise(const Call& call, Optional shape1 = GetRef(x1_sinfo->shape.as()); Optional shape2 = GetRef(x2_sinfo->shape.as()); + // Lets handle sub indexing as long as primal dims are matching - if (layout1->layout.ndim_primal() == layout2->layout.ndim_primal()) { - if ((layout1->layout.ndim() >= layout2->layout.ndim()) && shape2.defined()) { - if (CanProveLayoutTransform(layout2->layout, layout1->layout, shape2.value()->values)) { - return InferLayoutOutput({layout1, layout1}, {layout1}, Attrs(call->attrs)); - } - } else if (shape1.defined()) { - if (CanProveLayoutTransform(layout1->layout, layout2->layout, shape1.value()->values)) { - return InferLayoutOutput({layout2, layout2}, {layout2}, Attrs(call->attrs)); + if ((layout1->layout.ndim() != layout1->layout.ndim_primal()) || + (layout2->layout.ndim() != layout2->layout.ndim_primal())) { + if (layout1->layout.ndim_primal() == layout2->layout.ndim_primal()) { + if ((layout1->layout.ndim() >= layout2->layout.ndim()) && shape2.defined()) { + if (CanProveLayoutTransform(InitialLayout(shape2.value()->values.size()), layout1->layout, + shape2.value()->values)) { + return InferLayoutOutput({layout1, layout1}, {layout1}, Attrs(call->attrs)); + } + } else if (shape1.defined()) { + if (CanProveLayoutTransform(InitialLayout(shape1.value()->values.size()), layout2->layout, + shape1.value()->values)) { + return InferLayoutOutput({layout2, layout2}, {layout2}, Attrs(call->attrs)); + } } } } diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc index 3eb29a82e3d1..1e11f91ff26e 100644 --- a/src/relax/op/tensor/manipulate.cc +++ b/src/relax/op/tensor/manipulate.cc @@ -325,12 +325,54 @@ InferLayoutOutput InferLayoutConcat(const Call& call, const auto* attrs = call->attrs.as(); ICHECK(attrs != nullptr) << "Invalid Call"; + NLayout nlayout = GetNLayout(var_layout_map, call->args[0]); ICHECK(nlayout.IsNested()); ICHECK(nlayout.NestedArray()[0].IsLeaf()); int n_tensor = nlayout.NestedArray().size(); LayoutDecision layout = nlayout.NestedArray()[0].LeafValue(); + + // We may expect mix of sub indexed and regular layouts here + // Pick the first sub indexed layout and try to prove it for all tensors + // On any failre select first occuring regular layout for all + auto nlayout_array = nlayout.NestedArray(); + for (auto n_layout : nlayout_array) { + ICHECK(n_layout.IsLeaf()); + LayoutDecision in_layout = n_layout.LeafValue(); + if (in_layout->layout.ndim() != in_layout->layout.ndim_primal()) { + const auto* tuple_sinfo = GetStructInfoAs(call->args[0]); + ICHECK(tuple_sinfo != nullptr) + << " expects the input to be a Tuple of Tensors. 
However, the given input is " + << call->args[0]->struct_info_->GetTypeKey(); + for (size_t i = 0; i < tuple_sinfo->fields.size(); ++i) { + StructInfo field_sinfo = tuple_sinfo->fields[i]; + const auto* field_tensor_sinfo = field_sinfo.as(); + ICHECK(field_tensor_sinfo != nullptr) + << call->op + << " expects the input to be a Tuple of Tensors. However, the given input is " + << call->args[0]->struct_info_; + auto t_sinfo = GetRef(field_tensor_sinfo); + Optional t_shape = GetRef(t_sinfo->shape.as()); + LayoutDecision curr_layout = nlayout_array[i].LeafValue(); + if (!CanProveLayoutTransform(curr_layout->layout, in_layout->layout, + t_shape.value()->values)) { + // Some tensor unhappy with sub indexed layout, lets pick first regular layout + for (auto pick_layout : nlayout_array) { + if (pick_layout.LeafValue()->layout.ndim() == + pick_layout.LeafValue()->layout.ndim_primal()) { + in_layout = pick_layout.LeafValue(); + break; + } + } + break; + } + } + layout = in_layout; + break; + } + } + Array input_layouts, output_layouts; for (int i = 0; i < n_tensor; ++i) { input_layouts.push_back(layout); diff --git a/src/relax/transform/legalize_ops.cc b/src/relax/transform/legalize_ops.cc index a0ac6fffb62c..f5b56d847a40 100644 --- a/src/relax/transform/legalize_ops.cc +++ b/src/relax/transform/legalize_ops.cc @@ -30,6 +30,8 @@ #include #include +#include + namespace tvm { namespace relax { @@ -60,11 +62,16 @@ bool KnowAllShapeValues(const StructInfo& sinfo) { class LegalizeMutator : public ExprMutator { public: explicit LegalizeMutator(const IRModule& mod, const Optional>& cmap, - bool enable_warning) + const Optional> skip_ops, bool enable_warning) : ExprMutator(mod), mod_(std::move(mod)), enable_warning_(enable_warning) { if (cmap) { cmap_ = std::move(cmap.value()); } + if (skip_ops.defined()) { + for (const auto name : skip_ops.value()) { + skip_ops_.insert(Op::Get(name)); + } + } } IRModule Transform() { @@ -237,6 +244,10 @@ class LegalizeMutator : public ExprMutator { } auto op = GetRef(op_node); + if (skip_ops_.find(op) != skip_ops_.end()) { + return visited_call; + } + bool shapes_are_known_if_required = [&]() -> bool { bool requires_arg_shapes = requires_arg_shapes_map.get(op, Bool(true))->value; if (!requires_arg_shapes) { @@ -385,16 +396,21 @@ class LegalizeMutator : public ExprMutator { * legalization function is not registered. */ bool enable_warning_; + /*! 
+ * \brief List of ops to be skipped from legalization + */ + std::set skip_ops_; }; namespace transform { -Pass LegalizeOps(Optional> cmap, bool enable_warning) { +Pass LegalizeOps(Optional> cmap, Optional> skip_ops, + bool enable_warning) { auto pass_func = [=](IRModule mod, PassContext pc) { bool apply_legalize_ops = pc->GetConfig("relax.transform.apply_legalize_ops").value_or(Bool(true))->value; if (apply_legalize_ops) { - mod = LegalizeMutator(mod, cmap, enable_warning).Transform(); + mod = LegalizeMutator(mod, cmap, skip_ops, enable_warning).Transform(); } return mod; }; diff --git a/src/relax/transform/realize_vdevice.cc b/src/relax/transform/realize_vdevice.cc index ee4773fb3a24..70eb9f6a71ef 100644 --- a/src/relax/transform/realize_vdevice.cc +++ b/src/relax/transform/realize_vdevice.cc @@ -55,6 +55,7 @@ class VDeviceLookup { ICHECK(attrs); int32_t device_type = attrs->dev_type; int32_t device_id = attrs->dev_id; + String memory_scope = attrs->memory_scope; CHECK(opt_vdevices_.defined()) << "ValueError: The target VDevice in the GlobalInfos was not found."; @@ -65,7 +66,8 @@ class VDeviceLookup { for (auto vdevice : vdevices) { int dev_type = vdevice->target->GetTargetDeviceType(); - if (dev_type == device_type && vdevice->vdevice_id == device_id) { + if (dev_type == device_type && vdevice->vdevice_id == device_id && + memory_scope == vdevice->memory_scope) { return vdevice; } } diff --git a/src/relax/transform/specialize_primfunc_based_on_callsite.cc b/src/relax/transform/specialize_primfunc_based_on_callsite.cc new file mode 100644 index 000000000000..0e4ce22b537f --- /dev/null +++ b/src/relax/transform/specialize_primfunc_based_on_callsite.cc @@ -0,0 +1,172 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file src/relax/transform/specialize_tir_params.cc + * \brief Update PrimFunc buffers based on updated scope (or structure) info. 
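+ *
+ * Illustrative sketch (hypothetical shapes and names, not taken from this patch): for
+ *   R.call_tir(cls.conv2d, (x, w), out_sinfo=R.Tensor((2, 4, 26, 26), "float32",
+ *                                                     vdevice="opencl:0:global.texture-weight"))
+ * the pass re-declares the matching PrimFunc parameters as buffers whose storage scope is
+ * taken from the argument/output VDevice (here "global.texture-weight") and then specializes
+ * the PrimFunc with the updated buffer map.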
+ */ + +#include +#include +#include +#include +#include +#include +#include + +#include + +#include "../op/tensor/manipulate.h" +#include "infer_layout_utils.h" +#include "utils.h" + +namespace tvm { +namespace relax { + +using tvm::tir::Buffer; + +static Array GetShapeFromTensorStructInfo(const TensorStructInfo& tensor_sinfo) { + auto shape = tensor_sinfo->GetShape(); + ICHECK(shape.defined()); + return shape.value(); +} + +class SpecializeTIRCallArgs : ExprMutator { + public: + IRModule Run(IRModule mod) { + mod_ = mod; + for (const auto& [gv, func] : mod->functions) { + if (func->IsInstance()) { + const auto& base_func = mod->Lookup(gv); + // Only non primitive relax functions + if (base_func->HasNonzeroAttr(attr::kPrimitive)) { + continue; + } + relax::Function update_func = Downcast(VisitExpr(func)); + updates_->Add(gv, update_func); + } + } + mod_.CopyOnWrite()->Update(updates_); + return mod_; + } + + using ExprMutator::VisitExpr_; + + Expr VisitExpr_(const CallNode* call_node) override { + auto call = Downcast(ExprMutator::VisitExpr_(call_node)); + static const Op& call_tir_op = Op::Get("relax.call_tir"); + if (call->op == call_tir_op) { + return SpecializeTirPrimFunc(call); + } + return call; + } + + private: + Expr SpecializeTirPrimFunc(Call call) { + auto gv = Downcast(call->args[0]); + auto pfunc = Downcast(mod_->Lookup(gv)); + auto args = Downcast(call->args[1])->fields; + Map> param_map; + + for (size_t i = 0; i < args.size(); ++i) { + auto sinfo = GetStructInfo(args[i]); + CHECK(sinfo->IsInstance()) + << "Expected Tensor struct Info for call :" << call->op; + auto tensor_sinfo = Downcast(sinfo); + CHECK(tensor_sinfo->shape.defined()) << "Shape undefined for call:" << call->args[0]; + String scope = "global"; + if (tensor_sinfo->vdevice.defined()) { + scope = tensor_sinfo->vdevice.value()->memory_scope; + } + String name; + if (args[i]->IsInstance()) { + name = Downcast(args[i])->name_hint(); + } else { + name = std::string({static_cast('A' + i)}); + } + + const Buffer& buffer = tir::decl_buffer(GetShapeFromTensorStructInfo(tensor_sinfo), + tensor_sinfo->dtype, name, scope); + param_map.Set(pfunc->params[i], buffer); + } + String scope = "global"; + auto out_sinfo = call->sinfo_args[0]; + if (out_sinfo->IsInstance()) { + auto sinfo = Downcast(out_sinfo); + if (sinfo->vdevice.defined()) { + scope = sinfo->vdevice.value()->memory_scope; + } + const Buffer& buffer = + tir::decl_buffer(GetShapeFromTensorStructInfo(sinfo), sinfo->dtype, "ret_val", scope); + param_map.Set(pfunc->params[pfunc->params.size() - 1], buffer); + } else { + ICHECK(out_sinfo->IsInstance()) + << "Expect output struct info of call_tir to be either TupleStructInfo or " + "TensorStructInfo, but got " + << out_sinfo; + + const auto& tuple_sinfo = Downcast(out_sinfo); + Array sinfo_fields; + int index = 0; + for (const auto& si : tuple_sinfo->fields) { + ICHECK(si->IsInstance()) + << "Fields of TupleStructInfo must be TensorStructInfo for call_tir " + "output structinfo, but got " + << si; + auto sinfo = Downcast(si); + if (sinfo->vdevice.defined()) { + scope = sinfo->vdevice.value()->memory_scope; + } + const Buffer& buffer = + tir::decl_buffer(GetShapeFromTensorStructInfo(sinfo), sinfo->dtype, "ret_val", scope); + param_map.Set(pfunc->params[args.size() + index], buffer); + index++; + } + } + + auto new_pfunc = Specialize(pfunc, param_map); + for (const auto& [var, buffer] : new_pfunc->buffer_map) { + auto* ptr = buffer->data->type_annotation.as(); + ICHECK(ptr) << "Buffer Var's type annotation must be of 
PointerType"; + } + auto new_prim_func = WithAttr(new_pfunc, "scoped", Integer(1)); + updates_->Add(gv, new_prim_func); + return call; + } + IRModule mod_; + IRModule updates_; +}; + +namespace transform { + +Pass SpecializePrimFuncBasedOnCallSite() { + auto pass_func = [=](IRModule mod, PassContext pc) { + return relax::SpecializeTIRCallArgs().Run(mod); + }; + return CreateModulePass(/*pass_function=*/pass_func, + /*opt_level=*/0, + /*pass_name=*/"SpecializePrimFuncBasedOnCallSite", + /*required=*/{}); +} + +TVM_FFI_REGISTER_GLOBAL("relax.transform.SpecializePrimFuncBasedOnCallSite") + .set_body_typed(SpecializePrimFuncBasedOnCallSite); + +} // namespace transform +} // namespace relax +} // namespace tvm diff --git a/src/relax/transform/utils.h b/src/relax/transform/utils.h index edd953e3126e..8505ad72b5b6 100644 --- a/src/relax/transform/utils.h +++ b/src/relax/transform/utils.h @@ -386,7 +386,7 @@ inline String GetCodegenName(const std::string& composite_name) { inline int GetDeviceIndex(const IRModule& mod, const VDevice& vdevice) { Array vdevices = mod->global_infos["vdevice"]; for (int i = 0; i < static_cast(vdevices.size()); ++i) { - if (vdevices[i] == vdevice) { + if (vdevices[i].same_as(vdevice)) { return i; } } diff --git a/src/runtime/contrib/clml/clml_runtime.cc b/src/runtime/contrib/clml/clml_runtime.cc index 5ee90e29b009..0e2f33285ab8 100644 --- a/src/runtime/contrib/clml/clml_runtime.cc +++ b/src/runtime/contrib/clml/clml_runtime.cc @@ -313,7 +313,7 @@ class CLMLRuntime : public JSONRuntimeBase { const auto f = tvm::ffi::Function::GetGlobal("runtime.SaveParams"); if (f.has_value()) { - std::string dump_bytes = (*f)(dump_tensors); + std::string dump_bytes = (*f)(dump_tensors).cast(); std::ostringstream oss; /*TODO(Siva) HEX encoding doubles the size, look for better encode that can cross the RPC. */ for (size_t i = 0; i < dump_bytes.size(); ++i) { @@ -464,8 +464,8 @@ class CLMLRuntime : public JSONRuntimeBase { cl_channel_type cl_dtype = MakeCLDataType(tvm_dtype); int dtype_size = cl_dtype == CL_FLOAT ? 
4 : 2; void* tmpptr = reinterpret_cast(malloc(isize * dtype_size)); - TVMArrayCopyToBytes(const_cast(data_entry_[eid]), const_cast(tmpptr), - isize * dtype_size); + NDArray::CopyToBytes(const_cast(data_entry_[eid]), const_cast(tmpptr), + isize * dtype_size); CopyDataToCLMLTensor(layer_.inputs[nid], tmpptr); free(tmpptr); } @@ -479,7 +479,7 @@ class CLMLRuntime : public JSONRuntimeBase { if (cws->workspace->IsProfiling(cws->tentry->device)) { Timer t; auto f = tvm::ffi::Function::GetGlobal(std::string("profiling.timer.opencl")); - t = f->operator()(cws->tentry->device); + t = f->operator()(cws->tentry->device).cast(); t->Start(); queue = CLML_QUEUE; evts.resize(evts.size() + 1); @@ -500,7 +500,7 @@ class CLMLRuntime : public JSONRuntimeBase { if (cws->workspace->IsProfiling(cws->tentry->device)) { Timer t; auto f = tvm::ffi::Function::GetGlobal(std::string("profiling.timer.opencl")); - t = f->operator()(cws->tentry->device); + t = f->operator()(cws->tentry->device).cast(); t->Start(); queue = CLML_QUEUE; evts.resize(evts.size() + 1); @@ -551,8 +551,8 @@ class CLMLRuntime : public JSONRuntimeBase { void* tmpptr = reinterpret_cast(malloc(osize * dtype_size)); CopyDataFromCLMLTensor(layer_.outputs[0], tmpptr); - TVMArrayCopyFromBytes(const_cast(data_entry_[eid]), const_cast(tmpptr), - osize * dtype_size); + NDArray::CopyFromBytes(const_cast(data_entry_[eid]), const_cast(tmpptr), + osize * dtype_size); free(tmpptr); } } diff --git a/src/runtime/ndarray.cc b/src/runtime/ndarray.cc index f03a83a929ec..f783cd7d5150 100644 --- a/src/runtime/ndarray.cc +++ b/src/runtime/ndarray.cc @@ -96,6 +96,26 @@ void NDArray::CopyToBytes(const DLTensor* handle, void* data, size_t nbytes, DeviceAPI::Get(handle->device)->StreamSync(handle->device, stream); } +void NDArray::CopyFromBytes(const DLTensor* handle, void* data, size_t nbytes, + TVMStreamHandle stream) { + size_t arr_size = GetDataSize(*handle); + ICHECK_EQ(arr_size, nbytes) << "ArrayCopyToBytes: size mismatch"; + ICHECK(ffi::IsContiguous(*handle)) << "ArrayCopyToBytes only support contiguous array for now"; + + DLTensor from; + from.data = const_cast(data); + from.device = Device{kDLCPU, 0}; + from.ndim = handle->ndim; + from.dtype = handle->dtype; + from.shape = handle->shape; + from.strides = nullptr; + from.byte_offset = 0; + + DeviceAPI::Get(handle->device)->CopyDataFromTo(&from, const_cast(handle), stream); + // Synchronize in case data become unavailable later. 
+ DeviceAPI::Get(handle->device)->StreamSync(handle->device, stream); +} + NDArray NDArray::Empty(ffi::Shape shape, DLDataType dtype, Device dev, Optional mem_scope) { struct DeviceAPIAlloc { void AllocData(DLTensor* tensor, ffi::Optional mem_scope) { diff --git a/src/script/printer/relax/call.cc b/src/script/printer/relax/call.cc index 9c4efadc2b83..d1d1c2cc2428 100644 --- a/src/script/printer/relax/call.cc +++ b/src/script/printer/relax/call.cc @@ -254,7 +254,11 @@ Optional PrintHintOnDevice(const relax::Call& n, const ObjectPath& n_p, ICHECK(n->attrs.defined()); if (n->attrs.as()) { AttrPrinter(n_p->Attr("attrs"), d, &kwargs_keys, &kwargs_values)(n->attrs); + ExprDoc scope_val = kwargs_values.back(); + kwargs_keys.pop_back(); + kwargs_values.pop_back(); args.push_back(Relax(d, "device")->Call({}, kwargs_keys, kwargs_values)); + args.push_back(scope_val); } return Relax(d, "hint_on_device")->Call(args); } @@ -277,7 +281,8 @@ Optional PrintToVDevice(const relax::Call& n, const ObjectPath& n_p, int dev_index = FindVDeviceIndexByTargetKind(vdev, d); kwargs_keys.push_back("dst_vdevice"); kwargs_values.push_back( - LiteralDoc::Str(dev_kind + ":" + std::to_string(dev_index), n_p->Attr("dst_vdevice"))); + LiteralDoc::Str(dev_kind + ":" + std::to_string(dev_index) + ":" + vdev->memory_scope, + n_p->Attr("dst_vdevice"))); } return Relax(d, "to_vdevice")->Call(args, kwargs_keys, kwargs_values); } diff --git a/src/script/printer/relax/struct_info.cc b/src/script/printer/relax/struct_info.cc index 7043952c7c15..66643fc2e9fd 100644 --- a/src/script/printer/relax/struct_info.cc +++ b/src/script/printer/relax/struct_info.cc @@ -126,8 +126,9 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) kwargs_keys.push_back("vdevice"); std::string dev_kind = n->vdevice.value()->target->kind->name; int dev_index = FindVDeviceIndexByTargetKind(n->vdevice.value(), d); - kwargs_values.push_back( - LiteralDoc::Str(dev_kind + ":" + std::to_string(dev_index), n_p->Attr("vdevice"))); + kwargs_values.push_back(LiteralDoc::Str( + dev_kind + ":" + std::to_string(dev_index) + ":" + n->vdevice.value()->memory_scope, + n_p->Attr("vdevice"))); } if (args.empty() && kwargs_keys.empty()) { return Relax(d, "Tensor"); diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc index 9d23661bace3..b52689ab42db 100644 --- a/src/tir/schedule/analysis/analysis.cc +++ b/src/tir/schedule/analysis/analysis.cc @@ -2152,6 +2152,9 @@ TVM_FFI_REGISTER_GLOBAL("tir.schedule.GetLoopIterType") return "O"; } }); +TVM_FFI_REGISTER_GLOBAL("tir.schedule.HasIfThenElse").set_body_typed([](const Stmt& stmt) -> bool { + return HasIfThenElse(stmt); +}); } // namespace tir } // namespace tvm diff --git a/tests/python/relax/adreno/test_transform_annotate_custom_scope.py b/tests/python/relax/adreno/test_transform_annotate_custom_scope.py new file mode 100644 index 000000000000..24b4cf66b888 --- /dev/null +++ b/tests/python/relax/adreno/test_transform_annotate_custom_scope.py @@ -0,0 +1,1204 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import tvm +from tvm import relax +import tvm.testing +from tvm.script.parser import ir as I, relax as R, tir as T +from tvm.relax.transform.legalize_ops import adreno as legalize_adreno +from tvm.ir.module import IRModule +from tvm.relax.expr_functor import PyExprMutator, PyExprVisitor, mutator, visitor + + +@visitor +class ValidateScope(PyExprVisitor): # pylint: disable=abstract-method + def __init__(self, scope_info: dict) -> None: + self.scope_info = scope_info + self.matched = True + + def visit(self, mod: IRModule) -> None: + """Entry point""" + for _, func in mod.functions_items(): + if isinstance(func, relax.Function): + self.visit_expr(func) + return self.matched + + def visit_call_(self, call: relax.Call) -> None: # pylint: disable=arguments-renamed + if call.op.name == "relax.call_tir": + # if call.args[0].name_hint in self.scope_info: + for idx, arg in enumerate(call.args[1]): + arg_sinfo = arg.struct_info + assert isinstance( + arg_sinfo, relax.TensorStructInfo + ), f"Expected TensorStructInfo but git {type(arg_sinfo)}" + call_mem_scope = ( + "global" if not arg_sinfo.vdevice else arg_sinfo.vdevice.memory_scope + ) + assert ( + call_mem_scope == self.scope_info[call.args[0].name_hint][0][idx] + ), f"Scope mismatched for argument {idx} in {call.args[0].name_hint}" + if isinstance(call.sinfo_args[0], relax.TensorStructInfo): + call_mem_scope = ( + "global" + if not call.sinfo_args[0].vdevice + else call.sinfo_args[0].vdevice.memory_scope + ) + assert ( + call_mem_scope == self.scope_info[call.args[0].name_hint][1][0] + ), f"Scope mismatched for return scope: {call.args[0].name_hint}" + else: + assert isinstance( + call.sinfo_args[0], relax.TupleStructInfo + ), f"Expected TupleStructInfo but git {type(call.sinfo_args[0])}" + for idx, sinfo in enumerate(call.sinfo_args[0].fields): + call_mem_scope = "global" if not sinfo.vdevice else sinfo.vdevice.memory_scope + assert ( + call_mem_scope == self.scope_info[call.args[0].name_hint][1][idx] + ), f"Scope mismatched for return scope for {idx} in {call.args[0].name_hint}" + + +def verify(mod, expected): + tgt = tvm.target.Target("opencl --device=adreno", host="llvm") + skip_ops = [ + "relax.nn.conv2d", + "relax.nn.max_pool2d", + "relax.nn.adaptive_avg_pool2d", + # "relax.nn.layer_norm", + ] + with tgt: + mod = tvm.tir.transform.BindTarget(tvm.target.Target.current(allow_none=False))(mod) + mod = tvm.relax.transform.DecomposeOpsForInference()(mod) + mod = tvm.relax.transform.FoldConstant()(mod) + desired_layouts = {"relax.nn.conv2d": ["NCHW4c", "OIHW4o", "NCHW4c"]} + mod = tvm.relax.transform.ConvertLayout(desired_layouts)(mod) + mod = tvm.relax.transform.Normalize()(mod) + mod = tvm.relax.transform.FoldConstant()(mod) + mod = tvm.relax.transform.LegalizeOps(skip_ops=skip_ops)(mod) + mod = tvm.relax.transform.AnnotateTIROpPattern()(mod) + mod = tvm.relax.backend.adreno.transform.AnnotateCustomMemoryScope(tgt)(mod) + # There is a possibility of some skipped ops above might not use 5D layouts. 
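+        # The ops skipped earlier go through the default legalization below; conv2d is then
+        # lowered with the Adreno-specific conv2d_NCHWc_OIHWo legalization.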
+ mod = tvm.relax.transform.LegalizeOps()(mod) + mod = tvm.relax.transform.LegalizeOps( + {"relax.nn.conv2d": legalize_adreno.conv2d_NCHWc_OIHWo}, + )(mod) + # Lets get pattern info for newly legalized ops + mod = tvm.relax.transform.AnnotateTIROpPattern()(mod) + mod = tvm.relax.transform.FoldConstant()(mod) + mod = tvm.relax.transform.FuseOps()(mod) + mod = tvm.relax.transform.FuseTIR()(mod) + mod = tvm.relax.transform.DeadCodeElimination()(mod) + mod = tvm.relax.backend.adreno.transform.FoldVDeviceScopeChange()(mod) + mod = tvm.relax.transform.DeadCodeElimination()(mod) + mod = tvm.relax.transform.SpecializePrimFuncBasedOnCallSite()(mod) + mod = tvm.relax.transform.Normalize()(mod) + + ValidateScope(expected).visit(mod) + + +def test_conv2d(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 64, 56, 56), "float32"), w: R.Tensor((32, 64, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 32, 54, 54), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + R.output(gv) + return gv + + Expected = { + "te_layout_transform": (["global"], ["global.texture-nhwc"]), + "te_layout_transform1": (["global"], ["global.texture-weight"]), + "conv2d_NCHWc_OIHWo_opencl": (["global.texture-nhwc", "global.texture-weight"], ["global"]), + "te_layout_transform2": (["global"], ["global"]), + } + + verify(Input, Expected) + + +def test_conv2d_NCHW_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), "float32"), w: R.Tensor((4, 16, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d( + x, + w, + data_layout="NCHW", + kernel_layout="OIHW", + out_dtype="float32", + ) + R.output(gv) + return gv + + Expected = { + "te_layout_transform": (["global"], ["global.texture-weight"]), + "te_layout_transform1": (["global"], ["global.texture-weight"]), + "conv2d_NCHWc_OIHWo_opencl": ( + ["global.texture-weight", "global.texture-weight"], + ["global"], + ), + "te_layout_transform2": (["global"], ["global"]), + } + + verify(Input, Expected) + + +def test_conv2d_NHWC_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 28, 28, 16), "float32"), w: R.Tensor((4, 3, 3, 16), "float32") + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 26, 26, 4), "float32") = R.nn.conv2d( + x, + w, + data_layout="NHWC", + kernel_layout="OHWI", + out_dtype="float32", + ) + R.output(gv) + return gv + + Expected = { + "te_layout_transform": (["global"], ["global.texture-weight"]), + "te_layout_transform1": (["global"], ["global.texture-weight"]), + "conv2d_NCHWc_OIHWo_opencl": ( + ["global.texture-weight", "global.texture-weight"], + ["global"], + ), + "te_layout_transform2": (["global"], ["global"]), + } + + verify(Input, Expected) + + +def _test_conv2d_symbolic_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor("float32", ndim=4), w: R.Tensor("float32", ndim=4) + ) -> R.Tensor("float32", ndim=4): + with R.dataflow(): + N, C, H, W = T.int64(), T.int64(16), T.int64(), T.int64() + Nw, Cw, Hw, Ww = T.int64(4), T.int64(16), T.int64(), T.int64() + lv0 = R.match_cast(x, R.Tensor((N, C, H, W), "float32")) + lv1 = R.match_cast(w, R.Tensor((Nw, Cw, Hw, Ww), "float32")) + gv: R.Tensor( + (N, T.int64(4), H + T.int64(1) - Hw, W + T.int64(1) - Ww), "float32" + ) = R.nn.conv2d(lv0, lv1, out_dtype="float32") + R.output(gv) + return gv + + Expected = { + 
"te_layout_transform": (["global"], ["global.texture-weight"]), + "te_layout_transform1": (["global"], ["global.texture-weight"]), + "conv2d_NCHWc_OIHWo_opencl": ( + ["global.texture-weight", "global.texture-weight"], + ["global"], + ), + "te_layout_transform2": (["global"], ["global"]), + } + + verify(Input, Expected) + + +def test_conv2d_relu_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), "float32"), w: R.Tensor((4, 16, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((2, 4, 26, 26), "float32") = R.nn.relu(gv) + R.output(gv2) + return gv2 + + Expected = { + "te_layout_transform": (["global"], ["global.texture-weight"]), + "te_layout_transform1": (["global"], ["global.texture-weight"]), + "fused_conv2d_NCHWc_OIHWo_opencl_relu": ( + ["global.texture-weight", "global.texture-weight"], + ["global"], + ), + "te_layout_transform2": (["global"], ["global"]), + } + + verify(Input, Expected) + + +def test_relu_conv2d_relu_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), "float32"), w: R.Tensor((4, 16, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + x0: R.Tensor((2, 16, 28, 28), "float32") = R.nn.relu(x) + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x0, w, out_dtype="float32") + gv2: R.Tensor((2, 4, 26, 26), "float32") = R.nn.relu(gv) + R.output(gv2) + return gv2 + + Expected = { + "relu": (["global"], ["global"]), + "te_layout_transform": (["global"], ["global.texture-weight"]), + "te_layout_transform1": (["global"], ["global.texture-weight"]), + "fused_conv2d_NCHWc_OIHWo_opencl_relu1": ( + ["global.texture-weight", "global.texture-weight"], + ["global"], + ), + "te_layout_transform2": (["global"], ["global"]), + } + + verify(Input, Expected) + + +def test_conv2d_relu_tanh_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), "float32"), w: R.Tensor((4, 16, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((2, 4, 26, 26), "float32") = R.nn.relu(gv) + gv3: R.Tensor((2, 4, 26, 26), "float32") = R.tanh(gv2) + R.output(gv3) + return gv3 + + Expected = { + "te_layout_transform": (["global"], ["global.texture-weight"]), + "te_layout_transform1": (["global"], ["global.texture-weight"]), + "fused_conv2d_NCHWc_OIHWo_opencl_relu_tir_tanh": ( + ["global.texture-weight", "global.texture-weight"], + ["global"], + ), + "te_layout_transform2": (["global"], ["global"]), + } + verify(Input, Expected) + + +def test_conv2d_add_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), "float32"), + w: R.Tensor((4, 16, 3, 3), "float32"), + bias: R.Tensor((2, 4, 26, 26), "float32"), + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((2, 4, 26, 26), "float32") = R.add(gv, bias) + R.output(gv2) + return gv2 + + Expected = { + "te_layout_transform": (["global"], ["global.texture-weight"]), + "te_layout_transform1": (["global"], ["global.texture-weight"]), + "te_layout_transform2": (["global"], ["global.texture-weight"]), + "fused_conv2d_NCHWc_OIHWo_opencl_add": ( + ["global.texture-weight", "global.texture-weight", 
"global.texture-weight"], + ["global"], + ), + "te_layout_transform3": (["global"], ["global"]), + } + verify(Input, Expected) + + +def test_conv2d_fma_relu_conv2d_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 4, 28, 28), "float32"), + w: R.Tensor((4, 4, 3, 3), "float32"), + scale: R.Tensor((2, 4, 26, 26), dtype="float32"), + bias: R.Tensor((2, 4, 26, 26), "float32"), + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((2, 4, 26, 26), "float32") = R.ewise_fma(gv, scale, bias) + gv3: R.Tensor((2, 4, 26, 26), "float32") = R.nn.relu(gv2) + gv4: R.Tensor((2, 4, 24, 24), "float32") = R.nn.conv2d(gv3, w, out_dtype="float32") + R.output(gv4) + return gv4 + + Expected = { + "te_layout_transform": (["global"], ["global.texture-weight"]), + "te_layout_transform1": (["global"], ["global.texture-weight"]), + "conv2d_NCHWc_OIHWo_opencl": ( + ["global.texture-weight", "global.texture-weight"], + ["global"], + ), + "te_layout_transform2": (["global"], ["global"]), + "relu": (["global"], ["global"]), + "te_layout_transform3": (["global"], ["global.texture-weight"]), + "conv2d_NCHWc_OIHWo1_opencl": ( + ["global.texture-weight", "global.texture-weight"], + ["global"], + ), + "te_layout_transform4": (["global"], ["global"]), + } + verify(Input, Expected) + + +def test_conv2d_sum_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), "float32"), w: R.Tensor((4, 16, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=2): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((2, 4), "float32") = R.sum(gv, axis=[2, 3]) + R.output(gv2) + return gv2 + + Expected = { + "te_layout_transform": (["global"], ["global.texture-weight"]), + "te_layout_transform1": (["global"], ["global.texture-weight"]), + "conv2d_NCHWc_OIHWo_opencl": ( + ["global.texture-weight", "global.texture-weight"], + ["global"], + ), + "sum": (["global"], ["global"]), + "te_layout_transform2": (["global"], ["global"]), + } + verify(Input, Expected) + + +def test_conv2d_sum_keepdims_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), "float32"), w: R.Tensor((4, 16, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=2): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((2, 4, 1, 1), "float32") = R.sum(gv, axis=[2, 3], keepdims=True) + R.output(gv2) + return gv2 + + Expected = { + "te_layout_transform": (["global"], ["global.texture-weight"]), + "te_layout_transform1": (["global"], ["global.texture-weight"]), + "conv2d_NCHWc_OIHWo_opencl": ( + ["global.texture-weight", "global.texture-weight"], + ["global"], + ), + "sum": (["global"], ["global"]), + "te_layout_transform2": (["global"], ["global"]), + } + verify(Input, Expected) + + +def test_conv2d_sum_reduce_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), "float32"), w: R.Tensor((4, 16, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=2): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((2, 26), "float32") = R.sum(gv, axis=[1, 2]) + R.output(gv2) + return gv2 + + Expected = { + "te_layout_transform": (["global"], ["global.texture-weight"]), + "te_layout_transform1": (["global"], 
["global.texture-weight"]), + "conv2d_NCHWc_OIHWo_opencl": ( + ["global.texture-weight", "global.texture-weight"], + ["global"], + ), + "sum": (["global"], ["global"]), + "te_layout_transform2": (["global"], ["global"]), + } + verify(Input, Expected) + + +def test_conv2d_transpose_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), "float32"), w: R.Tensor((4, 16, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((26, 26, 4, 2), "float32") = R.permute_dims(gv, axes=[3, 2, 1, 0]) + R.output(gv2) + return gv2 + + Expected = { + "te_layout_transform": (["global"], ["global.texture-weight"]), + "te_layout_transform1": (["global"], ["global.texture-weight"]), + "conv2d_NCHWc_OIHWo_opencl": ( + ["global.texture-weight", "global.texture-weight"], + ["global"], + ), + "te_layout_transform2": (["global"], ["global"]), + "transpose": (["global"], ["global"]), + } + verify(Input, Expected) + + +def test_conv2d_expand_dims_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), "float32"), w: R.Tensor((4, 16, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=6): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((2, 1, 4, 1, 26, 26), "float32") = R.expand_dims(gv, axis=(-3, 1)) + R.output(gv2) + return gv2 + + Expected = { + "te_layout_transform": (["global"], ["global.texture-weight"]), + "te_layout_transform1": (["global"], ["global.texture-weight"]), + "conv2d_NCHWc_OIHWo_opencl": ( + ["global.texture-weight", "global.texture-weight"], + ["global"], + ), + "te_layout_transform2": (["global"], ["global"]), + "expand_dims": (["global"], ["global"]), + } + verify(Input, Expected) + + +def test_conv2d_squeeze_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((1, 16, 28, 28), "float32"), w: R.Tensor((4, 16, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=3): + with R.dataflow(): + gv: R.Tensor((1, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((4, 26, 26), "float32") = R.squeeze(gv, axis=[0]) + R.output(gv2) + return gv2 + + Expected = { + "te_layout_transform": (["global"], ["global.texture-weight"]), + "te_layout_transform1": (["global"], ["global.texture-weight"]), + "conv2d_NCHWc_OIHWo_opencl": ( + ["global.texture-weight", "global.texture-weight"], + ["global"], + ), + "te_layout_transform2": (["global"], ["global"]), + "squeeze": (["global"], ["global"]), + } + verify(Input, Expected) + + +def test_conv2d_strided_slice_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), "float32"), w: R.Tensor((4, 16, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((2, 2, 9, 7), dtype="float32") = R.strided_slice( + gv, begin=[0, 0, 0], end=[4, 26, 26], strides=[2, 3, 4], axes=[1, 2, 3] + ) + R.output(gv2) + return gv2 + + Expected = { + "te_layout_transform": (["global"], ["global.texture-weight"]), + "te_layout_transform1": (["global"], ["global.texture-weight"]), + "conv2d_NCHWc_OIHWo_opencl": ( + ["global.texture-weight", "global.texture-weight"], + ["global"], + ), + "te_layout_transform2": (["global"], ["global"]), + "strided_slice": (["global"], ["global"]), + 
} + verify(Input, Expected) + + +def test_conv2d_relu_concat_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), "float32"), w: R.Tensor((4, 16, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((2, 4, 26, 26), "float32") = R.nn.relu(gv) + gv3: R.Tensor((2, 8, 26, 26), "float32") = R.concat((gv, gv2), axis=1) + R.output(gv3) + return gv3 + + Expected = { + "te_layout_transform": (["global"], ["global.texture-weight"]), + "te_layout_transform1": (["global"], ["global.texture-weight"]), + "conv2d_NCHWc_OIHWo_opencl": ( + ["global.texture-weight", "global.texture-weight"], + ["global.texture-weight"], + ), + "fused_relu_concatenate": (["global.texture-weight"], ["global"]), + "te_layout_transform2": (["global"], ["global"]), + } + verify(Input, Expected) + + +def test_conv2d_relu_concat_split_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main(x: R.Tensor((2, 16, 28, 28), "float32"), w: R.Tensor((4, 16, 3, 3), "float32")): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((2, 4, 26, 26), "float32") = R.nn.relu(gv) + gv3: R.Tensor((2, 8, 26, 26), "float32") = R.concat((gv, gv2), axis=1) + gv4 = R.split(gv3, indices_or_sections=2, axis=1) + R.output(gv4) + return gv4 + + Expected = { + "te_layout_transform": (["global"], ["global.texture-weight"]), + "te_layout_transform1": (["global"], ["global.texture-weight"]), + "conv2d_NCHWc_OIHWo_opencl": ( + ["global.texture-weight", "global.texture-weight"], + ["global.texture-weight"], + ), + "fused_relu_concatenate_split": (["global.texture-weight"], ["global", "global"]), + "te_layout_transform2": (["global"], ["global"]), + "te_layout_transform3": (["global"], ["global"]), + } + verify(Input, Expected) + + +def test_conv2d_relu_concat_split_transpose_concat_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main(x: R.Tensor((2, 16, 28, 28), "float32"), w: R.Tensor((4, 16, 3, 3), "float32")): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((2, 4, 26, 26), "float32") = R.nn.relu(gv) + gv3: R.Tensor((2, 8, 26, 26), "float32") = R.concat((gv, gv2), axis=1) + gv4 = R.split(gv3, indices_or_sections=2, axis=1) + gv5: R.Tensor((26, 26, 4, 2), "float32") = R.permute_dims(gv4[0], axes=[3, 2, 1, 0]) + gv6: R.Tensor((26, 26, 4, 2), "float32") = R.permute_dims(gv4[1], axes=[3, 2, 1, 0]) + gv7: R.Tensor((26, 26, 8, 2), "float32") = R.concat((gv5, gv6), axis=2) + R.output(gv7) + return gv7 + + Expected = { + "te_layout_transform": (["global"], ["global.texture-weight"]), + "te_layout_transform1": (["global"], ["global.texture-weight"]), + "conv2d_NCHWc_OIHWo_opencl": ( + ["global.texture-weight", "global.texture-weight"], + ["global.texture-weight"], + ), + "fused_relu_concatenate_split": (["global.texture-weight"], ["global", "global"]), + "te_layout_transform2": (["global"], ["global"]), + "te_layout_transform3": (["global"], ["global"]), + "fused_transpose_transpose_concatenate1": (["global", "global"], ["global"]), + } + verify(Input, Expected) + + +def test_conv2d_maxpool2d_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), "float32"), w: R.Tensor((4, 16, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 
26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2 = R.nn.max_pool2d( + gv, + pool_size=[2, 2], + strides=[2, 2], + padding=[0, 0], + layout="NCHW", + out_layout="NCHW", + ) + R.output(gv2) + return gv2 + + Expected = { + "te_layout_transform": (["global"], ["global.texture-weight"]), + "te_layout_transform1": (["global"], ["global.texture-weight"]), + "conv2d_NCHWc_OIHWo_opencl": ( + ["global.texture-weight", "global.texture-weight"], + ["global.texture-weight"], + ), + "max_pool2d_opencl": (["global.texture-weight"], ["global"]), + "te_layout_transform2": (["global"], ["global"]), + } + verify(Input, Expected) + + +def test_conv2d_avgpool2d_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), "float32"), w: R.Tensor((4, 16, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2 = R.nn.adaptive_avg_pool2d(gv, output_size=[13, 13], layout="NCHW") + R.output(gv2) + return gv2 + + Expected = { + "te_layout_transform": (["global"], ["global.texture-weight"]), + "te_layout_transform1": (["global"], ["global.texture-weight"]), + "conv2d_NCHWc_OIHWo_opencl": ( + ["global.texture-weight", "global.texture-weight"], + ["global.texture-weight"], + ), + "adaptive_avg_pool2d_opencl": (["global.texture-weight"], ["global"]), + "te_layout_transform2": (["global"], ["global"]), + } + verify(Input, Expected) + + +def test_conv2d_softmax_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), "float32"), w: R.Tensor((4, 16, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2 = R.nn.softmax(gv, axis=1) + R.output(gv2) + return gv2 + + Expected = { + "te_layout_transform": (["global"], ["global.texture-weight"]), + "te_layout_transform1": (["global"], ["global.texture-weight"]), + "conv2d_NCHWc_OIHWo_opencl": ( + ["global.texture-weight", "global.texture-weight"], + ["global"], + ), + "te_layout_transform2": (["global"], ["global"]), + "softmax": (["global"], ["global"]), + } + verify(Input, Expected) + + +def test_conv2d_layernorm_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), "float32"), + w: R.Tensor((4, 16, 3, 3), "float32"), + gamma: R.Tensor((26, 26), dtype="float32"), + beta: R.Tensor((26, 26), dtype="float32"), + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((2, 4, 26, 26), dtype="float32") = R.nn.layer_norm( + gv, gamma, beta, axes=[-2, -1] + ) + R.output(gv2) + return gv2 + + Expected = { + "te_layout_transform": (["global"], ["global.texture-weight"]), + "te_layout_transform1": (["global"], ["global.texture-weight"]), + "conv2d_NCHWc_OIHWo_opencl": ( + ["global.texture-weight", "global.texture-weight"], + ["global"], + ), + "layer_norm": (["global", "global", "global"], ["global"]), + "te_layout_transform2": (["global"], ["global"]), + } + verify(Input, Expected) + + +def test_binary_broadcast_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), "float32"), + w: R.Tensor((4, 16, 3, 3), "float32"), + bias: R.Tensor((26, 26), "float32"), + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = 
R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((2, 4, 26, 26), "float32") = R.add(gv, bias) + R.output(gv2) + return gv2 + + Expected = { + "te_layout_transform": (["global"], ["global.texture-weight"]), + "te_layout_transform1": (["global"], ["global.texture-weight"]), + "conv2d_NCHWc_OIHWo_opencl": ( + ["global.texture-weight", "global.texture-weight"], + ["global"], + ), + "te_layout_transform2": (["global"], ["global"]), + "add": (["global", "global"], ["global"]), + } + verify(Input, Expected) + + +def test_binary_ewise_scalar_sub_indexed(): + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), "float32"), w: R.Tensor((4, 16, 3, 3), "float32") + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32") + gv2: R.Tensor((2, 4, 26, 26), "float32") = R.add(gv, R.const(1, "float32")) + R.output(gv2) + return gv2 + + Expected = { + "te_layout_transform": (["global"], ["global.texture-weight"]), + "te_layout_transform1": (["global"], ["global.texture-weight"]), + "fused_conv2d_NCHWc_OIHWo_opencl_add": ( + ["global.texture-weight", "global.texture-weight"], + ["global"], + ), + "te_layout_transform2": (["global"], ["global"]), + } + verify(Input, Expected) + + +def test_residual_block(): + """ + - some kind of residual block followed by convolution to have texture after residual block + - scalar data type verification which should be mapped to global memory scope + layout_transform (NCHW->NCHW4c) + | <- buffer + conv2d (1) <- to get textures as output + / \ + conv2d (2) | + \ / + add <- add should be fused into conv2d (2) + multiply to scalar <- buffer to the input of multiply scalar value + relu + | <- texture in intermediate tensor + conv2d (3) + relu + | <- buffer + layout_transform (NCHW4c->NCHW) + """ + + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 32, 40, 40), "float32"), + w1: R.Tensor((32, 32, 2, 2), "float32"), + w2: R.Tensor((32, 32, 1, 1), "float32"), + w3: R.Tensor((32, 32, 2, 2), "float32"), + bias: R.Tensor((1, 32, 1, 1), "float32"), + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv = R.nn.conv2d(x, w1, strides=[2, 2], out_dtype="float32") + gv1 = R.add(gv, bias) + gv2 = R.nn.relu(gv1) + gv3 = R.nn.conv2d(gv2, w2, strides=[1, 1], out_dtype="float32") + bias_1 = R.multiply(bias, R.const(0.15, "float32")) + gv4 = R.add(gv3, bias_1) + gv5 = R.nn.relu(gv4) + gv6 = R.nn.conv2d(gv5, w3, strides=[2, 2], out_dtype="float32") + gv7 = R.nn.relu(gv6) + R.output(gv7) + return gv7 + + Expected = { + "te_layout_transform": (["global"], ["global.texture-weight"]), + "te_layout_transform1": (["global"], ["global.texture-weight"]), + "te_layout_transform2": (["global"], ["global.texture-weight"]), + "fused_conv2d_NCHWc_OIHWo_opencl_add_relu": ( + ["global.texture-weight", "global.texture-weight", "global.texture-weight"], + ["global.texture-weight"], + ), + "te_layout_transform3": (["global"], ["global.texture-weight"]), + "multiply": (["global"], ["global"]), + "fused_conv2d_NCHWc_OIHWo1_opencl_add_relu": ( + ["global.texture-weight", "global.texture-weight", "global.texture-weight"], + ["global.texture-weight"], + ), + "fused_conv2d_NCHWc_OIHWo2_opencl_relu1": ( + ["global.texture-weight", "global.texture-weight"], + ["global"], + ), + "te_layout_transform4": (["global"], ["global"]), + } + verify(Input, Expected) + + +def test_conv2d_conv2d_fallback_to_buffer_conv2d(): + """ + layout_transform (NCHW->NCHW4c) + | <- texture 
+ conv2d (1) <- textures as output + / \ + conv2d (2) conv2d (3) <- conv2d (2) emits texture, conv2d (3) emits buffer + \ / <- concat shouldn't support textures here + concatenation + | <- buffer + layout_transform (NCHW4c->NCHW) + """ + + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 32, 40, 40), "float32"), + w1: R.Tensor((96, 32, 2, 2), "float32"), + w2: R.Tensor((32, 96, 2, 2), "float32"), + w3: R.Tensor((5, 96, 2, 2), "float32"), + bias1: R.Tensor((1, 96, 1, 1), "float32"), + bias2: R.Tensor((1, 32, 1, 1), "float32"), + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv = R.nn.conv2d(x, w1, strides=[2, 2], out_dtype="float32") + gv1 = R.add(gv, bias1) + gv2 = R.nn.relu(gv1) + gv3 = R.nn.conv2d(gv2, w2, strides=[2, 2], out_dtype="float32") + gv4 = R.add(gv3, bias2) + gv5 = R.nn.relu(gv4) + gv6 = R.nn.conv2d(gv2, w3, strides=[2, 2], out_dtype="float32") + gv7 = R.concat((gv3, gv6), axis=1) + R.output(gv7) + return gv7 + + Expected = { + "te_layout_transform": (["global"], ["global.texture-weight"]), + "te_layout_transform1": (["global"], ["global.texture-weight"]), + "te_layout_transform2": (["global"], ["global.texture-weight"]), + "fused_conv2d_NCHWc_OIHWo_opencl_add_relu": ( + ["global.texture-weight", "global.texture-weight", "global.texture-weight"], + ["global"], + ), + "te_layout_transform3": (["global"], ["global.texture-weight"]), + "conv2d_NCHWc_OIHWo1_opencl": ( + ["global.texture-weight", "global.texture-weight"], + ["global"], + ), + "te_layout_transform4": (["global"], ["global"]), + "conv2d": (["global", "global"], ["global"]), + "te_layout_transform5": (["global"], ["global"]), + "concatenate": (["global", "global"], ["global"]), + } + verify(Input, Expected) + + +def test_conv2d_conv2d_conv2d_concat(): + """ + layout_transform (NCHW->NCHW4c) + | <- texture + conv2d (1) <- textures as output + / \ + conv2d (2) conv2d (3) + \ / <- concat does support textures here + concatenation + | <- buffer + layout_transform (NCHW4c->NCHW) + """ + + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 32, 40, 40), "float32"), + w1: R.Tensor((96, 32, 2, 2), "float32"), + w2: R.Tensor((32, 96, 2, 2), "float32"), + w3: R.Tensor((8, 96, 2, 2), "float32"), + bias1: R.Tensor((1, 96, 1, 1), "float32"), + bias2: R.Tensor((1, 32, 1, 1), "float32"), + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv = R.nn.conv2d(x, w1, strides=[2, 2], out_dtype="float32") + gv1 = R.add(gv, bias1) + gv2 = R.nn.relu(gv1) + gv3 = R.nn.conv2d(gv2, w2, strides=[2, 2], out_dtype="float32") + gv4 = R.add(gv3, bias2) + gv5 = R.nn.relu(gv4) + gv6 = R.nn.conv2d(gv2, w3, strides=[2, 2], out_dtype="float32") + gv7 = R.concat((gv3, gv6), axis=1) + R.output(gv7) + return gv7 + + Expected = { + "te_layout_transform": (["global"], ["global.texture-weight"]), + "te_layout_transform1": (["global"], ["global.texture-weight"]), + "te_layout_transform2": (["global"], ["global.texture-weight"]), + "fused_conv2d_NCHWc_OIHWo_opencl_add_relu": ( + ["global.texture-weight", "global.texture-weight", "global.texture-weight"], + ["global.texture-weight"], + ), + "te_layout_transform3": (["global"], ["global.texture-weight"]), + "conv2d_NCHWc_OIHWo1_opencl": ( + ["global.texture-weight", "global.texture-weight"], + ["global.texture-weight"], + ), + "te_layout_transform4": (["global"], ["global.texture-weight"]), + "conv2d_NCHWc_OIHWo2_opencl": ( + ["global.texture-weight", "global.texture-weight"], + ["global.texture-weight"], + ), + "concatenate": 
(["global.texture-weight", "global.texture-weight"], ["global"]), + "te_layout_transform5": (["global"], ["global"]), + } + verify(Input, Expected) + + +def test_pooling_branching_texture_params(): + """ + Verification of the pooling and many branches having textures + layout_transform (NCHW->NCHW4c) + | <- texture + conv2d (0) <- to get textures + | <- textures + pooling + / \ \ <- textures + conv2d (1) conv2d (2) conv2d (3) + \ / | + add | <- to have the only one output, will be fused + \ / + add <- to have the only one output, will be fused + | <- buffer + layout_transform (NCHW4c->NCHW) + """ + + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 32, 40, 40), "float32"), + w1: R.Tensor((32, 32, 1, 1), "float32"), + w2: R.Tensor((32, 32, 2, 2), "float32"), + w3: R.Tensor((32, 32, 1, 1), "float32"), + w4: R.Tensor((32, 32, 2, 2), "float32"), + bias1: R.Tensor((1, 32, 1, 1), "float32"), + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv = R.nn.conv2d(x, w1, strides=[1, 1], out_dtype="float32") + gv1 = R.nn.max_pool2d(gv, pool_size=[2, 2], strides=[2, 2]) + gv2 = R.nn.conv2d( + gv1, w2, padding=[0, 0, 1, 1], strides=[1, 1], out_dtype="float32" + ) + gv3 = R.add(gv2, bias1) + gv4 = R.nn.relu(gv3) + gv5 = R.nn.conv2d( + gv1, w3, padding=[0, 0, 0, 0], strides=[1, 1], out_dtype="float32" + ) + gv6 = R.nn.conv2d( + gv1, w4, padding=[0, 1, 1, 0], strides=[1, 1], out_dtype="float32" + ) + gv7 = R.nn.relu(gv6) + gv8 = R.add(gv2, gv5) + gv9 = R.add(gv8, gv6) + R.output(gv9) + return gv9 + + Expected = { + "te_layout_transform": (["global"], ["global.texture-weight"]), + "te_layout_transform1": (["global"], ["global.texture-weight"]), + "conv2d_NCHWc_OIHWo_opencl": ( + ["global.texture-weight", "global.texture-weight"], + ["global.texture-weight"], + ), + "max_pool2d_opencl": (["global.texture-weight"], ["global.texture-weight"]), + "te_layout_transform2": (["global"], ["global.texture-weight"]), + "conv2d_NCHWc_OIHWo2_opencl": ( + ["global.texture-weight", "global.texture-weight"], + ["global.texture-weight"], + ), + "fused_conv2d_NCHWc_OIHWo1_opencl_add": ( + ["global.texture-weight", "global.texture-weight", "global.texture-weight"], + ["global.texture-weight"], + ), + "fused_conv2d_NCHWc_OIHWo3_opencl_add": ( + ["global.texture-weight", "global.texture-weight", "global.texture-weight"], + ["global"], + ), + "te_layout_transform3": (["global"], ["global"]), + } + verify(Input, Expected) + + +def test_injective_inputs1(): + """ + Input + / \ + / | + | / + conv2d (1) / + | / + conv2d (2) mean / + / \ / + | | \ / + | | (3) add + | | | + | \ / + \ mul + \ / + add + + """ + + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((1, 4, 40, 40), "float32"), + w1: R.Tensor((4, 4, 3, 3), "float32"), + w2: R.Tensor((4, 4, 3, 3), "float32"), + w3: R.Tensor((4, 4, 3, 3), "float32"), + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + mean = R.mean(x, axis=1, keepdims=True) + conv1 = R.nn.conv2d( + x, w1, padding=[1, 1, 1, 1], strides=[1, 1], out_dtype="float32" + ) + conv2 = R.nn.conv2d( + conv1, w2, padding=[1, 1, 1, 1], strides=[1, 1], out_dtype="float32" + ) + ad3 = R.add(conv1, conv2) + ad1 = R.add(mean, conv1) + ad2 = R.multiply(ad1, conv1) + gv = R.add(ad3, ad2) + R.output(gv) + return gv + + Expected = { + "te_layout_transform": (["global"], ["global.texture-weight"]), + "te_layout_transform1": (["global"], ["global.texture-weight"]), + "conv2d_NCHWc_OIHWo_opencl": ( + ["global.texture-weight", "global.texture-weight"], + ["global"], + 
), + "te_layout_transform2": (["global"], ["global"]), + "fused_mean_add1": (["global", "global"], ["global"]), + "fused_conv2d_NCHWc_OIHWo_opencl_add_multiply_add": ( + [ + "global.texture-weight", + "global.texture-weight", + "global.texture-weight", + "global.texture-weight", + "global.texture-weight", + ], + ["global"], + ), + } + verify(Input, Expected) + + +def test_injective_nwo_inputs2(): + """ + Input + / \ + | \ + conv2d \ + | / + conv2d mean / + / \ / + add | \ | + | | \ | + | | \ / + | | (3) add + | | | + | \ / + | \ / + \ mul + \ / + add + + """ + + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((1, 4, 40, 40), "float32"), + w1: R.Tensor((4, 4, 3, 3), "float32"), + w2: R.Tensor((4, 4, 3, 3), "float32"), + w3: R.Tensor((4, 4, 3, 3), "float32"), + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + mean = R.mean(x, axis=1, keepdims=True) + conv1 = R.nn.conv2d( + x, w1, padding=[1, 1, 1, 1], strides=[1, 1], out_dtype="float32" + ) + conv2 = R.nn.conv2d( + conv1, w2, padding=[1, 1, 1, 1], strides=[1, 1], out_dtype="float32" + ) + ad3 = R.add(conv1, conv2) + ad1 = R.add(mean, conv1) + ad2 = R.multiply(ad1, conv2) + gv = R.add(ad2, ad3) + R.output(gv) + return gv + + Expected = { + "te_layout_transform": (["global"], ["global.texture-weight"]), + "te_layout_transform1": (["global"], ["global.texture-weight"]), + "conv2d_NCHWc_OIHWo_opencl": ( + ["global.texture-weight", "global.texture-weight"], + ["global"], + ), + "te_layout_transform2": (["global"], ["global"]), + "fused_mean_add1": (["global", "global"], ["global"]), + "fused_conv2d_NCHWc_OIHWo_opencl_add_multiply_add": ( + [ + "global.texture-weight", + "global.texture-weight", + "global.texture-weight", + "global.texture-weight", + ], + ["global"], + ), + } + verify(Input, Expected) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/adreno/test_transform_fold_vdevice_scope_change.py b/tests/python/relax/adreno/test_transform_fold_vdevice_scope_change.py new file mode 100644 index 000000000000..b461f39dd744 --- /dev/null +++ b/tests/python/relax/adreno/test_transform_fold_vdevice_scope_change.py @@ -0,0 +1,282 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +import tvm +from tvm import relax +import tvm.testing +from tvm.script.parser import ir as I, relax as R, tir as T +from tvm.ir.module import IRModule + + +def verify(input, expected): + mod = tvm.relax.backend.adreno.transform.FoldVDeviceScopeChange()(input) + tvm.ir.assert_structural_equal(mod, expected) + + +def test_maxpool2d_scope_folding(): + @I.ir_module + class Input: + I.module_global_infos( + { + "vdevice": [ + I.vdevice({"device": "adreno", "kind": "opencl"}, 0, "global.texture-weight"), + I.vdevice({"device": "adreno", "kind": "opencl"}, 0, "global"), + ] + } + ) + + @T.prim_func(private=True) + def max_pool2d_opencl( + gv: T.Buffer((T.int64(2), T.int64(1), T.int64(26), T.int64(26), T.int64(4)), "float32"), + pool_max: T.Buffer( + (T.int64(2), T.int64(1), T.int64(13), T.int64(13), T.int64(4)), "float32" + ), + ): + # with T.block("root"): + for ax0, ax1, ax2, ax3, ax4, rv0, rv1 in T.grid( + T.int64(2), T.int64(1), T.int64(13), T.int64(13), T.int64(4), T.int64(2), T.int64(2) + ): + with T.block("pool_max"): + v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_rv0, v_rv1 = T.axis.remap( + "SSSSSRR", [ax0, ax1, ax2, ax3, ax4, rv0, rv1] + ) + T.reads( + gv[ + v_ax0, + v_ax1, + v_ax2 * T.int64(2) + v_rv0, + v_ax3 * T.int64(2) + v_rv1, + v_ax4, + ] + ) + T.writes(pool_max[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4]) + T.block_attr({"schedule_rule": "meta_schedule.pool_max"}) + with T.init(): + pool_max[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = T.float32( + -340282346638528859811704183484516925440.0 + ) + pool_max[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = T.max( + pool_max[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4], + gv[ + v_ax0, + v_ax1, + v_ax2 * T.int64(2) + v_rv0, + v_ax3 * T.int64(2) + v_rv1, + v_ax4, + ], + ) + + @T.prim_func(private=True) + def te_layout_transform( + x: T.Buffer((T.int64(2), T.int64(4), T.int64(26), T.int64(26)), "float32"), + te_layout_transform: T.Buffer( + (T.int64(2), T.int64(1), T.int64(26), T.int64(26), T.int64(4)), "float32" + ), + ): + # with T.block("root"): + for self, i0, i1, i2 in T.grid(T.int64(2), T.int64(4), T.int64(26), T.int64(26)): + with T.block("te_layout_transform"): + v_self, v_i0, v_i1, v_i2 = T.axis.remap("SSSS", [self, i0, i1, i2]) + T.reads(x[v_self, v_i0, v_i1, v_i2]) + T.writes( + te_layout_transform[ + v_self, v_i0 // T.int64(4), v_i1, v_i2, v_i0 % T.int64(4) + ] + ) + te_layout_transform[ + v_self, v_i0 // T.int64(4), v_i1, v_i2, v_i0 % T.int64(4) + ] = x[v_self, v_i0, v_i1, v_i2] + + @T.prim_func(private=True) + def te_layout_transform2( + lv2: T.Buffer( + (T.int64(2), T.int64(1), T.int64(13), T.int64(13), T.int64(4)), "float32" + ), + te_layout_transform: T.Buffer( + (T.int64(2), T.int64(4), T.int64(13), T.int64(13)), "float32" + ), + ): + # with T.block("root"): + for self, i0, i1, i2, i3 in T.grid( + T.int64(2), T.int64(1), T.int64(13), T.int64(13), T.int64(4) + ): + with T.block("te_layout_transform"): + v_self, v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSSS", [self, i0, i1, i2, i3]) + T.reads(lv2[v_self, v_i0, v_i1, v_i2, v_i3]) + T.writes(te_layout_transform[v_self, v_i3, v_i1, v_i2]) + te_layout_transform[v_self, v_i3, v_i1, v_i2] = lv2[ + v_self, v_i0, v_i1, v_i2, v_i3 + ] + + @R.function + def main( + x: R.Tensor((2, 4, 26, 26), dtype="float32", vdevice="opencl:1:global"), # noqa: F722 + ) -> R.Tensor((2, 4, 13, 13), dtype="float32", vdevice="opencl:1:global"): # noqa: F722 + cls = Input + with R.dataflow(): + lv = R.call_tir( + cls.te_layout_transform, + (x,), + out_sinfo=R.Tensor( + (2, 1, 26, 26, 4), dtype="float32", vdevice="opencl:0:global.texture-weight" 
+ ), + ) + lv2 = R.call_tir( + cls.max_pool2d_opencl, + (lv,), + out_sinfo=R.Tensor( + (2, 1, 13, 13, 4), dtype="float32", vdevice="opencl:0:global.texture-weight" + ), + ) + lv5: R.Tensor( + (2, 1, 13, 13, 4), dtype="float32", vdevice="opencl:1:global" # noqa: F722 + ) = R.to_vdevice(lv2, dst_vdevice="opencl:1:global") + gv2 = R.call_tir( + cls.te_layout_transform2, + (lv5,), + out_sinfo=R.Tensor((2, 4, 13, 13), dtype="float32", vdevice="opencl:1:global"), + ) + R.output(gv2) + return gv2 + + @I.ir_module + class Expected: + I.module_global_infos( + { + "vdevice": [ + I.vdevice({"device": "adreno", "kind": "opencl"}, 0, "global.texture-weight"), + I.vdevice({"device": "adreno", "kind": "opencl"}, 0, "global"), + ] + } + ) + + @T.prim_func(private=True) + def max_pool2d_opencl( + gv: T.Buffer((T.int64(2), T.int64(1), T.int64(26), T.int64(26), T.int64(4)), "float32"), + pool_max: T.Buffer( + (T.int64(2), T.int64(1), T.int64(13), T.int64(13), T.int64(4)), "float32" + ), + ): + # with T.block("root"): + for ax0, ax1, ax2, ax3, ax4, rv0, rv1 in T.grid( + T.int64(2), T.int64(1), T.int64(13), T.int64(13), T.int64(4), T.int64(2), T.int64(2) + ): + with T.block("pool_max"): + v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_rv0, v_rv1 = T.axis.remap( + "SSSSSRR", [ax0, ax1, ax2, ax3, ax4, rv0, rv1] + ) + T.reads( + gv[ + v_ax0, + v_ax1, + v_ax2 * T.int64(2) + v_rv0, + v_ax3 * T.int64(2) + v_rv1, + v_ax4, + ] + ) + T.writes(pool_max[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4]) + T.block_attr({"schedule_rule": "meta_schedule.pool_max"}) + with T.init(): + pool_max[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = T.float32( + -340282346638528859811704183484516925440.0 + ) + pool_max[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = T.max( + pool_max[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4], + gv[ + v_ax0, + v_ax1, + v_ax2 * T.int64(2) + v_rv0, + v_ax3 * T.int64(2) + v_rv1, + v_ax4, + ], + ) + + @T.prim_func(private=True) + def te_layout_transform( + x: T.Buffer((T.int64(2), T.int64(4), T.int64(26), T.int64(26)), "float32"), + te_layout_transform: T.Buffer( + (T.int64(2), T.int64(1), T.int64(26), T.int64(26), T.int64(4)), "float32" + ), + ): + # with T.block("root"): + for self, i0, i1, i2 in T.grid(T.int64(2), T.int64(4), T.int64(26), T.int64(26)): + with T.block("te_layout_transform"): + v_self, v_i0, v_i1, v_i2 = T.axis.remap("SSSS", [self, i0, i1, i2]) + T.reads(x[v_self, v_i0, v_i1, v_i2]) + T.writes( + te_layout_transform[ + v_self, v_i0 // T.int64(4), v_i1, v_i2, v_i0 % T.int64(4) + ] + ) + te_layout_transform[ + v_self, v_i0 // T.int64(4), v_i1, v_i2, v_i0 % T.int64(4) + ] = x[v_self, v_i0, v_i1, v_i2] + + @T.prim_func(private=True) + def te_layout_transform2( + lv2: T.Buffer( + (T.int64(2), T.int64(1), T.int64(13), T.int64(13), T.int64(4)), "float32" + ), + te_layout_transform: T.Buffer( + (T.int64(2), T.int64(4), T.int64(13), T.int64(13)), "float32" + ), + ): + # with T.block("root"): + for self, i0, i1, i2, i3 in T.grid( + T.int64(2), T.int64(1), T.int64(13), T.int64(13), T.int64(4) + ): + with T.block("te_layout_transform"): + v_self, v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSSS", [self, i0, i1, i2, i3]) + T.reads(lv2[v_self, v_i0, v_i1, v_i2, v_i3]) + T.writes(te_layout_transform[v_self, v_i3, v_i1, v_i2]) + te_layout_transform[v_self, v_i3, v_i1, v_i2] = lv2[ + v_self, v_i0, v_i1, v_i2, v_i3 + ] + + @R.function + def main( + x: R.Tensor((2, 4, 26, 26), dtype="float32", vdevice="opencl:1:global"), # noqa: F722 + ) -> R.Tensor((2, 4, 13, 13), dtype="float32", vdevice="opencl:1:global"): # noqa: F722 + cls = Expected + with R.dataflow(): + 
lv = R.call_tir( + cls.te_layout_transform, + (x,), + out_sinfo=R.Tensor( + (2, 1, 26, 26, 4), dtype="float32", vdevice="opencl:0:global.texture-weight" + ), + ) + lv5 = R.call_tir( + cls.max_pool2d_opencl, + (lv,), + out_sinfo=R.Tensor( + (2, 1, 13, 13, 4), dtype="float32", vdevice="opencl:1:global" + ), + ) + gv2 = R.call_tir( + cls.te_layout_transform2, + (lv5,), + out_sinfo=R.Tensor((2, 4, 13, 13), dtype="float32", vdevice="opencl:1:global"), + ) + R.output(gv2) + return gv2 + + verify(Input, Expected) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_transform.py b/tests/python/relax/test_transform.py index e3274aea886a..b0bec5e858af 100644 --- a/tests/python/relax/test_transform.py +++ b/tests/python/relax/test_transform.py @@ -17,6 +17,7 @@ import pytest import tvm +import tvm.testing from tvm import relax import tvm.script diff --git a/tests/python/relax/test_transform_convert_layout.py b/tests/python/relax/test_transform_convert_layout.py index 262e37b91b1b..83b81a6898a7 100644 --- a/tests/python/relax/test_transform_convert_layout.py +++ b/tests/python/relax/test_transform_convert_layout.py @@ -206,10 +206,9 @@ def main( lv2: R.Tensor((N, H, W, C), dtype="float32") = R.match_cast( lv0, R.Tensor((N, H, W, C), dtype="float32") ) - lv3: R.Tensor((N, C, H, W), dtype="float32") = R.permute_dims( - lv2, axes=[0, 3, 1, 2] - ) - gv: R.Tensor(dtype="float32", ndim=4) = R.add(lv3, w) + lv3: R.Tensor(dtype="float32", ndim=4) = R.permute_dims(w, axes=[0, 2, 3, 1]) + lv4: R.Tensor(dtype="float32", ndim=4) = R.add(lv2, lv3) + gv: R.Tensor(dtype="float32", ndim=4) = R.permute_dims(lv4, axes=[0, 3, 1, 2]) R.output(gv) return gv @@ -4585,5 +4584,413 @@ def main( verify(Input, Expected_NCHW4c, {"relax.nn.conv2d": ["NCHW4c", "OIHW4o"]}) +def test_conv2d_conv2d_conv2d_concat(): + """ + layout_transform (NCHW->NCHW4c) + | <- texture + conv2d (1) <- textures as output + / \ + conv2d (2) conv2d (3) + \ / <- concat does support textures here + concatenation + | <- buffer + layout_transform (NCHW4c->NCHW) + """ + + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 32, 40, 40), "float32"), + w1: R.Tensor((96, 32, 2, 2), "float32"), + w2: R.Tensor((32, 96, 2, 2), "float32"), + w3: R.Tensor((8, 96, 2, 2), "float32"), + bias1: R.Tensor((1, 96, 1, 1), "float32"), + bias2: R.Tensor((1, 32, 1, 1), "float32"), + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv = R.nn.conv2d(x, w1, strides=[2, 2], out_dtype="float32") + gv1 = R.add(gv, bias1) + gv2 = R.nn.relu(gv1) + gv3 = R.nn.conv2d(gv2, w2, strides=[2, 2], out_dtype="float32") + gv4 = R.add(gv3, bias2) + gv5 = R.nn.relu(gv4) + gv6 = R.nn.conv2d(gv2, w3, strides=[2, 2], out_dtype="float32") + gv7 = R.concat((gv3, gv6), axis=1) + R.output(gv7) + return gv7 + + @I.ir_module + class Expected_NCHW4c: + @R.function + def main( + x: R.Tensor((2, 32, 40, 40), dtype="float32"), + w1: R.Tensor((96, 32, 2, 2), dtype="float32"), + w2: R.Tensor((32, 96, 2, 2), dtype="float32"), + w3: R.Tensor((8, 96, 2, 2), dtype="float32"), + bias1: R.Tensor((1, 96, 1, 1), dtype="float32"), + bias2: R.Tensor((1, 32, 1, 1), dtype="float32"), + ) -> R.Tensor((2, 40, 10, 10), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((2, 8, 40, 40, 4), dtype="float32") = R.layout_transform( + x, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i1 // 4, i2, i3, i1 % 4), index_dtype="int32" + ), + ) + lv1: R.Tensor((24, 32, 2, 2, 4), dtype="float32") = R.layout_transform( + w1, + index_map=T.index_map( + lambda i0, 
i1, i2, i3: (i0 // 4, i1, i2, i3, i0 % 4), index_dtype="int32" + ), + ) + gv: R.Tensor((2, 24, 20, 20, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + strides=[2, 2], + data_layout="NCHW4c", + kernel_layout="OIHW4o", + out_layout="NCHW4c", + out_dtype="float32", + ) + lv2: R.Tensor((1, 24, 1, 1, 4), dtype="float32") = R.layout_transform( + bias1, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i1 // 4, i2, i3, i1 % 4), index_dtype="int32" + ), + ) + gv1: R.Tensor((2, 24, 20, 20, 4), dtype="float32") = R.add(gv, lv2) + gv2: R.Tensor((2, 24, 20, 20, 4), dtype="float32") = R.nn.relu(gv1) + lv3: R.Tensor((8, 96, 2, 2, 4), dtype="float32") = R.layout_transform( + w2, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i1, i2, i3, i0 % 4), index_dtype="int32" + ), + ) + gv3: R.Tensor((2, 8, 10, 10, 4), dtype="float32") = R.nn.conv2d( + gv2, + lv3, + strides=[2, 2], + data_layout="NCHW4c", + kernel_layout="OIHW4o", + out_layout="NCHW4c", + out_dtype="float32", + ) + lv4: R.Tensor((1, 8, 1, 1, 4), dtype="float32") = R.layout_transform( + bias2, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i1 // 4, i2, i3, i1 % 4), index_dtype="int32" + ), + ) + gv4: R.Tensor((2, 8, 10, 10, 4), dtype="float32") = R.add(gv3, lv4) + gv5: R.Tensor((2, 8, 10, 10, 4), dtype="float32") = R.nn.relu(gv4) + lv5: R.Tensor((2, 96, 2, 2, 4), dtype="float32") = R.layout_transform( + w3, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i1, i2, i3, i0 % 4), index_dtype="int32" + ), + ) + gv6: R.Tensor((2, 2, 10, 10, 4), dtype="float32") = R.nn.conv2d( + gv2, + lv5, + strides=[2, 2], + data_layout="NCHW4c", + kernel_layout="OIHW4o", + out_layout="NCHW4c", + out_dtype="float32", + ) + lv6: R.Tensor((2, 10, 10, 10, 4), dtype="float32") = R.concat((gv3, gv6), axis=1) + gv7: R.Tensor((2, 40, 10, 10), dtype="float32") = R.layout_transform( + lv6, + index_map=T.index_map( + lambda i0, i1, i2, i3, i4: (i0, i1 * 4 + i4, i2, i3), index_dtype="int32" + ), + ) + R.output(gv7) + return gv7 + + verify(Input, Expected_NCHW4c, {"relax.nn.conv2d": ["NCHW4c", "OIHW4o"]}) + + +def test_conv2d_conv2d_callback_to_buffer_conv2d_concat(): + """ + layout_transform (NCHW->NCHW4c) + | <- texture + conv2d (1) <- textures as output + / \ + conv2d (2) conv2d (3) <- conv2d (2) emits texture, conv2d (3) emits buffer + \ / <- concat shouldn't support textures here + concatenation + | <- buffer + layout_transform (NCHW4c->NCHW) + """ + + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 32, 40, 40), "float32"), + w1: R.Tensor((96, 32, 2, 2), "float32"), + w2: R.Tensor((32, 96, 2, 2), "float32"), + w3: R.Tensor((5, 96, 2, 2), "float32"), + bias1: R.Tensor((1, 96, 1, 1), "float32"), + bias2: R.Tensor((1, 32, 1, 1), "float32"), + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv = R.nn.conv2d(x, w1, strides=[2, 2], out_dtype="float32") + gv1 = R.add(gv, bias1) + gv2 = R.nn.relu(gv1) + gv3 = R.nn.conv2d(gv2, w2, strides=[2, 2], out_dtype="float32") + gv4 = R.add(gv3, bias2) + gv5 = R.nn.relu(gv4) + gv6 = R.nn.conv2d(gv2, w3, strides=[2, 2], out_dtype="float32") + gv7 = R.concat((gv3, gv6), axis=1) + R.output(gv7) + return gv7 + + @I.ir_module + class Expected_NCHW4c: + @R.function + def main( + x: R.Tensor((2, 32, 40, 40), dtype="float32"), + w1: R.Tensor((96, 32, 2, 2), dtype="float32"), + w2: R.Tensor((32, 96, 2, 2), dtype="float32"), + w3: R.Tensor((5, 96, 2, 2), dtype="float32"), + bias1: R.Tensor((1, 96, 1, 1), dtype="float32"), + bias2: R.Tensor((1, 32, 1, 1), dtype="float32"), + ) -> 
R.Tensor((2, 37, 10, 10), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((2, 8, 40, 40, 4), dtype="float32") = R.layout_transform( + x, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i1 // 4, i2, i3, i1 % 4), index_dtype="int32" + ), + ) + lv1: R.Tensor((24, 32, 2, 2, 4), dtype="float32") = R.layout_transform( + w1, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i1, i2, i3, i0 % 4), index_dtype="int32" + ), + ) + gv: R.Tensor((2, 24, 20, 20, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + strides=[2, 2], + data_layout="NCHW4c", + kernel_layout="OIHW4o", + out_layout="NCHW4c", + out_dtype="float32", + ) + lv2: R.Tensor((1, 24, 1, 1, 4), dtype="float32") = R.layout_transform( + bias1, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i1 // 4, i2, i3, i1 % 4), index_dtype="int32" + ), + ) + gv1: R.Tensor((2, 24, 20, 20, 4), dtype="float32") = R.add(gv, lv2) + gv2: R.Tensor((2, 24, 20, 20, 4), dtype="float32") = R.nn.relu(gv1) + lv3: R.Tensor((8, 96, 2, 2, 4), dtype="float32") = R.layout_transform( + w2, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i1, i2, i3, i0 % 4), index_dtype="int32" + ), + ) + gv3: R.Tensor((2, 8, 10, 10, 4), dtype="float32") = R.nn.conv2d( + gv2, + lv3, + strides=[2, 2], + data_layout="NCHW4c", + kernel_layout="OIHW4o", + out_layout="NCHW4c", + out_dtype="float32", + ) + lv4: R.Tensor((1, 8, 1, 1, 4), dtype="float32") = R.layout_transform( + bias2, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i1 // 4, i2, i3, i1 % 4), index_dtype="int32" + ), + ) + gv4: R.Tensor((2, 8, 10, 10, 4), dtype="float32") = R.add(gv3, lv4) + gv5: R.Tensor((2, 8, 10, 10, 4), dtype="float32") = R.nn.relu(gv4) + lv5: R.Tensor((2, 96, 20, 20), dtype="float32") = R.layout_transform( + gv2, + index_map=T.index_map( + lambda i0, i1, i2, i3, i4: (i0, i1 * 4 + i4, i2, i3), index_dtype="int32" + ), + ) + gv6: R.Tensor((2, 5, 10, 10), dtype="float32") = R.nn.conv2d( + lv5, + w3, + strides=[2, 2], + data_layout="NCHW", + kernel_layout="OIHW", + out_layout="NCHW", + out_dtype="float32", + ) + lv6: R.Tensor((2, 32, 10, 10), dtype="float32") = R.layout_transform( + gv3, + index_map=T.index_map( + lambda i0, i1, i2, i3, i4: (i0, i1 * 4 + i4, i2, i3), index_dtype="int32" + ), + ) + gv7: R.Tensor((2, 37, 10, 10), dtype="float32") = R.concat((lv6, gv6), axis=1) + R.output(gv7) + return gv7 + + verify(Input, Expected_NCHW4c, {"relax.nn.conv2d": ["NCHW4c", "OIHW4o"]}) + + +def test_pooling_branching_texture_params(): + """ + Verification of the pooling and many branches having textures + layout_transform (NCHW->NCHW4c) + | <- texture + conv2d (0) <- to get textures + | <- textures + pooling + / \ \ <- textures + conv2d (1) conv2d (2) conv2d (3) + \ / | + add | <- to have the only one output, will be fused + \ / + add <- to have the only one output, will be fused + | <- buffer + layout_transform (NCHW4c->NCHW) + """ + + @I.ir_module + class Input: + @R.function + def main( + x: R.Tensor((2, 32, 40, 40), "float32"), + w1: R.Tensor((32, 32, 1, 1), "float32"), + w2: R.Tensor((32, 32, 2, 2), "float32"), + w3: R.Tensor((32, 32, 1, 1), "float32"), + w4: R.Tensor((32, 32, 2, 2), "float32"), + bias1: R.Tensor((1, 32, 1, 1), "float32"), + ) -> R.Tensor(None, "float32", ndim=4): + with R.dataflow(): + gv = R.nn.conv2d(x, w1, strides=[1, 1], out_dtype="float32") + gv1 = R.nn.max_pool2d(gv, pool_size=[2, 2], strides=[2, 2]) + gv2 = R.nn.conv2d( + gv1, w2, padding=[0, 0, 1, 1], strides=[1, 1], out_dtype="float32" + ) + gv3 = R.add(gv2, bias1) + gv4 = R.nn.relu(gv3) + 
gv5 = R.nn.conv2d( + gv1, w3, padding=[0, 0, 0, 0], strides=[1, 1], out_dtype="float32" + ) + gv6 = R.nn.conv2d( + gv1, w4, padding=[0, 1, 1, 0], strides=[1, 1], out_dtype="float32" + ) + gv7 = R.nn.relu(gv6) + gv8 = R.add(gv2, gv5) + gv9 = R.add(gv8, gv6) + R.output(gv9) + return gv9 + + @I.ir_module + class Expected_NCHW4c: + @R.function + def main( + x: R.Tensor((2, 32, 40, 40), dtype="float32"), + w1: R.Tensor((32, 32, 1, 1), dtype="float32"), + w2: R.Tensor((32, 32, 2, 2), dtype="float32"), + w3: R.Tensor((32, 32, 1, 1), dtype="float32"), + w4: R.Tensor((32, 32, 2, 2), dtype="float32"), + bias1: R.Tensor((1, 32, 1, 1), dtype="float32"), + ) -> R.Tensor((2, 32, 20, 20), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((2, 8, 40, 40, 4), dtype="float32") = R.layout_transform( + x, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i1 // 4, i2, i3, i1 % 4), index_dtype="int32" + ), + ) + lv1: R.Tensor((8, 32, 1, 1, 4), dtype="float32") = R.layout_transform( + w1, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i1, i2, i3, i0 % 4), index_dtype="int32" + ), + ) + gv: R.Tensor((2, 8, 40, 40, 4), dtype="float32") = R.nn.conv2d( + lv, + lv1, + data_layout="NCHW4c", + kernel_layout="OIHW4o", + out_layout="NCHW4c", + out_dtype="float32", + ) + gv1: R.Tensor((2, 8, 20, 20, 4), dtype="float32") = R.nn.max_pool2d( + gv, pool_size=[2, 2], strides=[2, 2], layout="NCHW4c", out_layout="NCHW4c" + ) + lv2: R.Tensor((8, 32, 2, 2, 4), dtype="float32") = R.layout_transform( + w2, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i1, i2, i3, i0 % 4), index_dtype="int32" + ), + ) + gv2: R.Tensor((2, 8, 20, 20, 4), dtype="float32") = R.nn.conv2d( + gv1, + lv2, + padding=[0, 0, 1, 1], + data_layout="NCHW4c", + kernel_layout="OIHW4o", + out_layout="NCHW4c", + out_dtype="float32", + ) + lv3: R.Tensor((1, 8, 1, 1, 4), dtype="float32") = R.layout_transform( + bias1, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0, i1 // 4, i2, i3, i1 % 4), index_dtype="int32" + ), + ) + gv3: R.Tensor((2, 8, 20, 20, 4), dtype="float32") = R.add(gv2, lv3) + gv4: R.Tensor((2, 8, 20, 20, 4), dtype="float32") = R.nn.relu(gv3) + lv4: R.Tensor((8, 32, 1, 1, 4), dtype="float32") = R.layout_transform( + w3, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i1, i2, i3, i0 % 4), index_dtype="int32" + ), + ) + gv5: R.Tensor((2, 8, 20, 20, 4), dtype="float32") = R.nn.conv2d( + gv1, + lv4, + data_layout="NCHW4c", + kernel_layout="OIHW4o", + out_layout="NCHW4c", + out_dtype="float32", + ) + lv5: R.Tensor((8, 32, 2, 2, 4), dtype="float32") = R.layout_transform( + w4, + index_map=T.index_map( + lambda i0, i1, i2, i3: (i0 // 4, i1, i2, i3, i0 % 4), index_dtype="int32" + ), + ) + gv6: R.Tensor((2, 8, 20, 20, 4), dtype="float32") = R.nn.conv2d( + gv1, + lv5, + strides=[1, 1], + padding=[0, 1, 1, 0], + data_layout="NCHW4c", + kernel_layout="OIHW4o", + out_layout="NCHW4c", + out_dtype="float32", + ) + gv7: R.Tensor((2, 8, 20, 20, 4), dtype="float32") = R.nn.relu(gv6) + gv8: R.Tensor((2, 8, 20, 20, 4), dtype="float32") = R.add(gv2, gv5) + lv6: R.Tensor((2, 8, 20, 20, 4), dtype="float32") = R.add(gv8, gv6) + gv9: R.Tensor((2, 32, 20, 20), dtype="float32") = R.layout_transform( + lv6, + index_map=T.index_map( + lambda i0, i1, i2, i3, i4: (i0, i1 * 4 + i4, i2, i3), index_dtype="int32" + ), + ) + R.output(gv9) + return gv9 + + verify(Input, Expected_NCHW4c, {"relax.nn.conv2d": ["NCHW4c", "OIHW4o"]}) + + if __name__ == "__main__": tvm.testing.main() diff --git 
a/tests/python/relax/test_transform_specialize_primfunc_based_on_callsite.py b/tests/python/relax/test_transform_specialize_primfunc_based_on_callsite.py
new file mode 100644
index 000000000000..d92570025fce
--- /dev/null
+++ b/tests/python/relax/test_transform_specialize_primfunc_based_on_callsite.py
@@ -0,0 +1,344 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import tvm
+from tvm import relax
+import tvm.testing
+from tvm.script.parser import ir as I, relax as R, tir as T
+from tvm.relax.transform.legalize_ops import adreno as legalize_adreno
+from tvm.ir.module import IRModule
+from tvm.relax.expr_functor import PyExprMutator, PyExprVisitor, mutator, visitor
+
+
+@visitor
+class ValidateBufferScopes(PyExprVisitor):  # pylint: disable=abstract-method
+    def __init__(self, is_matched: bool) -> None:
+        self.is_matched = is_matched
+
+    def visit(self, mod: IRModule) -> None:
+        """Entry point"""
+        self.mod = mod
+        for key, func in mod.functions_items():
+            if isinstance(func, relax.Function):
+                self.visit_expr(func)
+
+    def visit_call_(self, call: relax.Call) -> None:  # pylint: disable=arguments-renamed
+        if call.op.name == "relax.call_tir":
+            pfunc = self.mod[call.args[0]]
+            if not self.is_matched:
+                # All scopes are expected to be global before the pass runs
+                for _, buf in pfunc.buffer_map.items():
+                    assert (
+                        "global" == buf.data.type_annotation.storage_scope
+                    ), f"expected to be global scoped, but got {buf.data.type_annotation.storage_scope}"
+            else:
+                for idx, arg in enumerate(call.args[1]):
+                    arg_sinfo = arg.struct_info
+                    assert isinstance(
+                        arg_sinfo, relax.TensorStructInfo
+                    ), f"Expected TensorStructInfo but got {type(arg_sinfo)}"
+                    buf = pfunc.buffer_map[pfunc.params[idx]]
+                    assert (
+                        arg_sinfo.vdevice.memory_scope == buf.data.type_annotation.storage_scope
+                    ), f"scope mismatched after specialization {arg_sinfo.vdevice.memory_scope} vs {buf.data.type_annotation.storage_scope}"
+                if isinstance(call.sinfo_args[0], relax.TensorStructInfo):
+                    buf = pfunc.buffer_map[pfunc.params[-1]]
+                    assert (
+                        call.sinfo_args[0].vdevice.memory_scope
+                        == buf.data.type_annotation.storage_scope
+                    ), f"scope mismatched after specialization {call.sinfo_args[0].vdevice.memory_scope} vs {buf.data.type_annotation.storage_scope}"
+                else:
+                    assert isinstance(
+                        call.sinfo_args[0], relax.TupleStructInfo
+                    ), f"Expected TupleStructInfo but got {type(call.sinfo_args[0])}"
+                    for idx, sinfo in enumerate(call.sinfo_args[0].fields):
+                        buf = pfunc.buffer_map[pfunc.params[len(call.args[1]) + idx]]
+                        assert (
+                            sinfo.vdevice.memory_scope == buf.data.type_annotation.storage_scope
+                        ), f"scope mismatched after specialization {sinfo.vdevice.memory_scope} vs {buf.data.type_annotation.storage_scope}"
+
+
+def verify(input):
+    ValidateBufferScopes(False).visit(input)
+    mod =
tvm.relax.transform.SpecializePrimFuncBasedOnCallSite()(input) + ValidateBufferScopes(True).visit(mod) + + +def test_single_arg_return(): + @I.ir_module + class Input: + I.module_global_infos( + { + "vdevice": [ + I.vdevice({"device": "adreno", "kind": "opencl"}, 0, "global.texture-weight"), + I.vdevice({"device": "adreno", "kind": "opencl"}, 0, "global"), + ] + } + ) + + @T.prim_func(private=True) + def max_pool2d_opencl( + gv: T.Buffer((T.int64(2), T.int64(1), T.int64(26), T.int64(26), T.int64(4)), "float32"), + pool_max: T.Buffer( + (T.int64(2), T.int64(1), T.int64(13), T.int64(13), T.int64(4)), "float32" + ), + ): + # with T.block("root"): + for ax0, ax1, ax2, ax3, ax4, rv0, rv1 in T.grid( + T.int64(2), T.int64(1), T.int64(13), T.int64(13), T.int64(4), T.int64(2), T.int64(2) + ): + with T.block("pool_max"): + v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_rv0, v_rv1 = T.axis.remap( + "SSSSSRR", [ax0, ax1, ax2, ax3, ax4, rv0, rv1] + ) + T.reads( + gv[ + v_ax0, + v_ax1, + v_ax2 * T.int64(2) + v_rv0, + v_ax3 * T.int64(2) + v_rv1, + v_ax4, + ] + ) + T.writes(pool_max[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4]) + T.block_attr({"schedule_rule": "meta_schedule.pool_max"}) + with T.init(): + pool_max[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = T.float32( + -340282346638528859811704183484516925440.0 + ) + pool_max[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = T.max( + pool_max[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4], + gv[ + v_ax0, + v_ax1, + v_ax2 * T.int64(2) + v_rv0, + v_ax3 * T.int64(2) + v_rv1, + v_ax4, + ], + ) + + @T.prim_func(private=True) + def te_layout_transform( + x: T.Buffer((T.int64(2), T.int64(4), T.int64(26), T.int64(26)), "float32"), + te_layout_transform: T.Buffer( + (T.int64(2), T.int64(1), T.int64(26), T.int64(26), T.int64(4)), "float32" + ), + ): + # with T.block("root"): + for self, i0, i1, i2 in T.grid(T.int64(2), T.int64(4), T.int64(26), T.int64(26)): + with T.block("te_layout_transform"): + v_self, v_i0, v_i1, v_i2 = T.axis.remap("SSSS", [self, i0, i1, i2]) + T.reads(x[v_self, v_i0, v_i1, v_i2]) + T.writes( + te_layout_transform[ + v_self, v_i0 // T.int64(4), v_i1, v_i2, v_i0 % T.int64(4) + ] + ) + te_layout_transform[ + v_self, v_i0 // T.int64(4), v_i1, v_i2, v_i0 % T.int64(4) + ] = x[v_self, v_i0, v_i1, v_i2] + + @T.prim_func(private=True) + def te_layout_transform2( + lv2: T.Buffer( + (T.int64(2), T.int64(1), T.int64(13), T.int64(13), T.int64(4)), "float32" + ), + te_layout_transform: T.Buffer( + (T.int64(2), T.int64(4), T.int64(13), T.int64(13)), "float32" + ), + ): + # with T.block("root"): + for self, i0, i1, i2, i3 in T.grid( + T.int64(2), T.int64(1), T.int64(13), T.int64(13), T.int64(4) + ): + with T.block("te_layout_transform"): + v_self, v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSSS", [self, i0, i1, i2, i3]) + T.reads(lv2[v_self, v_i0, v_i1, v_i2, v_i3]) + T.writes(te_layout_transform[v_self, v_i3, v_i1, v_i2]) + te_layout_transform[v_self, v_i3, v_i1, v_i2] = lv2[ + v_self, v_i0, v_i1, v_i2, v_i3 + ] + + @R.function + def main( + x: R.Tensor((2, 4, 26, 26), dtype="float32", vdevice="opencl:1:global"), # noqa: F722 + ) -> R.Tensor((2, 4, 13, 13), dtype="float32", vdevice="opencl:1:global"): # noqa: F722 + cls = Input + with R.dataflow(): + lv = R.call_tir( + cls.te_layout_transform, + (x,), + out_sinfo=R.Tensor( + (2, 1, 26, 26, 4), dtype="float32", vdevice="opencl:0:global.texture-weight" + ), + ) + lv2 = R.call_tir( + cls.max_pool2d_opencl, + (lv,), + out_sinfo=R.Tensor( + (2, 1, 13, 13, 4), dtype="float32", vdevice="opencl:0:global.texture-weight" + ), + ) + lv5: R.Tensor( + (2, 1, 13, 13, 4), 
dtype="float32", vdevice="opencl:1:global" # noqa: F722 + ) = R.to_vdevice(lv2, dst_vdevice="opencl:1:global") + gv2 = R.call_tir( + cls.te_layout_transform2, + (lv5,), + out_sinfo=R.Tensor((2, 4, 13, 13), dtype="float32", vdevice="opencl:1:global"), + ) + R.output(gv2) + return gv2 + + verify(Input) + + +def test_multi_arg_return(): + @I.ir_module + class Input: + I.module_global_infos( + { + "vdevice": [ + I.vdevice({"device": "adreno", "kind": "opencl"}, 0, "global.texture-weight"), + I.vdevice({"device": "adreno", "kind": "opencl"}, 0, "global"), + ] + } + ) + + @T.prim_func(private=True) + def conv2d_NCHWc_OIHWo_opencl( + lv: T.Buffer((T.int64(2), T.int64(4), T.int64(28), T.int64(28), T.int64(4)), "float32"), + lv1: T.Buffer((T.int64(1), T.int64(16), T.int64(3), T.int64(3), T.int64(4)), "float32"), + conv2d_NCHWc_OIHWo: T.Buffer( + (T.int64(2), T.int64(1), T.int64(26), T.int64(26), T.int64(4)), "float32" + ), + ): + conv2d_NCHWc_OIHWo[0, 0, 0, 0, 0] = T.float32(0.0) + + @T.prim_func(private=True) + def fused_relu_concatenate_split( + gv: T.Buffer((T.int64(2), T.int64(1), T.int64(26), T.int64(26), T.int64(4)), "float32"), + T_split_sections_intermediate: T.Buffer( + (T.int64(2), T.int64(1), T.int64(26), T.int64(26), T.int64(4)), "float32" + ), + T_split_sections_intermediate_1: T.Buffer( + (T.int64(2), T.int64(1), T.int64(26), T.int64(26), T.int64(4)), "float32" + ), + ): + T_split_sections_intermediate[0, 0, 0, 0, 0] = T.float32(0.0) + T_split_sections_intermediate_1[0, 0, 0, 0, 0] = T.float32(0.0) + + @T.prim_func(private=True) + def te_layout_transform( + x: T.Buffer((T.int64(2), T.int64(16), T.int64(28), T.int64(28)), "float32"), + te_layout_transform: T.Buffer( + (T.int64(2), T.int64(4), T.int64(28), T.int64(28), T.int64(4)), "float32" + ), + ): + te_layout_transform[0, 0, 0, 0, 0] = T.float32(0.0) + + @T.prim_func(private=True) + def te_layout_transform1( + w: T.Buffer((T.int64(4), T.int64(16), T.int64(3), T.int64(3)), "float32"), + te_layout_transform: T.Buffer( + (T.int64(1), T.int64(16), T.int64(3), T.int64(3), T.int64(4)), "float32" + ), + ): + te_layout_transform[0, 0, 0, 0, 0] = T.float32(0.0) + + @T.prim_func(private=True) + def te_layout_transform2( + lv3: T.Buffer( + (T.int64(2), T.int64(1), T.int64(26), T.int64(26), T.int64(4)), "float32" + ), + te_layout_transform: T.Buffer( + (T.int64(2), T.int64(4), T.int64(26), T.int64(26)), "float32" + ), + ): + te_layout_transform[0, 0, 0, 0] = T.float32(0.0) + + @R.function + def main( + x: R.Tensor((2, 16, 28, 28), dtype="float32", vdevice="opencl:1:global"), # noqa: F722 + w: R.Tensor((4, 16, 3, 3), dtype="float32", vdevice="opencl:1:global"), # noqa: F722 + ) -> R.Tuple( + R.Tensor((2, 4, 26, 26), dtype="float32", vdevice="opencl:1:global"), # noqa: F722 + R.Tensor((2, 4, 26, 26), dtype="float32", vdevice="opencl:1:global"), # noqa: F722 + ): + cls = Input + with R.dataflow(): + lv = R.call_tir( + cls.te_layout_transform, + (x,), + out_sinfo=R.Tensor( + (2, 4, 28, 28, 4), dtype="float32", vdevice="opencl:0:global.texture-weight" + ), + ) + lv1 = R.call_tir( + cls.te_layout_transform1, + (w,), + out_sinfo=R.Tensor( + (1, 16, 3, 3, 4), dtype="float32", vdevice="opencl:0:global.texture-weight" + ), + ) + gv = R.call_tir( + cls.conv2d_NCHWc_OIHWo_opencl, + (lv, lv1), + out_sinfo=R.Tensor( + (2, 1, 26, 26, 4), dtype="float32", vdevice="opencl:0:global.texture-weight" + ), + ) + lv_1 = R.call_tir( + cls.fused_relu_concatenate_split, + (gv,), + out_sinfo=[ + R.Tensor((2, 1, 26, 26, 4), dtype="float32", vdevice="opencl:1:global"), + 
R.Tensor((2, 1, 26, 26, 4), dtype="float32", vdevice="opencl:1:global"), + ], + ) + lv3: R.Tensor( + (2, 1, 26, 26, 4), dtype="float32", vdevice="opencl:1:global" # noqa: F722 + ) = lv_1[0] + lv4 = R.call_tir( + cls.te_layout_transform2, + (lv3,), + out_sinfo=R.Tensor((2, 4, 26, 26), dtype="float32", vdevice="opencl:1:global"), + ) + lv5: R.Tensor( + (2, 1, 26, 26, 4), dtype="float32", vdevice="opencl:1:global" # noqa: F722 + ) = lv_1[1] + lv6 = R.call_tir( + cls.te_layout_transform2, + (lv5,), + out_sinfo=R.Tensor((2, 4, 26, 26), dtype="float32", vdevice="opencl:1:global"), + ) + gv4: R.Tuple( + R.Tensor( + (2, 4, 26, 26), dtype="float32", vdevice="opencl:1:global" # noqa: F722 + ), + R.Tensor( + (2, 4, 26, 26), dtype="float32", vdevice="opencl:1:global" # noqa: F722 + ), + ) = (lv4, lv6) + R.output(gv4) + return gv4 + + verify(Input) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_tvmscript_parser_op_manipulate.py b/tests/python/relax/test_tvmscript_parser_op_manipulate.py index 694e7a688cf7..c0ff78ca4c6b 100644 --- a/tests/python/relax/test_tvmscript_parser_op_manipulate.py +++ b/tests/python/relax/test_tvmscript_parser_op_manipulate.py @@ -439,5 +439,20 @@ def foo(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"): _check(foo, bb.get()["foo"]) +def test_hint_on_device_scoped(): + @R.function + def foo(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"): + r = R.hint_on_device(x, R.device(4, 2), "global.texture") + return r + + x = relax.Var("x", R.Tensor((), "int32")) + bb = relax.BlockBuilder() + with bb.function("foo", (x,)): + tensor = bb.emit(relax.op.hint_on_device(x, R.opencl(2), "global.texture")) + bb.emit_func_output(tensor) + + _check(foo, bb.get()["foo"]) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/scripts/task_build_adreno_bins.sh b/tests/scripts/task_build_adreno_bins.sh index e5775c10ec34..91886281806d 100755 --- a/tests/scripts/task_build_adreno_bins.sh +++ b/tests/scripts/task_build_adreno_bins.sh @@ -39,7 +39,7 @@ echo set\(USE_OPENCL ON\) >> config.cmake fi echo set\(USE_RPC ON\) >> config.cmake echo set\(USE_CPP_RPC ON\) >> config.cmake -echo set\(USE_CPP_RTVM ON\) >> config.cmake +#echo set\(USE_CPP_RTVM ON\) >> config.cmake echo set\(USE_LIBBACKTRACE AUTO\) >> config.cmake echo set\(USE_KALLOC_ALIGNMENT 32\) >> config.cmake @@ -62,4 +62,4 @@ cmake -DCMAKE_TOOLCHAIN_FILE="${ANDROID_NDK_HOME}/build/cmake/android.toolchain. -DCMAKE_C_COMPILER="${ANDROID_NDK_HOME}/toolchains/llvm/prebuilt/linux-x86_64/bin/aarch64-linux-android28-clang" \ -DMACHINE_NAME="aarch64-linux-gnu" .. -make -j$(nproc) tvm_rpc rtvm opencl-cpptest +make -j$(nproc) tvm_rpc opencl-cpptest