Skip to content

Commit

Permalink
update algo_reduce schedule (#296)
Browse files Browse the repository at this point in the history
  • Loading branch information
msftsw authored Sep 1, 2021
1 parent 9b7e2bb commit 15da6bb
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 4 deletions.
2 changes: 0 additions & 2 deletions backends/c-rocm/schedule/standard/algo_reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@ def schedule_branch(attrs, output, prefix):
sizes = cfg.define_split(f"{prefix}R{i}", attrs.get_extent(ax), num_outputs=2)
if rax == i:
r_range = max(2, sizes[1])
if not attrs.backend.startswith('c-cuda'):
r_range = r_range if r_range != 32 else 16
ko, ki = s[output].split(ax, factor=r_range)
BF = s.rfactor(output, ki)

Expand Down
11 changes: 10 additions & 1 deletion engine/device-stub/tvm_v0.7.patch
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ index 9cd29357f..e65c61149 100644
TVMArrayGet(DataType::UInt(8), handle, builtin::kArrTypeBits) ==
IntImm(DataType::UInt(8), dtype.bits()) &&
diff --git a/src/tir/transforms/lower_thread_allreduce.cc b/src/tir/transforms/lower_thread_allreduce.cc
index f6cb09672..fae978d03 100644
index f6cb09672..9f39e841c 100644
--- a/src/tir/transforms/lower_thread_allreduce.cc
+++ b/src/tir/transforms/lower_thread_allreduce.cc
@@ -34,13 +34,17 @@
Expand All @@ -235,6 +235,15 @@ index f6cb09672..fae978d03 100644

Stmt VisitStmt_(const AttrStmtNode* op) final {
if (op->attr_key == attr::thread_extent) {
@@ -486,7 +490,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
// Also, the warp/wavefront size differs (64 on rocm, 32 on cuda).
bool is_warp_reduction(const std::vector<DataType>& types) const {
// Only cuda target supports warp reductions.
- if ((target_->kind->name != "cuda") && (target_->kind->name != "rocm")) return false;
+ if (strncmp("c-cuda", getenv("BACKEND"), 6) != 0) return false;

// rocm only supports 32 bit operands for shuffling at the moment
if ((target_->kind->name == "rocm") &&
diff --git a/src/tir/transforms/split_host_device.cc b/src/tir/transforms/split_host_device.cc
index 921c7ad79..d96f871de 100644
--- a/src/tir/transforms/split_host_device.cc
Expand Down
2 changes: 1 addition & 1 deletion engine/install_antares_host.sh
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
cd $(dirname $0)/..
ANTARES_ROOT=$(pwd)

VERSION_TAG=v0.2dev10
VERSION_TAG=v0.2dev11

REQUIRED_CMDS="git python3 g++ make"

Expand Down

0 comments on commit 15da6bb

Please sign in to comment.