Skip to content

Commit

Permalink
[MetaSchedule][Hexagon] Add postproc for verifying VTCM usage (apache…
Browse files Browse the repository at this point in the history
…#13538)

* add new postproc VerifyVTCMLimit

* remove pass

* add test

* add doc, missing file

* Add back VectorizeLoop in prereq lowering pass

* fix lint
  • Loading branch information
masahi authored Dec 6, 2022
1 parent 965490e commit 6574e16
Show file tree
Hide file tree
Showing 8 changed files with 288 additions and 4 deletions.
5 changes: 5 additions & 0 deletions include/tvm/meta_schedule/postproc.h
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,11 @@ class Postproc : public runtime::ObjectRef {
* \return The postprocessor created
*/
TVM_DLL static Postproc VerifyGPUCode();
/*!
* \brief Creates a postprocessor that verifies that the VTCM usage of a given schedule is
*  within the limit. The limit is read from the "vtcm-capacity" attribute of the Hexagon target
*  in the tuning context; a value of 0 disables the verification.
* \return The postprocessor created
*/
TVM_DLL static Postproc VerifyVTCMLimit();
/*!
* \brief Creates a postprocessor that rewrites the layout of input tensor
* \note Weight layout rewrite is supported so far, activation layout rewrite will be added.
Expand Down
8 changes: 8 additions & 0 deletions include/tvm/tir/analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,14 @@ TVM_DLL bool VerifyMemory(const PrimFunc& func);
*/
TVM_DLL bool VerifyGPUCode(const PrimFunc& func, Map<String, PrimExpr> constraints);

/*!
* \brief Verifies that the VTCM usage of the given prim_func is within the provided limit.
* \param func The function to be checked.
* \param limit The limit to check, compared against the bytes allocated in the "global.vtcm"
*  storage scope. A limit that is not positive disables the check.
* \return true if the VTCM usage is within the provided limit, or if the check is disabled.
*/
TVM_DLL bool VerifyVTCMLimit(const PrimFunc& func, Integer limit);

/*!
* \brief Auto detect the block access region according to its body stmt
* It will detect the access region as an array in order of appearance in AST
Expand Down
1 change: 1 addition & 0 deletions python/tvm/meta_schedule/postproc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,4 @@
from .rewrite_tensorize import RewriteTensorize
from .rewrite_unbound_block import RewriteUnboundBlock
from .verify_gpu_code import VerifyGPUCode
from .verify_vtcm_limit import VerifyVTCMLimit
31 changes: 31 additions & 0 deletions python/tvm/meta_schedule/postproc/verify_vtcm_limit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""A postprocessor that verifies the VTCM usage of a given schedule."""

from tvm._ffi.registry import register_object
from .. import _ffi_api
from .postproc import Postproc


@register_object("meta_schedule.VerifyVTCMLimit")
class VerifyVTCMLimit(Postproc):
    """Postprocessor that checks a schedule's VTCM usage against the allowed limit."""

    def __init__(self) -> None:
        # The node is created on the C++ side through the registered FFI constructor.
        constructor = _ffi_api.PostprocVerifyVTCMLimit  # type: ignore # pylint: disable=no-member
        self.__init_handle_by_constructor__(constructor)
7 changes: 3 additions & 4 deletions src/meta_schedule/postproc/postproc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -94,10 +94,9 @@ Array<Postproc> Postproc::DefaultCUDATensorCore() {

/*!
 * \brief Build the default list of postprocessors for Hexagon.
 *
 * Each postprocessor appears exactly once. VerifyVTCMLimit is listed last, after the rewriting
 * postprocessors.
 *
 * \return The default Hexagon postprocessors.
 */
Array<Postproc> Postproc::DefaultHexagon() {
  return Array<Postproc>{
      Postproc::DisallowDynamicLoop(), Postproc::RewriteParallelVectorizeUnroll(),
      Postproc::RewriteReductionBlock(), Postproc::RewriteLayout(),
      Postproc::VerifyVTCMLimit(),
  };
}

Expand Down
104 changes: 104 additions & 0 deletions src/meta_schedule/postproc/verify_vtcm_limit.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include <tvm/tir/transform.h>

#include "../utils.h"

namespace tvm {
namespace meta_schedule {

class VerifyVTCMLimitNode : public PostprocNode {
public:
Integer vtcm_capacity;

void InitializeWithTuneContext(const TuneContext& context) final {
ICHECK(context->target.defined());
Target target = context->target.value();
ICHECK(target->kind->name == "hexagon");
// The value of 0 will disable VTCM verification.
vtcm_capacity = target->GetAttr<Integer>("vtcm-capacity").value_or(0);
}

bool Verify(const IRModule& mod) const {
for (const auto& kv : mod->functions) {
if (const auto* prim_func = kv.second.as<tir::PrimFuncNode>()) {
if (!tir::VerifyVTCMLimit(GetRef<tir::PrimFunc>(prim_func), vtcm_capacity)) {
return false;
}
}
}
return true;
}

bool Apply(const tir::Schedule& sch) final {
IRModule mod = sch->mod();
for (const auto& kv : mod->functions) {
const GlobalVar& g_var = kv.first;
const BaseFunc& base_func = kv.second;
if (const auto* prim_func = base_func.as<tir::PrimFuncNode>()) {
IRModule lowered{nullptr};
try {
auto pass_list = Array<tvm::transform::Pass>();
pass_list.push_back(tir::transform::LowerInitBlock());
pass_list.push_back(tir::transform::PlanAndUpdateBufferAllocationLocation());
pass_list.push_back(tir::transform::ConvertBlocksToOpaque());
pass_list.push_back(tir::transform::CompactBufferAllocation());
pass_list.push_back(tir::transform::LowerMatchBuffer());
pass_list.push_back(tir::transform::InjectSoftwarePipeline());
pass_list.push_back(tir::transform::LowerOpaqueBlock());
pass_list.push_back(tir::transform::FlattenBuffer());
pass_list.push_back(tir::transform::Simplify());
pass_list.push_back(tir::transform::VectorizeLoop(true));
pass_list.push_back(tir::transform::StorageRewrite());
transform::PassContext pass_ctx = transform::PassContext::Current();
tir::PrimFunc f = WithAttr(GetRef<tir::PrimFunc>(prim_func), "global_symbol",
runtime::String(g_var->name_hint));
IRModule mod = IRModule(Map<GlobalVar, BaseFunc>({{GlobalVar(g_var->name_hint), f}}));
lowered = tvm::transform::Sequential(pass_list)(std::move(mod));
} catch (const dmlc::Error& e) {
return false;
}
if (!Verify(lowered)) {
return false;
}
}
}
return true;
}

Postproc Clone() const {
ObjectPtr<VerifyVTCMLimitNode> n = make_object<VerifyVTCMLimitNode>(*this);
return Postproc(n);
}

static constexpr const char* _type_key = "meta_schedule.VerifyVTCMLimit";
TVM_DECLARE_FINAL_OBJECT_INFO(VerifyVTCMLimitNode, PostprocNode);
};

/*! \brief Factory: wrap a freshly allocated VerifyVTCMLimitNode in a Postproc reference. */
Postproc Postproc::VerifyVTCMLimit() {
  return Postproc(make_object<VerifyVTCMLimitNode>());
}

// Register the node type, and expose the factory as the global function
// "meta_schedule.PostprocVerifyVTCMLimit" (the Python wrapper constructs the object through
// `_ffi_api.PostprocVerifyVTCMLimit`).
TVM_REGISTER_NODE_TYPE(VerifyVTCMLimitNode);
TVM_REGISTER_GLOBAL("meta_schedule.PostprocVerifyVTCMLimit")
.set_body_typed(Postproc::VerifyVTCMLimit);

} // namespace meta_schedule
} // namespace tvm
9 changes: 9 additions & 0 deletions src/tir/analysis/calculate_allocated_memory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,15 @@ TVM_REGISTER_GLOBAL("tir.analysis.calculate_allocated_bytes").set_body_typed([](
return CalculateAllocatedBytes(func);
});

/*!
 * \brief Check the bytes a PrimFunc allocates in the "global.vtcm" scope against a limit.
 * \param func The function to be checked.
 * \param limit The limit to compare against; a limit that is not positive disables the check.
 * \return false only when the limit is positive and the allocated size exceeds it.
 */
bool VerifyVTCMLimit(const PrimFunc& func, Integer limit) {
  const auto bytes_per_scope = CalculateAllocatedBytes(func);
  // A scope with no recorded allocation counts as zero bytes.
  const auto vtcm_bytes = bytes_per_scope.Get("global.vtcm").value_or(0);
  const int64_t capacity = limit.IntValue();
  const bool check_enabled = capacity > 0;
  return !check_enabled || vtcm_bytes.IntValue() <= capacity;
}

namespace transform {

Pass VerifyVTCMLimit(const Integer& limit) {
Expand Down
127 changes: 127 additions & 0 deletions tests/python/unittest/test_meta_schedule_postproc_verify_vtcm_limit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring
import tvm
import tvm.testing
from tvm import meta_schedule as ms
from tvm import tir
from tvm.script import tir as T


def _create_context(mod, target) -> ms.TuneContext:
    """Build a TuneContext whose space generator has VerifyVTCMLimit as its only postproc."""
    vtcm_only_generator = ms.space_generator.PostOrderApply(
        sch_rules=[],
        postprocs=[ms.postproc.VerifyVTCMLimit()],
        mutator_probs={},
    )
    return ms.TuneContext(
        mod=mod,
        target=target,
        space_generator=vtcm_only_generator,
        task_name="test",
    )


# pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument,not-callable,misplaced-comparison-constant
# fmt: off


# Test fixture: an already-scheduled uint8 NCHWc conv2d (int32 accumulation) written in TVMScript.
# Both inputs are first copied into buffers allocated in the "global.vtcm" scope
# (p0_global_vtcm, p1_global_vtcm), and the reduction loop that performs the copies carries
# software-pipeline annotations. The VerifyVTCMLimit postproc is run against this module.
# NOTE(review): the buffers are declared at full input size here; the postproc measures usage
# after its own lowering passes, so the verified size is presumably smaller -- confirm.
@tvm.script.ir_module
class Conv2dNCHWcVTCM:
@T.prim_func
def main(p0: T.Buffer[(T.int64(1), T.int64(2), T.int64(56), T.int64(56), T.int64(32)), "uint8"], p1: T.Buffer[(T.int64(2), T.int64(2), T.int64(3), T.int64(3), T.int64(8), T.int64(32), T.int64(4)), "uint8"], conv2d_NCHWc_int8: T.Buffer[(T.int64(1), T.int64(2), T.int64(54), T.int64(54), T.int64(32)), "int32"]):
T.func_attr({"tir.noalias": True, "global_symbol": "main"})
p0_global_vtcm = T.alloc_buffer([T.int64(1), T.int64(2), T.int64(56), T.int64(56), T.int64(32)], dtype="uint8", scope="global.vtcm")
p1_global_vtcm = T.alloc_buffer([T.int64(2), T.int64(2), T.int64(3), T.int64(3), T.int64(8), T.int64(32), T.int64(4)], dtype="uint8", scope="global.vtcm")
for n_0 in T.serial(T.int64(1), annotations={"pragma_auto_unroll_max_step":16, "pragma_unroll_explicit":1}):
for oc_chunk_0, oh_0, ow_0, oc_block_0_0 in T.grid(T.int64(2), T.int64(2), T.int64(2), T.int64(1)):
for oc_chunk_1_init, oh_1_init, ow_1_init, oc_chunk_2_init, oh_2_init, ow_2_init in T.grid(T.int64(1), T.int64(27), T.int64(3), T.int64(1), T.int64(1), T.int64(9)):
with T.block("conv2d_NCHWc_int8_o_init"):
v_n = T.axis.spatial(T.int64(1), T.int64(0))
v_oc_chunk = T.axis.spatial(T.int64(2), oc_chunk_1_init + oc_chunk_2_init + oc_chunk_0)
v_oh = T.axis.spatial(T.int64(54), oh_2_init + oh_0 * T.int64(27) + oh_1_init)
v_ow = T.axis.spatial(T.int64(54), ow_0 * T.int64(27) + ow_1_init * T.int64(9) + ow_2_init)
v_oc_block_o = T.axis.spatial(T.int64(1), T.int64(0))
T.reads()
T.writes(conv2d_NCHWc_int8[v_n, v_oc_chunk, v_oh, v_ow, T.int64(0) : T.int64(32)])
for oc_block_1 in T.vectorized(T.int64(32)):
with T.block("conv2d_NCHWc_int8_init"):
v_oc_block_i_init = T.axis.spatial(T.int64(32), oc_block_1)
T.reads()
T.writes(conv2d_NCHWc_int8[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block_i_init])
conv2d_NCHWc_int8[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block_i_init] = 0
for kh_0_kw_0_ic_outer_0_ic_f_inner_0_ic_s_inner_0_0_fused in T.serial(T.int64(2), annotations={"software_pipeline_async_stages":[0], "software_pipeline_order":[0, 1, 2], "software_pipeline_stage":[0, 0, 1]}):
for ax0_ax1_ax2_ax3_ax4_fused in T.serial(T.int64(26912)):
with T.block("p0_global.vtcm"):
v0 = T.axis.spatial(T.int64(1), T.int64(0))
v1 = T.axis.spatial(T.int64(2), ax0_ax1_ax2_ax3_ax4_fused // T.int64(13456))
v2 = T.axis.spatial(T.int64(56), oh_0 * T.int64(27) + ax0_ax1_ax2_ax3_ax4_fused % T.int64(13456) // T.int64(464))
v3 = T.axis.spatial(T.int64(56), ow_0 * T.int64(27) + ax0_ax1_ax2_ax3_ax4_fused % T.int64(464) // T.int64(16))
v4 = T.axis.spatial(T.int64(32), kh_0_kw_0_ic_outer_0_ic_f_inner_0_ic_s_inner_0_0_fused * T.int64(16) + ax0_ax1_ax2_ax3_ax4_fused % T.int64(16))
T.reads(p0[v0, v1, v2, v3, v4])
T.writes(p0_global_vtcm[v0, v1, v2, v3, v4])
p0_global_vtcm[v0, v1, v2, v3, v4] = p0[v0, v1, v2, v3, v4]
for ax0_ax1_ax2_ax3_ax4_ax5_ax6_fused in T.serial(T.int64(9216)):
with T.block("p1_global.vtcm"):
v0 = T.axis.spatial(T.int64(2), oc_chunk_0)
v1 = T.axis.spatial(T.int64(2), ax0_ax1_ax2_ax3_ax4_ax5_ax6_fused // T.int64(4608))
v2 = T.axis.spatial(T.int64(3), ax0_ax1_ax2_ax3_ax4_ax5_ax6_fused % T.int64(4608) // T.int64(1536))
v3 = T.axis.spatial(T.int64(3), ax0_ax1_ax2_ax3_ax4_ax5_ax6_fused % T.int64(1536) // T.int64(512))
v4 = T.axis.spatial(T.int64(8), kh_0_kw_0_ic_outer_0_ic_f_inner_0_ic_s_inner_0_0_fused * T.int64(4) + ax0_ax1_ax2_ax3_ax4_ax5_ax6_fused % T.int64(512) // T.int64(128))
v5 = T.axis.spatial(T.int64(32), ax0_ax1_ax2_ax3_ax4_ax5_ax6_fused % T.int64(128) // T.int64(4))
v6 = T.axis.spatial(T.int64(4), ax0_ax1_ax2_ax3_ax4_ax5_ax6_fused % T.int64(4))
T.reads(p1[v0, v1, v2, v3, v4, v5, v6])
T.writes(p1_global_vtcm[v0, v1, v2, v3, v4, v5, v6])
p1_global_vtcm[v0, v1, v2, v3, v4, v5, v6] = p1[v0, v1, v2, v3, v4, v5, v6]
for n_1, oc_chunk_1, oh_1, ow_1, oc_block_0_1, kh_1, kw_1, ic_outer_1, ic_f_inner_1, ic_s_inner_0_1, n_2, oc_chunk_2, oh_2, ow_2, oc_block_0_2 in T.grid(T.int64(1), T.int64(1), T.int64(27), T.int64(3), T.int64(1), T.int64(3), T.int64(3), T.int64(2), T.int64(4), T.int64(1), T.int64(1), T.int64(1), T.int64(1), T.int64(9), T.int64(1)):
with T.block("conv2d_NCHWc_int8_o_update"):
v_n = T.axis.spatial(T.int64(1), T.int64(0))
v_oc_chunk = T.axis.spatial(T.int64(2), oc_chunk_1 + oc_chunk_2 + oc_chunk_0)
v_oh = T.axis.spatial(T.int64(54), oh_2 + oh_0 * T.int64(27) + oh_1)
v_ow = T.axis.spatial(T.int64(54), ow_0 * T.int64(27) + ow_1 * T.int64(9) + ow_2)
v_oc_block_o = T.axis.spatial(T.int64(1), T.int64(0))
v_kh, v_kw, v_ic_outer = T.axis.remap("RRR", [kh_1, kw_1, ic_outer_1])
v_ic_f_inner = T.axis.reduce(T.int64(8), kh_0_kw_0_ic_outer_0_ic_f_inner_0_ic_s_inner_0_0_fused * T.int64(4) + ic_f_inner_1)
v_ic_s_inner_o = T.axis.reduce(T.int64(1), T.int64(0))
T.reads(conv2d_NCHWc_int8[v_n, v_oc_chunk, v_oh, v_ow, T.int64(0) : T.int64(32)], p0_global_vtcm[v_n, v_ic_outer, v_oh + v_kh, v_ow + v_kw, v_ic_f_inner * T.int64(4) : v_ic_f_inner * T.int64(4) + T.int64(4)], p1_global_vtcm[v_oc_chunk, v_ic_outer, v_kh, v_kw, v_ic_f_inner, T.int64(0) : T.int64(32), T.int64(0) : T.int64(4)])
T.writes(conv2d_NCHWc_int8[v_n, v_oc_chunk, v_oh, v_ow, T.int64(0) : T.int64(32)])
for oc_block_1, ic_s_inner_1 in T.grid(T.int64(32), T.int64(4)):
with T.block("conv2d_NCHWc_int8"):
v_oc_block_i, v_ic_s_inner_i = T.axis.remap("SR", [oc_block_1, ic_s_inner_1])
T.reads(conv2d_NCHWc_int8[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block_i], p0_global_vtcm[v_n, v_ic_outer, v_oh + v_kh, v_ow + v_kw, v_ic_f_inner * T.int64(4) + v_ic_s_inner_i], p1_global_vtcm[v_oc_chunk, v_ic_outer, v_kh, v_kw, v_ic_f_inner, v_oc_block_i, v_ic_s_inner_i])
T.writes(conv2d_NCHWc_int8[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block_i])
T.block_attr({"meta_schedule.tiling_structure":"SRSRS"})
conv2d_NCHWc_int8[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block_i] = conv2d_NCHWc_int8[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block_i] + T.Cast("int32", p0_global_vtcm[v_n, v_ic_outer, v_oh + v_kh, v_ow + v_kw, v_ic_f_inner * T.int64(4) + v_ic_s_inner_i]) * T.Cast("int32", p1_global_vtcm[v_oc_chunk, v_ic_outer, v_kh, v_kw, v_ic_f_inner, v_oc_block_i, v_ic_s_inner_i])

# fmt: on


def test_conv2d_vtcm():
    """The postproc rejects the schedule at a 70000 VTCM capacity and accepts it at 75000."""

    def get_target(vtcm_cap):
        hexagon = tvm.target.hexagon("v68", vtcm_capacity=vtcm_cap)
        return tvm.target.Target(hexagon, host=hexagon)

    sch = tir.Schedule(Conv2dNCHWcVTCM, debug_mask="all")

    # Capacity below the schedule's VTCM usage: the postproc must reject it.
    tight_ctx = _create_context(Conv2dNCHWcVTCM, target=get_target(70000))
    rejected = not tight_ctx.space_generator.postprocs[0].apply(sch)
    assert rejected

    # Capacity above the schedule's VTCM usage: the postproc must accept it.
    roomy_ctx = _create_context(Conv2dNCHWcVTCM, target=get_target(75000))
    accepted = roomy_ctx.space_generator.postprocs[0].apply(sch)
    assert accepted


# Allow running this test file directly as a script.
if __name__ == "__main__":
tvm.testing.main()

0 comments on commit 6574e16

Please sign in to comment.