Skip to content

Commit

Permalink
Schedule loop domains such that reshape transforms are cancelled (#3679)
Browse files Browse the repository at this point in the history
This PR adds a scheduling primitive,
`cancelReshapeInLoopDomains(TensorView* from_tv)`, which schedules loop
domains so that all reshape transforms appearing between `from_tv` and
the fusion outputs are effectively cancelled. Please see the
[comment](https://github.com/NVIDIA/Fuser/pull/3679/files#diff-dc44235151285593f374bf60312da86dddebe6aed272e619001c088db507b783R72)
for a motivating example.

This could be used to remove the restriction of the interfering reshape
in reduction/normalization fusions.
  • Loading branch information
naoyam authored Jan 10, 2025
1 parent 5251964 commit 05ec62b
Show file tree
Hide file tree
Showing 3 changed files with 492 additions and 19 deletions.
182 changes: 172 additions & 10 deletions csrc/scheduler/tools/loop_domain_scheduler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include <id_model/id_model.h>
#include <id_model/schedule.h>
#include <ir/internal_nodes.h>
#include <ir/utils.h>
#include <scheduler/tools/loop_domain_scheduler.h>
#include <val_graph_visitor.h>

Expand Down Expand Up @@ -407,7 +408,8 @@ void scheduleLoopDomainsLike(

void scheduleLoopDomainsBy(
const std::vector<TensorView*>& tvs,
Expr* transform) {
Expr* transform,
Direction replay_dir) {
Fusion* fusion = transform->fusion();
IdModel id_model(fusion, /*build_graphs=*/false);
const ValGraph& exact_graph = id_model.buildExactGraph();
Expand Down Expand Up @@ -439,29 +441,32 @@ void scheduleLoopDomainsBy(
}
}

Direction replay_dir = Direction::Undefined;

// It should be either: all of the inputs found and none of the
// outputs found, or none of the inputs found and all of the
// outputs found.
if (input_ids.size() == transform->inputs().size()) {
Direction replay_dir_tv = Direction::Undefined;
if (replay_dir != Direction::Backward &&
input_ids.size() == transform->inputs().size()) {
NVF_ERROR(output_ids.empty());
replay_dir = Direction::Forward;
} else if (output_ids.size() == transform->outputs().size()) {
replay_dir_tv = Direction::Forward;
} else if (
replay_dir != Direction::Forward &&
output_ids.size() == transform->outputs().size()) {
NVF_ERROR(input_ids.empty());
replay_dir = Direction::Backward;
replay_dir_tv = Direction::Backward;
} else {
// Replay not possible since none of inputs nor outputs are connected with
// the transform
continue;
}

const auto& existing_ids =
replay_dir == Direction::Forward ? input_ids : output_ids;
replay_dir_tv == Direction::Forward ? input_ids : output_ids;

// Clone inputs or outputs
auto& new_ids = replay_dir == Direction::Forward ? output_ids : input_ids;
const auto& ref_of_ids_to_generate = replay_dir == Direction::Forward
auto& new_ids =
replay_dir_tv == Direction::Forward ? output_ids : input_ids;
const auto& ref_of_ids_to_generate = replay_dir_tv == Direction::Forward
? transform->outputs()
: transform->inputs();

Expand Down Expand Up @@ -500,5 +505,162 @@ void scheduleLoopDomainsBy(
return;
}

// Walks every reshape (ViewOp) found between from_tv and the fusion
// outputs, last one first, and rewrites the loop domains of the
// reshape output and of the tensors depending on it so that the
// reshape's transform exprs are undone wherever that is allowed. A
// reshape-produced ID is left alone when, in the exact graph, it can
// reach the input of a Resize or a group containing a reduction ID.
void cancelReshapeInLoopDomains(TensorView* from_tv) {
  Fusion* fusion = from_tv->fusion();
  IdModel id_model(fusion, /*build_graphs=*/false);
  id_model.buildExactGraph();
  const auto& exact_graph = id_model.idGraph(IdMappingMode::EXACT);

  // Exact groups that pin a reshape in place. A reshape whose new IDs
  // can reach any of these groups must be kept.
  ValGroups reshape_dependent_ids;

  // 1) Input groups of resize exprs
  for (const ExprGroup& expr_group :
       exact_graph.disjointExprSets().disjointSets()) {
    if (!expr_group->front()->isA<Resize>()) {
      continue;
    }
    reshape_dependent_ids.pushBack(exact_graph.inputGroups(expr_group));
  }

  // 2) Groups that have at least one reduction ID
  const auto is_reduction_id = [](Val* val) {
    NVF_ERROR(val->isA<IterDomain>());
    return val->as<IterDomain>()->isReduction();
  };
  for (const ValGroup& val_group :
       exact_graph.disjointValSets().disjointSets()) {
    if (std::any_of(val_group->begin(), val_group->end(), is_reduction_id)) {
      reshape_dependent_ids.pushBack(val_group);
    }
  }

  auto exprs_from_tv =
      DependencyCheck::getAllExprsBetween({from_tv}, fusion->outputs());

  // Process the reshapes from the last one to the first one, i.e., in
  // a reverse topological order
  for (auto expr_it = exprs_from_tv.rbegin(); expr_it != exprs_from_tv.rend();
       ++expr_it) {
    auto view_op = dynamic_cast<ViewOp*>(*expr_it);
    if (view_op == nullptr) {
      // Not a reshape
      continue;
    }

    auto reshape_out = view_op->out();

    auto vals_after_reshape =
        DependencyCheck::getAllValsBetween({reshape_out}, fusion->outputs());
    // The first entry is dropped so that reshape_out itself is not
    // handed to scheduleLoopDomainsBy below. reshape_out already owns
    // the reshape exprs, so replaying them on it would fail; its loop
    // domain is instead assembled from its existing IDs at the end of
    // this iteration.
    vals_after_reshape.erase(vals_after_reshape.begin());
    auto dependent_tvs = ir_utils::filterByType<TensorView>(vals_after_reshape);

    // Logical IDs that are not root IDs are the ones this reshape
    // creates. Logical IDs that are also root IDs need no handling.
    std::vector<IterDomain*> reshape_created_ids;
    for (const auto& logical_id : reshape_out->getLogicalDomain()) {
      if (reshape_out->domain()->isRoot(logical_id)) {
        continue;
      }
      reshape_created_ids.push_back(logical_id);
    }

    if (reshape_created_ids.empty()) {
      // The reshape creates no new ID, so there is nothing to
      // cancel. This may not happen.
      continue;
    }

    // Collect the created IDs that are free to disappear from the
    // loop domain, i.e., those that cannot reach any of the groups in
    // reshape_dependent_ids.
    std::unordered_set<Val*> cancellable_ids;
    for (const auto created_id : reshape_created_ids) {
      auto created_id_group = exact_graph.toGroup(created_id);
      auto reachable_blockers = getReachableNodesFrom<ValGraphPermissiveBFS>(
          {created_id_group},
          {reshape_dependent_ids.begin(), reshape_dependent_ids.end()},
          Direction::Forward,
          exact_graph);
      if (reachable_blockers.empty()) {
        cancellable_ids.insert(created_id);
      }
    }

    if (cancellable_ids.empty()) {
      continue;
    }

    // The exprs making up this reshape, i.e., those between the root
    // and logical domains of reshape_out. They are undone one at a
    // time from the last to the first.
    auto root_to_logical_exprs = DependencyCheck::getAllExprsBetween(
        {reshape_out->getRootDomain().begin(),
         reshape_out->getRootDomain().end()},
        {reshape_out->getLogicalDomain().begin(),
         reshape_out->getLogicalDomain().end()});

    // Working copy of the loop domain of reshape_out. Some exprs may
    // have to stay, so it is patched per cancelled expr rather than
    // being replaced wholesale with the root domain.
    auto patched_loop_domain = reshape_out->getLoopDomain();

    for (auto transform_it = root_to_logical_exprs.rbegin();
         transform_it != root_to_logical_exprs.rend();
         ++transform_it) {
      auto transform = *transform_it;
      const auto& transform_outputs = transform->outputs();
      const auto& transform_inputs = transform->inputs();

      // An expr can be cancelled only when every one of its outputs
      // is cancellable
      const bool all_outputs_cancellable = std::all_of(
          transform_outputs.begin(),
          transform_outputs.end(),
          [&](Val* transform_out) -> bool {
            return cancellable_ids.count(transform_out) != 0;
          });
      if (!all_outputs_cancellable) {
        continue;
      }

      // Replay the expr backward on all of the tensors depending on
      // reshape_out
      scheduleLoopDomainsBy(
          dependent_tvs.vector(), transform, Direction::Backward);

      // The inputs take over the role of the outputs, so exprs
      // producing them become candidates as well
      cancellable_ids.insert(transform_inputs.begin(), transform_inputs.end());

      // In the working loop domain, place the input IDs where the
      // first output ID currently sits, keeping the input order
      auto insert_pos = std::find(
          patched_loop_domain.begin(),
          patched_loop_domain.end(),
          transform_outputs.front());
      NVF_ERROR(insert_pos != patched_loop_domain.end());
      for (auto transform_in : transform_inputs) {
        insert_pos = patched_loop_domain.insert(
            insert_pos, transform_in->as<IterDomain>());
        ++insert_pos;
      }

      // Then drop all of the output IDs
      const auto is_transform_output = [&](IterDomain* loop_id) {
        return std::find(
                   transform_outputs.begin(),
                   transform_outputs.end(),
                   loop_id) != transform_outputs.end();
      };
      patched_loop_domain.erase(
          std::remove_if(
              patched_loop_domain.begin(),
              patched_loop_domain.end(),
              is_transform_output),
          patched_loop_domain.end());
    }

    reshape_out->setLoopDomain(patched_loop_domain);
  }
}

} // namespace scheduler_tools
} // namespace nvfuser
63 changes: 54 additions & 9 deletions csrc/scheduler/tools/loop_domain_scheduler.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,17 @@
// clang-format on
#pragma once

#include <bfs.h>

#include <vector>

namespace nvfuser {

class Expr;
class Fusion;
class TensorView;
class IterDomain;
class ViewOp;

namespace scheduler_tools {

Expand All @@ -30,14 +34,14 @@ void scheduleLoopDomainsLike(
bool update_loop_domain_only = false);

// Replay a transform expr on the loop domain of each of the given
// tensors. If the input of the transform is exact mapped with the loop
// domain, the transform is replayed as a forward op. If the output
// is exact mapped with the loop domain, it's replayed as a backward
// op. The loop domain of each tensor is updated with the replayed
// transform expr. If it's replayed as a forward op, the outputs
// replace the inputs in the loop domain. If it's replayed as a
// backward op, the inputs replace the outputs in the loop domain. The
// new IDs are inserted at the outermost position of the input IDs.
// tensors. If the replay direction is specified, the expr is replayed
// as specified. Otherwise, if the input of the transform is exact mapped with
// the loop domain, the transform is replayed as a forward op. If the output is
// exact mapped with the loop domain, it's replayed as a backward op. The loop
// domain of each tensor is updated with the replayed transform expr. If it's
// replayed as a forward op, the outputs replace the inputs in the loop domain.
// If it's replayed as a backward op, the inputs replace the outputs in the loop
// domain. The new IDs are inserted at the outermost position of the input IDs.
//
// For example, suppose a fusion has:
//
Expand All @@ -62,7 +66,48 @@ void scheduleLoopDomainsLike(
// LoopDomainSchedulingTest.ScheduleLoopDomainsBy1 for more examples.
void scheduleLoopDomainsBy(
const std::vector<TensorView*>& tvs,
Expr* transform);
Expr* transform,
Direction replay_dir = Direction::Undefined);

// For each of the immediate and indirect consumer tensors of from_tv,
// schedule its loop domain such that reshape transforms appearing
// between the tensor and from_tv are cancelled. For example, suppose
// a fusion has:
//
// t0 = makeSymbolicTensor(3); // [i0, i1, i2]
// t1 = permute(t0, {1, 0, 2}); // [i1, i0, i2]
// t2 = reshape(t1, {i1, i0*i2}); // [i1, i0*i2]
// t3 = sin(t2) // [i1, i0*i2]
//
// In this case, cancelReshapeInLoopDomains(t0) would affect t2 and t3
// as follows:
//
// t2:
// root: [i1, i0, i2] (unchanged)
// logical: [i1, i0*i2] (unchanged)
// loop: [i1, i0, i2]
//
// t3:
// logical: [i1, i0*i2] (unchanged)
// loop: [i1, i0, i2]
//
// t1 would not be changed at all as there's no reshape between t0 and
// t1.
//
// This scheduling could help optimize memory accesses to
// fusion inputs. In the above case, we could then reorder the loop
// domains of t1, t2 and t3 as [i0, i1, i2], i.e., the same ordering
// as t0, which could minimize strided accesses.
//
// This scheduling is not always feasible. Specifically, if a reshape
// output iter domain is resized, the loop domain needs to keep using
// the reshape output iter domain. Similarly, if a reshape output iter
// domain is reduced, the reshape is currently not cancelled. This is
// because if a reshape has a split and only one of the split output
// iter domains is reduced, the split needs to remain. If a reshape
// only consists of merge transforms, cancellation should be possible,
// but that is not currently supported.
void cancelReshapeInLoopDomains(TensorView* from_tv);

} // namespace scheduler_tools
} // namespace nvfuser
Loading

0 comments on commit 05ec62b

Please sign in to comment.