Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CIR] Remove the !cir.void return type for functions returning void #1203

Merged
merged 2 commits into from
Dec 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions clang/include/clang/CIR/Dialect/IR/CIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -3471,8 +3471,6 @@ def FuncOp : CIR_Op<"func", [
/// Returns the results types that the callable region produces when
/// executed.
llvm::ArrayRef<mlir::Type> getCallableResults() {
if (::llvm::isa<cir::VoidType>(getFunctionType().getReturnType()))
return {};
return getFunctionType().getReturnTypes();
}

Expand All @@ -3489,10 +3487,15 @@ def FuncOp : CIR_Op<"func", [
}

/// Returns the argument types of this function.
llvm::ArrayRef<mlir::Type> getArgumentTypes() { return getFunctionType().getInputs(); }
llvm::ArrayRef<mlir::Type> getArgumentTypes() {
return getFunctionType().getInputs();
}

/// Returns the result types of this function.
llvm::ArrayRef<mlir::Type> getResultTypes() { return getFunctionType().getReturnTypes(); }
/// Returns 0 or 1 result type of this function (0 in the case of a function
/// returing void)
llvm::ArrayRef<mlir::Type> getResultTypes() {
return getFunctionType().getReturnTypes();
}

/// Hook for OpTrait::FunctionOpInterfaceTrait, called after verifying that
/// the 'type' attribute is present and checks if it holds a function type.
Expand Down
19 changes: 12 additions & 7 deletions clang/include/clang/CIR/Dialect/IR/CIRTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -379,22 +379,27 @@ def CIR_FuncType : CIR_Type<"Func", "func"> {

```mlir
!cir.func<!bool ()>
!cir.func<!cir.void ()>
!cir.func<!s32i (!s8i, !s8i)>
!cir.func<!s32i (!s32i, ...)>
```
}];

