Skip to content

Commit

Permalink
[tools][triton-tensor-layout] Allow parsing ttgir files with triton_n…
Browse files Browse the repository at this point in the history
…vidia_gpu ops
  • Loading branch information
bertmaher committed Sep 19, 2024
1 parent fad49b2 commit c837d65
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 5 deletions.
4 changes: 4 additions & 0 deletions bin/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -95,5 +95,9 @@ export_executable_symbols_for_plugins(triton-llvm-opt)
add_llvm_executable(triton-tensor-layout triton-tensor-layout.cpp PARTIAL_SOURCES_INTENDED)
target_link_libraries(triton-tensor-layout PRIVATE
TritonGPUIR
TritonNvidiaGPUIR
${triton_libs}
${conversion_libs}
${dialect_libs}
TritonTestAnalysis
)
12 changes: 7 additions & 5 deletions bin/triton-tensor-layout.cpp
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
#include "RegisterTritonDialects.h"

#include "mlir/AsmParser/AsmParser.h"
#include "mlir/AsmParser/AsmParserState.h"
#include "mlir/IR/MLIRContext.h"

#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"

#include "llvm/Support/CommandLine.h"
#include "llvm/Support/ErrorOr.h"
Expand Down Expand Up @@ -114,7 +117,7 @@ LogicalResult printLayoutFromFile(MLIRContext *context, StringRef filename,
return failure();
}

auto printLambda = [&](StringRef name, Attribute attr) {
auto printLambda = [&](StringRef name, mlir::Attribute attr) {
ss << "Print layout attribute: #" << name << " = " << attr << "\n";

auto rankedTensorTy = RankedTensorType::get(
Expand Down Expand Up @@ -155,7 +158,7 @@ LogicalResult printLayoutFromString(MLIRContext *context,
if (layoutAttrStr.empty())
return success();

Attribute layout = parseAttribute(layoutAttrStr, context);
mlir::Attribute layout = parseAttribute(layoutAttrStr, context);
if (!layout) {
llvm::errs() << "Invalid layout attribute: " << layoutAttrStr << "\n";
return failure();
Expand All @@ -178,8 +181,7 @@ int main(int argc, char **argv) {
cl::ParseCommandLineOptions(argc, argv, "tensor layout printer\n");

DialectRegistry registry;
// Register all dialects that can print tensor layout.
registry.insert<triton::gpu::TritonGPUDialect>();
registerTritonDialects(registry);

MLIRContext ctx(registry);
ctx.loadAllAvailableDialects();
Expand All @@ -189,7 +191,7 @@ int main(int argc, char **argv) {
return 1;
}

Type parsedTy = parseType(TensorStr, &ctx);
mlir::Type parsedTy = parseType(TensorStr, &ctx);
if (!parsedTy) {
llvm::errs() << "Fail to parse the tensor type argument: " << TensorStr
<< "\n";
Expand Down

0 comments on commit c837d65

Please sign in to comment.