Skip to content

Commit d0dd697

Browse files
authored
[mlir][spirv] Switch to llvm::interleaved. NFC. (#136240)
Clean up printing code by switching to `llvm::interleaved` from #135517.
1 parent f2ecd86 commit d0dd697

File tree

3 files changed

+21
-34
lines changed

3 files changed

+21
-34
lines changed

mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp

+5-6
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
1616
#include "mlir/Interfaces/CallInterfaces.h"
1717

18+
#include "llvm/Support/InterleavedRange.h"
19+
1820
#include "SPIRVOpUtils.h"
1921
#include "SPIRVParsingUtils.h"
2022

@@ -119,12 +121,9 @@ ParseResult BranchConditionalOp::parse(OpAsmParser &parser,
119121
void BranchConditionalOp::print(OpAsmPrinter &printer) {
120122
printer << ' ' << getCondition();
121123

122-
if (auto weights = getBranchWeights()) {
123-
printer << " [";
124-
llvm::interleaveComma(weights->getValue(), printer, [&](Attribute a) {
125-
printer << llvm::cast<IntegerAttr>(a).getInt();
126-
});
127-
printer << "]";
124+
if (std::optional<ArrayAttr> weights = getBranchWeights()) {
125+
printer << ' '
126+
<< llvm::interleaved_array(weights->getAsValueRange<IntegerAttr>());
128127
}
129128

130129
printer << ", ";

mlir/lib/Dialect/SPIRV/IR/SPIRVAttributes.cpp

+8-10
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include "mlir/IR/Builders.h"
1313
#include "mlir/IR/DialectImplementation.h"
1414
#include "llvm/ADT/TypeSwitch.h"
15+
#include "llvm/Support/InterleavedRange.h"
1516

1617
using namespace mlir;
1718
using namespace mlir::spirv;
@@ -621,17 +622,14 @@ Attribute SPIRVDialect::parseAttribute(DialectAsmParser &parser,
621622
//===----------------------------------------------------------------------===//
622623

623624
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer) {
624-
auto &os = printer.getStream();
625625
printer << spirv::VerCapExtAttr::getKindName() << "<"
626-
<< spirv::stringifyVersion(triple.getVersion()) << ", [";
627-
llvm::interleaveComma(
628-
triple.getCapabilities(), os,
629-
[&](spirv::Capability cap) { os << spirv::stringifyCapability(cap); });
630-
printer << "], [";
631-
llvm::interleaveComma(triple.getExtensionsAttr(), os, [&](Attribute attr) {
632-
os << llvm::cast<StringAttr>(attr).getValue();
633-
});
634-
printer << "]>";
626+
<< spirv::stringifyVersion(triple.getVersion()) << ", "
627+
<< llvm::interleaved_array(llvm::map_range(
628+
triple.getCapabilities(), spirv::stringifyCapability))
629+
<< ", "
630+
<< llvm::interleaved_array(
631+
triple.getExtensionsAttr().getAsValueRange<StringAttr>())
632+
<< ">";
635633
}
636634

637635
static void print(spirv::TargetEnvAttr targetEnv, DialectAsmPrinter &printer) {

mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp

+8-18
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
#include "llvm/ADT/STLExtras.h"
3636
#include "llvm/ADT/StringExtras.h"
3737
#include "llvm/ADT/TypeSwitch.h"
38+
#include "llvm/Support/InterleavedRange.h"
3839
#include <cassert>
3940
#include <numeric>
4041
#include <optional>
@@ -807,10 +808,8 @@ void spirv::EntryPointOp::print(OpAsmPrinter &printer) {
807808
printer << " \"" << stringifyExecutionModel(getExecutionModel()) << "\" ";
808809
printer.printSymbolName(getFn());
809810
auto interfaceVars = getInterface().getValue();
810-
if (!interfaceVars.empty()) {
811-
printer << ", ";
812-
llvm::interleaveComma(interfaceVars, printer);
813-
}
811+
if (!interfaceVars.empty())
812+
printer << ", " << llvm::interleaved(interfaceVars);
814813
}
815814

816815
LogicalResult spirv::EntryPointOp::verify() {
@@ -862,13 +861,9 @@ void spirv::ExecutionModeOp::print(OpAsmPrinter &printer) {
862861
printer << " ";
863862
printer.printSymbolName(getFn());
864863
printer << " \"" << stringifyExecutionMode(getExecutionMode()) << "\"";
865-
auto values = this->getValues();
866-
if (values.empty())
867-
return;
868-
printer << ", ";
869-
llvm::interleaveComma(values, printer, [&](Attribute a) {
870-
printer << llvm::cast<IntegerAttr>(a).getInt();
871-
});
864+
ArrayAttr values = this->getValues();
865+
if (!values.empty())
866+
printer << ", " << llvm::interleaved(values.getAsValueRange<IntegerAttr>());
872867
}
873868

874869
//===----------------------------------------------------------------------===//
@@ -1824,13 +1819,8 @@ ParseResult spirv::SpecConstantCompositeOp::parse(OpAsmParser &parser,
18241819
void spirv::SpecConstantCompositeOp::print(OpAsmPrinter &printer) {
18251820
printer << " ";
18261821
printer.printSymbolName(getSymName());
1827-
printer << " (";
1828-
auto constituents = this->getConstituents().getValue();
1829-
1830-
if (!constituents.empty())
1831-
llvm::interleaveComma(constituents, printer);
1832-
1833-
printer << ") : " << getType();
1822+
printer << " (" << llvm::interleaved(this->getConstituents().getValue())
1823+
<< ") : " << getType();
18341824
}
18351825

18361826
LogicalResult spirv::SpecConstantCompositeOp::verify() {

0 commit comments

Comments
 (0)