let parameters = (ins ArrayRefParameter<"mlir::Type">:$inputs, "mlir::Type":$returnType,
let parameters = (ins ArrayRefParameter<"mlir::Type">:$inputs, ArrayRefParameter<"mlir::Type">:$returnTypes,
"bool":$varArg);
let assemblyFormat = [{
`<` $returnType ` ` `(` custom<FuncTypeArgs>($inputs, $varArg) `>`
`<` custom<FuncType>($returnTypes, $inputs, $varArg) `>`
}];

let builders = [
// Construct with an actual return type or explicit !cir.void
TypeBuilderWithInferredContext<(ins
"llvm::ArrayRef<mlir::Type>":$inputs, "mlir::Type":$returnType,
CArg<"bool", "false">:$isVarArg), [{
return $_get(returnType.getContext(), inputs, returnType, isVarArg);
return $_get(returnType.getContext(), inputs,
::mlir::isa<::cir::VoidType>(returnType) ? llvm::ArrayRef<mlir::Type>{}
: llvm::ArrayRef{returnType},
isVarArg);
}]>
];

Expand All @@ -408,11 +413,11 @@ def CIR_FuncType : CIR_Type<"Func", "func"> {
/// Returns the number of arguments to the function.
unsigned getNumInputs() const { return getInputs().size(); }

/// Returns the result type of the function as an ArrayRef, enabling better
/// integration with generic MLIR utilities.
llvm::ArrayRef<mlir::Type> getReturnTypes() const;
/// Returns the result type of the function as an actual return type or
/// explicit !cir.void
mlir::Type getReturnType() const;

/// Returns whether the function is returns void.
/// Returns whether the function returns void.
bool isVoid() const;

/// Returns a clone of this function type with the given argument
Expand Down
2 changes: 1 addition & 1 deletion clang/lib/CIR/CodeGen/CIRGenTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ mlir::Type CIRGenTypes::ConvertFunctionTypeInternal(QualType QFT) {
assert(QFT.isCanonical());
const Type *Ty = QFT.getTypePtr();
const FunctionType *FT = cast<FunctionType>(QFT.getTypePtr());
// First, check whether we can build the full fucntion type. If the function
// First, check whether we can build the full function type. If the function
// type depends on an incomplete type (e.g. a struct or enum), we cannot lower
// the function type.
assert(isFuncTypeConvertible(FT) && "NYI");
Expand Down
39 changes: 29 additions & 10 deletions clang/lib/CIR/Dialect/IR/CIRDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2218,6 +2218,26 @@ void cir::FuncOp::build(OpBuilder &builder, OperationState &result,
getResAttrsAttrName(result.name));
}

// A specific version of function_interface_impl::parseFunctionSignature able to
// handle the "-> !void" special fake return type.
static ParseResult
parseFunctionSignature(OpAsmParser &parser, bool allowVariadic,
SmallVectorImpl<OpAsmParser::Argument> &arguments,
bool &isVariadic, SmallVectorImpl<Type> &resultTypes,
SmallVectorImpl<DictionaryAttr> &resultAttrs) {
if (function_interface_impl::parseFunctionArgumentList(parser, allowVariadic,
arguments, isVariadic))
return failure();
if (succeeded(parser.parseOptionalArrow())) {
if (parser.parseOptionalExclamationKeyword("!void").succeeded())
// This is just an empty return type and attribute.
return success();
return function_interface_impl::parseFunctionResultList(parser, resultTypes,
resultAttrs);
}
return success();
}

ParseResult cir::FuncOp::parse(OpAsmParser &parser, OperationState &state) {
llvm::SMLoc loc = parser.getCurrentLocation();

Expand Down Expand Up @@ -2278,9 +2298,8 @@ ParseResult cir::FuncOp::parse(OpAsmParser &parser, OperationState &state) {

// Parse the function signature.
bool isVariadic = false;
if (function_interface_impl::parseFunctionSignature(
parser, /*allowVariadic=*/true, arguments, isVariadic, resultTypes,
resultAttrs))
if (parseFunctionSignature(parser, /*allowVariadic=*/true, arguments,
isVariadic, resultTypes, resultAttrs))
return failure();

for (auto &arg : arguments)
Expand Down Expand Up @@ -2483,13 +2502,8 @@ void cir::FuncOp::print(OpAsmPrinter &p) {
p.printSymbolName(getSymName());
auto fnType = getFunctionType();
llvm::SmallVector<Type, 1> resultTypes;
if (!fnType.isVoid())
function_interface_impl::printFunctionSignature(
p, *this, fnType.getInputs(), fnType.isVarArg(),
fnType.getReturnTypes());
else
function_interface_impl::printFunctionSignature(
p, *this, fnType.getInputs(), fnType.isVarArg(), {});
function_interface_impl::printFunctionSignature(
p, *this, fnType.getInputs(), fnType.isVarArg(), fnType.getReturnTypes());

if (mlir::ArrayAttr annotations = getAnnotationsAttr()) {
p << ' ';
Expand Down Expand Up @@ -2558,6 +2572,11 @@ LogicalResult cir::FuncOp::verifyType() {
if (!getNoProto() && type.isVarArg() && type.getNumInputs() == 0)
return emitError()
<< "prototyped function must have at least one non-variadic input";
if (auto rt = type.getReturnTypes();
!rt.empty() && mlir::isa<cir::VoidType>(rt.front()))
return emitOpError("The return type for a function returning void should "
"be empty instead of an explicit !cir.void");

return success();
}

Expand Down
93 changes: 81 additions & 12 deletions clang/lib/CIR/Dialect/IR/CIRTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/MathExtras.h"
#include <cassert>
#include <optional>

using cir::MissingFeatures;
Expand All @@ -42,13 +43,16 @@ using cir::MissingFeatures;
//===----------------------------------------------------------------------===//

static mlir::ParseResult
parseFuncTypeArgs(mlir::AsmParser &p, llvm::SmallVector<mlir::Type> &params,
bool &isVarArg);
static void printFuncTypeArgs(mlir::AsmPrinter &p,
mlir::ArrayRef<mlir::Type> params, bool isVarArg);
parseFuncType(mlir::AsmParser &p, llvm::SmallVector<mlir::Type> &returnTypes,
llvm::SmallVector<mlir::Type> &params, bool &isVarArg);

static void printFuncType(mlir::AsmPrinter &p,
mlir::ArrayRef<mlir::Type> returnTypes,
mlir::ArrayRef<mlir::Type> params, bool isVarArg);

static mlir::ParseResult parsePointerAddrSpace(mlir::AsmParser &p,
mlir::Attribute &addrSpaceAttr);

static void printPointerAddrSpace(mlir::AsmPrinter &p,
mlir::Attribute addrSpaceAttr);

Expand Down Expand Up @@ -913,9 +917,46 @@ FuncType FuncType::clone(TypeRange inputs, TypeRange results) const {
return get(llvm::to_vector(inputs), results[0], isVarArg());
}

mlir::ParseResult parseFuncTypeArgs(mlir::AsmParser &p,
llvm::SmallVector<mlir::Type> &params,
bool &isVarArg) {
// A special parser is needed for function returning void to consume the "!void"
// returned type in the case there is no alias defined.
static mlir::ParseResult
parseFuncTypeReturn(mlir::AsmParser &p,
llvm::SmallVector<mlir::Type> &returnTypes) {
if (p.parseOptionalExclamationKeyword("!void").succeeded())
// !void means no return type.
return p.parseLParen();
if (succeeded(p.parseOptionalLParen()))
// If we have already a '(', the function has no return type
return mlir::success();

mlir::Type type;
auto result = p.parseOptionalType(type);
if (!result.has_value())
return mlir::failure();
if (failed(*result) || isa<cir::VoidType>(type))
// No return type specified.
return p.parseLParen();
// Otherwise use the actual type.
returnTypes.push_back(type);
return p.parseLParen();
}

// A special pretty-printer for function returning void to emit a "!void"
// returned type. Note that there is no real type used here since it does not
// appear in the IR and thus the alias might not be defined and cannot be
// referred to. This is why this is a pure syntactic-sugar string which is used.
static void printFuncTypeReturn(mlir::AsmPrinter &p,
mlir::ArrayRef<mlir::Type> returnTypes) {
if (returnTypes.empty())
// Pretty-print no return type as "!void"
p << "!void ";
else
p << returnTypes << ' ';
}

static mlir::ParseResult
parseFuncTypeArgs(mlir::AsmParser &p, llvm::SmallVector<mlir::Type> &params,
bool &isVarArg) {
isVarArg = false;
// `(` `)`
if (succeeded(p.parseOptionalRParen()))
Expand Down Expand Up @@ -945,8 +986,10 @@ mlir::ParseResult parseFuncTypeArgs(mlir::AsmParser &p,
return p.parseRParen();
}

void printFuncTypeArgs(mlir::AsmPrinter &p, mlir::ArrayRef<mlir::Type> params,
bool isVarArg) {
static void printFuncTypeArgs(mlir::AsmPrinter &p,
mlir::ArrayRef<mlir::Type> params,
bool isVarArg) {
p << '(';
llvm::interleaveComma(params, p,
[&p](mlir::Type type) { p.printType(type); });
if (isVarArg) {
Expand All @@ -957,11 +1000,37 @@ void printFuncTypeArgs(mlir::AsmPrinter &p, mlir::ArrayRef<mlir::Type> params,
p << ')';
}

llvm::ArrayRef<mlir::Type> FuncType::getReturnTypes() const {
return static_cast<detail::FuncTypeStorage *>(getImpl())->returnType;
static mlir::ParseResult
parseFuncType(mlir::AsmParser &p, llvm::SmallVector<mlir::Type> &returnTypes,
llvm::SmallVector<mlir::Type> &params, bool &isVarArg) {
if (failed(parseFuncTypeReturn(p, returnTypes)))
return failure();
return parseFuncTypeArgs(p, params, isVarArg);
}

static void printFuncType(mlir::AsmPrinter &p,
mlir::ArrayRef<mlir::Type> returnTypes,
mlir::ArrayRef<mlir::Type> params, bool isVarArg) {
printFuncTypeReturn(p, returnTypes);
printFuncTypeArgs(p, params, isVarArg);
}

bool FuncType::isVoid() const { return mlir::isa<VoidType>(getReturnType()); }
// Return the actual return type or an explicit !cir.void if the function does
// not return anything
mlir::Type FuncType::getReturnType() const {
if (isVoid())
return cir::VoidType::get(getContext());
return static_cast<detail::FuncTypeStorage *>(getImpl())->returnTypes.front();
}

bool FuncType::isVoid() const {
auto rt = static_cast<detail::FuncTypeStorage *>(getImpl())->returnTypes;
assert(rt.empty() ||
!mlir::isa<cir::VoidType>(rt.front()) &&
"The return type for a function returning void should be empty "
"instead of a real !cir.void");
return rt.empty();
}

//===----------------------------------------------------------------------===//
// MethodType Definitions
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ FuncType LowerTypes::getFunctionType(const LowerFunctionInfo &FI) {
}
}

return FuncType::get(getMLIRContext(), ArgTypes, resultType, FI.isVariadic());
return FuncType::get(ArgTypes, resultType, FI.isVariadic());
}

/// Convert a CIR type to its ABI-specific default form.
Expand Down
35 changes: 35 additions & 0 deletions clang/test/CIR/IR/being_and_nothingness.cir
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
// RUN: cir-opt %s | FileCheck %s
// Exercise different ways to encode a function returning void
!s32i = !cir.int<s, 32>
!fnptr1 = !cir.ptr<!cir.func<!cir.void(!s32i)>>
// Note there is no !void alias defined
!fnptr2 = !cir.ptr<!cir.func<!void(!s32i)>>
!fnptr3 = !cir.ptr<!cir.func<(!s32i)>>
module {
cir.func @ind1(%fnptr: !fnptr1, %a : !s32i) {
// CHECK: cir.func @ind1(%arg0: !cir.ptr<!cir.func<!void (!s32i)>>, %arg1: !s32i) {
cir.return
}

cir.func @ind2(%fnptr: !fnptr2, %a : !s32i) {
// CHECK: cir.func @ind2(%arg0: !cir.ptr<!cir.func<!void (!s32i)>>, %arg1: !s32i) {
cir.return
}
cir.func @ind3(%fnptr: !fnptr3, %a : !s32i) {
// CHECK: cir.func @ind3(%arg0: !cir.ptr<!cir.func<!void (!s32i)>>, %arg1: !s32i) {
cir.return
}
cir.func @f1() -> !cir.void {
// CHECK: cir.func @f1() {
cir.return
}
// Note there is no !void alias defined
cir.func @f2() -> !void {
// CHECK: cir.func @f2() {
cir.return
}
cir.func @f3() {
// CHECK: cir.func @f3() {
cir.return
}
}
3 changes: 3 additions & 0 deletions mlir/include/mlir/IR/OpImplementation.h
Original file line number Diff line number Diff line change
Expand Up @@ -922,6 +922,9 @@ class AsmParser {
/// Parse an optional keyword or string.
virtual ParseResult parseOptionalKeywordOrString(std::string *result) = 0;

/// Parse the given exclamation-prefixed keyword if present.
virtual ParseResult parseOptionalExclamationKeyword(StringRef keyword) = 0;

//===--------------------------------------------------------------------===//
// Attribute/Type Parsing
//===--------------------------------------------------------------------===//
Expand Down
22 changes: 22 additions & 0 deletions mlir/include/mlir/Interfaces/FunctionImplementation.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,28 @@ parseFunctionSignature(OpAsmParser &parser, bool allowVariadic,
bool &isVariadic, SmallVectorImpl<Type> &resultTypes,
SmallVectorImpl<DictionaryAttr> &resultAttrs);

/// Parse a function argument list using `parser`. The `allowVariadic` argument
/// indicates whether functions with variadic arguments are supported. The
/// trailing arguments are populated by this function with names, types,
/// attributes and locations of the arguments.
ParseResult
parseFunctionArgumentList(OpAsmParser &parser, bool allowVariadic,
SmallVectorImpl<OpAsmParser::Argument> &arguments,
bool &isVariadic);

/// Parse a function result list using `parser`.
///
/// function-result-list ::= function-result-list-parens
/// | non-function-type
/// function-result-list-parens ::= `(` `)`
/// | `(` function-result-list-no-parens `)`
/// function-result-list-no-parens ::= function-result (`,` function-result)*
/// function-result ::= type attribute-dict?
///
ParseResult
parseFunctionResultList(OpAsmParser &parser, SmallVectorImpl<Type> &resultTypes,
SmallVectorImpl<DictionaryAttr> &resultAttrs);

/// Parser implementation for function-like operations. Uses
/// `funcTypeBuilder` to construct the custom function type given lists of
/// input and output types. The parser sets the `typeAttrName` attribute to the
Expand Down
13 changes: 13 additions & 0 deletions mlir/lib/AsmParser/AsmParserImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -396,6 +396,19 @@ class AsmParserImpl : public BaseT {
return parseOptionalString(result);
}

/// Parse the given exclamation-prefixed keyword if present.
ParseResult parseOptionalExclamationKeyword(StringRef keyword) override {
if (parser.getToken().isCodeCompletion())
return parser.codeCompleteOptionalTokens(keyword);

// Check that the current token has the same spelling.
if (!parser.getToken().is(Token::Kind::exclamation_identifier) ||
parser.getTokenSpelling() != keyword)
return failure();
parser.consumeToken();
return success();
}

//===--------------------------------------------------------------------===//
// Attribute Parsing
//===--------------------------------------------------------------------===//
Expand Down
Loading
Loading