From 05ec62b70550691ddf599a8da04b7efc5cc21350 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Fri, 10 Jan 2025 10:56:44 -0800 Subject: [PATCH] Schedule loop domains such that reshape transforms are cancelled (#3679) This PR adds a scheduling primitive, `cancelReshapeInLoopDomains(TensorView* from_tv)`, where all reshape transforms appearing between `from_tv` and fusion outputs are effectively cancelled in their loop domains. Please see the [comment](https://github.com/NVIDIA/Fuser/pull/3679/files#diff-dc44235151285593f374bf60312da86dddebe6aed272e619001c088db507b783R72) for a motivating example. This could be used to remove the restriction of the interfering reshape in reduction/normalization fusions. --- .../scheduler/tools/loop_domain_scheduler.cpp | 182 +++++++++++- csrc/scheduler/tools/loop_domain_scheduler.h | 63 ++++- tests/cpp/test_loop_domain_scheduling.cpp | 266 ++++++++++++++++++ 3 files changed, 492 insertions(+), 19 deletions(-) diff --git a/csrc/scheduler/tools/loop_domain_scheduler.cpp b/csrc/scheduler/tools/loop_domain_scheduler.cpp index fd7f2a01240..daa7c92eacc 100644 --- a/csrc/scheduler/tools/loop_domain_scheduler.cpp +++ b/csrc/scheduler/tools/loop_domain_scheduler.cpp @@ -10,6 +10,7 @@ #include #include #include +#include #include #include @@ -407,7 +408,8 @@ void scheduleLoopDomainsLike( void scheduleLoopDomainsBy( const std::vector& tvs, - Expr* transform) { + Expr* transform, + Direction replay_dir) { Fusion* fusion = transform->fusion(); IdModel id_model(fusion, /*build_graphs=*/false); const ValGraph& exact_graph = id_model.buildExactGraph(); @@ -439,17 +441,19 @@ void scheduleLoopDomainsBy( } } - Direction replay_dir = Direction::Undefined; - // It should be either: all of the inputs found and none of the // outputs found, or none of the inputs found and all of the // outputs found. 
- if (input_ids.size() == transform->inputs().size()) { + Direction replay_dir_tv = Direction::Undefined; + if (replay_dir != Direction::Backward && + input_ids.size() == transform->inputs().size()) { NVF_ERROR(output_ids.empty()); - replay_dir = Direction::Forward; - } else if (output_ids.size() == transform->outputs().size()) { + replay_dir_tv = Direction::Forward; + } else if ( + replay_dir != Direction::Forward && + output_ids.size() == transform->outputs().size()) { NVF_ERROR(input_ids.empty()); - replay_dir = Direction::Backward; + replay_dir_tv = Direction::Backward; } else { // Replay not possible since none of inputs nor outputs are connected with // the transform @@ -457,11 +461,12 @@ void scheduleLoopDomainsBy( } const auto& existing_ids = - replay_dir == Direction::Forward ? input_ids : output_ids; + replay_dir_tv == Direction::Forward ? input_ids : output_ids; // Clone inputs or outputs - auto& new_ids = replay_dir == Direction::Forward ? output_ids : input_ids; - const auto& ref_of_ids_to_generate = replay_dir == Direction::Forward + auto& new_ids = + replay_dir_tv == Direction::Forward ? output_ids : input_ids; + const auto& ref_of_ids_to_generate = replay_dir_tv == Direction::Forward ? 
transform->outputs() : transform->inputs(); @@ -500,5 +505,162 @@ void scheduleLoopDomainsBy( return; } +void cancelReshapeInLoopDomains(TensorView* from_tv) { + Fusion* fusion = from_tv->fusion(); + IdModel id_model(fusion, /*build_graphs=*/false); + id_model.buildExactGraph(); + const auto& exact_graph = id_model.idGraph(IdMappingMode::EXACT); + + // Reshapes producing these IDs should not be cancelled + ValGroups reshape_dependent_ids; + for (const ExprGroup& expr_g : + exact_graph.disjointExprSets().disjointSets()) { + if (expr_g->front()->isA()) { + reshape_dependent_ids.pushBack(exact_graph.inputGroups(expr_g)); + } + } + + for (const ValGroup& val_g : exact_graph.disjointValSets().disjointSets()) { + if (std::any_of(val_g->begin(), val_g->end(), [](Val* val) { + NVF_ERROR(val->isA()); + return val->as()->isReduction(); + })) { + reshape_dependent_ids.pushBack(val_g); + } + } + + auto all_dep_exprs_from_tv = + DependencyCheck::getAllExprsBetween({from_tv}, fusion->outputs()); + + // Visit all reshapes in a reverse topological order + for (auto exprs_it = all_dep_exprs_from_tv.rbegin(); + exprs_it != all_dep_exprs_from_tv.rend(); + ++exprs_it) { + auto reshape = dynamic_cast(*exprs_it); + if (reshape == nullptr) { + continue; + } + + auto reshape_out = reshape->out(); + + auto all_dep_vals = + DependencyCheck::getAllValsBetween({reshape_out}, fusion->outputs()); + // Exclude reshape_out. These tensors are going to be updated by + // replaying the reshape transform exprs using + // scheduleLoopDomainsBy. Since the reshape output + // tensor already has the exprs, replaying with + // scheduleLoopDomainsBy would complain if not excluded. For the + // reshape output tensor, setLoopDomain is done with the existing + // IDs without replaying. + all_dep_vals.erase(all_dep_vals.begin()); + auto all_dep_tvs = ir_utils::filterByType(all_dep_vals); + + // Find logical IDs that do not exist in the root domain. 
They are + // the new IDs that are produced by this reshape op. If a logical + // ID is already found in the root domain, there's nothing to do + // for it. + std::vector new_logical_ids; + for (const auto& logical_id : reshape_out->getLogicalDomain()) { + if (!reshape_out->domain()->isRoot(logical_id)) { + new_logical_ids.push_back(logical_id); + } + } + + if (new_logical_ids.empty()) { + // Nothing to do with a no-op reshape. This may not happen. + continue; + } + + // Find logical IDs that do not need to exist in the loop domain + std::unordered_set cancellable_ids; + for (const auto new_logical_id : new_logical_ids) { + auto new_id_group = exact_graph.toGroup(new_logical_id); + // Not cancellable if used by resize or reduced. + auto reachable_exprs = getReachableNodesFrom( + {new_id_group}, + {reshape_dependent_ids.begin(), reshape_dependent_ids.end()}, + Direction::Forward, + exact_graph); + if (!reachable_exprs.empty()) { + continue; + } + + cancellable_ids.insert(new_logical_id); + } + + if (cancellable_ids.empty()) { + continue; + } + + // Update the loop domain by each of the reshape exprs in a + // reverse topological order. + auto reshape_exprs = DependencyCheck::getAllExprsBetween( + {reshape_out->getRootDomain().begin(), + reshape_out->getRootDomain().end()}, + {reshape_out->getLogicalDomain().begin(), + reshape_out->getLogicalDomain().end()}); + + auto reshape_out_loop_domain = reshape_out->getLoopDomain(); + + for (auto reshape_exprs_it = reshape_exprs.rbegin(); + reshape_exprs_it != reshape_exprs.rend(); + ++reshape_exprs_it) { + auto reshape_expr = *reshape_exprs_it; + + // If any of the output IDs of reshape_expr is not found in + // cancellable_ids, that means the expr cannot be cancelled. 
+ if (std::any_of( + reshape_expr->outputs().begin(), + reshape_expr->outputs().end(), + [&](Val* reshape_expr_out) -> bool { + return !cancellable_ids.count(reshape_expr_out); + })) { + continue; + } + + // Update all of the dependent TVs by this reshape expr + scheduleLoopDomainsBy( + all_dep_tvs.vector(), reshape_expr, Direction::Backward); + + cancellable_ids.insert( + reshape_expr->inputs().begin(), reshape_expr->inputs().end()); + + // For the reshape output tensor itself, since it already has the + // reshape expr, it just needs + // tv->setLoopDomain(tv->getRootDomain()). However, since some of the + // reshape exprs may not be cancellable, update a vector of the + // loop IDs for each of the cancelled exprs individually and use + // it to set the loop domain of the reshape output tensor + + // Insert the input IDs to the loop domain + auto insert_pos = std::find( + reshape_out_loop_domain.begin(), + reshape_out_loop_domain.end(), + reshape_expr->outputs().front()); + NVF_ERROR(insert_pos != reshape_out_loop_domain.end()); + for (auto inp : reshape_expr->inputs()) { + insert_pos = + reshape_out_loop_domain.insert(insert_pos, inp->as()); + ++insert_pos; + } + + // Remove the output IDs + reshape_out_loop_domain.erase( + std::remove_if( + reshape_out_loop_domain.begin(), + reshape_out_loop_domain.end(), + [&](IterDomain* cur_loop_id) { + return std::find( + reshape_expr->outputs().begin(), + reshape_expr->outputs().end(), + cur_loop_id) != reshape_expr->outputs().end(); + }), + reshape_out_loop_domain.end()); + } + + reshape_out->setLoopDomain(reshape_out_loop_domain); + } +} + } // namespace scheduler_tools } // namespace nvfuser diff --git a/csrc/scheduler/tools/loop_domain_scheduler.h b/csrc/scheduler/tools/loop_domain_scheduler.h index 5939c9d31e2..fa0d4e0d2ae 100644 --- a/csrc/scheduler/tools/loop_domain_scheduler.h +++ b/csrc/scheduler/tools/loop_domain_scheduler.h @@ -7,13 +7,17 @@ // clang-format on #pragma once +#include + #include namespace nvfuser 
{ class Expr; +class Fusion; class TensorView; class IterDomain; +class ViewOp; namespace scheduler_tools { @@ -30,14 +34,14 @@ void scheduleLoopDomainsLike( bool update_loop_domain_only = false); // Replay a transform expr on the loop domain of each of the given -// tensors. If the input of the transform is exact mapped with the loop -// domain, the transform is replayed as a forward op. If the output -// is exact mapped with the loop domain, it's replayed as a backward -// op. The loop domain of each tensor is updated with the replayed -// transform expr. If it's replayed as a forward op, the outputs -// replace the inputs in the loop domain. If it's replayed as a -// backward op, the inputs replace the outputs in the loop domain. The -// new IDs are inserted at the outermost position of the input IDs. +// tensors. If the replay direction is specified, the expr is replayed +// as specified. Otherwise, if the input of the transform is exact mapped with +// the loop domain, the transform is replayed as a forward op. If the output is +// exact mapped with the loop domain, it's replayed as a backward op. The loop +// domain of each tensor is updated with the replayed transform expr. If it's +// replayed as a forward op, the outputs replace the inputs in the loop domain. +// If it's replayed as a backward op, the inputs replace the outputs in the loop +// domain. The new IDs are inserted at the outermost position of the input IDs. // // For example, suppose a fusion has: // @@ -62,7 +66,48 @@ void scheduleLoopDomainsLike( // LoopDomainSchedulingTest.ScheduleLoopDomainsBy1 for more examples. void scheduleLoopDomainsBy( const std::vector& tvs, - Expr* transform); + Expr* transform, + Direction replay_dir = Direction::Undefined); + +// For each of immediate and indirect consumer tensors of from_tv, +// schedule its loop domain such that reshape transforms appearing +// between the tensor and from_tv are cancelled. 
For example, suppose +// a fusion has: +// +// t0 = makeSymbolicTensor(3); // [i0, i1, i2] +// t1 = permute(t0, {1, 0, 2}); // [i1, i0, i2] +// t2 = reshape(t1, {i1, i0*i2}); // [i1, i0*i2] +// t3 = sin(t2) // [i1, i0*i2] +// +// In this case, cancelReshapeInLoopDomains(t0) would affect t2 and t3 +// as follows: +// +// t2: +// root: [i1, i0*i2] (unchanged) +// logical: [i1, i0*i2] (unchanged) +// loop: [i1, i0, i2] +// +// t3: +// logical: [i1, i0*i2] (unchanged) +// loop: [i1, i0, i2] +// +// t1 would not be changed at all as there's no reshape between t0 and +// t1. +// +// This scheduling could help optimize memory accesses to +// fusion inputs. In the above case, we could then reorder the loop +// domains of t1, t2 and t3 as [i0, i1, i2], i.e., the same ordering +// as t0, which could minimize strided accesses. +// +// This scheduling is not always feasible. Specifically, if a reshape +// output iter domain is resized, the loop domain needs to keep using +// the reshape output iter domain. Similarly, if a reshape output iter +// domain is reduced, the reshape is currently not cancelled. This is +// because if a reshape has a split and only one of the split output +// iter domain is reduced, the split needs to remain. If a reshape +// only consists of merge transforms, cancellation should be possible, +// but that is not currently supported. 
+void cancelReshapeInLoopDomains(TensorView* from_tv); } // namespace scheduler_tools } // namespace nvfuser diff --git a/tests/cpp/test_loop_domain_scheduling.cpp b/tests/cpp/test_loop_domain_scheduling.cpp index 32492821a89..0ba0efc5283 100644 --- a/tests/cpp/test_loop_domain_scheduling.cpp +++ b/tests/cpp/test_loop_domain_scheduling.cpp @@ -542,4 +542,270 @@ TEST_F(LoopDomainSchedulingTest, BroadcastRefereceIDs) { } } +// Cancelling a reshape to make all tensors ordered as the input +TEST_F(LoopDomainSchedulingTest, CancelReshape1) { + Fusion fusion; + FusionGuard fg(&fusion); + + std::vector shape{16, 32, 2}; + + auto tv0 = makeContigConcreteTensor(shape); // [i0, i1, i2] + fusion.addInput(tv0); + auto tv1 = permute(tv0, {1, 0, 2}); // [i1, i0, i2] + auto tv2 = + reshape(tv1, shape, {shape[1], shape[0] * shape[2]}); // [i1, i0*i2] + auto tv3 = sin(tv2); + fusion.addOutput(tv3); + + // Cancel the reshape of tv2 + scheduler_tools::cancelReshapeInLoopDomains(tv0); + + // The loop domain of tv2 should now be the same as its root domain. 
+ EXPECT_EQ(tv2->getRootDomain(), tv2->getLoopDomain()); + // The loop domain of tv3 should be exact mapped with the tv2 loop + // domain + { + IdModel id_model(&fusion, /*build_graphs=*/false); + const auto& exact_graph = id_model.buildExactGraph(); + EXPECT_EQ( + exact_graph.toGroups(tv3->getLoopDomain()), + exact_graph.toGroups(tv2->getLoopDomain())); + } + + // Reorder tv3 as the input + tv3->reorder({1, 0, 2}); + tv3->flatten(); + tv3->split(0, 128); + scheduler_tools::scheduleLoopDomainsLike({tv1, tv2}, tv3->getLoopDomain()); + + // All loop domains should be exact mapped + { + IdModel id_model(&fusion, /*build_graphs=*/false); + const auto& exact_graph = id_model.buildExactGraph(); + const auto ref_loop = exact_graph.toGroups(tv3->getLoopDomain()); + for (auto tv : {tv1, tv2}) { + EXPECT_EQ(exact_graph.toGroups(tv->getLoopDomain()), ref_loop); + } + } + + inlineMost(); + + tv3->axis(0)->parallelize(ParallelType::BIDx); + tv3->axis(1)->parallelize(ParallelType::TIDx); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + auto t0 = at::randn(shape, options); + std::vector inputs({t0}); + + KernelExecutor ke; + ke.compile(&fusion, inputs); + auto outputs = ke.run(inputs); + testValidate(&fusion, outputs, inputs, __LINE__, __FILE__); +} + +// Cancelling chained reshape ops +TEST_F(LoopDomainSchedulingTest, CancelReshape2) { + Fusion fusion; + FusionGuard fg(&fusion); + + std::vector shape{10, 11, 12}; + + auto tv0 = makeContigConcreteTensor(shape); // [i0, i1, i2] + fusion.addInput(tv0); + auto tv1 = reshape( + tv0, + {IrBuilder::create(shape[1]), + IrBuilder::create(shape[0] * shape[2])}); + auto tv2 = reshape( + tv1, + {IrBuilder::create(shape[1]), + IrBuilder::create(shape[2]), + IrBuilder::create(shape[0])}); + auto tv3 = reshape( + tv2, + {IrBuilder::create(shape[0] * shape[1]), + IrBuilder::create(shape[2])}); + fusion.addOutput(tv3); + + // Cancel all reshape ops + scheduler_tools::cancelReshapeInLoopDomains(tv0); + + // All of 
the tensors should have the same loop domain + { + IdModel id_model(&fusion, /*build_graphs=*/false); + const auto& exact_graph = id_model.buildExactGraph(); + const auto ref_loop = exact_graph.toGroups(tv0->getLoopDomain()); + for (auto tv : {tv1, tv2, tv3}) { + EXPECT_EQ(exact_graph.toGroups(tv->getLoopDomain()), ref_loop); + } + } + + tv3->flatten(); + tv3->split(0, 32); + + inlineMost(); + + tv3->axis(0)->parallelize(ParallelType::BIDx); + tv3->axis(1)->parallelize(ParallelType::TIDx); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + auto t0 = at::randn(shape, options); + std::vector inputs({t0}); + + KernelExecutor ke; + ke.compile(&fusion, inputs); + auto outputs = ke.run(inputs); + testValidate(&fusion, outputs, inputs, __LINE__, __FILE__); +} + +// Two reshapes that get merged by a binary op +TEST_F(LoopDomainSchedulingTest, CancelReshape3) { + Fusion fusion; + FusionGuard fg(&fusion); + + std::vector shape{10, 11}; + + auto tv0 = makeContigConcreteTensor(shape); + fusion.addInput(tv0); + auto tv1 = reshape(tv0, {IrBuilder::create(-1L)}); + auto tv2 = reshape(tv0, {IrBuilder::create(-1L)}); + auto tv3 = add(tv1, tv2); + fusion.addOutput(tv3); + + // The cancellation of the second reshape won't do anything as the + // loop domain is already updated by the first reshape. 
+ scheduler_tools::cancelReshapeInLoopDomains(tv0); + + // All of the tensors should have the same loop domain + { + IdModel id_model(&fusion, /*build_graphs=*/false); + const auto& exact_graph = id_model.buildExactGraph(); + const auto ref_loop = exact_graph.toGroups(tv0->getLoopDomain()); + for (auto tv : {tv1, tv2, tv3}) { + EXPECT_EQ(exact_graph.toGroups(tv->getLoopDomain()), ref_loop); + } + } + + inlineMost(); + + tv3->axis(0)->parallelize(ParallelType::BIDx); + tv3->axis(1)->parallelize(ParallelType::TIDx); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + auto t0 = at::randn(shape, options); + std::vector inputs({t0}); + + KernelExecutor ke; + ke.compile(&fusion, inputs); + auto outputs = ke.run(inputs); + testValidate(&fusion, outputs, inputs, __LINE__, __FILE__); +} + +// Resize should prevent cancellation +TEST_F(LoopDomainSchedulingTest, CancelReshape4) { + Fusion fusion; + FusionGuard fg(&fusion); + + std::vector shape{10, 11, 12}; + + auto tv0 = makeContigConcreteTensor(shape); + fusion.addInput(tv0); + + // Non-cancellable reshape due to the following slice + auto tv1 = reshape( + tv0, {IrBuilder::create(shape[0]), IrBuilder::create(-1L)}); + auto tv2 = slice( + tv1, + {{fusion.zeroVal(), tv1->axis(0)->extent()}, + {fusion.oneVal(), tv1->axis(1)->extent()}}); + fusion.addOutput(tv2); + + // Cancellable reshape + auto tv3 = reshape( + tv0, + {IrBuilder::create(shape[0] * shape[1]), + IrBuilder::create(-1L)}); + auto tv4 = slice( + tv3, + {{fusion.zeroVal(), tv3->axis(0)->extent()}, + {fusion.oneVal(), tv3->axis(1)->extent()}}); + fusion.addOutput(tv4); + + const auto tv1_original_loop = tv1->getLoopDomain(); + const auto tv2_original_loop = tv2->getLoopDomain(); + + // tv1 and tv2 should not be modified as the slice depends on the reshaped + // domain + scheduler_tools::cancelReshapeInLoopDomains(tv0); + + EXPECT_EQ(tv1->getLoopDomain(), tv1_original_loop); + EXPECT_EQ(tv2->getLoopDomain(), tv2_original_loop); + + // The 
tv3 reshape should be cancelled as the slice does not + // depend on the reshape expr + { + IdModel id_model(&fusion, /*build_graphs=*/false); + const auto& exact_graph = id_model.buildExactGraph(); + ValGroups ref_loop; + for (const auto i : c10::irange(2)) { + ref_loop.pushBack(exact_graph.toGroup(tv0->getLoopDomain().at(i))); + } + // The first two loop IDs should be exact mapped with tv0 + for (auto tv : {tv3, tv4}) { + ASSERT_EQ(tv->getLoopDomain().size(), 3); + ValGroups tv_loop_groups; + for (const auto i : c10::irange(2)) { + tv_loop_groups.pushBack(exact_graph.toGroup(tv->getLoopDomain().at(i))); + } + EXPECT_EQ(tv_loop_groups, ref_loop); + } + } +} + +// Reduction should prevent cancellation +TEST_F(LoopDomainSchedulingTest, CancelReshape5) { + Fusion fusion; + FusionGuard fg(&fusion); + + std::vector shape{10, 11, 12}; + + auto tv0 = makeContigConcreteTensor(shape); + fusion.addInput(tv0); + + // Non-cancellable reshape due to the following reduction + auto tv1 = reshape( + tv0, {IrBuilder::create(shape[0]), IrBuilder::create(-1L)}); + auto tv2 = sum(tv1, {1}); + fusion.addOutput(tv2); + + // Cancellable reshape + auto tv3 = reshape( + tv0, + {IrBuilder::create(shape[0] * shape[1]), + IrBuilder::create(-1L)}); + auto tv4 = sum(tv3, {1}); + fusion.addOutput(tv4); + + const auto tv1_original_loop = tv1->getLoopDomain(); + const auto tv2_original_loop = tv2->getLoopDomain(); + + // tv1 and tv2 should not be modified as the tv2 reduction depends on the + // reshaped domain + scheduler_tools::cancelReshapeInLoopDomains(tv0); + + EXPECT_EQ(tv1->getLoopDomain(), tv1_original_loop); + EXPECT_EQ(tv2->getLoopDomain(), tv2_original_loop); + + // The tv3 reshape should be cancelled as the reduction does not + // depend on the reshape expr + { + IdModel id_model(&fusion, /*build_graphs=*/false); + const auto& exact_graph = id_model.buildExactGraph(); + const auto ref_loop = exact_graph.toGroups(tv0->getLoopDomain()); + for (auto tv : {tv3, tv4}) { + 
EXPECT_EQ(exact_graph.toGroups(tv->getLoopDomain()), ref_loop); + } + } +} + } // namespace nvfuser