You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
// This case will fail when using tidx = 8 and tidy = 64.
// For the inner reduction, tidy is derived as 10240 / (tidx * vecx * nloadx) = 64.
// For the outer reduction, tidy is derived as 216 / nloady = 54.
// The kernel will be launched with bdimy = 64.
// In the generated kernel, all 64 threads participate in the block
// reduction, but only 54 of them have valid initial values; the result is
// therefore polluted by the other 10 threads and can't pass validation. To
// avoid this issue, we can use one of the following methods: (1) make sure
// the tidy derived from the inner reduction and the outer reduction is the
// same (when 216 % tidy == 0), or (2) instead of splitting the outer
// reduction tensor with nloady, split it with bdimy. The current scheduler
// uses method 2.
We need to capture this parallel-pattern error and provide a more helpful error message.
// Repro test: manually schedules a fused inner (dim 1) + outer (dim 0)
// reduction on a [2048, 10240] float tensor and validates both results
// against ATen. With the active branch (tidx = 8, tidy = 64) this
// reproduces the bdimy mismatch between the two reductions described in
// the comment block above the test.
TEST_F(NVFuserTest, FusionCombinedReduction_CUDA) {
// Integer ceiling division, used for all tiling arithmetic below.
auto ceilDiv = [](const int a, const int b) { return (a + b - 1) / b; };
constexpr bool verbose = false;
const auto dev_prop = at::cuda::getCurrentDeviceProperties();
const int64_t device_multiprocessor_count =
(int64_t)dev_prop->multiProcessorCount;
// Problem size: dim0 is the outer extent, dim1 the inner (fastest) extent.
const int dim0 = 2048;
const int dim1 = 10240;
// Toggle between the passing (tidx = 64, tidy = 8) and failing
// (tidx = 8, tidy = 64) thread-block configurations.
#if 0
const int tidx = 64;
const int tidy = 8;
#else
const int tidx = 8;
const int tidy = 64;
#endif
// 216 on a 108-SM device; also used as the grid-y extent below.
const int bidy = 2 * device_multiprocessor_count; // 216
// Vectorization width for the inner-reduction loads.
const int vecx = 4;
const int nloadx =
ceilDiv(dim1, vecx * tidx * tidy); // 5, simulate persistent buffer
// NOTE(review): with the active tidy = 64 this is ceilDiv(216, 64) = 4;
// the outer reduction then derives tidy as 216 / 4 = 54 != 64.
const int nloady = ceilDiv(bidy, tidy); // 216/16=13.5 -> 14
// Build the fusion: one 2-D input feeding two reductions over different axes.
Fusion fusion;
FusionGuard fg(&fusion);
TensorView* tv0 = makeContigTensor(2);
TensorView* tv1 = sum(tv0, {1}); // inner (fastest-dim) reduction
TensorView* tv2 = sum(tv0, {0}); // outer reduction
fusion.addInput(tv0);
fusion.addOutput(tv1);
fusion.addOutput(tv2);
// Cache inputs/outputs the way the persistent scheduler would before
// scheduling by hand.
auto cached_inputs = scheduler_utils::cacheInputs(&fusion, true);
auto cached_outputs = scheduler_utils::cacheAndForkOutputs(&fusion, true);
auto reduction_tvs = scheduler_utils::getReductionTvs(&fusion);
scheduler_utils::clearMemorySpace(&fusion);
// Partition the reduction tensors by whether they reduce the fastest dim.
std::vector<TensorView*> inner_reduction_tvs, outer_reduction_tvs;
for (auto tv : reduction_tvs) {
if (scheduler_utils::isFastestDimReduction(tv)) {
inner_reduction_tvs.emplace_back(tv);
} else {
outer_reduction_tvs.emplace_back(tv);
}
if (verbose)
std::cout << "tv= " << tv->toString() << ", fastest_dim_reduction= "
<< scheduler_utils::isFastestDimReduction(tv) << std::endl;
}
TensorView* inner_reduction_tv = inner_reduction_tvs[0];
TensorView* outer_reduction_tv = outer_reduction_tvs[0];
// Schedule the inner reduction: split the reduction axis into
// [nloadx, TIDy, TIDx, vecx] and the iteration axis by bidy.
inner_reduction_tv->split(-1, vecx);
inner_reduction_tv->split(-2, tidx);
inner_reduction_tv->split(-3, nloadx, false);
inner_reduction_tv->split(0, bidy, false);
inner_reduction_tv->axis(0)->parallelize(ParallelType::BIDy);
inner_reduction_tv->axis(-3)->parallelize(ParallelType::TIDy);
inner_reduction_tv->axis(-2)->parallelize(ParallelType::TIDx);
inner_reduction_tv->axis(-1)->parallelize(ParallelType::Vectorize);
if (verbose)
std::cout << "inner_reduction_tv " << inner_reduction_tv->toString()
<< std::endl;
auto reference_tv_inner =
reduction_scheduler_utils::sortAndRFactor(inner_reduction_tv);
if (verbose)
std::cout << "reference_tv_inner " << reference_tv_inner->toString()
<< std::endl;
// Schedule the outer reduction as two passes: rFactor a partial result
// that is staged through global memory, then reduce the partials.
outer_reduction_tv->split(0, bidy, false);
auto partialResult = outer_reduction_tv->rFactor({1});
partialResult->cacheBefore();
partialResult->setMemoryType(MemoryType::Global);
auto partialResultReload = partialResult->cacheAfter();
// Splitting by nloady here is what yields tidy = 54 for this reduction;
// see the failure analysis in the comment above the test.
outer_reduction_tv->split(0, nloady, false);
outer_reduction_tv->split(-1, tidx);
outer_reduction_tv->split(-2, bidy);
outer_reduction_tv->axis(1)->parallelize(ParallelType::TIDy);
outer_reduction_tv->axis(-2)->parallelize(ParallelType::BIDy);
outer_reduction_tv->axis(-1)->parallelize(ParallelType::TIDx);
if (verbose)
std::cout << "outer_reduction_tv " << outer_reduction_tv->toString()
<< std::endl;
auto reference_tv_outer =
reduction_scheduler_utils::sortAndRFactor(outer_reduction_tv);
if (verbose)
std::cout << "reference_tv_outer " << reference_tv_outer->toString()
<< std::endl;
// Propagate the transforms from each reference tensor, stopping at the
// reloaded partial result so the two schedules stay separate.
reduction_scheduler_utils::propagateTransformation(
reference_tv_inner, {partialResultReload});
reduction_scheduler_utils::propagateTransformation(
reference_tv_outer, {partialResultReload});
std::vector<TensorView*> cached_gmem_temp{partialResult};
// cached_gmem is float, may use a different vectorization factor
for (auto tv : cached_gmem_temp) {
tv->split(-1, 4);
tv->axis(-1)->parallelize(ParallelType::Vectorize);
}
// Propagate parallelization for each reduction independently.
reduction_scheduler_utils::propagateParallelization(
&fusion,
inner_reduction_tv,
reference_tv_inner,
true,
true,
cached_inputs,
cached_outputs);
reduction_scheduler_utils::propagateParallelization(
&fusion,
outer_reduction_tv,
reference_tv_outer,
true,
true,
cached_inputs,
cached_outputs);
inlineMost();
LaunchParams launch_constraints;
constexpr int64_t maxrregcount = 64;
CompileParams compile_params{DataType::Int, maxrregcount, true};
if (verbose)
fusion.print();
// Run the scheduled fusion and compare both reduction outputs to ATen.
auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
at::Tensor tv_input = at::randn({dim0, dim1}, options);
auto tv_aten_output = tv_input.to(at::kFloat).sum({1});
at::Tensor tv_cg_output = at::empty({dim0}, options);
at::Tensor qv_cg_output = at::empty({dim1}, options);
auto qv_aten_output = tv_input.to(at::kFloat).sum({0});
FusionExecutor fe;
fe.compileFusion(&fusion, {tv_input}, launch_constraints, compile_params);
fe.runFusion(
{tv_input},
{tv_cg_output, qv_cg_output},
launch_constraints,
compile_params);
testValidate(
&fusion,
{tv_cg_output, qv_cg_output},
{tv_input},
{tv_aten_output, qv_aten_output},
__LINE__,
__FILE__);
}
Alternatives
No response
Additional context
No response
The text was updated successfully, but these errors were encountered:
🚀 The feature, motivation and pitch
// This case will fail when using tidx = 8 and tidy = 64.
// For the inner reduction, tidy is derived as 10240 / (tidx * vecx * nloadx) = 64.
// For the outer reduction, tidy is derived as 216 / nloady = 54.
// The kernel will be launched with bdimy = 64.
// In the generated kernel, all 64 threads participate in the block
// reduction, but only 54 of them have valid initial values; the result is
// therefore polluted by the other 10 threads and can't pass validation. To
// avoid this issue, we can use one of the following methods: (1) make sure
// the tidy derived from the inner reduction and the outer reduction is the
// same (when 216 % tidy == 0), or (2) instead of splitting the outer
// reduction tensor with nloady, split it with bdimy. The current scheduler
// uses method 2.
We need to capture this parallel-pattern error and provide a more helpful error message.
Alternatives
No response
Additional context
No response
The text was updated successfully, but these errors were encountered: