[WIP] Resize scheduler update #3657

Draft · wants to merge 126 commits into main
Conversation

naoyam (Collaborator) commented Dec 31, 2024

No description provided.

github-actions bot commented Jan 15, 2025

PR Reviewer Guide 🔍

(Review updated until commit 75e0aef)

Here are some key observations to aid the review process:

⏱️ Estimated effort to review: 4 🔵🔵🔵🔵⚪
🧪 No relevant tests
⚡ Recommended focus areas for review

Potential Logic Change

IndexLowering::handle for PadOp now consults a new helper, canOmitPadPredicate; when the predicate can be omitted, the pad is lowered to an unpredicated LoadStoreOp (a plain set) instead of a where op. This changes the lowering logic and should be reviewed carefully.

void IndexLowering::handle(const PadOp* pad) {
  // Convert to a where op as:
  // consumer[consumer_idx] = (consumer_idx >= left_pad && consumer_idx <
  //                           consumer_extent - right_pad) ?
  //     producer[producer_idx] :
  //     0;

  auto producer_tv = pad->in()->as<TensorView>();
  auto consumer_tv = pad->out()->as<TensorView>();
  auto producer_doms =
      TensorDomain::noReductions(producer_tv->getLogicalDomain());

  const auto in = lowerSrcIndex(pad->in(), pad->out());
  const auto out = lowerDstIndex(pad->out());

  const auto pad_val = pad->value();

  // Build a predicate for where
  bool can_omit_where_predicate = canOmitPadPredicate(pad);

  if (can_omit_where_predicate) {
    pushBack(IrBuilder::create<LoadStoreOp>(LoadStoreOpType::Set, out, in));
  } else {
    auto consumer_root_indices = Index::getConsumerPerDimLogicalIndex(
        consumer_tv, for_loops_, getRotatedLoop());
    Val* pred = consumer_tv->fusion()->trueVal();
    for (auto padded_axis : pad->getPaddedAxes()) {
      auto consumer_idx = consumer_root_indices.at(padded_axis);
      auto consumer_root_id = consumer_tv->getLogicalDomain().at(padded_axis);
      NVF_ERROR(!consumer_root_id->maybePartial());
      const auto& pad_widths = pad->getPadWidths(padded_axis);
      pred = SimplifyingIrBuilder::logicalAndExpr(
          pred,
          // idx >= left_pad && idx < extent - right_pad
          SimplifyingIrBuilder::logicalAndExpr(
              SimplifyingIrBuilder::geExpr(consumer_idx, pad_widths.first),
              SimplifyingIrBuilder::ltExpr(
                  consumer_idx,
                  SimplifyingIrBuilder::subExpr(
                      consumer_root_id->getMaybeExpandedExtent(),
                      pad_widths.second))));
    }

    pred = GpuLower::current()->commonScalarMap().hoistScalar(pred, for_loops_);

    pushBack(IrBuilder::create<TernaryOp>(
        TernaryOpType::Where, out, pred, in, pad_val));
  }

  GpuLower::current()->propagateExprInfo(pad, back());
}
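For intuition, here is a hypothetical standalone sketch of one condition under which the where predicate could be omitted: if every pad width is a constant zero, the PadOp degenerates to a plain copy. The actual criterion in the PR's canOmitPadPredicate is not shown in this excerpt and may well be broader; the sketch only reuses the PadOp/Val accessors that appear in the diff above.

// Hypothetical sketch, not the PR's actual canOmitPadPredicate: the
// predicate is trivially omittable when all pad widths are statically
// zero, since the where op would then always select the producer value.
bool canOmitPadPredicateSketch(const PadOp* pad) {
  for (auto padded_axis : pad->getPaddedAxes()) {
    const auto& pad_widths = pad->getPadWidths(padded_axis);
    for (Val* width : {pad_widths.first, pad_widths.second}) {
      if (!width->isConstInt() || width->evaluate().as<int64_t>() != 0) {
        return false;
      }
    }
  }
  return true;
}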
Potential Logic Change

The hasNonUniqueBcast function now takes an additional check_static_size parameter: when it is set, a broadcast domain concretized to multiple IterDomains is still tolerated as long as every concretized extent is the same constant size. This changes which fusions are rejected and should be reviewed carefully.

// Reusing logic from lowering (see
// ConcretizedBroadcastDomains::maybeNonUniquelyConcretized in
// lower_trivial_broadcast.cpp), this checks whether a broadcast
// iteration domain is broadcast to seemingly different extents, meaning
// we don't know in the kernel if the dimension is broadcast to one size
// multiple times or to different sizes. This is a hard-to-optimize
// problem and likely indicates we shouldn't be fusing.
bool hasNonUniqueBcast(Fusion* fusion, bool check_static_size) {
  ConcretizedBroadcastDomains concretize_info(fusion);

  for (auto tv : fusion->allTvs()) {
    for (auto id : tv->getMaybeRootDomain()) {
      if (concretize_info.maybeNonUniquelyConcretized(id)) {
        if (check_static_size) {
          int64_t static_size = -1;
          for (auto concrete_id : concretize_info.allConcretizedDomains(id)) {
            if (!concrete_id->extent()->isConstInt()) {
              return true;
            }
            auto this_static_size =
                concrete_id->extent()->evaluate().as<int64_t>();
            if (static_size == -1) {
              static_size = this_static_size;
            } else if (static_size != this_static_size) {
              return true;
            }
          }

          std::cerr << "Concretized to multiple IDs but same extents: "
                    << id->toString() << " -> "
                    << toDelimitedString(
                           concretize_info.allConcretizedDomains(id))
                    << "\n";
        } else {
          return true;
        }
      }
    }
  }
  return false;
}
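To see what the check_static_size path accepts and rejects, here is a minimal standalone sketch of the same extent-uniqueness check over plain values (hasNonUniqueExtent is a hypothetical name; only the control flow mirrors the function above):

#include <cstdint>
#include <optional>
#include <vector>

// Standalone sketch of the check_static_size logic above: report
// "non-unique" if any concretized extent is dynamic (unknown at compile
// time) or if two static extents differ.
bool hasNonUniqueExtent(const std::vector<std::optional<int64_t>>& extents) {
  int64_t static_size = -1;
  for (const auto& extent : extents) {
    if (!extent.has_value()) {
      return true; // dynamic extent: uniqueness cannot be proven
    }
    if (static_size == -1) {
      static_size = *extent;
    } else if (static_size != *extent) {
      return true; // concretized to two different static sizes
    }
  }
  return false; // all extents are the same static size
}

// E.g., {4, 4, 4} -> false (unique), {4, 8} -> true, {4, nullopt} -> true.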
Potential Logic Change

The ResizeScheduler::schedule function now contains additional logging and debugging statements (fusion->print() calls, std::cout/std::cerr output, and .dot graph dumps). These do not change the scheduling logic itself, but they look like WIP debug output and should be reviewed (and likely removed or guarded) before merge.

void ResizeScheduler::schedule(Fusion* fusion, const HeuristicParams* params) {
  FUSER_PERF_SCOPE("ResizeScheduler::schedule");

  FusionGuard fg(fusion);
  const auto resize_params = dynamic_cast<const ResizeParams*>(params);
  NVF_ERROR(resize_params != nullptr);

  scheduler_utils::clearMemorySpace(fusion);

  {
    std::cout << std::endl;
    std::cout << "Resize scheduling\n";
    fusion->print();
    std::cout << std::endl;
  }

  {
    std::stringstream file_name;
    file_name << "pre_scheduling.dot";
    IrGraphGenerator::print(
        fusion,
        file_name.str().c_str(),
        IrGraphGenerator::DetailLevel::ComputeOnly);
  }

  auto ref_tv = getReferenceTensor(fusion);
  NVF_ERROR(ref_tv != nullptr);

  scheduler_utils::cacheInputs(fusion, true);
  scheduler_utils::cacheAndForkOutputs(fusion, true);

  auto resize_tensor_ops = ir_utils::getOpsOfType<SliceOp, PadOp>(fusion);

  std::unique_ptr<IdModel> id_model =
      std::make_unique<IdModel>(fusion, /*build_graphs=*/false);
  id_model->buildExactGraph();

  // Replicate resize inputs if necessary to avoid conflicting
  // propagations
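  // (Illustration, assumed from the comment above rather than from code
  // shown in this excerpt: if one tensor is consumed by two different
  // slice/pad ops, propagating each op's resize back to the shared
  // producer would conflict, so the producer is recomputed per use.)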
  const auto exclusivity_info_map = scheduler_tools::getNonExclusiveResizeInfo(
      resize_tensor_ops, id_model->idGraph(IdMappingMode::EXACT));
  for (auto resize_tensor_op : resize_tensor_ops) {
    auto out_tv = resize_tensor_op->output(0)->as<TensorView>();
    if (exclusivity_info_map.count(out_tv) == 0) {
      continue;
    }
    auto inp_tv = resize_tensor_op->input(0)->as<TensorView>();
    // Since cacheInput may skip caching if an input is used by
    // slice/pad, inp_tv may be a fusion input, in which case it is
    // not necessary to recompute the tensor.
    if (inp_tv->isFusionInput()) {
      continue;
    }
    auto inp_tv_copy = RecomputeTv::recompute(inp_tv);
    ir_utils::replaceValInExprInputs(resize_tensor_op, inp_tv, inp_tv_copy);
  }

  {
    std::cout << std::endl;
    std::cout << "After recomputation\n";
    fusion->print();
    std::cout << std::endl;

    std::stringstream file_name;
    file_name << "after_recomputation.dot";
    IrGraphGenerator::print(
        fusion,
        file_name.str().c_str(),
        IrGraphGenerator::DetailLevel::ComputeOnly);
  }

  TensorView* largest_input = nullptr;
  if (resize_params->largest_input >= 0) {
    largest_input =
        fusion->inputs().at(resize_params->largest_input)->as<TensorView>();

    // The tensors are going to be reordered to align with the largest
    // input. To make that work, the merge operations introduced by
    // reshapes must be cancelled.
    scheduler_tools::cancelReshapeInLoopDomains(largest_input);
  }

  {
    std::cout << std::endl;
    std::cout << "After reshape cancel\n";
    fusion->print();
    std::cout << std::endl;
  }
  for (auto expr : fusion->exprs()) {
    if (!expr->isOneOf<SliceOp, PadOp>()) {
      continue;
    }

    std::cerr << "propagateResize: " << expr->toString();

    scheduler_tools::propagateResizeToInputs(expr);
  }

  // Update the IdModel
  id_model = std::make_unique<IdModel>(fusion, /*build_graphs=*/false);
  id_model->buildExactGraph();

  // Detect an ending repeat
  auto static_repeat_info = scheduler_tools::getMaybeStaticRepeatInfo(ref_tv);
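  // (Assumed context: a "static repeat" here refers to a trailing
  // broadcast + expand + reshape pattern that repeats a dimension by a
  // static factor, e.g., [I] -> [I, 1] -> [I, r] -> [I*r].)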

  if (static_repeat_info.has_value()) {
    std::cerr << "Static repeat: "
              << static_repeat_info->reshape_repeat_id->toString() << "\n";
  }

  // Just simple scheduling for now.
  // TODO: Do something smarter. Can just use the pointwise scheduler?

  std::cerr << "Ref tensor: " << ref_tv->toString() << "\n";

  // Reorder tensors to align with the largest input. This is expected
  // to improve memory read performance, while write performance could
  // be lowered. Optimizing reads should generally matter more, but a
  // more robust decision process would be needed.
  if (largest_input != nullptr) {
    std::vector<IterDomain*> ref_alloc;
    ref_alloc.reserve(largest_input->getMaybeAllocationDomain().size());
    std::copy_if(
        largest_input->getMaybeAllocationDomain().begin(),
        largest_input->getMaybeAllocationDomain().end(),
        std::back_inserter(ref_alloc),
        [](IterDomain* alloc_id) {
          return !alloc_id->isBroadcast() && !alloc_id->isReduction() &&
              !alloc_id->isDeviceDim();
        });

    // Reorder the reference as the allocation domain of the largest fusion
    // input
    scheduler_utils::reorderTensorLike(ref_tv, ref_alloc);
  }

  const int64_t bdimx = 128;

  // Make sure the DID IDs are located at the outermost positions
  auto outermost_pos = scheduler_utils::reorderDevicesToOuter(ref_tv);

  // [DID, ..., ...]
  //        ^
  //        +--- outermost_pos

  // Move the static repeat ID to the outermost position if
  // detected. The repeat ID then just remains there with no
  // scheduling.
  bool repeat_id_moved_to_outermost = false;
  if (static_repeat_info.has_value()) {
    NVF_ERROR(ref_tv == static_repeat_info->repeat_output_tv);
    auto ref_repeat_id_it = std::find_if(
        ref_tv->getLoopDomain().begin(),
        ref_tv->getLoopDomain().end(),
        [&](IterDomain* loop_id) {
          return id_model->idGraph(IdMappingMode::EXACT)
              .disjointValSets()
              .strictAreMapped(loop_id, static_repeat_info->reshape_repeat_id);
        });
    // Give up if the repeat ID is not found. It is unclear whether this
    // could actually happen, though.
    if (ref_repeat_id_it != ref_tv->getLoopDomain().end()) {
      auto repeat_id_pos =
          std::distance(ref_tv->getLoopDomain().begin(), ref_repeat_id_it);
      NVF_ERROR(
          repeat_id_pos >= outermost_pos,
          "Unexpected to have DID-parallelized repeat axis: ",
          static_repeat_info->reshape_repeat_id->toString());

      // [DID, ..., repeat_id, ...]
      //        ^
      //        +--- outermost_pos
      ref_tv->reorder(std::unordered_map<int64_t, int64_t>{{repeat_id_pos, 0}});
      ++outermost_pos;
      // [repeat_id, DID, ...]
      //                   ^
      //                   +--- outermost_pos

      repeat_id_moved_to_outermost = true;
    }
  }

  const int64_t vec_factor = resize_params->vectorization_factor;

  int64_t next_innermost_pos = -1;
  // [..., ...]
  //        ^
  //        +--- next_innermost_pos

  if (vec_factor > 1) {
    ref_tv->split(-1, vec_factor);
    --next_innermost_pos;
    // [..., vec_factor]
    //   ^
    //   +--- next_innermost_pos
  }

  ref_tv->flatten(outermost_pos, next_innermost_pos);
  // [..., I0, vec_factor]
  //       ^
  //       +--- next_innermost_pos

  ref_tv->split(next_innermost_pos, bdimx);
  ref_tv->axis(next_innermost_pos)->parallelize(ParallelType::TIDx);
  --next_innermost_pos;
  // [..., I0/bdimx, bdimx(TIDx), vec_factor]
  //         ^
  //         +--- next_innermost_pos

  if (resize_params->split_grid_x_dim) {
    ref_tv->split(next_innermost_pos, ResizeParams::max_gdimx);
    // [..., I0/bdimx/max_gdimx, max_gdimx, bdimx(TIDx), vec_factor]
  }
  ref_tv->axis(next_innermost_pos)->parallelize(ParallelType::BIDx);
  // [..., I0/bdimx/max_gdimx, max_gdimx(BIDx), bdimx(TIDx), vec_factor] or
  // [..., I0/bdimx(BIDx), bdimx(TIDx), vec_factor]

  std::cout << "Before ref prop\n";
  fusion->print();
  std::cout << std::endl;

  for (auto tv : fusion->allTvs()) {
    std::cerr << tv->toString() << "\n";
    for (auto expr : tv->domain()->allExprs()) {
      std::cerr << expr->toString();
    }
    std::cerr << "---\n";
  }

  {
    IdModel idg(fusion, false);
    idg.buildExactGraph();
    std::ofstream ofs("exact_graph_before_ref_prop.dot", std::ofstream::trunc);
    auto dot_string = idg.idGraph(IdMappingMode::EXACT).toGraphvizDotGraph();
    ofs << dot_string;
    ofs.close();
  }

  // Propagate the reference to the other tensors. Note that the
  // update flag is enabled to work around the resize propagation
  // issue. This may not work if there's a tensor that is reshaped
  // from the reference tensor, but that should not be the case as the
  // reference is picked by the same routine used for the pointwise
  // scheduler.
  //
  // When an ending static repeat is detected and the repeat ID is
  // moved to the outermost position, propagation is done separately
  // between the tensors before the repeat and after the repeat. The
  // tensors are first grouped into the pre-repeat group and the
  // post-repeat group, where only the latter group has the repeat
  // IDs. When propagating the loop domain of the reference tensor,
  // which has the repeat ID, the full loop domain is propagated only
  // to the post-repeat group. For the pre-repeat group, the repeat ID
  // is dropped and only the remaining loop domain is propagated.
  if (repeat_id_moved_to_outermost) {
    // Divide all tvs into the pre- and post-repeat groups
    auto all_tvs = fusion->allTvs();
    std::vector<TensorView*> post_repeat_tvs;
    post_repeat_tvs.reserve(static_repeat_info->repeat_tvs.size());
    std::vector<TensorView*> pre_repeat_tvs;
    pre_repeat_tvs.reserve(
        all_tvs.size() - static_repeat_info->repeat_tvs.size());
    for (auto tv : all_tvs) {
      if (static_repeat_info->repeat_tvs.count(tv)) {
        post_repeat_tvs.push_back(tv);
      } else {
        pre_repeat_tvs.push_back(tv);
      }
    }

    // The repeat ID should be located at the outermost position
    std::vector<IterDomain*> non_repeated_loop{
        ref_tv->getLoopDomain().begin() + 1, ref_tv->getLoopDomain().end()};

    scheduler_tools::scheduleLoopDomainsLike(
        pre_repeat_tvs,
        non_repeated_loop,
        /*update_loop_domain_only=*/true);
    scheduler_tools::scheduleLoopDomainsLike(
        post_repeat_tvs,
        ref_tv->getLoopDomain(),
        /*update_loop_domain_only=*/true);
    // (diff truncated here in the review excerpt; the flag value above
    // is assumed to match the pre-repeat call, per the comment earlier
    // noting the update flag is enabled)
  }
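For reference, the schedule built on ref_tv above follows a conventional [BIDx, TIDx, vectorize] pattern. Below is a minimal standalone sketch of the same split/parallelize sequence on a 1-D tensor, using the TensorView scheduling calls that appear in the diff; the factors 4 and 128 are illustrative, and the gdimx split, DID handling, and repeat handling are omitted.

// Sketch only: schedule a 1-D tensor tv the way ResizeScheduler
// schedules ref_tv above (illustrative factors).
// Start: [I]
tv->split(-1, 4);    // vectorization split:   [I/4, 4]
tv->split(-2, 128);  // thread split (bdimx):  [I/4/128, 128, 4]
tv->axis(-2)->parallelize(ParallelType::TIDx);
tv->axis(-3)->parallelize(ParallelType::BIDx);
// Result: [I/4/128 (BIDx), 128 (TIDx), 4], with the innermost axis as
// the vectorization candidate.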
