From 15da6bba7cd1fca3356e4cf9297bc065da67fb9f Mon Sep 17 00:00:00 2001 From: ghostplant Date: Wed, 1 Sep 2021 20:25:45 +0800 Subject: [PATCH] update algo_reduce schedule (#296) --- backends/c-rocm/schedule/standard/algo_reduce.py | 2 -- engine/device-stub/tvm_v0.7.patch | 11 ++++++++++- engine/install_antares_host.sh | 2 +- 3 files changed, 11 insertions(+), 4 deletions(-) diff --git a/backends/c-rocm/schedule/standard/algo_reduce.py b/backends/c-rocm/schedule/standard/algo_reduce.py index 864563d7..8ab5ef3c 100644 --- a/backends/c-rocm/schedule/standard/algo_reduce.py +++ b/backends/c-rocm/schedule/standard/algo_reduce.py @@ -11,8 +11,6 @@ def schedule_branch(attrs, output, prefix): sizes = cfg.define_split(f"{prefix}R{i}", attrs.get_extent(ax), num_outputs=2) if rax == i: r_range = max(2, sizes[1]) - if not attrs.backend.startswith('c-cuda'): - r_range = r_range if r_range != 32 else 16 ko, ki = s[output].split(ax, factor=r_range) BF = s.rfactor(output, ki) diff --git a/engine/device-stub/tvm_v0.7.patch b/engine/device-stub/tvm_v0.7.patch index aa3899d5..4ebfc735 100644 --- a/engine/device-stub/tvm_v0.7.patch +++ b/engine/device-stub/tvm_v0.7.patch @@ -213,7 +213,7 @@ index 9cd29357f..e65c61149 100644 TVMArrayGet(DataType::UInt(8), handle, builtin::kArrTypeBits) == IntImm(DataType::UInt(8), dtype.bits()) && diff --git a/src/tir/transforms/lower_thread_allreduce.cc b/src/tir/transforms/lower_thread_allreduce.cc -index f6cb09672..fae978d03 100644 +index f6cb09672..9f39e841c 100644 --- a/src/tir/transforms/lower_thread_allreduce.cc +++ b/src/tir/transforms/lower_thread_allreduce.cc @@ -34,13 +34,17 @@ @@ -235,6 +235,15 @@ index f6cb09672..fae978d03 100644 Stmt VisitStmt_(const AttrStmtNode* op) final { if (op->attr_key == attr::thread_extent) { +@@ -486,7 +490,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { + // Also, the warp/wavefront size differs (64 on rocm, 32 on cuda). + bool is_warp_reduction(const std::vector& types) const { + // Only cuda target supports warp reductions. +- if ((target_->kind->name != "cuda") && (target_->kind->name != "rocm")) return false; ++ if (strncmp("c-cuda", getenv("BACKEND"), 6) != 0) return false; + + // rocm only supports 32 bit operands for shuffling at the moment + if ((target_->kind->name == "rocm") && diff --git a/src/tir/transforms/split_host_device.cc b/src/tir/transforms/split_host_device.cc index 921c7ad79..d96f871de 100644 --- a/src/tir/transforms/split_host_device.cc diff --git a/engine/install_antares_host.sh b/engine/install_antares_host.sh index c2338810..cdeec86e 100755 --- a/engine/install_antares_host.sh +++ b/engine/install_antares_host.sh @@ -3,7 +3,7 @@ cd $(dirname $0)/.. ANTARES_ROOT=$(pwd) -VERSION_TAG=v0.2dev10 +VERSION_TAG=v0.2dev11 REQUIRED_CMDS="git python3 g++ make"