Skip to content

Commit

Permalink
[CIR][LLVMLowering] Lower cir.objectsize (#545)
Browse files Browse the repository at this point in the history
Lowers the `cir.objsize` operation to a call to the `llvm.objectsize` intrinsic.
  • Loading branch information
ghehg authored and lanza committed Apr 17, 2024
1 parent e40c251 commit 87a61f3
Show file tree
Hide file tree
Showing 10 changed files with 82 additions and 82 deletions.
27 changes: 10 additions & 17 deletions clang/include/clang/CIR/Dialect/IR/CIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1775,31 +1775,24 @@ def GetGlobalOp : CIR_Op<"get_global",
[Pure, DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
let summary = "Get the address of a global variable";
let description = [{
The `cir.get_global` operation retrieves the address pointing to a
named global variable. If the global variable is marked constant, writing
to the resulting address (such as through a `cir.store` operation) is
undefined. Resulting type must always be a `!cir.ptr<...>` type.
The `cir.get_global` operation retrieves the address pointing to a
named global variable. If the global variable is marked constant, writing
to the resulting address (such as through a `cir.store` operation) is
undefined. Resulting type must always be a `!cir.ptr<...>` type.

Addresses of thread local globals can only be retrieved if this operation
is marked `thread_local`, which indicates the address isn't constant.
Example:

Example:
```mlir
%x = cir.get_global @foo : !cir.ptr<i32>
...
%y = cir.get_global thread_local @batata : !cir.ptr<i32>
```
```mlir
%x = cir.get_global @foo : !cir.ptr<i32>
```
}];

let arguments = (ins FlatSymbolRefAttr:$name, UnitAttr:$tls);
let arguments = (ins FlatSymbolRefAttr:$name);
let results = (outs Res<CIR_PointerType, "", []>:$addr);

// FIXME: we should not be printing `cir.ptr` below, that should come
// from the pointer type directly.
let assemblyFormat = [{
(`thread_local` $tls^)?
$name `:` `cir.ptr` type($addr) attr-dict
}];
let assemblyFormat = "$name `:` `cir.ptr` type($addr) attr-dict";

// `GetGlobalOp` is fully verified by its traits.
let hasVerifier = 0;
Expand Down
8 changes: 3 additions & 5 deletions clang/lib/CIR/CodeGen/CIRGenBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -697,11 +697,9 @@ class CIRGenBuilderTy : public CIRBaseBuilderTy {
return create<mlir::cir::GlobalOp>(loc, uniqueName, type, isConst, linkage);
}

mlir::Value createGetGlobal(mlir::cir::GlobalOp global,
bool threadLocal = false) {
return create<mlir::cir::GetGlobalOp>(global.getLoc(),
getPointerTo(global.getSymType()),
global.getName(), threadLocal);
mlir::Value createGetGlobal(mlir::cir::GlobalOp global) {
return create<mlir::cir::GetGlobalOp>(
global.getLoc(), getPointerTo(global.getSymType()), global.getName());
}

mlir::Value createGetBitfield(mlir::Location loc, mlir::Type resultType,
Expand Down
5 changes: 3 additions & 2 deletions clang/lib/CIR/CodeGen/CIRGenExpr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -719,10 +719,11 @@ static LValue buildGlobalVarDeclLValue(CIRGenFunction &CGF, const Expr *E,
if (CGF.getLangOpts().OpenMP)
llvm_unreachable("not implemented");

// Traditional LLVM codegen handles thread local separately, CIR handles
// as part of getAddrOfGlobalVar.
auto V = CGF.CGM.getAddrOfGlobalVar(VD);

if (VD->getTLSKind() != VarDecl::TLS_None)
llvm_unreachable("NYI");

auto RealVarTy = CGF.getTypes().convertTypeForMem(VD->getType());
auto realPtrTy = CGF.getBuilder().getPointerTo(RealVarTy);
if (realPtrTy != V.getType())
Expand Down
5 changes: 2 additions & 3 deletions clang/lib/CIR/CodeGen/CIRGenModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -836,12 +836,11 @@ mlir::Value CIRGenModule::getAddrOfGlobalVar(const VarDecl *D, mlir::Type Ty,
if (!Ty)
Ty = getTypes().convertTypeForMem(ASTTy);

bool tlsAccess = D->getTLSKind() != VarDecl::TLS_None;
auto g = buildGlobal(D, Ty, IsForDefinition);
auto ptrTy =
mlir::cir::PointerType::get(builder.getContext(), g.getSymType());
return builder.create<mlir::cir::GetGlobalOp>(
getLoc(D->getSourceRange()), ptrTy, g.getSymName(), tlsAccess);
return builder.create<mlir::cir::GetGlobalOp>(getLoc(D->getSourceRange()),
ptrTy, g.getSymName());
}

mlir::cir::GlobalViewAttr
Expand Down
8 changes: 2 additions & 6 deletions clang/lib/CIR/Dialect/IR/CIRDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1634,13 +1634,9 @@ GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
<< "' does not reference a valid cir.global or cir.func";

mlir::Type symTy;
if (auto g = dyn_cast<GlobalOp>(op)) {
if (auto g = dyn_cast<GlobalOp>(op))
symTy = g.getSymType();
// Verify that for thread local global access, the global needs to
// be marked with tls bits.
if (getTls() && !g.getTlsModel())
return emitOpError("access to global not marked thread local");
} else if (auto f = dyn_cast<FuncOp>(op))
else if (auto f = dyn_cast<FuncOp>(op))
symTy = f.getFunctionType();
else
llvm_unreachable("shall not get here");
Expand Down
45 changes: 33 additions & 12 deletions clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1613,16 +1613,7 @@ class CIRGetGlobalOpLowering

auto type = getTypeConverter()->convertType(op.getType());
auto symbol = op.getName();
mlir::Operation *newop =
rewriter.create<mlir::LLVM::AddressOfOp>(op.getLoc(), type, symbol);

if (op.getTls()) {
// Handle access to TLS via intrinsic.
newop = rewriter.create<mlir::LLVM::ThreadlocalAddressOp>(
op.getLoc(), type, newop->getResult(0));
}

rewriter.replaceOp(op, newop);
rewriter.replaceOpWithNewOp<mlir::LLVM::AddressOfOp>(op, type, symbol);
return mlir::success();
}
};
Expand Down Expand Up @@ -2287,6 +2278,36 @@ class CIRBitClrsbOpLowering
}
};

/// Conversion pattern that rewrites `cir.objsize` into a call to the
/// `llvm.objectsize` intrinsic.
///
/// The intrinsic call receives four operands, in this order: the converted
/// pointer, an i1 that is false for `SizeInfoType::max` and true for any
/// other kind, an i1 that is always true, and an i1 that mirrors the op's
/// `dynamic` flag.
class CIRObjSizeOpLowering
    : public mlir::OpConversionPattern<mlir::cir::ObjSizeOp> {
public:
  using OpConversionPattern<mlir::cir::ObjSizeOp>::OpConversionPattern;

  mlir::LogicalResult
  matchAndRewrite(mlir::cir::ObjSizeOp op, OpAdaptor adaptor,
                  mlir::ConversionPatternRewriter &rewriter) const override {
    mlir::Type resultTy = getTypeConverter()->convertType(op.getType());
    mlir::Location opLoc = op->getLoc();

    auto intrinsicName =
        mlir::StringAttr::get(rewriter.getContext(), "llvm.objectsize");

    // Materialize both boolean constants up front (false first, then true)
    // so the flag operands below can simply pick between them.
    auto i1Ty = rewriter.getI1Type();
    auto falseConst =
        rewriter.create<mlir::LLVM::ConstantOp>(opLoc, i1Ty, false);
    auto trueConst =
        rewriter.create<mlir::LLVM::ConstantOp>(opLoc, i1Ty, true);

    // Second operand: false when the op asks for the `max` kind, true
    // otherwise.
    mlir::Value minFlag = trueConst;
    if (op.getKind() == mlir::cir::SizeInfoType::max)
      minFlag = falseConst;

    // Fourth operand: follows the op's `dynamic` attribute.
    mlir::Value dynamicFlag = falseConst;
    if (op.getDynamic())
      dynamicFlag = trueConst;

    // The third operand is unconditionally true.
    rewriter.replaceOpWithNewOp<mlir::LLVM::CallIntrinsicOp>(
        op, resultTy, intrinsicName,
        mlir::ValueRange{adaptor.getPtr(), minFlag, trueConst, dynamicFlag});

    return mlir::success();
  }
};

class CIRBitClzOpLowering
: public mlir::OpConversionPattern<mlir::cir::BitClzOp> {
public:
Expand Down Expand Up @@ -3033,8 +3054,8 @@ void populateCIRToLLVMConversionPatterns(mlir::RewritePatternSet &patterns,
CIRVectorShuffleVecLowering, CIRStackSaveLowering,
CIRStackRestoreLowering, CIRUnreachableLowering, CIRTrapLowering,
CIRInlineAsmOpLowering, CIRSetBitfieldLowering, CIRGetBitfieldLowering,
CIRPrefetchLowering, CIRIsConstantOpLowering>(converter,
patterns.getContext());
CIRPrefetchLowering, CIRObjSizeOpLowering, CIRIsConstantOpLowering>(
converter, patterns.getContext());
}

namespace {
Expand Down
28 changes: 28 additions & 0 deletions clang/test/CIR/CodeGen/pass-object-size.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir-enable -emit-cir %s -o %t.cir
// RUN: FileCheck --input-file=%t.cir %s -check-prefix=CIR
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir-enable -emit-llvm %s -o %t.ll
// RUN: FileCheck --input-file=%t.ll %s -check-prefix=LLVM
// Callee whose pointer parameter carries pass_object_size(0); the checks
// below expect the matching call to be fed a `max`-kind cir.objsize.
void b(void *__attribute__((pass_object_size(0))));
// Callee whose pointer parameter carries pass_object_size(2); the checks
// below expect the matching call to be fed a `min`-kind cir.objsize.
void e(void *__attribute__((pass_object_size(2))));
// Passes a variable-length array to both callees so that each call site
// needs an object-size computation for its extra argument.
void c() {
// `a` is left uninitialized; it only serves as a runtime array bound so that
// `d` is emitted as a "vla" alloca (see the CIR checks).
int a;
int d[a];
b(d);
e(d);
}

// CIR: cir.func no_proto @c()
// CIR: [[TMP0:%.*]] = cir.alloca !s32i, cir.ptr <!s32i>, %{{[0-9]+}} : !u64i, ["vla"] {alignment = 16 : i64}
// CIR: [[TMP1:%.*]] = cir.cast(bitcast, [[TMP0]] : !cir.ptr<!s32i>), !cir.ptr<!void>
// CIR-NEXT: [[TMP2:%.*]] = cir.objsize([[TMP1]] : <!void>, max) -> !u64i
// CIR-NEXT: cir.call @b([[TMP1]], [[TMP2]]) : (!cir.ptr<!void>, !u64i) -> ()
// CIR: [[TMP3:%.*]] = cir.cast(bitcast, [[TMP0]] : !cir.ptr<!s32i>), !cir.ptr<!void>
// CIR: [[TMP4:%.*]] = cir.objsize([[TMP3]] : <!void>, min) -> !u64i
// CIR-NEXT: cir.call @e([[TMP3]], [[TMP4]]) : (!cir.ptr<!void>, !u64i) -> ()

// LLVM: define void @c()
// LLVM: [[TMP0:%.*]] = alloca i32, i64 %{{[0-9]+}},
// LLVM: [[TMP1:%.*]] = call i64 @llvm.objectsize.i64.p0(ptr [[TMP0]], i1 false, i1 true, i1 false),
// LLVM-NEXT: call void @b(ptr [[TMP0]], i64 [[TMP1]])
// LLVM: [[TMP2:%.*]] = call i64 @llvm.objectsize.i64.p0(ptr [[TMP0]], i1 true, i1 true, i1 false),
// LLVM-NEXT: call void @e(ptr [[TMP0]], i64 [[TMP2]])
11 changes: 0 additions & 11 deletions clang/test/CIR/CodeGen/tls.c
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,6 @@
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir-enable -emit-llvm %s -o %t.ll
// RUN: FileCheck --check-prefix=LLVM --input-file=%t.ll %s

extern __thread int b;
int c(void) { return *&b; }
// CIR: cir.global "private" external tls_dyn @b : !s32i
// CIR: cir.func @c() -> !s32i
// CIR: %[[TLS_ADDR:.*]] = cir.get_global thread_local @b : cir.ptr <!s32i>

__thread int a;
// CIR: cir.global external tls_dyn @a = #cir.int<0> : !s32i

// LLVM: @b = external thread_local global i32
// LLVM: @a = thread_local global i32 0

// LLVM-LABEL: @c
// LLVM: = call ptr @llvm.threadlocal.address.p0(ptr @b)
14 changes: 1 addition & 13 deletions clang/test/CIR/IR/global.cir
Original file line number Diff line number Diff line change
Expand Up @@ -63,12 +63,6 @@ module {
cir.global external tls_local_dyn @model1 = #cir.int<0> : !s32i
cir.global external tls_init_exec @model2 = #cir.int<0> : !s32i
cir.global external tls_local_exec @model3 = #cir.int<0> : !s32i

cir.global "private" external tls_dyn @batata : !s32i
cir.func @f35() {
%0 = cir.get_global thread_local @batata : cir.ptr <!s32i>
cir.return
}
}

// CHECK: cir.global external @a = #cir.int<3> : !s32i
Expand Down Expand Up @@ -97,10 +91,4 @@ module {
// CHECK: cir.global external tls_dyn @model0 = #cir.int<0> : !s32i
// CHECK: cir.global external tls_local_dyn @model1 = #cir.int<0> : !s32i
// CHECK: cir.global external tls_init_exec @model2 = #cir.int<0> : !s32i
// CHECK: cir.global external tls_local_exec @model3 = #cir.int<0> : !s32i

// CHECK: cir.global "private" external tls_dyn @batata : !s32i
// CHECK: cir.func @f35() {
// CHECK: %0 = cir.get_global thread_local @batata : cir.ptr <!s32i>
// CHECK: cir.return
// CHECK: }
// CHECK: cir.global external tls_local_exec @model3 = #cir.int<0> : !s32i
13 changes: 0 additions & 13 deletions clang/test/CIR/IR/invalid.cir
Original file line number Diff line number Diff line change
Expand Up @@ -1033,16 +1033,3 @@ cir.func @bad_fetch(%x: !cir.ptr<!cir.float>, %y: !cir.float) -> () {
%12 = cir.atomic.fetch(xor, %x : !cir.ptr<!cir.float>, %y : !cir.float, seq_cst) : !cir.float
cir.return
}

// -----

!s32i = !cir.int<s, 32>

module {
cir.global "private" external @batata : !s32i
cir.func @f35() {
// expected-error@+1 {{access to global not marked thread local}}
%0 = cir.get_global thread_local @batata : cir.ptr <!s32i>
cir.return
}
}

0 comments on commit 87a61f3

Please sign in to comment.