Merge branch 'main' of https://github.com/NVIDIA/Fuser into compiled_kernel_2
csarofeen committed Jan 11, 2025
2 parents e008bc7 + 05ec62b commit fb38b77
Showing 73 changed files with 3,253 additions and 772 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/build.yml
@@ -16,7 +16,7 @@ env:

jobs:
clang-build:
-runs-on: ubuntu-latest
+runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v4
with:
@@ -37,7 +37,7 @@ jobs:
python setup.py build
dynamic-type-meson:
-runs-on: ubuntu-latest
+runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v4
with:
6 changes: 3 additions & 3 deletions .github/workflows/lint.yml
@@ -15,7 +15,7 @@ env:

jobs:
check-license:
-runs-on: ubuntu-latest
+runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v4
with:
@@ -28,7 +28,7 @@ jobs:
test ! -s missing-header-files.txt
clang-tidy:
-runs-on: ubuntu-latest
+runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v4
with:
@@ -72,7 +72,7 @@ jobs:
git --no-pager diff --diff-filter=d --name-only $head_commit | grep -e "csrc/.*\.cpp" -e "csrc/.*\.h" | xargs lintrunner --take CLANGTIDY --force-color
lintrunner:
-runs-on: ubuntu-latest
+runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v4
with:
21 changes: 14 additions & 7 deletions CMakeLists.txt
@@ -298,13 +298,18 @@ endif()
add_library(codegen_internal OBJECT ${NVFUSER_SRCS})

if(NOT MSVC)
-  # -Werror is not enabled, because of gcc 12.2 used in manylinux image.
-  # consider enable this when we upgrade. linking comment:
-  # https://github.com/NVIDIA/Fuser/pull/3001#discussion_r1772551266
-  target_compile_options(codegen_internal PRIVATE
-    -Wall -Wno-unused-function
-    # -Werror
-  )
+  if("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU")
+    target_compile_options(codegen_internal PRIVATE
+      -Wall -Wno-unused-function -Werror
+      # These warnings are not treated as errors because of the gcc 12.2 used
+      # in the manylinux image. Consider enabling them when we upgrade.
+      # Linked comment:
+      # https://github.com/NVIDIA/Fuser/pull/3001#discussion_r1772551266
+      -Wno-error=restrict -Wno-error=stringop-overflow)
+  else()
+    target_compile_options(codegen_internal PRIVATE
+      -Wall -Wno-unused-function -Werror)
+  endif()
endif()

target_compile_definitions(codegen_internal PRIVATE "-DTORCH_CUDA_BUILD_MAIN_LIB")
@@ -442,8 +447,10 @@ if(BUILD_PYTHON)
# nvfuser python API sources
set(NVFUSER_PYTHON_SRCS)
list(APPEND NVFUSER_PYTHON_SRCS
${NVFUSER_SRCS_DIR}/python_frontend/communicator_bindings.cpp
${NVFUSER_SRCS_DIR}/python_frontend/python_bindings.cpp
${NVFUSER_SRCS_DIR}/python_frontend/python_bindings_extension.cpp
${NVFUSER_SRCS_DIR}/python_frontend/schedule_bindings.cpp
)

add_library(nvf_py_internal OBJECT ${NVFUSER_PYTHON_SRCS})
71 changes: 71 additions & 0 deletions csrc/bfs.h
@@ -549,6 +549,77 @@ class BFS {
Direction allowed_direction_ = Direction::Undefined;
};

// Unlike the default BFS behavior, an Expr is considered ready to visit as
// soon as any one of its inputs or outputs has its dependency satisfied.
template <
typename ExprT,
typename ValT,
typename DefinitionT,
typename UsesT,
typename InputsT,
typename OutputsT>
class BFSWithPermissiveDependence
: public BFS<ExprT, ValT, DefinitionT, UsesT, InputsT, OutputsT> {
public:
using NodeType =
typename BFS<ExprT, ValT, DefinitionT, UsesT, InputsT, OutputsT>::
NodeType;

BFSWithPermissiveDependence(
DefinitionT definition,
UsesT uses,
InputsT inputs,
OutputsT outputs,
std::vector<NodeType> from,
std::vector<NodeType> to,
bool require_all_to_visited = true,
Direction allowed_direction = Direction::Undefined)
: BFS<ExprT, ValT, DefinitionT, UsesT, InputsT, OutputsT>(
definition,
uses,
inputs,
outputs,
std::move(from),
std::move(to),
require_all_to_visited,
allowed_direction) {}

std::optional<std::pair<Direction, std::vector<NodeType>>> isReady(
const ExprT& expr) const override {
// Ready if at least one input or one output has its dependency satisfied
decltype(auto) inputs = this->inputs_(expr);
if (!inputs.empty() && this->allowed_direction_ != Direction::Backward &&
std::any_of(
inputs.begin(), inputs.end(), [&](const ValT& input) -> bool {
return this->isDependencySatisfied(input);
})) {
std::vector<NodeType> prev_nodes;
std::copy_if(
inputs.begin(),
inputs.end(),
std::back_inserter(prev_nodes),
[&](const ValT& input) -> bool { return this->isVisited(input); });
return std::make_pair(Direction::Forward, prev_nodes);
}

decltype(auto) outputs = this->outputs_(expr);
if (!outputs.empty() && this->allowed_direction_ != Direction::Forward &&
std::any_of(
outputs.begin(), outputs.end(), [&](const ValT& output) -> bool {
return this->isDependencySatisfied(output);
})) {
std::vector<NodeType> prev_nodes;
std::copy_if(
outputs.begin(),
outputs.end(),
std::back_inserter(prev_nodes),
[&](const ValT& output) -> bool { return this->isVisited(output); });
return std::make_pair(Direction::Backward, prev_nodes);
}
return std::nullopt;
}
};

// Find the shortest path from the from vals to the to
// vals. Dependency between vals and exprs must be satisfied.
// It is an error if no valid path is found unless
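To make the permissive policy concrete, here is a minimal, self-contained C++ sketch using a toy expression type rather than the real Fuser IR; ToyExpr, readyDefault, and readyPermissive are illustrative names, not part of this diff:

#include <algorithm>
#include <iostream>
#include <string>
#include <vector>

// Toy stand-in for an expression with value inputs.
struct ToyExpr {
  std::vector<std::string> inputs;
};

// Default BFS rule: every input must be visited before the expr is ready.
bool readyDefault(const ToyExpr& e, const std::vector<std::string>& visited) {
  return std::all_of(
      e.inputs.begin(), e.inputs.end(), [&](const std::string& v) {
        return std::find(visited.begin(), visited.end(), v) != visited.end();
      });
}

// Permissive rule: any single visited input makes the expr ready.
bool readyPermissive(const ToyExpr& e, const std::vector<std::string>& visited) {
  return std::any_of(
      e.inputs.begin(), e.inputs.end(), [&](const std::string& v) {
        return std::find(visited.begin(), visited.end(), v) != visited.end();
      });
}

int main() {
  ToyExpr add{{"a", "b"}}; // c = add(a, b)
  std::vector<std::string> visited{"a"}; // only `a` reached so far
  std::cout << readyDefault(add, visited) << "\n"; // 0: `b` is still missing
  std::cout << readyPermissive(add, visited) << "\n"; // 1: `a` alone suffices
}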
27 changes: 22 additions & 5 deletions csrc/codegen.cpp
@@ -158,9 +158,10 @@ class CudaKernelGenerator : private kir::ConstIrVisitor {
public:
static std::string generateKernelDefinition(
const kir::Kernel* kernel,
-      const std::string& kernel_name) {
+      const std::string& kernel_name,
+      std::optional<int64_t> num_threads_per_cta) {
    CudaKernelGenerator codegen(kernel);
-    codegen.genDeclaration(kernel_name);
+    codegen.genDeclaration(kernel_name, num_threads_per_cta);
codegen.startBlock();
codegen.genPrologue();
codegen.genBody();
@@ -272,8 +273,18 @@ class CudaKernelGenerator : private kir::ConstIrVisitor {
}

// Generates the kernel function declaration
-  void genDeclaration(const std::string& kernel_name) {
+  void genDeclaration(
+      const std::string& kernel_name,
+      std::optional<int64_t> num_threads_per_cta) {
    code_ << "__global__ void ";
+    if (kernel_->hasManaged("enable_register_sharing") &&
+        kernel_->getManaged<bool>("enable_register_sharing")) {
+      NVF_ERROR(
+          num_threads_per_cta.has_value(),
+          "__launch_bounds__ must be set for register sharing warp specialization");
+      code_ << "__launch_bounds__(/*MAX_THREADS_PER_BLOCK=*/"
+            << num_threads_per_cta.value() << ") ";
+    }
if (kernel_->hasManaged("cluster_dims")) {
auto cluster_dims =
kernel_->getManaged<std::tuple<int64_t, int64_t, int64_t>>(
@@ -3510,6 +3521,10 @@ class CudaKernelGenerator : private kir::ConstIrVisitor {
indent() << "NVFUSER_UPDATE_MAGIC_ZERO;\n";
}

void handle(const kir::Return* ret) final {
indent() << "return;\n";
}

private:
std::stringstream code_;
const kir::Kernel* kernel_;
@@ -3538,9 +3553,11 @@

std::string generateCudaKernel(
const kir::Kernel* kernel,
-    const std::string& kernel_name) {
+    const std::string& kernel_name,
+    std::optional<int64_t> num_threads_per_cta) {
  FUSER_PERF_SCOPE("generateCudaKernel");
-  return CudaKernelGenerator::generateKernelDefinition(kernel, kernel_name);
+  return CudaKernelGenerator::generateKernelDefinition(
+      kernel, kernel_name, num_threads_per_cta);
}

} // namespace codegen
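For orientation, a hedged sketch of the declaration this change emits when enable_register_sharing is managed on the kernel; the kernel name, parameter list, and the thread count of 384 are illustrative placeholders, not output from this diff:

__global__ void __launch_bounds__(/*MAX_THREADS_PER_BLOCK=*/384)
    CUDAGeneratedKernel(/* kernel parameters */) {
  // generated kernel body
}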
3 changes: 2 additions & 1 deletion csrc/codegen.h
@@ -19,7 +19,8 @@ namespace codegen {
//! Generates a CUDA kernel definition for the given kernel
NVF_API std::string generateCudaKernel(
const kir::Kernel* kernel,
const std::string& kernel_name = "CUDAGeneratedKernel");
const std::string& kernel_name = "CUDAGeneratedKernel",
std::optional<int64_t> num_threads_per_cta = std::nullopt);

} // namespace codegen
} // namespace nvfuser
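A hypothetical call site for the extended API; `kernel` and the count of 384 are placeholders, and omitting the new argument keeps the old behavior since it defaults to std::nullopt:

// Pass the CTA size so the generator can emit __launch_bounds__ when
// register sharing is enabled.
std::string cuda_src = nvfuser::codegen::generateCudaKernel(
    kernel, "CUDAGeneratedKernel", /*num_threads_per_cta=*/384);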
23 changes: 23 additions & 0 deletions csrc/device_lower/analysis/circular_buffer.cpp
@@ -143,6 +143,29 @@ void validateCircularBufferedTensor(const TensorView* tv) {
". Consumer memory type: ",
c_mem_type);

// Ensure that the warp-specialized circular buffer loop is the outer-most
// for-loop if register sharing is enabled.
if (std::holds_alternative<WarpSpecialized>(
tv->circularBufferOptions().type) &&
std::get<WarpSpecialized>(tv->circularBufferOptions().type)
.num_registers.has_value()) {
for (int64_t axis : c10::irange((int64_t)tv->getLoopDomain().size())) {
// short-circuit: only check IterDomains to the left of the circular
// buffer position
if (axis >= circular_buffer_pos) {
break;
}
NVF_ERROR(
tv->getLoopDomain().at(axis)->isThread() ||
tv->getLoopDomain().at(axis)->isDeviceDim() ||
tv->getLoopDomain().at(axis)->isBroadcast() ||
tv->getLoopDomain().at(axis)->isOneInt(),
"When using register sharing with warp-specialized circular "
"buffering, the circular buffer loop must be the outer-most "
"for-loop.");
}
}

return;
}

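To illustrate the constraint, a hedged sketch with hypothetical loop-domain layouts (the axis arrangements are illustrative, not from this diff):

// OK:  [BIDx, TIDy, cb_serial, ...]  // only thread/device dims sit to the
//                                    // left of the circular buffer axis
// Bad: [serial, cb_serial, ...]      // a serial for-loop outside the circular
//                                    // buffer loop trips the NVF_ERROR above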
9 changes: 9 additions & 0 deletions csrc/device_lower/analysis/device_version.cpp
@@ -69,6 +69,15 @@ void MinimumDeviceVersion::handle(LoadStoreOp* ls_op) {
}
}

void MinimumDeviceVersion::handle(TensorView* tv) {
if (std::holds_alternative<WarpSpecialized>(
tv->circularBufferOptions().type)) {
ensureVersion(
{9, 0},
"Warp Specialized Circular Buffering uses the setmaxnreg ptx instruction, which requires Hopper (9.0) or newer");
}
}

void MinimumDeviceVersion::ensureVersion(
std::pair<int, int> version,
std::string reason) {
5 changes: 5 additions & 0 deletions csrc/device_lower/analysis/device_version.h
@@ -48,6 +48,11 @@ class MinimumDeviceVersion : private IterVisitor {
//! https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async
void handle(LoadStoreOp* ls_op) final;

//! If TensorView has warp specialized circular buffering, it will use the
//! setmaxnreg ptx instruction that requires Hopper (9.0+).
//! https://docs.nvidia.com/cuda/parallel-thread-execution/#miscellaneous-instructions-setmaxnreg
void handle(TensorView* tv) final;

//! bump min_version_ to at least this value
void ensureVersion(std::pair<int, int> version, std::string reason);

37 changes: 37 additions & 0 deletions csrc/device_lower/pass/circular_buffer.cpp
@@ -1394,11 +1394,48 @@ class CircularBufferInserter : private kir::ExprMutator {
warp_specialize_on),
circular_buffer_loop->fusion()->oneVal()))));

// Determine whether register sharing is requested and record it on the kernel
auto& circular_buffer_options =
GpuLower::current()->circularBufferInfo().getCircularBufferOptionsFor(
circular_buffer_loop->iter_domain());
bool enable_register_sharing =
std::holds_alternative<WarpSpecialized>(circular_buffer_options.type) &&
std::get<WarpSpecialized>(circular_buffer_options.type)
.num_registers.has_value();

GpuLower::current()->kernel()->manage(
"enable_register_sharing", enable_register_sharing);

if (enable_register_sharing) {
auto&& [decrease_num_registers, increase_num_registers] =
std::get<WarpSpecialized>(circular_buffer_options.type)
.num_registers.value();

// Decrease registers in load warp group
kir::SetMaxNReg* dec_reg_load_warp = IrBuilder::create<kir::SetMaxNReg>(
IrBuilder::create<Val>(decrease_num_registers, DataType::Index),
/*increase_registers=*/false);
warp_dispatch_ite->thenBody().push_back(dec_reg_load_warp);

// Increase registers in compute warp group
kir::SetMaxNReg* inc_reg_compute_warp = IrBuilder::create<kir::SetMaxNReg>(
IrBuilder::create<Val>(increase_num_registers, DataType::Index),
/*increase_registers=*/true);
warp_dispatch_ite->elseBody().push_back(inc_reg_compute_warp);
}

// Load loop:
ForLoop* load_loop = CloneTmaCircularBufferLoopAndInsertSync::clone(
circular_buffer_loop, loads, CircularBufferLoopStage::LoadWarp);
warp_dispatch_ite->thenBody().push_back(load_loop);

if (enable_register_sharing) {
// Terminate the warp group handling Load loop immediately after
// finishing its work.
kir::Return* ret = IrBuilder::create<kir::Return>();
warp_dispatch_ite->thenBody().push_back(ret);
}

// Prefetch:
auto prefetch_loop = createArrivesForWar(circular_buffer_loop);
warp_dispatch_ite->elseBody().push_back(prefetch_loop);
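Taken together, a hedged sketch of the control flow this pass builds when register sharing is enabled; the dispatch predicate kLoadWarpGroup, the register counts 40 and 232, and the loop bodies are illustrative placeholders, not output from this diff:

if (threadIdx.y == kLoadWarpGroup) { // warp_dispatch_ite->thenBody()
  // setmaxnreg.dec.sync.aligned.u32 40;  <- dec_reg_load_warp
  // ... cloned TMA load loop (CircularBufferLoopStage::LoadWarp) ...
  return; // kir::Return: load warps exit once all loads are issued
} else { // warp_dispatch_ite->elseBody()
  // setmaxnreg.inc.sync.aligned.u32 232;  <- inc_reg_compute_warp
  // ... WAR arrives (prefetch) + compute loop ...
}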
10 changes: 10 additions & 0 deletions csrc/device_lower/pass/fusion_simplifier.cpp
@@ -56,6 +56,16 @@ class LoadStoreOpInserter : private kir::ExprMutator {
container, LoadStoreOpType::Set, out, in));
}

void handle(RepeatOp* op) final {
auto out = op->out();
auto in = op->in();
auto container = out->container();
registerReplaceAndPropagate(
op,
IrBuilder::createInContainer<LoadStoreOp>(
container, LoadStoreOpType::Set, out, in));
}

void handle(ViewOp* vop) final {
auto out = vop->out();
auto in = vop->in();
10 changes: 10 additions & 0 deletions csrc/device_lower/pass/index.cpp
@@ -2583,6 +2583,16 @@ void IndexLowering::handle(const kir::WgMmaFence* fence) {
pushBack(const_cast<kir::WgMmaFence*>(fence)); // NOLINT
}

void IndexLowering::handle(const kir::SetMaxNReg* maxnreg) {
// TODO(kir): remove the need for const_cast
pushBack(const_cast<kir::SetMaxNReg*>(maxnreg)); // NOLINT
}

void IndexLowering::handle(const kir::Return* ret) {
// TODO(kir): remove the need for const_cast
pushBack(const_cast<kir::Return*>(ret)); // NOLINT
}

void IndexLowering::handle(const kir::AsyncCommit* commit) {
// TODO(kir): remove the need for const_cast
pushBack(const_cast<kir::AsyncCommit*>(commit)); // NOLINT
2 changes: 2 additions & 0 deletions csrc/device_lower/pass/index.h
@@ -75,6 +75,8 @@ class IndexLowering : private OptOutConstDispatch {
void handle(const kir::GridSync*) final;
void handle(const kir::FenceAsyncProxy*) final;
void handle(const kir::WgMmaFence*) final;
void handle(const kir::SetMaxNReg*) final;
void handle(const kir::Return*) final;
void handle(const kir::MBarrierInit*) final;
void handle(const kir::MBarrierInvalidate*) final;
void handle(const kir::MBarrierArrive*) final;
13 changes: 13 additions & 0 deletions csrc/device_lower/pass/inline_ptx.cpp
@@ -272,6 +272,19 @@ class LowerToInlinePtx : public kir::ExprMutator {
std::vector<Val*>{},
kir::Asm::Options{/*volatile=*/true}));
}

void handle(kir::SetMaxNReg* maxnreg) final {
std::string ptx = (maxnreg->increaseRegisters())
? "setmaxnreg.inc.sync.aligned.u32"
: "setmaxnreg.dec.sync.aligned.u32";
registerReplace(
maxnreg,
IrBuilder::create<kir::Asm>(
ptx,
std::vector<Val*>{},
std::vector<Val*>{maxnreg->numberOfRegisters()},
kir::Asm::Options{/*volatile=*/true}));
}
};

std::vector<Expr*> lowerToInlinePtx(const std::vector<Expr*>& exprs) {
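For reference, a hedged sketch of the inline assembly this replacement is expected to render to in the final CUDA source; the register count of 232 is an illustrative placeholder, and the exact operand formatting is up to kir::Asm:

// Increase variant, as a fragment inside the generated kernel; the "n"
// constraint passes the count as the immediate operand setmaxnreg requires.
asm volatile("setmaxnreg.inc.sync.aligned.u32 %0;\n" ::"n"(232));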