Skip to content

Commit

Permalink
enforcement on loop partition control
Browse files Browse the repository at this point in the history
  • Loading branch information
wrongtest authored and tqchen committed Jan 20, 2025
1 parent d641354 commit d6c1489
Show file tree
Hide file tree
Showing 7 changed files with 145 additions and 17 deletions.
3 changes: 3 additions & 0 deletions include/tvm/tir/builtin.h
Original file line number Diff line number Diff line change
Expand Up @@ -970,6 +970,9 @@ TVM_DLL const Op& vscale();
*/
TVM_DLL const Op& get_active_lane_mask();

/*!
 * \brief Annotate a predicate so that it is not considered as a target condition of loop
 *        partition.
 */
TVM_DLL const Op& ignore_loop_partition();

/*! \brief The kind of structure field info used in intrinsic */
enum TVMStructFieldKind : int {
// array head address
Expand Down
2 changes: 2 additions & 0 deletions python/tvm/script/ir_builder/tir/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -1910,6 +1910,7 @@ def wrapped(*args, **kwargs):
anylist_setitem_call_packed = _op_wrapper(_tir_op.anylist_setitem_call_packed)
anylist_setitem_call_cpacked = _op_wrapper(_tir_op.anylist_setitem_call_cpacked)
vscale = _op_wrapper(_tir_op.vscale)
ignore_loop_partition = _op_wrapper(_tir_op.ignore_loop_partition)


def _dtype_forward(func):
Expand Down Expand Up @@ -2262,4 +2263,5 @@ def wrapped(*args, **kwargs):
"vscale",
"get_active_lane_mask",
"call_kernel",
"ignore_loop_partition",
]
1 change: 1 addition & 0 deletions python/tvm/tir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@
from .op import start_profile_intrinsic, end_profile_intrinsic
from .op import vscale, get_active_lane_mask, get_vscale_expr
from .op import dp4a
from .op import ignore_loop_partition
from .generic import add, subtract, multiply

from .schedule import StmtSRef, BlockScope, ScheduleState, Schedule, ScheduleError
Expand Down
12 changes: 12 additions & 0 deletions python/tvm/tir/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -3581,6 +3581,18 @@ def get_vscale_expr(dtype: Union[str, tvm.DataType], min_size: int = 128) -> Pri
return min_size // dtype.bits * vscale()


def ignore_loop_partition(predicate) -> PrimExpr:
    """Annotate a predicate so that it is not considered as a target condition of loop partition.

    Parameters
    ----------
    predicate : PrimExpr
        The annotated predicate expression.

    Returns
    -------
    call : PrimExpr
        A call to the ``tir.ignore_loop_partition`` intrinsic, created with dtype
        ``"bool"``, that wraps ``predicate``.
    """
    return call_intrin("bool", "tir.ignore_loop_partition", predicate)


# pylint: disable=unnecessary-lambda
sum = comm_reducer(lambda x, y: x + y, lambda t: const(0, dtype=t), name="sum")
min = comm_reducer(lambda x, y: _ffi_api._OpMin(x, y, None), max_value, name="min") # type: ignore
Expand Down
6 changes: 6 additions & 0 deletions src/tir/op/builtin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -422,6 +422,12 @@ TIR_DEFINE_BUILTIN_FUNC(get_active_lane_mask)
.set_attr<TScriptDtypePrintLocation>("TScriptDtypePrintLocation",
Integer(ScriptDtypePrintLocation::kFirst));

