diff --git a/python_bindings/src/halide/halide_/PyStage.cpp b/python_bindings/src/halide/halide_/PyStage.cpp index b412a6f2b39e..fac47fa3cf1f 100644 --- a/python_bindings/src/halide/halide_/PyStage.cpp +++ b/python_bindings/src/halide/halide_/PyStage.cpp @@ -14,7 +14,7 @@ void define_stage(py::module &m) { .def("dump_argument_list", &Stage::dump_argument_list) .def("name", &Stage::name) - .def("rfactor", (Func(Stage::*)(std::vector>)) & Stage::rfactor, + .def("rfactor", (Func(Stage::*)(const std::vector> &)) & Stage::rfactor, py::arg("preserved")) .def("rfactor", (Func(Stage::*)(const RVar &, const Var &)) & Stage::rfactor, py::arg("r"), py::arg("v")) diff --git a/src/ApplySplit.cpp b/src/ApplySplit.cpp index b6491f063fba..22c3425c02a4 100644 --- a/src/ApplySplit.cpp +++ b/src/ApplySplit.cpp @@ -157,7 +157,6 @@ vector apply_split(const Split &split, const string &prefix, } } break; case Split::RenameVar: - case Split::PurifyRVar: result.emplace_back(prefix + split.old_var, outer, ApplySplitResult::Substitution); result.emplace_back(prefix + split.old_var, outer, ApplySplitResult::LetStmt); break; @@ -167,10 +166,7 @@ vector apply_split(const Split &split, const string &prefix, } vector> compute_loop_bounds_after_split(const Split &split, const string &prefix) { - // Define the bounds on the split dimensions using the bounds - // on the function args. If it is a purify, we should use the bounds - // from the dims instead. - + // Define the bounds on the split dimensions using the bounds on the function args. vector> let_stmts; Expr old_var_extent = Variable::make(Int(32), prefix + split.old_var + ".loop_extent"); @@ -201,9 +197,6 @@ vector> compute_loop_bounds_after_split(const Split &spl let_stmts.emplace_back(prefix + split.outer + ".loop_max", old_var_max); let_stmts.emplace_back(prefix + split.outer + ".loop_extent", old_var_extent); break; - case Split::PurifyRVar: - // Do nothing for purify - break; } return let_stmts; diff --git a/src/BoundsInference.cpp b/src/BoundsInference.cpp index 724adb993afd..aba76f7798ed 100644 --- a/src/BoundsInference.cpp +++ b/src/BoundsInference.cpp @@ -14,6 +14,7 @@ #include #include +#include namespace Halide { namespace Internal { @@ -297,7 +298,6 @@ class BoundsInference : public IRMutator { } // Default case (no specialization) - vector predicates = def.split_predicate(); for (const ReductionVariable &rv : def.schedule().rvars()) { rvars.insert(rv); } @@ -308,23 +308,15 @@ class BoundsInference : public IRMutator { } vecs[1] = def.values(); + vector predicates = def.split_predicate(); for (size_t i = 0; i < result.size(); ++i) { for (const Expr &val : vecs[i]) { - if (!predicates.empty()) { - Expr cond_val = Call::make(val.type(), - Internal::Call::if_then_else, - {likely(predicates[0]), val}, - Internal::Call::PureIntrinsic); - for (size_t i = 1; i < predicates.size(); ++i) { - cond_val = Call::make(cond_val.type(), - Internal::Call::if_then_else, - {likely(predicates[i]), cond_val}, - Internal::Call::PureIntrinsic); - } - result[i].emplace_back(const_true(), cond_val); - } else { - result[i].emplace_back(const_true(), val); - } + Expr cond_val = std::accumulate( + predicates.begin(), predicates.end(), val, + [](const auto &acc, const auto &pred) { + return Call::make(acc.type(), Call::if_then_else, {likely(pred), acc}, Call::PureIntrinsic); + }); + result[i].emplace_back(const_true(), cond_val); } } diff --git a/src/Derivative.cpp b/src/Derivative.cpp index e64fb4ada94b..05d3168c95d3 100644 --- a/src/Derivative.cpp +++ b/src/Derivative.cpp @@ -1532,7 +1532,7 @@ void ReverseAccumulationVisitor::propagate_halide_function_call( // f(r.x) = ... && r is associative // => f(x) = ... if (var != nullptr && var->reduction_domain.defined() && - var->reduction_domain.split_predicate().empty()) { + is_const_one(var->reduction_domain.predicate())) { ReductionDomain rdom = var->reduction_domain; int rvar_id = -1; for (int rid = 0; rid < (int)rdom.domain().size(); rid++) { diff --git a/src/Deserialization.cpp b/src/Deserialization.cpp index b915a507f090..b6f49cb1bf43 100644 --- a/src/Deserialization.cpp +++ b/src/Deserialization.cpp @@ -368,8 +368,6 @@ Split::SplitType Deserializer::deserialize_split_type(Serialize::SplitType split return Split::SplitType::RenameVar; case Serialize::SplitType::FuseVars: return Split::SplitType::FuseVars; - case Serialize::SplitType::PurifyRVar: - return Split::SplitType::PurifyRVar; default: user_error << "unknown split type " << (int)split_type << "\n"; return Split::SplitType::SplitVar; diff --git a/src/Func.cpp b/src/Func.cpp index 8ffa9cb1e563..e87b14c89c9a 100644 --- a/src/Func.cpp +++ b/src/Func.cpp @@ -1,6 +1,8 @@ #include #include #include +#include +#include #include #ifdef _MSC_VER @@ -35,8 +37,12 @@ namespace Halide { using std::map; using std::ofstream; +using std::optional; using std::pair; using std::string; +using std::tuple; +using std::unordered_map; +using std::unordered_set; using std::vector; using namespace Internal; @@ -425,7 +431,6 @@ void check_for_race_conditions_in_split_with_blend(const StageSchedule &sched) { } break; case Split::RenameVar: - case Split::PurifyRVar: if (parallel.count(split.outer)) { parallel.insert(split.old_var); } @@ -448,7 +453,6 @@ void check_for_race_conditions_in_split_with_blend(const StageSchedule &sched) { } break; case Split::RenameVar: - case Split::PurifyRVar: if (parallel.count(split.old_var)) { parallel.insert(split.outer); } @@ -602,530 +606,427 @@ class SubstituteSelfReference : public IRMutator { /** Substitute all self-reference calls to 'func' with 'substitute' which * args (LHS) is the old args (LHS) plus 'new_args' in that order. * Expect this method to be called on the value (RHS) of an update definition. */ -Expr substitute_self_reference(Expr val, const string &func, const Function &substitute, - const vector &new_args) { +vector substitute_self_reference(const vector &values, const string &func, + const Function &substitute, const vector &new_args) { SubstituteSelfReference subs(func, substitute, new_args); - val = subs.mutate(val); - return val; -} - -// Substitute the occurrence of 'name' in 'exprs' with 'value'. -void substitute_var_in_exprs(const string &name, const Expr &value, vector &exprs) { - for (auto &expr : exprs) { - expr = substitute(name, value, expr); + vector result; + for (const auto &val : values) { + result.push_back(subs.mutate(val)); } + return result; } -void apply_split_result(const vector> &bounds_let_stmts, - const vector &splits_result, - vector &predicates, vector &args, - vector &values) { - - for (const auto &res : splits_result) { - switch (res.type) { - case ApplySplitResult::Substitution: - case ApplySplitResult::LetStmt: - // Apply substitutions to the list of predicates, args, and values. - // Make sure we substitute in all the let stmts as well since we are - // not going to add them to the exprs. - substitute_var_in_exprs(res.name, res.value, predicates); - substitute_var_in_exprs(res.name, res.value, args); - substitute_var_in_exprs(res.name, res.value, values); - break; - default: - internal_assert(res.type == ApplySplitResult::Predicate); - predicates.push_back(res.value); - break; - } - } +} // anonymous namespace - // Make sure we substitute in all the let stmts from 'bounds_let_stmts' - // since we are not going to add them to the exprs. - for (const auto &let : bounds_let_stmts) { - substitute_var_in_exprs(let.first, let.second, predicates); - substitute_var_in_exprs(let.first, let.second, args); - substitute_var_in_exprs(let.first, let.second, values); - } +Func Stage::rfactor(const RVar &r, const Var &v) { + definition.schedule().touched() = true; + return rfactor({{r, v}}); } -/** Apply split directives on the reduction variables. Remove the old RVar from - * the list and add the split result (inner and outer RVars) to the list. Add - * new predicates corresponding to the TailStrategy to the RDom predicate list. */ -bool apply_split(const Split &s, vector &rvars, - vector &predicates, vector &args, - vector &values, map &dim_extent_alignment) { - internal_assert(s.split_type == Split::SplitVar); - const auto it = std::find_if(rvars.begin(), rvars.end(), - [&s](const ReductionVariable &rv) { return (s.old_var == rv.var); }); - - Expr old_max, old_min, old_extent; - - if (it != rvars.end()) { - debug(4) << " Splitting " << it->var << " into " << s.outer << " and " << s.inner << "\n"; - - old_max = simplify(it->min + it->extent - 1); - old_min = it->min; - old_extent = it->extent; - - it->var = s.inner; - it->min = 0; - it->extent = s.factor; - - rvars.insert(it + 1, {s.outer, 0, simplify((old_extent - 1 + s.factor) / s.factor)}); - - vector splits_result = apply_split(s, "", dim_extent_alignment); - vector> bounds_let_stmts = compute_loop_bounds_after_split(s, ""); - apply_split_result(bounds_let_stmts, splits_result, predicates, args, values); +// Helpers for rfactor implementation +namespace { - return true; +optional find_dim(const vector &items, const VarOrRVar &v) { + const auto has_v = std::find_if(items.begin(), items.end(), [&](auto &x) { + return dim_match(x, v); + }); + return has_v == items.end() ? std::nullopt : std::make_optional(*has_v); +} + +using SubstitutionMap = std::map; + +/** This is a helper function for building up a substitution map that + * corresponds to pushing down a nest of lets. The lets should be fed + * to this function from innermost to outermost. This is equivalent to + * building a let-nest as a one-hole context and then simplifying. + * + * This looks like it might be quadratic or worse, and technically it is, + * but this isn't a problem for the way it is used inside rfactor. There + * are only a few uses: + * + * 1. Remapping preserved RVars to new RVars + * 2. Remapping factored RVars to new Vars + * 3. Filling the holes in the associative template + * 4. Accumulating the lets from ApplySplit + * + * These are naturally bounded by O(#splits + #dims) which is quite small + * in practice. Classes (1) and (2) cannot blow up expressions since they + * simply rename variables. Class (3) cannot blow up expressions either + * since nothing else can refer to the holes. That leaves only class (1). + * Fortunately, the lets generated by splits are benign. Split factors can't + * refer to RVars, and we won't see the consumed RVars in another split. So + * in total, this avoids any sort of exponentially sized substitution. + * + * @param subst The existing let nest (represented by a SubstitutionMap). + * @param name The name to bind, cannot already exist in the nest. + * @param value The value to bind. Will be substituted into nested values. + */ +void add_let(SubstitutionMap &subst, const string &name, const Expr &value) { + internal_assert(!subst.count(name)) << "would shadow " << name << " in let nest.\n" + << "\tPresent value: " << subst[name] << "\n" + << "\tProposed value: " << value; + for (auto &[_, e] : subst) { + e = substitute(name, value, e); + } + subst.emplace(name, value); +} + +pair project_rdom(const vector &dims, const ReductionDomain &rdom, const vector &splits) { + // The bounds projections maps expressions that reference the old RDom + // bounds to expressions that reference the new RDom bounds (from dims). + // We call this a projection because we are computing the symbolic image + // of the N-dimensional RDom (dimensionality including splits) in the + // M < N - dimensional result. + SubstitutionMap bounds_projection{}; + for (const Split &split : reverse_view(splits)) { + for (const auto &[name, value] : compute_loop_bounds_after_split(split, "")) { + add_let(bounds_projection, name, value); + } + } + for (const auto &[var, min, extent] : rdom.domain()) { + add_let(bounds_projection, var + ".loop_min", min); + add_let(bounds_projection, var + ".loop_max", min + extent - 1); + add_let(bounds_projection, var + ".loop_extent", extent); + } + + // Build the new RDom from the bounds_projection. + vector new_rvars; + for (const Dim &dim : dims) { + const Expr new_min = simplify(bounds_projection.at(dim.var + ".loop_min")); + const Expr new_extent = simplify(bounds_projection.at(dim.var + ".loop_extent")); + new_rvars.push_back(ReductionVariable{dim.var, new_min, new_extent}); + } + ReductionDomain new_rdom{new_rvars}; + new_rdom.where(rdom.predicate()); + + // Compute a mapping from old dimensions to equivalent values using only + // the new dimensions. For example, if we have an RDom {{0, 20}} and we + // split r.x by 2 into r.xo and r.xi, then this map will contain: + // r.x ~> 2 * r.xo + r.xi + // Certain split tail cases can place additional predicates on the RDom. + // These are handled here, too. + SubstitutionMap dim_projection{}; + SubstitutionMap dim_extent_alignment{}; + for (const auto &[var, _, extent] : rdom.domain()) { + dim_extent_alignment[var] = extent; + } + for (const Split &split : splits) { + for (const auto &result : apply_split(split, "", dim_extent_alignment)) { + switch (result.type) { + case ApplySplitResult::LetStmt: + add_let(dim_projection, result.name, substitute(bounds_projection, result.value)); + break; + case ApplySplitResult::PredicateCalls: + case ApplySplitResult::PredicateProvides: + case ApplySplitResult::Predicate: + new_rdom.where(substitute(bounds_projection, result.value)); + break; + case ApplySplitResult::Substitution: + case ApplySplitResult::SubstitutionInCalls: + case ApplySplitResult::SubstitutionInProvides: + case ApplySplitResult::BlendProvides: + // The lets returned by ApplySplit are sufficient + break; + } + } } - return false; -} - -/** Apply fuse directives on the reduction variables. Remove the - * fused RVars from the list and add the fused RVar to the list. */ -bool apply_fuse(const Split &s, vector &rvars, - vector &predicates, vector &args, - vector &values, map &dim_extent_alignment) { - internal_assert(s.split_type == Split::FuseVars); - const auto &iter_outer = std::find_if(rvars.begin(), rvars.end(), - [&s](const ReductionVariable &rv) { return (s.outer == rv.var); }); - const auto &iter_inner = std::find_if(rvars.begin(), rvars.end(), - [&s](const ReductionVariable &rv) { return (s.inner == rv.var); }); - - Expr inner_min, inner_extent, outer_min, outer_extent; - if ((iter_outer != rvars.end()) && (iter_inner != rvars.end())) { - debug(4) << " Fusing " << s.outer << " and " << s.inner << " into " << s.old_var << "\n"; - - inner_min = iter_inner->min; - inner_extent = iter_inner->extent; - outer_min = iter_outer->min; - outer_extent = iter_outer->extent; - - Expr extent = iter_outer->extent * iter_inner->extent; - iter_outer->var = s.old_var; - iter_outer->min = 0; - iter_outer->extent = extent; - rvars.erase(iter_inner); - - vector splits_result = apply_split(s, "", dim_extent_alignment); - vector> bounds_let_stmts = compute_loop_bounds_after_split(s, ""); - apply_split_result(bounds_let_stmts, splits_result, predicates, args, values); - - return true; + for (const auto &rv : new_rdom.domain()) { + add_let(dim_projection, rv.var, Variable::make(Int(32), rv.var, new_rdom)); } - return false; + return {new_rdom, dim_projection}; } -/** Apply purify directives on the reduction variables and predicates. Purify - * replace a RVar with a Var, thus, the RVar needs to be removed from the list. - * Any reference to the RVar in the predicates will be replaced with reference - * to a Var. */ -bool apply_purify(const Split &s, vector &rvars, - vector &predicates, vector &args, - vector &values, map &dim_extent_alignment) { - internal_assert(s.split_type == Split::PurifyRVar); - const auto &iter = std::find_if(rvars.begin(), rvars.end(), - [&s](const ReductionVariable &rv) { return (s.old_var == rv.var); }); - if (iter != rvars.end()) { - debug(4) << " Purify RVar " << iter->var << " into Var " << s.outer - << ", deleting it from the rvars list\n"; - rvars.erase(iter); +} // namespace - vector splits_result = apply_split(s, "", dim_extent_alignment); - vector> bounds_let_stmts = compute_loop_bounds_after_split(s, ""); - apply_split_result(bounds_let_stmts, splits_result, predicates, args, values); +pair, vector> Stage::rfactor_validate_args(const std::vector> &preserved, const AssociativeOp &prover_result) { + const vector &dims = definition.schedule().dims(); - return true; - } - return false; -} + user_assert(prover_result.associative()) + << "In schedule for " << name() << ": can't perform rfactor() " + << "because we can't prove associativity of the operator\n" + << dump_argument_list(); -/** Apply rename directives on the reduction variables. */ -bool apply_rename(const Split &s, vector &rvars, - vector &predicates, vector &args, - vector &values, map &dim_extent_alignment) { - internal_assert(s.split_type == Split::RenameVar); - const auto &iter = std::find_if(rvars.begin(), rvars.end(), - [&s](const ReductionVariable &rv) { return (s.old_var == rv.var); }); - if (iter != rvars.end()) { - debug(4) << " Renaming " << iter->var << " into " << s.outer << "\n"; - iter->var = s.outer; + unordered_set is_rfactored; + for (const auto &[rv, v] : preserved) { + // Check that the RVars are in the dims list + const auto &rv_dim = find_dim(dims, rv); + user_assert(rv_dim && rv_dim->is_rvar()) + << "In schedule for " << name() << ": can't perform rfactor() " + << "on " << rv.name() << " since either it is not in the reduction " + << "domain, or has already been consumed by another scheduling directive\n" + << dump_argument_list(); - vector splits_result = apply_split(s, "", dim_extent_alignment); - vector> bounds_let_stmts = compute_loop_bounds_after_split(s, ""); - apply_split_result(bounds_let_stmts, splits_result, predicates, args, values); + is_rfactored.insert(rv_dim->var); - return true; + // Check that the new pure Vars we used to rename the RVar aren't already in the dims list + user_assert(!find_dim(dims, v)) + << "In schedule for " << name() << ": can't perform rfactor() " + << "on " << rv.name() << " because the name " << v.name() + << "is already used elsewhere in the Func's schedule.\n" + << dump_argument_list(); } - return false; -} -/** Apply scheduling directives (e.g. split, fuse, etc.) on the reduction - * variables. */ -bool apply_split_directive(const Split &s, vector &rvars, - vector &predicates, vector &args, - vector &values) { - map dim_extent_alignment; - for (const ReductionVariable &rv : rvars) { - dim_extent_alignment[rv.var] = rv.extent; - } + // If the operator is associative but non-commutative, rfactor() on inner + // dimensions (excluding the outer dimensions) is not valid. + if (!prover_result.commutative()) { + optional last_rvar; + for (const auto &d : reverse_view(dims)) { + bool is_inner = is_rfactored.count(d.var) && last_rvar && !is_rfactored.count(last_rvar->var); + user_assert(!is_inner) + << "In schedule for " << name() << ": can't rfactor an inner " + << "dimension " << d.var << " without rfactoring the outer " + << "dimensions, since the operator is non-commutative.\n" + << dump_argument_list(); - vector> rvar_bounds; - for (const ReductionVariable &rv : rvars) { - rvar_bounds.emplace_back(rv.var + ".loop_min", rv.min); - rvar_bounds.emplace_back(rv.var + ".loop_max", simplify(rv.min + rv.extent - 1)); - rvar_bounds.emplace_back(rv.var + ".loop_extent", rv.extent); + if (d.is_rvar()) { + last_rvar = d; + } + } } - bool found = false; - switch (s.split_type) { - case Split::SplitVar: - found = apply_split(s, rvars, predicates, args, values, dim_extent_alignment); - break; - case Split::FuseVars: - found = apply_fuse(s, rvars, predicates, args, values, dim_extent_alignment); - break; - case Split::PurifyRVar: - found = apply_purify(s, rvars, predicates, args, values, dim_extent_alignment); - break; - case Split::RenameVar: - found = apply_rename(s, rvars, predicates, args, values, dim_extent_alignment); - break; + // Check that no Vars were fused into RVars + vector var_splits, rvar_splits; + Scope<> rdims; + for (const ReductionVariable &rv : definition.schedule().rvars()) { + rdims.push(rv.var); } + for (const Split &split : definition.schedule().splits()) { + switch (split.split_type) { + case Split::SplitVar: + if (rdims.contains(split.old_var)) { + rdims.pop(split.old_var); + rdims.push(split.outer); + rdims.push(split.inner); + rvar_splits.emplace_back(split); + } else { + var_splits.emplace_back(split); + } + break; + case Split::FuseVars: + if (rdims.contains(split.outer) || rdims.contains(split.inner)) { + user_assert(rdims.contains(split.outer) && rdims.contains(split.inner)) + << "In schedule for " << name() << ": can't rfactor an Func " + << "that has fused a Var into an RVar: " << split.outer + << ", " << split.inner << "\n" + << dump_argument_list(); - if (found) { - for (const auto &let : rvar_bounds) { - substitute_var_in_exprs(let.first, let.second, predicates); - substitute_var_in_exprs(let.first, let.second, args); - substitute_var_in_exprs(let.first, let.second, values); + rdims.pop(split.outer); + rdims.pop(split.inner); + rdims.push(split.old_var); + rvar_splits.emplace_back(split); + } else { + var_splits.emplace_back(split); + } + break; + case Split::RenameVar: + if (rdims.contains(split.old_var)) { + rdims.pop(split.old_var); + rdims.push(split.outer); + rvar_splits.emplace_back(split); + } else { + var_splits.emplace_back(split); + } + break; } } - return found; + return std::make_pair(std::move(var_splits), std::move(rvar_splits)); } -} // anonymous namespace - -Func Stage::rfactor(const RVar &r, const Var &v) { - definition.schedule().touched() = true; - return rfactor({{r, v}}); -} - -Func Stage::rfactor(vector> preserved) { +Func Stage::rfactor(const vector> &preserved) { user_assert(!definition.is_init()) << "rfactor() must be called on an update definition\n"; definition.schedule().touched() = true; - const string &func_name = function.name(); - vector &args = definition.args(); - vector &values = definition.values(); - - // Figure out which pure vars were used in this update definition. - std::set pure_vars_used; - internal_assert(args.size() == dim_vars.size()); - for (size_t i = 0; i < args.size(); i++) { - if (const Internal::Variable *var = args[i].as()) { - if (var->name == dim_vars[i].name()) { - pure_vars_used.insert(var->name); - } - } - } - // Check whether the operator is associative and determine the operator and // its identity for each value in the definition if it is a Tuple - const auto &prover_result = prove_associativity(func_name, args, values); - - user_assert(prover_result.associative()) - << "Failed to call rfactor() on " << name() - << " since it can't prove associativity of the operator\n"; - internal_assert(prover_result.size() == values.size()); + const auto &prover_result = prove_associativity(function.name(), definition.args(), definition.values()); + + const auto &[var_splits, rvar_splits] = rfactor_validate_args(preserved, prover_result); + + const vector dim_vars_exprs = [&] { + vector result; + result.insert(result.end(), dim_vars.begin(), dim_vars.end()); + return result; + }(); + + // sort preserved by the dimension ordering + vector preserved_rvars; + vector preserved_vars; + vector preserved_rdims; + unordered_set preserved_rdims_set; + vector intermediate_rdims; + { + unordered_map dim_ordering; + for (size_t i = 0; i < definition.schedule().dims().size(); i++) { + dim_ordering.emplace(definition.schedule().dims()[i].var, i); + } - vector &splits = definition.schedule().splits(); - vector &dims = definition.schedule().dims(); - vector &rvars = definition.schedule().rvars(); - vector predicates = definition.split_predicate(); + vector> preserved_with_dims; + for (const auto &[rv, v] : preserved) { + const optional rdim = find_dim(definition.schedule().dims(), rv); + internal_assert(rdim); + preserved_with_dims.emplace_back(rv, v, *rdim); + } - Scope scope; // Contains list of RVars lifted to the intermediate Func - vector rvars_removed; + std::sort(preserved_with_dims.begin(), preserved_with_dims.end(), [&](const auto &lhs, const auto &rhs) { + return dim_ordering.at(std::get<2>(lhs).var) < dim_ordering.at(std::get<2>(rhs).var); + }); - vector is_rfactored(dims.size(), false); - for (const pair &i : preserved) { - const RVar &rv = i.first; - const Var &v = i.second; - { - // Check that the RVar are in the dims list - const auto &iter = std::find_if(dims.begin(), dims.end(), - [&rv](const Dim &dim) { return var_name_match(dim.var, rv.name()); }); - user_assert((iter != dims.end()) && (*iter).is_rvar()) - << "In schedule for " << name() - << ", can't perform rfactor() on " << rv.name() - << " since it is not in the reduction domain\n" - << dump_argument_list(); - is_rfactored[iter - dims.begin()] = true; + for (const auto &[rv, v, dim] : preserved_with_dims) { + preserved_rvars.push_back(rv); + preserved_vars.push_back(v); + preserved_rdims.push_back(dim); + preserved_rdims_set.insert(dim.var); } - { - // Check that the new pure Vars we used to rename the RVar aren't already in the dims list - const auto &iter = std::find_if(dims.begin(), dims.end(), - [&v](const Dim &dim) { return var_name_match(dim.var, v.name()); }); - user_assert(iter == dims.end()) - << "In schedule for " << name() - << ", can't rename the rvars " << rv.name() << " into " << v.name() - << ", since it is already used in this Func's schedule elsewhere.\n" - << dump_argument_list(); - } - } - // If the operator is associative but non-commutative, rfactor() on inner - // dimensions (excluding the outer dimensions) is not valid. - if (!prover_result.commutative()) { - int last_rvar = -1; - for (int i = dims.size() - 1; i >= 0; --i) { - if ((last_rvar != -1) && is_rfactored[i]) { - user_assert(is_rfactored[last_rvar]) - << "In schedule for " << name() - << ", can't rfactor an inner dimension " << dims[i].var - << " without rfactoring the outer dimensions, since the " - << "operator is non-commutative.\n" - << dump_argument_list(); - } - if (dims[i].is_rvar()) { - last_rvar = i; + for (const Dim &dim : definition.schedule().dims()) { + if (dim.is_rvar() && !preserved_rdims_set.count(dim.var)) { + intermediate_rdims.push_back(dim); } } } - // We need to apply the split directives on the reduction vars, so that we can - // correctly lift the RVars not in 'rvars_kept' and distribute the RVars to the - // intermediate and merge Funcs. + // Project the RDom into each side + ReductionDomain intermediate_rdom, preserved_rdom; + SubstitutionMap intermediate_map, preserved_map; { - vector temp; - for (const Split &s : splits) { - // If it's already applied, we should remove it from the split list. - if (!apply_split_directive(s, rvars, predicates, args, values)) { - temp.push_back(s); - } - } - splits = temp; - } - - // Reduction domain of the intermediate update definition - vector intm_rvars; - for (const auto &rv : rvars) { - const auto &iter = std::find_if(preserved.begin(), preserved.end(), - [&rv](const pair &pair) { return var_name_match(rv.var, pair.first.name()); }); - if (iter == preserved.end()) { - intm_rvars.push_back(rv); - scope.push(rv.var, rv.var); - } - } - RDom intm_rdom(intm_rvars); - - // Sort the Rvars kept and their Vars replacement based on the RVars of - // the reduction domain AFTER applying the split directives, so that we - // can have a consistent args order for the update definition of the - // intermediate and new merge Funcs. - std::sort(preserved.begin(), preserved.end(), - [&](const pair &lhs, const pair &rhs) { - const auto &iter_lhs = std::find_if(rvars.begin(), rvars.end(), - [&lhs](const ReductionVariable &rv) { return var_name_match(rv.var, lhs.first.name()); }); - const auto &iter_rhs = std::find_if(rvars.begin(), rvars.end(), - [&rhs](const ReductionVariable &rv) { return var_name_match(rv.var, rhs.first.name()); }); - return iter_lhs < iter_rhs; - }); - // The list of RVars to keep in the new update definition - vector rvars_kept(preserved.size()); - // List of pure Vars to replace the RVars in the intermediate's update definition - vector vars_rename(preserved.size()); - for (size_t i = 0; i < preserved.size(); ++i) { - const auto &val = preserved[i]; - rvars_kept[i] = val.first; - vars_rename[i] = val.second; - } - - // List of RVars for the new reduction domain. Any RVars not in 'rvars_kept' - // are removed from the RDom - { - vector temp; - for (const auto &rv : rvars) { - const auto &iter = std::find_if(rvars_kept.begin(), rvars_kept.end(), - [&rv](const RVar &rvar) { return var_name_match(rv.var, rvar.name()); }); - if (iter != rvars_kept.end()) { - temp.push_back(rv); - } else { - rvars_removed.push_back(rv.var); - } - } - rvars.swap(temp); - } - RDom f_rdom(rvars); - - // Init definition of the intermediate Func + ReductionDomain rdom{definition.schedule().rvars(), definition.predicate(), true}; - // Compute args of the init definition of the intermediate Func. - // Replace the RVars, which are in 'rvars_kept', with the specified new pure - // Vars. Also, add the pure Vars of the original init definition as part of - // the args. - // For example, if we have the following Func f: - // f(x, y) = 10 - // f(r.x, r.y) += h(r.x, r.y) - // Calling f.update(0).rfactor({{r.y, u}}) will generate the following - // intermediate Func: - // f_intm(x, y, u) = 0 - // f_intm(r.x, u, u) += h(r.x, u) - - vector init_args; - init_args.insert(init_args.end(), dim_vars.begin(), dim_vars.end()); - init_args.insert(init_args.end(), vars_rename.begin(), vars_rename.end()); + // Intermediate + std::tie(intermediate_rdom, intermediate_map) = project_rdom(intermediate_rdims, rdom, rvar_splits); + for (size_t i = 0; i < preserved.size(); i++) { + add_let(intermediate_map, preserved_rdims[i].var, preserved_vars[i]); + } + intermediate_rdom.set_predicate(simplify(substitute(intermediate_map, intermediate_rdom.predicate()))); - vector init_vals(values.size()); - for (size_t i = 0; i < init_vals.size(); ++i) { - init_vals[i] = prover_result.pattern.identities[i]; + // Preserved + std::tie(preserved_rdom, preserved_map) = project_rdom(preserved_rdims, rdom, rvar_splits); + Scope intm_rdom; + for (const auto &[var, min, extent] : intermediate_rdom.domain()) { + intm_rdom.push(var, Interval{min, min + extent - 1}); + } + preserved_rdom.set_predicate(or_condition_over_domain(substitute(preserved_map, preserved_rdom.predicate()), intm_rdom)); } - Func intm(func_name + "_intm"); - intm(init_args) = Tuple(init_vals); + // Intermediate func + Func intm(function.name() + "_intm"); - // Args of the update definition of the intermediate Func - vector update_args(args.size() + vars_rename.size()); - - // We need to substitute the reference to the old RDom's RVars with - // the new RDom's RVars. Also, substitute the reference to RVars which - // are in 'rvars_kept' with their corresponding new pure Vars - map substitution_map; - for (size_t i = 0; i < intm_rvars.size(); ++i) { - substitution_map[intm_rvars[i].var] = intm_rdom[i]; - } - for (size_t i = 0; i < vars_rename.size(); i++) { - update_args[i + args.size()] = vars_rename[i]; - RVar rvar_kept = rvars_kept[i]; - // Find the full name of rvar_kept in rvars - const auto &iter = std::find_if(rvars.begin(), rvars.end(), - [&rvar_kept](const ReductionVariable &rv) { return var_name_match(rv.var, rvar_kept.name()); }); - substitution_map[iter->var] = vars_rename[i]; - } - for (size_t i = 0; i < args.size(); i++) { - Expr arg = substitute(substitution_map, args[i]); - update_args[i] = arg; + // Intermediate pure definition + { + vector args = dim_vars_exprs; + args.insert(args.end(), preserved_vars.begin(), preserved_vars.end()); + intm(args) = Tuple(prover_result.pattern.identities); } - // Compute the predicates for the intermediate Func and the new update definition - for (const Expr &pred : predicates) { - Expr subs_pred = substitute(substitution_map, pred); - intm_rdom.where(subs_pred); - if (!expr_uses_vars(pred, scope)) { - // Only keep the predicate that does not depend on the lifted RVars - // (either explicitly or implicitly). For example, if 'rx' is split - // into 'rxo' and 'rxi' and 'rxo' is part of the lifted RVars, we'll - // ignore every predicate that depends on 'rx' - f_rdom.where(pred); + // Intermediate update definition + { + vector args = definition.args(); + args.insert(args.end(), preserved_vars.begin(), preserved_vars.end()); + args = substitute(intermediate_map, args); + + vector values = definition.values(); + values = substitute_self_reference(values, function.name(), intm.function(), preserved_vars); + values = substitute(intermediate_map, values); + intm.function().define_update(args, values, intermediate_rdom); + + // Intermediate schedule + vector intm_dims = definition.schedule().dims(); + + // Replace rvar dims IN the preserved list with their Vars in the INTERMEDIATE Func + for (auto &dim : intm_dims) { + const auto it = std::find_if(preserved_rvars.begin(), preserved_rvars.end(), [&](const auto &rv) { + return dim_match(dim, rv); + }); + if (it != preserved_rvars.end()) { + const auto offset = it - preserved_rvars.begin(); + const auto &var = preserved_vars[offset]; + const auto &pure_dim = find_dim(intm.function().definition().schedule().dims(), var); + internal_assert(pure_dim); + dim = *pure_dim; + } } - } - definition.predicate() = f_rdom.domain().predicate(); - // The update values the intermediate Func should compute - vector update_vals(values.size()); - for (size_t i = 0; i < update_vals.size(); i++) { - Expr val = substitute(substitution_map, values[i]); - // Need to update the self-reference in the update definition to point - // to the new intermediate Func - val = substitute_self_reference(val, func_name, intm.function(), vars_rename); - update_vals[i] = val; - } - // There may not actually be a reference to the RDom in the args or values, - // so we use Function::define_update, which lets pass pass an explicit RDom. - intm.function().define_update(update_args, update_vals, intm_rdom.domain()); - - // Determine the dims and schedule of the update definition of the - // intermediate Func. We copy over the schedule from the original - // update definition (e.g. split, parallelize, vectorize, etc.) - intm.function().update(0).schedule().dims() = dims; - intm.function().update(0).schedule().splits() = splits; - - // Copy over the storage order of the original pure dims - vector &intm_storage_dims = intm.function().schedule().storage_dims(); - internal_assert(intm_storage_dims.size() == - function.schedule().storage_dims().size() + vars_rename.size()); - for (size_t i = 0; i < function.schedule().storage_dims().size(); ++i) { - intm_storage_dims[i] = function.schedule().storage_dims()[i]; - } - - for (size_t i = 0; i < rvars_kept.size(); ++i) { - // Apply the purify directive that replaces the RVar in rvars_kept - // with a pure Var - intm.update(0).purify(rvars_kept[i], vars_rename[i]); - } + // Add factored pure dims to the INTERMEDIATE func just before outermost + unordered_set dims; + for (const auto &dim : intm_dims) { + dims.insert(dim.var); + } + for (const Var &var : preserved_vars) { + const optional &dim = find_dim(intm.function().definition().schedule().dims(), var); + internal_assert(dim) << "Failed to find " << var.name() << " in list of pure dims"; + if (!dims.count(dim->var)) { + intm_dims.insert(intm_dims.end() - 1, *dim); + } + } - // Determine the dims of the new update definition - - // The new update definition needs all the pure vars of the Func, but the - // one we're rfactoring may not have used them all. Add any missing ones to - // the dims list. - - // Add pure Vars from the original init definition to the dims list - // if they are not already in the list - for (const Var &v : dim_vars) { - if (!pure_vars_used.count(v.name())) { - Dim d = {v.name(), ForType::Serial, DeviceAPI::None, DimType::PureVar, Partition::Auto}; - // Insert it just before Var::outermost - dims.insert(dims.end() - 1, d); - } + intm.function().update(0).schedule() = definition.schedule().get_copy(); + intm.function().update(0).schedule().dims() = std::move(intm_dims); + intm.function().update(0).schedule().rvars() = intermediate_rdom.domain(); + intm.function().update(0).schedule().splits() = var_splits; } - // Then, we need to remove lifted RVars from the dims list - for (const string &rv : rvars_removed) { - remove(rv); - } + // Preserved update definition + { + // Replace the current definition with calls to the intermediate func. + vector f_load_args = dim_vars_exprs; + for (const ReductionVariable &rv : preserved_rdom.domain()) { + f_load_args.push_back(Variable::make(Int(32), rv.var, preserved_rdom)); + } - // Define the new update definition which refers to the intermediate Func. - // Using the same example as above, the new update definition is: - // f(x, y) += f_intm(x, y, r.y) + for (size_t i = 0; i < definition.values().size(); ++i) { + if (!prover_result.ys[i].var.empty()) { + Expr r = (definition.values().size() == 1) ? Expr(intm(f_load_args)) : Expr(intm(f_load_args)[i]); + add_let(preserved_map, prover_result.ys[i].var, r); + } - // Args for store in the new update definition - vector f_store_args(dim_vars.size()); - for (size_t i = 0; i < f_store_args.size(); ++i) { - f_store_args[i] = dim_vars[i]; - } - - // Call's args to the intermediate Func in the new update definition - vector f_load_args; - f_load_args.insert(f_load_args.end(), dim_vars.begin(), dim_vars.end()); - for (int i = 0; i < f_rdom.dimensions(); ++i) { - f_load_args.push_back(f_rdom[i]); - } - internal_assert(f_load_args.size() == init_args.size()); + if (!prover_result.xs[i].var.empty()) { + Expr prev_val = Call::make(intm.types()[i], function.name(), + dim_vars_exprs, Call::CallType::Halide, + FunctionPtr(), i); + add_let(preserved_map, prover_result.xs[i].var, prev_val); + } else { + user_warning << "Update definition of " << name() << " at index " << i + << " doesn't depend on the previous value. This isn't a" + << " reduction operation\n"; + } + } - // Update value of the new update definition. It loads values from - // the intermediate Func. - vector f_values(values.size()); + vector reducing_dims; + { + // Remove rvar dims NOT IN the preserved list from the REDUCING Func + for (const auto &dim : definition.schedule().dims()) { + if (!dim.is_rvar() || preserved_rdims_set.count(dim.var)) { + reducing_dims.push_back(dim); + } + } - // There might be cross-dependencies between tuple elements, so we need - // to collect all substitutions first. - map replacements; - for (size_t i = 0; i < f_values.size(); ++i) { - if (!prover_result.ys[i].var.empty()) { - Expr r = (values.size() == 1) ? Expr(intm(f_load_args)) : Expr(intm(f_load_args)[i]); - replacements.emplace(prover_result.ys[i].var, r); + // Add missing pure vars to the REDUCING func just before outermost. + // This is necessary whenever the update does not reference one of the + // pure variables. For instance, factoring a histogram (clamps elided): + // g(x) = 0; g(f(r.x, r.y)) += 1; + // Func intm = g.rfactor(r.y, u); + // Here we generate an intermediate func intm that looks like: + // intm(x, u) = 0; intm(f(r.x, u), u) += 1; + // And we need the reducing func to be: + // g(x) += intm(x, r.y); + // But x was not referenced in the original update definition, so that + // dimension is added here. + for (size_t i = 0; i < dim_vars.size(); i++) { + if (!expr_uses_var(definition.args()[i], dim_vars[i].name())) { + Dim d = {dim_vars[i].name(), ForType::Serial, DeviceAPI::None, DimType::PureVar, Partition::Auto}; + reducing_dims.insert(reducing_dims.end() - 1, d); + } + } } - if (!prover_result.xs[i].var.empty()) { - Expr prev_val = Call::make(intm.types()[i], func_name, - f_store_args, Call::CallType::Halide, - FunctionPtr(), i); - replacements.emplace(prover_result.xs[i].var, prev_val); - } else { - user_warning << "Update definition of " << name() << " at index " << i - << " doesn't depend on the previous value. This isn't a" - << " reduction operation\n"; - } - } - for (size_t i = 0; i < f_values.size(); ++i) { - f_values[i] = substitute(replacements, prover_result.pattern.ops[i]); + definition.args() = dim_vars_exprs; + definition.values() = substitute(preserved_map, prover_result.pattern.ops); + definition.predicate() = preserved_rdom.predicate(); + definition.schedule().dims() = std::move(reducing_dims); + definition.schedule().rvars() = preserved_rdom.domain(); + definition.schedule().splits() = var_splits; } - // Update the definition - args.swap(f_store_args); - values.swap(f_values); - return intm; } @@ -1187,7 +1088,7 @@ void Stage::split(const string &old, const string &outer, const string &inner, c bool round_up_ok = !exact; if (round_up_ok && !definition.is_init()) { // If it's the outermost split in this dimension, RoundUp - // is OK. Otherwise we need GuardWithIf to avoid + // is OK. Otherwise, we need GuardWithIf to avoid // recomputing values in the case where the inner split // factor does not divide the outer split factor. std::set inner_vars; @@ -1200,7 +1101,6 @@ void Stage::split(const string &old, const string &outer, const string &inner, c } break; case Split::RenameVar: - case Split::PurifyRVar: if (inner_vars.count(s.old_var)) { inner_vars.insert(s.outer); } @@ -1224,7 +1124,7 @@ void Stage::split(const string &old, const string &outer, const string &inner, c bool predicate_loads_ok = !exact; if (predicate_loads_ok && tail == TailStrategy::PredicateLoads) { // If it's the outermost split in this dimension, PredicateLoads - // is OK. Otherwise we can't prove it's safe. + // is OK. Otherwise, we can't prove it's safe. std::set inner_vars; for (const Split &s : definition.schedule().splits()) { switch (s.split_type) { @@ -1235,7 +1135,6 @@ void Stage::split(const string &old, const string &outer, const string &inner, c } break; case Split::RenameVar: - case Split::PurifyRVar: if (inner_vars.count(s.old_var)) { inner_vars.insert(s.outer); } @@ -1297,7 +1196,6 @@ void Stage::split(const string &old, const string &outer, const string &inner, c } break; case Split::RenameVar: - case Split::PurifyRVar: if (it != descends_from_shiftinwards_outer.end()) { descends_from_shiftinwards_outer[s.outer] = it->second; } @@ -1484,46 +1382,6 @@ void Stage::specialize_fail(const std::string &message) { s.failure_message = message; } -Stage &Stage::purify(const VarOrRVar &old_var, const VarOrRVar &new_var) { - user_assert(old_var.is_rvar && !new_var.is_rvar) - << "In schedule for " << name() - << ", can't rename " << (old_var.is_rvar ? "RVar " : "Var ") << old_var.name() - << " to " << (new_var.is_rvar ? "RVar " : "Var ") << new_var.name() - << "; purify must take a RVar as old_Var and a Var as new_var\n"; - - debug(4) << "In schedule for " << name() << ", purify RVar " - << old_var.name() << " to Var " << new_var.name() << "\n"; - - StageSchedule &schedule = definition.schedule(); - - // Replace the old dimension with the new dimensions in the dims list - bool found = false; - string old_name, new_name = new_var.name(); - vector &dims = schedule.dims(); - - for (size_t i = 0; (!found) && i < dims.size(); i++) { - if (dim_match(dims[i], old_var)) { - found = true; - old_name = dims[i].var; - dims[i].var = new_name; - dims[i].dim_type = DimType::PureVar; - } - } - - if (!found) { - user_error - << "In schedule for " << name() - << ", could not find rename dimension: " - << old_var.name() - << "\n" - << dump_argument_list(); - } - - Split split = {old_name, new_name, "", 1, false, TailStrategy::RoundUp, Split::PurifyRVar}; - definition.schedule().splits().push_back(split); - return *this; -} - void Stage::remove(const string &var) { debug(4) << "In schedule for " << name() << ", remove " << var << "\n"; @@ -1601,7 +1459,6 @@ void Stage::remove(const string &var) { } break; case Split::RenameVar: - case Split::PurifyRVar: debug(4) << " replace/rename " << split.old_var << " into " << split.outer << "\n"; if (should_remove(split.outer)) { @@ -1690,7 +1547,6 @@ Stage &Stage::rename(const VarOrRVar &old_var, const VarOrRVar &new_var) { break; case Split::SplitVar: case Split::RenameVar: - case Split::PurifyRVar: if (split.inner == old_name) { split.inner = new_name; found = true; diff --git a/src/Func.h b/src/Func.h index ae739b8dc538..d0d566a2e1c2 100644 --- a/src/Func.h +++ b/src/Func.h @@ -60,6 +60,7 @@ struct VarOrRVar { class ImageParam; namespace Internal { +struct AssociativeOp; class Function; struct Split; struct StorageDim; @@ -81,7 +82,6 @@ class Stage { void split(const std::string &old, const std::string &outer, const std::string &inner, const Expr &factor, bool exact, TailStrategy tail); void remove(const std::string &var); - Stage &purify(const VarOrRVar &old_name, const VarOrRVar &new_name); const std::vector &storage_dims() const { return function.schedule().storage_dims(); @@ -89,6 +89,9 @@ class Stage { Stage &compute_with(LoopLevel loop_level, const std::map &align); + std::pair, std::vector> + rfactor_validate_args(const std::vector> &preserved, const Internal::AssociativeOp &prover_result); + public: Stage(Internal::Function f, Internal::Definition d, size_t stage_index) : function(std::move(f)), definition(std::move(d)), stage_index(stage_index) { @@ -184,7 +187,7 @@ class Stage { * */ // @{ - Func rfactor(std::vector> preserved); + Func rfactor(const std::vector> &preserved); Func rfactor(const RVar &r, const Var &v); // @} diff --git a/src/Inline.cpp b/src/Inline.cpp index 54399cf77b76..31b6efcdf749 100644 --- a/src/Inline.cpp +++ b/src/Inline.cpp @@ -76,8 +76,6 @@ void validate_schedule_inlined_function(Function f) { << split.inner << " because " << f.name() << " is scheduled inline.\n"; - break; - case Split::PurifyRVar: break; } } diff --git a/src/Schedule.h b/src/Schedule.h index ea2692752a9e..906dbe6c7b5e 100644 --- a/src/Schedule.h +++ b/src/Schedule.h @@ -332,8 +332,7 @@ struct Split { enum SplitType { SplitVar = 0, RenameVar, - FuseVars, - PurifyRVar }; + FuseVars }; // If split_type is Rename, then this is just a renaming of the // old_var to the outer and not a split. The inner var should @@ -341,10 +340,6 @@ struct Split { // the same list as splits so that ordering between them is // respected. - // If split type is Purify, this replaces the old_var RVar to - // the outer Var. The inner var should be ignored, and factor - // should be one. - // If split_type is Fuse, then this does the opposite of a // split, it joins the outer and inner into the old_var. SplitType split_type; diff --git a/src/ScheduleFunctions.cpp b/src/ScheduleFunctions.cpp index c7a257dd085e..02f5553f748c 100644 --- a/src/ScheduleFunctions.cpp +++ b/src/ScheduleFunctions.cpp @@ -178,7 +178,6 @@ Stmt build_loop_nest( const auto &dims = func.args(); const auto &func_s = func.schedule(); const auto &stage_s = def.schedule(); - const auto &predicates = def.split_predicate(); // We'll build it from inside out, starting from the body, // then wrapping it in for loops. @@ -306,7 +305,7 @@ Stmt build_loop_nest( } // Put all the reduction domain predicates into the containers vector. - for (Expr pred : predicates) { + for (Expr pred : def.split_predicate()) { pred = qualify(prefix, pred); // Add a likely qualifier if there isn't already one if (Call::as_intrinsic(pred, {Call::likely, Call::likely_if_innermost})) { @@ -413,8 +412,7 @@ Stmt build_loop_nest( } // Define the bounds on the split dimensions using the bounds - // on the function args. If it is a purify, we should use the bounds - // from the dims instead. + // on the function args. for (const Split &split : reverse_view(splits)) { vector> let_stmts = compute_loop_bounds_after_split(split, prefix); for (const auto &let_stmt : let_stmts) { @@ -2229,7 +2227,7 @@ bool validate_schedule(Function f, const Stmt &s, const Target &target, bool is_ // // However, there are four types of Split, and the concept of a child var varies across them: // - For a vanilla split, inner and outer are the children and old_var is the parent. - // - For rename and purify, the outer is the child and the inner is meaningless. + // - For rename, the outer is the child and the inner is meaningless. // - For fuse, old_var is the child and inner/outer are the parents. // // (@abadams comments: "I acknowledge that this is gross and should be refactored.") @@ -2249,7 +2247,6 @@ bool validate_schedule(Function f, const Stmt &s, const Target &target, bool is_ } break; case Split::RenameVar: - case Split::PurifyRVar: if (parallel_vars.count(split.outer)) { parallel_vars.insert(split.old_var); } diff --git a/src/Serialization.cpp b/src/Serialization.cpp index 15722d878974..d731d9c9d85c 100644 --- a/src/Serialization.cpp +++ b/src/Serialization.cpp @@ -338,8 +338,6 @@ Serialize::SplitType Serializer::serialize_split_type(const Split::SplitType &sp return Serialize::SplitType::RenameVar; case Split::SplitType::FuseVars: return Serialize::SplitType::FuseVars; - case Split::SplitType::PurifyRVar: - return Serialize::SplitType::PurifyRVar; default: user_error << "Unsupported split type\n"; return Serialize::SplitType::SplitVar; diff --git a/src/Solve.cpp b/src/Solve.cpp index 3f124601345a..20d6f5200101 100644 --- a/src/Solve.cpp +++ b/src/Solve.cpp @@ -1239,6 +1239,10 @@ Expr and_condition_over_domain(const Expr &e, const Scope &varying) { return simplify(bounds.min); } +Expr or_condition_over_domain(const Expr &c, const Scope &varying) { + return simplify(!and_condition_over_domain(simplify(!c), varying)); +} + // Testing code namespace { diff --git a/src/Solve.h b/src/Solve.h index ff5124e508c6..4d06fda47d6b 100644 --- a/src/Solve.h +++ b/src/Solve.h @@ -47,6 +47,13 @@ Interval solve_for_inner_interval(const Expr &c, const std::string &variable); * 'and' over the vector lanes, and return a scalar result. */ Expr and_condition_over_domain(const Expr &c, const Scope &varying); +/** Take a conditional that includes variables that vary over some + * domain, and convert it to a weaker (less frequently false) condition + * that doesn't depend on those variables. Formally, the input expr + * implies the output expr. Note that this function might be unable to + * provide a better response than simply const_true(). */ +Expr or_condition_over_domain(const Expr &c, const Scope &varying); + void solve_test(); } // namespace Internal diff --git a/src/Substitute.h b/src/Substitute.h index 22bdf640b7a8..e514fda40359 100644 --- a/src/Substitute.h +++ b/src/Substitute.h @@ -6,6 +6,7 @@ * Defines methods for substituting out variables in expressions and * statements. */ +#include #include #include "Expr.h" @@ -37,6 +38,16 @@ Expr substitute(const Expr &find, const Expr &replacement, const Expr &expr); Stmt substitute(const Expr &find, const Expr &replacement, const Stmt &stmt); // @} +/** Substitute a container of Exprs or Stmts out of place */ +template +T substitute(const std::map &replacements, const T &container) { + T output; + std::transform(container.begin(), container.end(), std::back_inserter(output), [&](const auto &expr_or_stmt) { + return substitute(replacements, expr_or_stmt); + }); + return output; +} + /** Substitutions where the IR may be a general graph (and not just a * DAG). */ // @{ diff --git a/src/halide_ir.fbs b/src/halide_ir.fbs index efc465cbee82..499488ce8b95 100644 --- a/src/halide_ir.fbs +++ b/src/halide_ir.fbs @@ -548,7 +548,6 @@ enum SplitType: ubyte { SplitVar, RenameVar, FuseVars, - PurifyRVar, } table Split { diff --git a/test/common/expect_abort.cpp b/test/common/expect_abort.cpp index cb09a7242921..fec89b0913b7 100644 --- a/test/common/expect_abort.cpp +++ b/test/common/expect_abort.cpp @@ -19,6 +19,9 @@ auto handler = ([]() { << std::flush; suppress_abort = false; std::abort(); // We should never EXPECT an internal error + } catch (const Halide::Error &e) { + std::cerr << e.what() << "\n" + << std::flush; } catch (const std::exception &e) { std::cerr << e.what() << "\n" << std::flush; diff --git a/test/error/CMakeLists.txt b/test/error/CMakeLists.txt index 0478c3b11087..5272b2717de7 100644 --- a/test/error/CMakeLists.txt +++ b/test/error/CMakeLists.txt @@ -94,6 +94,8 @@ tests(GROUPS error require_fail.cpp reuse_var_in_schedule.cpp reused_args.cpp + rfactor_after_var_and_rvar_fusion.cpp + rfactor_fused_var_and_rvar.cpp rfactor_inner_dim_non_commutative.cpp round_up_and_blend_race.cpp run_with_large_stack_throws.cpp diff --git a/test/error/rfactor_after_var_and_rvar_fusion.cpp b/test/error/rfactor_after_var_and_rvar_fusion.cpp new file mode 100644 index 000000000000..acda4e4bb6fb --- /dev/null +++ b/test/error/rfactor_after_var_and_rvar_fusion.cpp @@ -0,0 +1,25 @@ +#include "Halide.h" + +using namespace Halide; + +int main(int argc, char **argv) { + Func f{"f"}; + RDom r({{0, 5}, {0, 5}, {0, 5}}, "r"); + Var x{"x"}, y{"y"}; + f(x, y) = 0; + f(x, y) += r.x + r.y + r.z; + + RVar rxy{"rxy"}, yrz{"yrz"}; + Var z{"z"}; + + // Error: In schedule for f.update(0), can't perform rfactor() after fusing y and r$z + f.update() + .fuse(r.x, r.y, rxy) + .fuse(r.z, y, yrz) + .rfactor(rxy, z); + + f.print_loop_nest(); + + printf("Success!\n"); + return 0; +} diff --git a/test/error/rfactor_fused_var_and_rvar.cpp b/test/error/rfactor_fused_var_and_rvar.cpp new file mode 100644 index 000000000000..64a79c269690 --- /dev/null +++ b/test/error/rfactor_fused_var_and_rvar.cpp @@ -0,0 +1,26 @@ +#include "Halide.h" + +using namespace Halide; + +int main(int argc, char **argv) { + Func f{"f"}; + RDom r({{0, 5}, {0, 5}, {0, 5}}, "r"); + Var x{"x"}, y{"y"}; + f(x, y) = 0; + f(x, y) += r.x + r.y + r.z; + + RVar rxy{"rxy"}, yrz{"yrz"}, yr{"yr"}; + Var z{"z"}; + + // Error: In schedule for f.update(0), can't perform rfactor() after fusing r$z and y + f.update() + .fuse(r.x, r.y, rxy) + .fuse(y, r.z, yrz) + .fuse(rxy, yrz, yr) + .rfactor(yr, z); + + f.print_loop_nest(); + + printf("Success!\n"); + return 0; +}