Skip to content

Commit

Permalink
review feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
Nicholas Sarkauskas authored and nsarka committed Jan 14, 2025
1 parent 9ad47e3 commit 4409faa
Show file tree
Hide file tree
Showing 5 changed files with 32 additions and 39 deletions.
8 changes: 6 additions & 2 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -661,7 +661,6 @@ if(BUILD_TEST)
${NVFUSER_ROOT}/tests/cpp/test_multidevice_communications.cpp
${NVFUSER_ROOT}/tests/cpp/test_multidevice_communicator.cpp
${NVFUSER_ROOT}/tests/cpp/test_multidevice_host_ir.cpp
${NVFUSER_ROOT}/tests/cpp/test_multidevice_host_ir_integration.cpp
${NVFUSER_ROOT}/tests/cpp/test_multidevice_lower_communication.cpp
${NVFUSER_ROOT}/tests/cpp/test_multidevice_matmul.cpp
${NVFUSER_ROOT}/tests/cpp/test_multidevice_pipeline.cpp
Expand Down Expand Up @@ -700,7 +699,12 @@ if(BUILD_TEST)
add_test(tutorial "${NVFUSER_ROOT}/tests/cpp/test_tutorial.cpp" "")
list(APPEND TEST_BINARIES tutorial)

add_test(test_host_ir "${NVFUSER_ROOT}/tests/cpp/test_host_irs.cpp" "")
set(HOSTIR_TEST_SRCS)
list(APPEND HOSTIR_TEST_SRCS
${NVFUSER_ROOT}/tests/cpp/test_host_irs.cpp
${NVFUSER_ROOT}/tests/cpp/test_host_ir_integration.cpp
)
add_test(test_host_ir "${HOSTIR_TEST_SRCS}" "")
list(APPEND TEST_BINARIES test_host_ir)

if(BUILD_PYTHON)
Expand Down
6 changes: 2 additions & 4 deletions csrc/host_ir/executor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -313,11 +313,9 @@ void HostIrEvaluator::handle(LaunchKernel* launch_kernel) {
args.push(input_evaluation);
}

// placeholder for storing the outputs
std::vector<at::Tensor> outputs;

// run the compiled kernel
outputs = container_->getKernelExecutor(launch_kernel->getIndex())->run(args);
std::vector<at::Tensor> outputs =
container_->getKernelExecutor(launch_kernel->getIndex())->run(args);

// Store the outputs in the context
for (auto output_idx : c10::irange(outputs.size())) {
Expand Down
14 changes: 5 additions & 9 deletions csrc/host_ir/host_ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ bool PostOnStream::sameAs(const Statement* other) const {

LaunchKernel::LaunchKernel(
IrBuilderPasskey passkey,
int hic_executor_index,
int64_t hic_executor_index,
std::vector<Val*> inputs,
std::vector<Val*> outputs)
: Expr(passkey, std::move(inputs), std::move(outputs), {}),
Expand All @@ -131,31 +131,27 @@ NVFUSER_DEFINE_CLONE_AND_CREATE(LaunchKernel)

std::string LaunchKernel::toString(int indent_size) const {
std::stringstream ss;
indent(ss, indent_size) << "LaunchKernel ("
<< "Inputs:{";
indent(ss, indent_size) << "LaunchKernel("
<< "Inputs: {";
std::for_each(inputs().begin(), inputs().end(), [&ss](auto input) {
ss << input->toString(0) << ", ";
});
ss << "}, Outputs:{";
ss << "}, Outputs: {";
std::for_each(outputs().begin(), outputs().end(), [&ss](auto output) {
ss << output->toString(0) << ", ";
});
ss << "})" << std::endl;
return ss.str();
}

int LaunchKernel::getIndex() const {
int64_t LaunchKernel::getIndex() const {
return hic_executor_index_;
}

std::string LaunchKernel::toInlineString(int indent_size) const {
NVF_CHECK(false, "Can not be printed inline");
}

bool LaunchKernel::sameAs(const Statement* other) const {
return false;
}

Stream::Stream(IrBuilderPasskey passkey, Val* index)
: Val(passkey, ValType::Stream), index_(index) {}

Expand Down
14 changes: 5 additions & 9 deletions csrc/host_ir/host_ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,9 @@ class LaunchKernel : public Expr {
using Expr::Expr;
LaunchKernel(
IrBuilderPasskey passkey,
int hic_executor_index, // TODO
int64_t hic_executor_index, // Index into the HostIrContainer's vector of
// KernelExecutors--i.e., the kernel this IR
// should launch
std::vector<Val*> inputs,
std::vector<Val*> outputs);

Expand All @@ -137,15 +139,9 @@ class LaunchKernel : public Expr {
return "hir::LaunchKernel";
}

int getIndex() const;
int64_t getIndex() const;

bool sameAs(const Statement* other) const override;

Expr* hostOpToPost() const {
return attributes_.at(0)->as<Expr>();
}

int hic_executor_index_;
int64_t hic_executor_index_;
};

class Stream : public Val {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// clang-format off
/*
* SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
* SPDX-FileCopyrightText: Copyright (c) 2025-present NVIDIA CORPORATION & AFFILIATES.
* All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*/
Expand All @@ -10,13 +10,15 @@
#include <host_ir/executor.h>
#include <ir/all_nodes.h>
#include <ops/all_ops.h>
#include <tests/cpp/multidevice.h>
#include <tests/cpp/utils.h>

namespace nvfuser {

namespace hir {

TEST_F(MultiDeviceTest, LaunchKernel) {
using HostIrIntegrationTest = NVFuserTest;

TEST_F(HostIrIntegrationTest, LaunchKernel) {
Fusion fusion;
FusionGuard fg(&fusion);
TensorView* tv0 = makeSymbolicTensor(2);
Expand All @@ -40,26 +42,23 @@ TEST_F(MultiDeviceTest, LaunchKernel) {
auto tv2 = ir_cloner.clone(tv0);
auto tv3 = ir_cloner.clone(tv1);

std::vector<Val*> lk_inputs = {tv2};
std::vector<Val*> lk_outputs = {tv3};
std::vector<Val*> launch_kernel_inputs = {tv2};
std::vector<Val*> launch_kernel_outputs = {tv3};

hic->addInput(lk_inputs.back());
hic->addOutput(lk_outputs.back());
hic->addInput(launch_kernel_inputs.back());
hic->addOutput(launch_kernel_outputs.back());

auto launch_kernel =
IrBuilder::create<LaunchKernel>(0, lk_inputs, lk_outputs);
auto launch_kernel = IrBuilder::create<LaunchKernel>(
0, launch_kernel_inputs, launch_kernel_outputs);

hic->pushBackTopLevelExprs(launch_kernel);

HostIrEvaluatorParams params;
params.use_fusion_executor_cache = false;
HostIrEvaluator hie(std::move(hic), communicator_, params);
HostIrEvaluator hie(std::move(hic));

at::Tensor output = at::empty({32, 32}, options);
auto outputs =
hie.runWithInput({{lk_inputs.back(), t0}, {lk_outputs.back(), output}});
auto outputs = hie.runWithInput({{tv2, t0}, {tv3, output}});

ASSERT_TRUE(outputs[0].equal(t0));
EXPECT_TRUE(outputs[0].equal(t0));
}

} // namespace hir
Expand Down

0 comments on commit 4409faa

Please sign in to comment.