// Builtin that wraps a predicate to exclude it from loop partition.
// It takes exactly one input (the wrapped predicate) and is registered as a pure call.
TIR_DEFINE_BUILTIN_FUNC(ignore_loop_partition)
    .set_num_inputs(1)
    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kPure))
    .set_attr<TScriptDtypePrintLocation>("TScriptDtypePrintLocation",
                                         Integer(ScriptDtypePrintLocation::kNone));

} // namespace builtin
} // namespace tir
} // namespace tvm
51 changes: 36 additions & 15 deletions src/tir/transforms/loop_partition.cc
Original file line number Diff line number Diff line change
Expand Up @@ -101,15 +101,15 @@ class CandidateSelector final : public StmtExprVisitor {
: partition_const_loop_(partition_const_loop) {}

void VisitStmt_(const ForNode* op) final {
// always treat var with hint to be partitioned
const VarNode* var = op->loop_var.get();
if (partition_hint_vars.count(var)) {
candidates.insert(GetRef<Stmt>(op));
StmtExprVisitor::VisitStmt_(op);
return;
}
// partition const loop when sets partition_const_loop_
if (!is_const_int(op->min) || !is_const_int(op->extent) || partition_const_loop_) {
// always treat var with hint to be partitioned
const VarNode* var = op->loop_var.get();
if (partition_hint_vars.count(var)) {
candidates.insert(GetRef<Stmt>(op));
StmtExprVisitor::VisitStmt_(op);
return;
}
record_.insert({var, false});
StmtExprVisitor::VisitStmt_(op);
if (record_.at(var) && !no_split_) {
Expand All @@ -126,14 +126,14 @@ class CandidateSelector final : public StmtExprVisitor {
const IterVarNode* iv = op->node.as<IterVarNode>();
ICHECK(iv);
Var var = iv->var;
// always treat var with hint to be partitioned
if (partition_hint_vars.count(var.get())) {
candidates.insert(GetRef<Stmt>(op));
StmtExprVisitor::VisitStmt_(op);
return;
}
runtime::ThreadScope scope = runtime::ThreadScope::Create(iv->thread_tag);
if ((scope.rank == 0) && (!is_const_int(op->value) || partition_const_loop_)) {
// always treat var with hint to be partitioned
if (partition_hint_vars.count(var.get())) {
candidates.insert(GetRef<Stmt>(op));
StmtExprVisitor::VisitStmt_(op);
return;
}
record_.insert({var.get(), false});
StmtExprVisitor::VisitStmt_(op);
if (record_.at(var.get()) && !no_split_) {
Expand Down Expand Up @@ -262,6 +262,8 @@ class PartitionFinder : public StmtExprVisitor {
void VisitExpr_(const CallNode* op) final {
if (op->op.same_as(builtin::likely())) {
DeduceCondition(op->args[0]);
} else if (op->op.same_as(builtin::ignore_loop_partition())) {
return;
} else {
StmtExprVisitor::VisitExpr_(op);
}
Expand All @@ -287,6 +289,22 @@ class PartitionFinder : public StmtExprVisitor {
// cond is true within interval
partitions[{cond, true}] = interval;
}

if (interval.IsNothing()) {
// `DeduceBound` does not support NE yet, so when deducing
// l == r fails, fall back to trying (l <= r && l >= r)
if (const EQNode* op = cond.as<EQNode>()) {
IntSet part1 = DeduceBound(current_var_, GE(op->a, op->b), hint_map_, relax_map_);
IntSet part2 = DeduceBound(current_var_, LE(op->a, op->b), hint_map_, relax_map_);
interval = arith::Intersect({part1, part2});
if (!interval.IsNothing()) {
// cond is true within interval
partitions[{cond, true}] = interval;
return;
}
}
}

PrimExpr inverse_cond = InverseCond(cond);
if (inverse_cond.defined()) {
IntSet interval = DeduceBound(current_var_, inverse_cond, hint_map_, relax_map_);
Expand Down Expand Up @@ -469,6 +487,7 @@ std::pair<IntSet, ExpressionSet> LoopPartitioner::GetIntervalAndCondset(
if (kv.first.second == cond_value) {
arith::IntervalSet interval = Downcast<arith::IntervalSet>(kv.second);
arith::IntervalSet intersection = arith::Intersect(&analyzer_, interval, for_interval);

if (!intersection->IsEmpty()) {
sets.push_back(kv.second);
cond_set.insert(kv.first.first);
Expand Down Expand Up @@ -625,8 +644,7 @@ Stmt LoopPartitioner::TryPartition(const Stmt& stmt, Var var, PrimExpr min, Prim
}();

if (middle_interval.IsNothing() && opt_cond_value == false) {
// Return loop directly as it can be simplified.
return stmt;
return Stmt();
}

if (!opt_cond_value.has_value()) {
Expand Down Expand Up @@ -750,6 +768,9 @@ class RemoveLikelyTagsAndHints : public StmtExprMutator {
if (op->op.same_as(builtin::likely())) {
ICHECK_EQ(op->args.size(), 1);
return StmtExprMutator::VisitExpr(op->args[0]);
} else if (op->op.same_as(builtin::ignore_loop_partition())) {
ICHECK_EQ(op->args.size(), 1);
return StmtExprMutator::VisitExpr(op->args[0]);
} else {
return StmtExprMutator::VisitExpr_(op);
}
Expand Down
87 changes: 85 additions & 2 deletions tests/python/tir-transform/test_tir_transform_loop_partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -570,11 +570,12 @@ def test_explicit_partition_hint():
tvm.ir.assert_structural_equal(mod["main"], partitioned_concat)


def partition_from_scheduled_tir(prim_func, pass_cfg):
def partition_from_scheduled_tir(prim_func, pass_cfg, do_flatten=True):
with tvm.transform.PassContext(config=pass_cfg):
mod = IRModule.from_expr(prim_func.with_attr("global_symbol", "main"))
mod = tvm.tir.transform.LowerOpaqueBlock()(mod)
mod = tvm.tir.transform.FlattenBuffer()(mod)
if do_flatten:
mod = tvm.tir.transform.FlattenBuffer()(mod)
mod = tvm.tir.transform.LoopPartition()(mod)
mod = tvm.tir.transform.Simplify()(mod)
mod = tvm.tir.transform.RemoveNoOp()(mod)
Expand Down Expand Up @@ -1037,6 +1038,29 @@ def concat_five_buffers_with_equalities_expected(
T_concat_1[i0 * 129 + 129] = buffer_e_1[i0]


# Input fixture: an outer hinted loop whose body branches on the single point `i == 1`,
# with a hinted inner loop guarded by `j > 2` in each branch.
@T.prim_func
def nested_partition_with_single_points(A: T.Buffer[(25,), "int32"]):
    for i in T.serial(5, annotations={"pragma_loop_partition_hint": 1}):
        if i == 1:
            for j in T.serial(5, annotations={"pragma_loop_partition_hint": 1}):
                if j > 2:
                    A[i * 5 + j] = i * 5 + j
        else:
            for j in T.serial(5, annotations={"pragma_loop_partition_hint": 1}):
                if j > 2:
                    A[i * 5 + j] = i * 15 + j


# Expected result of partitioning `nested_partition_with_single_points`: all guards are
# removed and the iteration space is split into three condition-free loop nests.
@T.prim_func
def nested_partition_with_single_points_expected(A: T.Buffer[(25,), "int32"]):
    # Corresponds to i == 0 (else branch), j in [3, 5).
    for j in range(2):
        A[j + 3] = j + 3
    # Corresponds to the single point i == 1, j in [3, 5).
    for j in range(2):
        A[j + 8] = j + 8
    # Corresponds to i in [2, 5) (else branch), j in [3, 5).
    for i, j in T.grid(3, 2):
        A[i * 5 + j + 13] = i * 15 + j + 33


@pytest.mark.parametrize(
"origin,expected",
[
Expand All @@ -1045,6 +1069,7 @@ def concat_five_buffers_with_equalities_expected(
(concat_func_end_point_equality, concat_func_end_point_equality_expected),
(concat_func_edge_equalities, concat_func_edge_equalities_expected),
(concat_five_buffers_with_equalities, concat_five_buffers_with_equalities_expected),
(nested_partition_with_single_points, nested_partition_with_single_points_expected),
],
)
def test_single_point_partition(origin, expected):
Expand All @@ -1062,5 +1087,63 @@ def test_single_point_partition(origin, expected):
tvm.ir.assert_structural_equal(mod["main"], expected)


def test_equation_on_floordiv():
    """Check partitioning on an equality condition that contains a floordiv of the loop var."""

    @T.prim_func
    def before(A: T.Buffer[(2, 2, 20), "int32"]):
        for i in T.serial(5, annotations={"pragma_loop_partition_hint": 1}):
            if i == 1:
                for vv in T.vectorized(640, annotations={"pragma_loop_partition_hint": 1}):
                    # With i == 1 this holds only for vv in [320, 640).
                    if i * 2 + vv // 320 == 3:
                        A[i - 1, i * 2 + vv // 320 - 3, vv % 320 // 16] = 1

    @T.prim_func
    def expected(A: T.Buffer[(2, 2, 20), "int32"]):
        for vv in T.vectorized(320):
            A[0, 0, vv // 16] = 1

    expected = expected.with_attr({"global_symbol": "main"})
    # Buffers are multi-dimensional here, so the FlattenBuffer step is skipped.
    after = partition_from_scheduled_tir(
        before.with_attr("global_symbol", "main"), {}, do_flatten=False
    )
    tvm.ir.assert_structural_equal(after["main"], expected)


def test_ignore_loop_partition_hint():
    """Predicates wrapped in T.ignore_loop_partition must not drive loop partition.

    In this pipeline-style case only the unwrapped condition ``2 <= i`` splits the
    loop; the wrapped conditions remain as guards inside the partitioned loops.
    """

    @T.prim_func
    def before(A: T.Buffer[(10), "float32"], D: T.Buffer[(10), "float32"]):
        B = T.decl_buffer([2], "float32")
        C = T.decl_buffer([2], "float32")
        for i in T.serial(12, annotations={"pragma_loop_partition_hint": 1}):
            if T.ignore_loop_partition(i < 10):
                B[i % 2] = A[i] + 1.0
            if T.ignore_loop_partition(1 <= i and i < 11):
                C[(i - 1) % 2] = B[(i - 1) % 2] + 2.0
            if 2 <= i:
                D[i - 2] = C[i % 2] + 3.0

    @T.prim_func
    def expected(A: T.Buffer[(10), "float32"], D: T.Buffer[(10), "float32"]):
        B = T.decl_buffer([2], "float32")
        C = T.decl_buffer([2], "float32")
        # Corresponds to original i in [0, 2), where `2 <= i` is false.
        for i in range(2):
            B[i] = A[i] + 1.0
            if i == 1:
                C[i - 1] = B[i - 1] + 2.0
        # Corresponds to original i in [2, 12), where `2 <= i` is true; i is shifted by 2.
        for i in T.serial(10):
            if i < 8:
                B[i % 2] = A[i + 2] + 1.0
            if i < 9:
                C[(i + 1) % 2] = B[(i + 1) % 2] + 2.0
            D[i] = C[i % 2] + 3.0

    expected = expected.with_attr({"global_symbol": "main"})
    after = partition_from_scheduled_tir(
        before.with_attr({"global_symbol": "main"}), {}, do_flatten=False
    )
    tvm.ir.assert_structural_equal(after["main"], expected)


if __name__ == "__main__":
tvm.testing.main()

0 comments on commit d6c1489

Please sign in to comment.