diff --git a/jlm/llvm/Makefile.sub b/jlm/llvm/Makefile.sub index ec979bb5b..c0d1c34dc 100644 --- a/jlm/llvm/Makefile.sub +++ b/jlm/llvm/Makefile.sub @@ -4,7 +4,6 @@ libllvm_SOURCES = \ jlm/llvm/backend/jlm2llvm/instruction.cpp \ jlm/llvm/backend/jlm2llvm/jlm2llvm.cpp \ - jlm/llvm/backend/jlm2llvm/type.cpp \ jlm/llvm/backend/dot/DotWriter.cpp \ jlm/llvm/backend/rvsdg2jlm/rvsdg2jlm.cpp \ \ @@ -12,7 +11,6 @@ libllvm_SOURCES = \ jlm/llvm/frontend/InterProceduralGraphConversion.cpp \ jlm/llvm/frontend/LlvmInstructionConversion.cpp \ jlm/llvm/frontend/LlvmModuleConversion.cpp \ - jlm/llvm/frontend/LlvmTypeConversion.cpp \ \ jlm/llvm/ir/aggregation.cpp \ jlm/llvm/ir/Annotation.cpp \ @@ -25,6 +23,7 @@ libllvm_SOURCES = \ jlm/llvm/ir/domtree.cpp \ jlm/llvm/ir/ipgraph.cpp \ jlm/llvm/ir/ipgraph-module.cpp \ + jlm/llvm/ir/TypeConverter.cpp \ jlm/llvm/ir/operators/alloca.cpp \ jlm/llvm/ir/operators/call.cpp \ jlm/llvm/ir/operators/delta.cpp \ @@ -94,7 +93,6 @@ libllvm_HEADERS = \ jlm/llvm/opt/inversion.hpp \ jlm/llvm/opt/RvsdgTreePrinter.hpp \ jlm/llvm/frontend/LlvmModuleConversion.hpp \ - jlm/llvm/frontend/LlvmTypeConversion.hpp \ jlm/llvm/frontend/ControlFlowRestructuring.hpp \ jlm/llvm/frontend/LlvmConversionContext.hpp \ jlm/llvm/frontend/LlvmInstructionConversion.hpp \ @@ -112,6 +110,7 @@ libllvm_HEADERS = \ jlm/llvm/ir/ipgraph.hpp \ jlm/llvm/ir/cfg.hpp \ jlm/llvm/ir/ssa.hpp \ + jlm/llvm/ir/TypeConverter.hpp \ jlm/llvm/ir/variable.hpp \ jlm/llvm/ir/basic-block.hpp \ jlm/llvm/ir/print.hpp \ @@ -135,7 +134,6 @@ libllvm_HEADERS = \ jlm/llvm/backend/rvsdg2jlm/rvsdg2jlm.hpp \ jlm/llvm/backend/rvsdg2jlm/context.hpp \ jlm/llvm/backend/jlm2llvm/jlm2llvm.hpp \ - jlm/llvm/backend/jlm2llvm/type.hpp \ jlm/llvm/backend/jlm2llvm/instruction.hpp \ jlm/llvm/backend/jlm2llvm/context.hpp \ @@ -151,8 +149,6 @@ libllvm_TESTS += \ tests/jlm/llvm/backend/llvm/jlm-llvm/test-function-calls \ tests/jlm/llvm/backend/llvm/jlm-llvm/test-ignore-memory-state \ tests/jlm/llvm/backend/llvm/jlm-llvm/test-select-with-state \ - tests/jlm/llvm/backend/llvm/jlm-llvm/test-type-conversion \ - tests/jlm/llvm/frontend/llvm/LlvmTypeConversionTests \ tests/jlm/llvm/frontend/llvm/LoadTests \ tests/jlm/llvm/frontend/llvm/MemCpyTests \ tests/jlm/llvm/frontend/llvm/StoreTests \ @@ -193,6 +189,7 @@ libllvm_TESTS += \ tests/jlm/llvm/ir/TestAnnotation \ tests/jlm/llvm/ir/TestCallSummary \ tests/jlm/llvm/ir/ThreeAddressCodeTests \ + tests/jlm/llvm/ir/TypeConverterTests \ tests/jlm/llvm/opt/alias-analyses/TestAgnosticMemoryNodeProvider \ tests/jlm/llvm/opt/alias-analyses/TestAndersen \ tests/jlm/llvm/opt/alias-analyses/TestDifferencePropagation \ diff --git a/jlm/llvm/backend/jlm2llvm/context.hpp b/jlm/llvm/backend/jlm2llvm/context.hpp index 1164d1403..19ce04ff4 100644 --- a/jlm/llvm/backend/jlm2llvm/context.hpp +++ b/jlm/llvm/backend/jlm2llvm/context.hpp @@ -6,6 +6,7 @@ #ifndef JLM_LLVM_BACKEND_JLM2LLVM_CONTEXT_HPP #define JLM_LLVM_BACKEND_JLM2LLVM_CONTEXT_HPP +#include #include #include @@ -108,18 +109,10 @@ class context final return it->second; } - inline ::llvm::StructType * - structtype(const StructType::Declaration * dcl) + TypeConverter & + GetTypeConverter() { - auto it = structtypes_.find(dcl); - return it != structtypes_.end() ? it->second : nullptr; - } - - inline void - add_structtype(const StructType::Declaration * dcl, ::llvm::StructType * type) - { - JLM_ASSERT(structtypes_.find(dcl) == structtypes_.end()); - structtypes_[dcl] = type; + return TypeConverter_; } private: @@ -127,7 +120,7 @@ class context final ipgraph_module & im_; std::unordered_map variables_; std::unordered_map nodes_; - std::unordered_map structtypes_; + TypeConverter TypeConverter_; }; } diff --git a/jlm/llvm/backend/jlm2llvm/instruction.cpp b/jlm/llvm/backend/jlm2llvm/instruction.cpp index 5100ef035..02775c76b 100644 --- a/jlm/llvm/backend/jlm2llvm/instruction.cpp +++ b/jlm/llvm/backend/jlm2llvm/instruction.cpp @@ -14,7 +14,6 @@ #include #include -#include #include #include @@ -164,6 +163,8 @@ convert_undef( context & ctx) { JLM_ASSERT(is(op)); + auto & llvmContext = ctx.llvm_module().getContext(); + auto & typeConverter = ctx.GetTypeConverter(); auto & resultType = *op.result(0); @@ -171,7 +172,8 @@ convert_undef( if (is(resultType)) return nullptr; - return ::llvm::UndefValue::get(convert_type(resultType, ctx)); + auto type = typeConverter.ConvertJlmType(resultType, llvmContext); + return ::llvm::UndefValue::get(type); } static ::llvm::Value * @@ -181,7 +183,11 @@ convert( ::llvm::IRBuilder<> &, context & ctx) { - return ::llvm::PoisonValue::get(convert_type(operation.GetType(), ctx)); + auto & llvmContext = ctx.llvm_module().getContext(); + auto & typeConverter = ctx.GetTypeConverter(); + + auto type = typeConverter.ConvertJlmType(operation.GetType(), llvmContext); + return ::llvm::PoisonValue::get(type); } static ::llvm::Value * @@ -192,6 +198,8 @@ convert( context & ctx) { auto function = ctx.value(args[0]); + auto & llvmContext = ctx.llvm_module().getContext(); + auto & typeConverter = ctx.GetTypeConverter(); std::vector<::llvm::Value *> operands; for (size_t n = 1; n < args.size(); n++) @@ -216,7 +224,7 @@ convert( operands.push_back(ctx.value(argument)); } - auto ftype = convert_type(*op.GetFunctionType(), ctx); + auto ftype = typeConverter.ConvertFunctionType(*op.GetFunctionType(), llvmContext); return builder.CreateCall(ftype, function, operands); } @@ -277,13 +285,15 @@ convert_phi( { JLM_ASSERT(is(op)); auto & phi = *static_cast(&op); + auto & llvmContext = ctx.llvm_module().getContext(); + auto & typeConverter = ctx.GetTypeConverter(); if (rvsdg::is(phi.type())) return nullptr; if (rvsdg::is(phi.type())) return nullptr; - auto t = convert_type(phi.type(), ctx); + auto t = typeConverter.ConvertJlmType(phi.type(), llvmContext); return builder.CreatePHI(t, op.narguments()); } @@ -296,7 +306,10 @@ CreateLoadInstruction( ::llvm::IRBuilder<> & builder, context & ctx) { - auto type = convert_type(loadedType, ctx); + auto & llvmContext = ctx.llvm_module().getContext(); + auto & typeConverter = ctx.GetTypeConverter(); + + auto type = typeConverter.ConvertJlmType(loadedType, llvmContext); auto loadInstruction = builder.CreateLoad(type, ctx.value(address), isVolatile); loadInstruction->setAlignment(::llvm::Align(alignment)); return loadInstruction; @@ -385,8 +398,10 @@ convert_alloca( { JLM_ASSERT(is(op)); auto & aop = *static_cast(&op); + auto & llvmContext = ctx.llvm_module().getContext(); + auto & typeConverter = ctx.GetTypeConverter(); - auto t = convert_type(aop.value_type(), ctx); + auto t = typeConverter.ConvertJlmType(aop.value_type(), llvmContext); auto i = builder.CreateAlloca(t, ctx.value(args[0])); i->setAlignment(::llvm::Align(aop.alignment())); return i; @@ -401,9 +416,11 @@ convert_getelementptr( { JLM_ASSERT(is(op) && args.size() >= 2); auto & pop = *static_cast(&op); + auto & llvmContext = ctx.llvm_module().getContext(); + auto & typeConverter = ctx.GetTypeConverter(); std::vector<::llvm::Value *> indices; - auto t = convert_type(pop.GetPointeeType(), ctx); + auto t = typeConverter.ConvertJlmType(pop.GetPointeeType(), llvmContext); for (size_t n = 1; n < args.size(); n++) indices.push_back(ctx.value(args[n])); @@ -506,6 +523,8 @@ convert( context & ctx) { JLM_ASSERT(is(op)); + ::llvm::LLVMContext & llvmContext = ctx.llvm_module().getContext(); + auto & typeConverter = ctx.GetTypeConverter(); std::vector<::llvm::Constant *> data; for (size_t n = 0; n < operands.size(); n++) @@ -516,7 +535,7 @@ convert( } auto at = std::dynamic_pointer_cast(op.result(0)); - auto type = convert_type(*at, ctx); + auto type = typeConverter.ConvertArrayType(*at, llvmContext); return ::llvm::ConstantArray::get(type, data); } @@ -527,7 +546,10 @@ convert( ::llvm::IRBuilder<> &, context & ctx) { - auto type = convert_type(*op.result(0), ctx); + ::llvm::LLVMContext & llvmContext = ctx.llvm_module().getContext(); + auto & typeConverter = ctx.GetTypeConverter(); + + auto type = typeConverter.ConvertJlmType(*op.result(0), llvmContext); return ::llvm::ConstantAggregateZero::get(type); } @@ -642,11 +664,14 @@ convert( ::llvm::IRBuilder<> &, context & ctx) { + ::llvm::LLVMContext & llvmContext = ctx.llvm_module().getContext(); + auto & typeConverter = ctx.GetTypeConverter(); + std::vector<::llvm::Constant *> operands; for (const auto & arg : args) operands.push_back(::llvm::cast<::llvm::Constant>(ctx.value(arg))); - auto t = convert_type(op.type(), ctx); + auto t = typeConverter.ConvertStructType(op.type(), llvmContext); return ::llvm::ConstantStruct::get(t, operands); } @@ -657,7 +682,10 @@ convert( ::llvm::IRBuilder<> &, context & ctx) { - auto pointerType = convert_type(operation.GetPointerType(), ctx); + ::llvm::LLVMContext & llvmContext = ctx.llvm_module().getContext(); + auto & typeConverter = ctx.GetTypeConverter(); + + auto pointerType = typeConverter.ConvertPointerType(operation.GetPointerType(), llvmContext); return ::llvm::ConstantPointerNull::get(pointerType); } @@ -850,22 +878,24 @@ convert_cast( context & ctx) { JLM_ASSERT(::llvm::Instruction::isCast(OPCODE)); + auto & typeConverter = ctx.GetTypeConverter(); + ::llvm::LLVMContext & llvmContext = ctx.llvm_module().getContext(); auto dsttype = std::dynamic_pointer_cast(op.result(0)); auto operand = operands[0]; if (auto vt = dynamic_cast(&operand->type())) { - auto type = convert_type(fixedvectortype(dsttype, vt->size()), ctx); + auto type = typeConverter.ConvertJlmType(fixedvectortype(dsttype, vt->size()), llvmContext); return builder.CreateCast(OPCODE, ctx.value(operand), type); } if (auto vt = dynamic_cast(&operand->type())) { - auto type = convert_type(scalablevectortype(dsttype, vt->size()), ctx); + auto type = typeConverter.ConvertJlmType(scalablevectortype(dsttype, vt->size()), llvmContext); return builder.CreateCast(OPCODE, ctx.value(operand), type); } - auto type = convert_type(*dsttype, ctx); + auto type = typeConverter.ConvertJlmType(*dsttype, llvmContext); return builder.CreateCast(OPCODE, ctx.value(operand), type); } @@ -888,9 +918,10 @@ convert( context & ctx) { JLM_ASSERT(args.size() == 1); + auto & typeConverter = ctx.GetTypeConverter(); auto & lm = ctx.llvm_module(); - auto fcttype = convert_type(op.fcttype(), ctx); + auto fcttype = typeConverter.ConvertFunctionType(op.fcttype(), lm.getContext()); auto function = lm.getOrInsertFunction("malloc", fcttype); auto operands = std::vector<::llvm::Value *>(1, ctx.value(args[0])); return builder.CreateCall(function, operands); @@ -903,9 +934,12 @@ convert( ::llvm::IRBuilder<> & builder, context & ctx) { + auto & typeConverter = ctx.GetTypeConverter(); auto & llvmmod = ctx.llvm_module(); - auto fcttype = convert_type(rvsdg::FunctionType({ op.argument(0) }, {}), ctx); + auto fcttype = typeConverter.ConvertFunctionType( + rvsdg::FunctionType({ op.argument(0) }, {}), + llvmmod.getContext()); auto function = llvmmod.getOrInsertFunction("free", fcttype); auto operands = std::vector<::llvm::Value *>(1, ctx.value(args[0])); return builder.CreateCall(function, operands); diff --git a/jlm/llvm/backend/jlm2llvm/jlm2llvm.cpp b/jlm/llvm/backend/jlm2llvm/jlm2llvm.cpp index 396a1ed1e..3b4dc705a 100644 --- a/jlm/llvm/backend/jlm2llvm/jlm2llvm.cpp +++ b/jlm/llvm/backend/jlm2llvm/jlm2llvm.cpp @@ -13,7 +13,6 @@ #include #include #include -#include #include #include @@ -102,6 +101,8 @@ static void create_switch(const cfg_node * node, context & ctx) { JLM_ASSERT(node->noutedges() >= 2); + ::llvm::LLVMContext & llvmContext = ctx.llvm_module().getContext(); + auto & typeConverter = ctx.GetTypeConverter(); auto bb = static_cast(node); ::llvm::IRBuilder<> builder(ctx.basic_block(node)); @@ -120,7 +121,8 @@ create_switch(const cfg_node * node, context & ctx) for (const auto & alt : *mop) { auto & type = *std::static_pointer_cast(mop->argument(0)); - auto value = ::llvm::ConstantInt::get(convert_type(type, ctx), alt.first); + auto value = + ::llvm::ConstantInt::get(typeConverter.ConvertBitType(type, llvmContext), alt.first); sw->addCase(value, ctx.basic_block(node->outedge(alt.second)->sink())); } } @@ -294,9 +296,11 @@ ConvertIntAttribute(const llvm::int_attribute & attribute, context & ctx) static ::llvm::Attribute ConvertTypeAttribute(const llvm::type_attribute & attribute, context & ctx) { + auto & typeConverter = ctx.GetTypeConverter(); auto & llvmContext = ctx.llvm_module().getContext(); + auto kind = convert_attribute_kind(attribute.kind()); - auto type = convert_type(attribute.type(), ctx); + auto type = typeConverter.ConvertJlmType(attribute.type(), llvmContext); return ::llvm::Attribute::get(llvmContext, kind, type); } @@ -499,6 +503,7 @@ convert_linkage(const llvm::linkage & linkage) static void convert_ipgraph(context & ctx) { + auto & typeConverter = ctx.GetTypeConverter(); auto & jm = ctx.module(); auto & lm = ctx.llvm_module(); @@ -509,7 +514,7 @@ convert_ipgraph(context & ctx) if (auto dataNode = dynamic_cast(&node)) { - auto type = convert_type(*dataNode->GetValueType(), ctx); + auto type = typeConverter.ConvertJlmType(*dataNode->GetValueType(), lm.getContext()); auto linkage = convert_linkage(dataNode->linkage()); auto gv = new ::llvm::GlobalVariable( @@ -524,7 +529,7 @@ convert_ipgraph(context & ctx) } else if (auto n = dynamic_cast(&node)) { - auto type = convert_type(n->fcttype(), ctx); + auto type = typeConverter.ConvertFunctionType(n->fcttype(), lm.getContext()); auto linkage = convert_linkage(n->linkage()); auto f = ::llvm::Function::Create(type, linkage, n->name(), &lm); ctx.insert(v, f); diff --git a/jlm/llvm/backend/jlm2llvm/type.cpp b/jlm/llvm/backend/jlm2llvm/type.cpp deleted file mode 100644 index e6603f6e2..000000000 --- a/jlm/llvm/backend/jlm2llvm/type.cpp +++ /dev/null @@ -1,158 +0,0 @@ -/* - * Copyright 2017 Nico Reißmann - * See COPYING for terms of redistribution. - */ - -#include -#include -#include - -#include - -#include -#include - -namespace jlm::llvm -{ - -namespace jlm2llvm -{ - -static ::llvm::Type * -convert(const rvsdg::bittype & type, context & ctx) -{ - return ::llvm::Type::getIntNTy(ctx.llvm_module().getContext(), type.nbits()); -} - -static ::llvm::Type * -convert(const rvsdg::FunctionType & functionType, context & ctx) -{ - auto & lctx = ctx.llvm_module().getContext(); - - bool isvararg = false; - std::vector<::llvm::Type *> argumentTypes; - for (auto & argumentType : functionType.Arguments()) - { - if (rvsdg::is(argumentType)) - { - isvararg = true; - continue; - } - - if (rvsdg::is(argumentType)) - continue; - if (rvsdg::is(argumentType)) - continue; - - argumentTypes.push_back(convert_type(*argumentType, ctx)); - } - - /* - The return type can either be (ValueType, StateType, StateType, ...) if the function has - a return value, or (StateType, StateType, ...) if the function returns void. - */ - auto resultType = ::llvm::Type::getVoidTy(lctx); - if (functionType.NumResults() > 0 && rvsdg::is(functionType.ResultType(0))) - resultType = convert_type(functionType.ResultType(0), ctx); - - return ::llvm::FunctionType::get(resultType, argumentTypes, isvararg); -} - -static ::llvm::Type * -convert(const PointerType &, context & ctx) -{ - return ::llvm::PointerType::get(ctx.llvm_module().getContext(), 0); -} - -static ::llvm::Type * -convert(const arraytype & type, context & ctx) -{ - return ::llvm::ArrayType::get(convert_type(type.element_type(), ctx), type.nelements()); -} - -static ::llvm::Type * -convert(const rvsdg::ControlType & type, context & ctx) -{ - if (type.nalternatives() == 2) - return ::llvm::Type::getInt1Ty(ctx.llvm_module().getContext()); - - return ::llvm::Type::getInt32Ty(ctx.llvm_module().getContext()); -} - -static ::llvm::Type * -convert(const fptype & type, context & ctx) -{ - static std::unordered_map map( - { { fpsize::half, ::llvm::Type::getHalfTy }, - { fpsize::flt, ::llvm::Type::getFloatTy }, - { fpsize::dbl, ::llvm::Type::getDoubleTy }, - { fpsize::x86fp80, ::llvm::Type::getX86_FP80Ty }, - { fpsize::fp128, ::llvm::Type::getFP128Ty } }); - - JLM_ASSERT(map.find(type.size()) != map.end()); - return map[type.size()](ctx.llvm_module().getContext()); -} - -static ::llvm::Type * -convert(const StructType & type, context & ctx) -{ - auto & decl = type.GetDeclaration(); - - if (auto st = ctx.structtype(&decl)) - return st; - - auto st = ::llvm::StructType::create(ctx.llvm_module().getContext()); - ctx.add_structtype(&decl, st); - - std::vector<::llvm::Type *> elements; - for (size_t n = 0; n < decl.NumElements(); n++) - elements.push_back(convert_type(decl.GetElement(n), ctx)); - - if (type.HasName()) - st->setName(type.GetName()); - st->setBody(elements, type.IsPacked()); - - return st; -} - -static ::llvm::Type * -convert(const fixedvectortype & type, context & ctx) -{ - return ::llvm::VectorType::get(convert_type(type.type(), ctx), type.size(), false); -} - -static ::llvm::Type * -convert(const scalablevectortype & type, context & ctx) -{ - return ::llvm::VectorType::get(convert_type(type.type(), ctx), type.size(), true); -} - -template -static ::llvm::Type * -convert(const rvsdg::Type & type, context & ctx) -{ - JLM_ASSERT(rvsdg::is(type)); - return convert(*static_cast(&type), ctx); -} - -::llvm::Type * -convert_type(const rvsdg::Type & type, context & ctx) -{ - static std:: - unordered_map> - map({ { typeid(rvsdg::bittype), convert }, - { typeid(rvsdg::FunctionType), convert }, - { typeid(PointerType), convert }, - { typeid(arraytype), convert }, - { typeid(rvsdg::ControlType), convert }, - { typeid(fptype), convert }, - { typeid(StructType), convert }, - { typeid(fixedvectortype), convert }, - { typeid(scalablevectortype), convert } }); - - JLM_ASSERT(map.find(typeid(type)) != map.end()); - return map[typeid(type)](type, ctx); -} - -} -} diff --git a/jlm/llvm/backend/jlm2llvm/type.hpp b/jlm/llvm/backend/jlm2llvm/type.hpp deleted file mode 100644 index 70cf4669f..000000000 --- a/jlm/llvm/backend/jlm2llvm/type.hpp +++ /dev/null @@ -1,98 +0,0 @@ -/* - * Copyright 2017 Nico Reißmann - * See COPYING for terms of redistribution. - */ - -#ifndef JLM_LLVM_BACKEND_JLM2LLVM_TYPE_HPP -#define JLM_LLVM_BACKEND_JLM2LLVM_TYPE_HPP - -#include -#include -#include -#include - -#include -#include - -namespace llvm -{ - -class FunctionType; - -} - -namespace jlm::llvm::jlm2llvm -{ - -class context; - -::llvm::Type * -convert_type(const rvsdg::Type & type, context & ctx); - -static inline ::llvm::IntegerType * -convert_type(const rvsdg::bittype & type, context & ctx) -{ - auto t = convert_type(*static_cast(&type), ctx); - JLM_ASSERT(t->getTypeID() == ::llvm::Type::IntegerTyID); - return ::llvm::cast<::llvm::IntegerType>(t); -} - -static inline ::llvm::FunctionType * -convert_type(const rvsdg::FunctionType & type, context & ctx) -{ - auto t = convert_type(*static_cast(&type), ctx); - JLM_ASSERT(t->getTypeID() == ::llvm::Type::FunctionTyID); - return ::llvm::cast<::llvm::FunctionType>(t); -} - -static inline ::llvm::PointerType * -convert_type(const PointerType & type, context & ctx) -{ - auto t = convert_type(*static_cast(&type), ctx); - JLM_ASSERT(t->getTypeID() == ::llvm::Type::PointerTyID); - return ::llvm::cast<::llvm::PointerType>(t); -} - -static inline ::llvm::ArrayType * -convert_type(const arraytype & type, context & ctx) -{ - auto t = convert_type(*static_cast(&type), ctx); - JLM_ASSERT(t->getTypeID() == ::llvm::Type::ArrayTyID); - return ::llvm::cast<::llvm::ArrayType>(t); -} - -static inline ::llvm::IntegerType * -convert_type(const rvsdg::ControlType & type, context & ctx) -{ - auto t = convert_type(*static_cast(&type), ctx); - JLM_ASSERT(t->getTypeID() == ::llvm::Type::IntegerTyID); - return ::llvm::cast<::llvm::IntegerType>(t); -} - -static inline ::llvm::Type * -convert_type(const fptype & type, context & ctx) -{ - auto t = convert_type(*static_cast(&type), ctx); - JLM_ASSERT(t->isHalfTy() || t->isFloatTy() || t->isDoubleTy()); - return t; -} - -static inline ::llvm::StructType * -convert_type(const StructType & type, context & ctx) -{ - auto t = convert_type(*static_cast(&type), ctx); - JLM_ASSERT(t->isStructTy()); - return ::llvm::cast<::llvm::StructType>(t); -} - -static inline ::llvm::VectorType * -convert_type(const vectortype & type, context & ctx) -{ - auto t = convert_type(*static_cast(&type), ctx); - JLM_ASSERT(t->isVectorTy()); - return ::llvm::cast<::llvm::VectorType>(t); -} - -} - -#endif diff --git a/jlm/llvm/frontend/LlvmConversionContext.hpp b/jlm/llvm/frontend/LlvmConversionContext.hpp index 6dcdb3523..64c29961c 100644 --- a/jlm/llvm/frontend/LlvmConversionContext.hpp +++ b/jlm/llvm/frontend/LlvmConversionContext.hpp @@ -6,10 +6,10 @@ #ifndef JLM_LLVM_FRONTEND_LLVMCONVERSIONCONTEXT_HPP #define JLM_LLVM_FRONTEND_LLVMCONVERSIONCONTEXT_HPP -#include #include #include #include +#include #include @@ -182,26 +182,6 @@ class context final vmap_[value] = variable; } - const StructType::Declaration * - lookup_declaration(const ::llvm::StructType * type) - { - // Return declaration if we already created one for this type instance - if (auto it = declarations_.find(type); it != declarations_.end()) - { - return it->second; - } - - // Otherwise create a new one and return it - auto declaration = StructType::Declaration::Create(); - for (size_t n = 0; n < type->getNumElements(); n++) - { - declaration->Append(ConvertType(type->getElementType(n), *this)); - } - - declarations_[type] = declaration.get(); - return &module().AddStructTypeDeclaration(std::move(declaration)); - } - inline ipgraph_module & module() const noexcept { @@ -220,6 +200,12 @@ class context final return node_; } + TypeConverter & + GetTypeConverter() noexcept + { + return TypeConverter_; + } + private: ipgraph_module & module_; basic_block_map bbmap_; @@ -228,7 +214,7 @@ class context final llvm::variable * iostate_; llvm::variable * memory_state_; std::unordered_map vmap_; - std::unordered_map declarations_; + TypeConverter TypeConverter_; }; } diff --git a/jlm/llvm/frontend/LlvmInstructionConversion.cpp b/jlm/llvm/frontend/LlvmInstructionConversion.cpp index dc13a908b..db6122233 100644 --- a/jlm/llvm/frontend/LlvmInstructionConversion.cpp +++ b/jlm/llvm/frontend/LlvmInstructionConversion.cpp @@ -104,7 +104,7 @@ convert_undefvalue( { JLM_ASSERT(c->getValueID() == ::llvm::Value::UndefValueVal); - auto t = ConvertType(c->getType(), ctx); + auto t = ctx.GetTypeConverter().ConvertLlvmType(*c->getType()); tacs.push_back(UndefValueOperation::Create(t)); return tacs.back()->result(0); @@ -144,7 +144,7 @@ convert_constantFP( JLM_ASSERT(constant->getValueID() == ::llvm::Value::ConstantFPVal); auto c = ::llvm::cast<::llvm::ConstantFP>(constant); - auto type = ConvertType(c->getType(), ctx); + auto type = ctx.GetTypeConverter().ConvertLlvmType(*c->getType()); tacs.push_back(ConstantFP::create(c->getValueAPF(), type)); return tacs.back()->result(0); @@ -169,7 +169,7 @@ convert_constantPointerNull( JLM_ASSERT(::llvm::dyn_cast(constant)); auto & c = *::llvm::cast(constant); - auto t = ConvertPointerType(c.getType(), ctx); + auto t = ctx.GetTypeConverter().ConvertPointerType(*c.getType()); tacs.push_back(ConstantPointerNullOperation::Create(t)); return tacs.back()->result(0); @@ -194,7 +194,7 @@ convert_constantAggregateZero( { JLM_ASSERT(c->getValueID() == ::llvm::Value::ConstantAggregateZeroVal); - auto type = ConvertType(c->getType(), ctx); + auto type = ctx.GetTypeConverter().ConvertLlvmType(*c->getType()); tacs.push_back(ConstantAggregateZero::create(type)); return tacs.back()->result(0); @@ -270,7 +270,7 @@ ConvertConstantStruct( for (size_t n = 0; n < c->getNumOperands(); n++) elements.push_back(ConvertConstant(c->getAggregateElement(n), tacs, ctx)); - auto type = ConvertType(c->getType(), ctx); + auto type = ctx.GetTypeConverter().ConvertLlvmType(*c->getType()); tacs.push_back(ConstantStruct::create(elements, type)); return tacs.back()->result(0); @@ -288,7 +288,7 @@ convert_constantVector( for (size_t n = 0; n < c->getNumOperands(); n++) elements.push_back(ConvertConstant(c->getAggregateElement(n), tacs, ctx)); - auto type = ConvertType(c->getType(), ctx); + auto type = ctx.GetTypeConverter().ConvertLlvmType(*c->getType()); tacs.push_back(constantvector_op::create(elements, type)); return tacs.back()->result(0); @@ -318,7 +318,7 @@ ConvertConstant( tacsvector_t & threeAddressCodeVector, llvm::context & context) { - auto type = ConvertType(poisonValue->getType(), context); + auto type = context.GetTypeConverter().ConvertLlvmType(*poisonValue->getType()); threeAddressCodeVector.push_back(PoisonValueOperation::Create(type)); return threeAddressCodeVector.back()->result(0); @@ -459,6 +459,7 @@ static inline const variable * convert_icmp_instruction(::llvm::Instruction * instruction, tacsvector_t & tacs, context & ctx) { JLM_ASSERT(instruction->getOpcode() == ::llvm::Instruction::ICmp); + auto & typeConverter = ctx.GetTypeConverter(); auto i = ::llvm::cast(instruction); auto t = i->getOperand(0)->getType(); @@ -547,12 +548,12 @@ convert_icmp_instruction(::llvm::Instruction * instruction, tacsvector_t & tacs, else if (t->isPointerTy() || (t->isVectorTy() && t->getScalarType()->isPointerTy())) { auto pt = ::llvm::cast<::llvm::PointerType>(t->isVectorTy() ? t->getScalarType() : t); - binop = std::make_unique(ConvertPointerType(pt, ctx), ptrmap[p]); + binop = std::make_unique(typeConverter.ConvertPointerType(*pt), ptrmap[p]); } else JLM_UNREACHABLE("This should have never happend."); - auto type = ConvertType(i->getType(), ctx); + auto type = typeConverter.ConvertLlvmType(*i->getType()); JLM_ASSERT(is(*binop)); if (t->isVectorTy()) @@ -575,6 +576,7 @@ static const variable * convert_fcmp_instruction(::llvm::Instruction * instruction, tacsvector_t & tacs, context & ctx) { JLM_ASSERT(instruction->getOpcode() == ::llvm::Instruction::FCmp); + auto & typeConverter = ctx.GetTypeConverter(); auto i = ::llvm::cast(instruction); auto t = i->getOperand(0)->getType(); @@ -596,14 +598,14 @@ convert_fcmp_instruction(::llvm::Instruction * instruction, tacsvector_t & tacs, { ::llvm::CmpInst::FCMP_ULE, fpcmp::ule }, { ::llvm::CmpInst::FCMP_UNE, fpcmp::une } }); - auto type = ConvertType(i->getType(), ctx); + auto type = typeConverter.ConvertLlvmType(*i->getType()); auto op1 = ConvertValue(i->getOperand(0), tacs, ctx); auto op2 = ConvertValue(i->getOperand(1), tacs, ctx); JLM_ASSERT(map.find(i->getPredicate()) != map.end()); auto fptype = t->isVectorTy() ? t->getScalarType() : t; - fpcmp_op operation(map[i->getPredicate()], ExtractFloatingPointSize(fptype)); + fpcmp_op operation(map[i->getPredicate()], typeConverter.ExtractFloatingPointSize(*fptype)); if (t->isVectorTy()) tacs.push_back(vectorbinary_op::create(operation, op1, op2, type)); @@ -621,7 +623,7 @@ convert_load_instruction(::llvm::Instruction * i, tacsvector_t & tacs, context & auto alignment = instruction->getAlign().value(); auto address = ConvertValue(instruction->getPointerOperand(), tacs, ctx); - auto loadedType = ConvertType(instruction->getType(), ctx); + auto loadedType = ctx.GetTypeConverter().ConvertLlvmType(*instruction->getType()); const tacvariable * loadedValue; const tacvariable * memoryState; @@ -743,7 +745,7 @@ convert_phi_instruction(::llvm::Instruction * i, tacsvector_t & tacs, context & // instructions that have not yet been converted. // For now, a phi_op with no operands is created. // Once all basic blocks have been converted, all phi_ops get visited again and given operands. - auto type = ConvertType(i->getType(), ctx); + auto type = ctx.GetTypeConverter().ConvertLlvmType(*i->getType()); tacs.push_back(phi_op::create({}, type)); return tacs.back()->result(0); } @@ -752,6 +754,7 @@ static const variable * convert_getelementptr_instruction(::llvm::Instruction * inst, tacsvector_t & tacs, context & ctx) { JLM_ASSERT(::llvm::dyn_cast(inst)); + auto & typeConverter = ctx.GetTypeConverter(); auto i = ::llvm::cast<::llvm::GetElementPtrInst>(inst); std::vector indices; @@ -759,8 +762,8 @@ convert_getelementptr_instruction(::llvm::Instruction * inst, tacsvector_t & tac for (auto it = i->idx_begin(); it != i->idx_end(); it++) indices.push_back(ConvertValue(*it, tacs, ctx)); - auto pointeeType = ConvertType(i->getSourceElementType(), ctx); - auto resultType = ConvertType(i->getType(), ctx); + auto pointeeType = typeConverter.ConvertLlvmType(*i->getSourceElementType()); + auto resultType = typeConverter.ConvertLlvmType(*i->getType()); tacs.push_back(GetElementPtrOperation::Create(base, indices, pointeeType, resultType)); @@ -905,7 +908,7 @@ convert_call_instruction(::llvm::Instruction * instruction, tacsvector_t & tacs, return convert_memcpy_call(i, tacs, ctx); auto ftype = i->getFunctionType(); - auto convertedFType = ConvertFunctionType(ftype, ctx); + auto convertedFType = ctx.GetTypeConverter().ConvertFunctionType(*ftype); auto arguments = create_arguments(i, tacs, ctx); if (ftype->isVarArg()) @@ -1097,7 +1100,7 @@ convert_binary_operator(::llvm::Instruction * instruction, tacsvector_t & tacs, else JLM_ASSERT(0); - auto type = ConvertType(i->getType(), ctx); + auto type = ctx.GetTypeConverter().ConvertLlvmType(*i->getType()); auto op1 = ConvertValue(i->getOperand(0), tacs, ctx); auto op2 = ConvertValue(i->getOperand(1), tacs, ctx); @@ -1125,7 +1128,7 @@ convert_alloca_instruction(::llvm::Instruction * instruction, tacsvector_t & tac auto memstate = ctx.memory_state(); auto size = ConvertValue(i->getArraySize(), tacs, ctx); - auto vtype = ConvertType(i->getAllocatedType(), ctx); + auto vtype = ctx.GetTypeConverter().ConvertLlvmType(*i->getAllocatedType()); auto alignment = i->getAlign().value(); tacs.push_back(alloca_op::create(vtype, size, alignment)); @@ -1194,14 +1197,15 @@ static const variable * convert(::llvm::UnaryOperator * unaryOperator, tacsvector_t & threeAddressCodeVector, context & ctx) { JLM_ASSERT(unaryOperator->getOpcode() == ::llvm::Instruction::FNeg); + auto & typeConverter = ctx.GetTypeConverter(); auto type = unaryOperator->getType(); - auto scalarType = ConvertType(type->getScalarType(), ctx); + auto scalarType = typeConverter.ConvertLlvmType(*type->getScalarType()); auto operand = ConvertValue(unaryOperator->getOperand(0), threeAddressCodeVector, ctx); if (type->isVectorTy()) { - auto vectorType = ConvertType(type, ctx); + auto vectorType = typeConverter.ConvertLlvmType(*type); threeAddressCodeVector.push_back(vectorunary_op::create( fpneg_op(std::static_pointer_cast(scalarType)), operand, @@ -1226,6 +1230,7 @@ static const variable * convert_cast_instruction(::llvm::Instruction * i, tacsvector_t & tacs, context & ctx) { JLM_ASSERT(::llvm::dyn_cast<::llvm::CastInst>(i)); + auto & typeConverter = ctx.GetTypeConverter(); auto st = i->getOperand(0)->getType(); auto dt = i->getType(); @@ -1247,11 +1252,11 @@ convert_cast_instruction(::llvm::Instruction * i, tacsvector_t & tacs, context & { ::llvm::Instruction::FPExt, create_unop }, { ::llvm::Instruction::BitCast, create_unop } }); - auto type = ConvertType(i->getType(), ctx); + auto type = ctx.GetTypeConverter().ConvertLlvmType(*i->getType()); auto op = ConvertValue(i->getOperand(0), tacs, ctx); - auto srctype = ConvertType(st->isVectorTy() ? st->getScalarType() : st, ctx); - auto dsttype = ConvertType(dt->isVectorTy() ? dt->getScalarType() : dt, ctx); + auto srctype = typeConverter.ConvertLlvmType(*(st->isVectorTy() ? st->getScalarType() : st)); + auto dsttype = typeConverter.ConvertLlvmType(*(dt->isVectorTy() ? dt->getScalarType() : dt)); JLM_ASSERT(map.find(i->getOpcode()) != map.end()); auto unop = map[i->getOpcode()](std::move(srctype), std::move(dsttype)); diff --git a/jlm/llvm/frontend/LlvmModuleConversion.cpp b/jlm/llvm/frontend/LlvmModuleConversion.cpp index 312b74c3a..8b61327c7 100644 --- a/jlm/llvm/frontend/LlvmModuleConversion.cpp +++ b/jlm/llvm/frontend/LlvmModuleConversion.cpp @@ -8,6 +8,7 @@ #include #include #include +#include #include #include @@ -218,13 +219,13 @@ ConvertTypeAttribute(const ::llvm::Attribute & attribute, context & ctx) if (attribute.getKindAsEnum() == ::llvm::Attribute::AttrKind::ByVal) { - auto type = ConvertType(attribute.getValueAsType(), ctx); + auto type = ctx.GetTypeConverter().ConvertLlvmType(*attribute.getValueAsType()); return { attribute::kind::ByVal, std::move(type) }; } if (attribute.getKindAsEnum() == ::llvm::Attribute::AttrKind::StructRet) { - auto type = ConvertType(attribute.getValueAsType(), ctx); + auto type = ctx.GetTypeConverter().ConvertLlvmType(*attribute.getValueAsType()); return { attribute::kind::StructRet, std::move(type) }; } @@ -274,7 +275,7 @@ convert_argument(const ::llvm::Argument & argument, context & ctx) { auto function = argument.getParent(); auto name = argument.getName().str(); - auto type = ConvertType(argument.getType(), ctx); + auto type = ctx.GetTypeConverter().ConvertLlvmType(*argument.getType()); auto attributes = convert_attributes(function->getAttributes().getParamAttrs(argument.getArgNo()), ctx); @@ -394,7 +395,7 @@ create_cfg(::llvm::Function & f, context & ctx) const tacvariable * result = nullptr; if (!f.getReturnType()->isVoidTy()) { - auto type = ConvertType(f.getReturnType(), ctx); + auto type = ctx.GetTypeConverter().ConvertLlvmType(*f.getReturnType()); entry_block->append_last(UndefValueOperation::Create(type, "_r_")); result = entry_block->last()->result(0); @@ -459,7 +460,7 @@ declare_globals(::llvm::Module & lm, context & ctx) { auto name = gv.getName().str(); auto constant = gv.isConstant(); - auto type = ConvertType(gv.getValueType(), ctx); + auto type = ctx.GetTypeConverter().ConvertLlvmType(*gv.getValueType()); auto linkage = convert_linkage(gv.getLinkage()); auto section = gv.getSection().str(); @@ -476,7 +477,7 @@ declare_globals(::llvm::Module & lm, context & ctx) { auto name = f.getName().str(); auto linkage = convert_linkage(f.getLinkage()); - auto type = ConvertFunctionType(f.getFunctionType(), ctx); + auto type = ctx.GetTypeConverter().ConvertFunctionType(*f.getFunctionType()); auto attributes = convert_attributes(f.getAttributes().getFnAttrs(), ctx); return function_node::create(ctx.module().ipgraph(), name, type, linkage, attributes); @@ -539,6 +540,8 @@ ConvertLlvmModule(::llvm::Module & m) declare_globals(m, ctx); convert_globals(m, ctx); + im->SetStructTypeDeclarations(ctx.GetTypeConverter().ReleaseStructTypeDeclarations()); + return im; } diff --git a/jlm/llvm/frontend/LlvmTypeConversion.cpp b/jlm/llvm/frontend/LlvmTypeConversion.cpp deleted file mode 100644 index f712e1e9f..000000000 --- a/jlm/llvm/frontend/LlvmTypeConversion.cpp +++ /dev/null @@ -1,154 +0,0 @@ -/* - * Copyright 2014 Nico Reißmann - * See COPYING for terms of redistribution. - */ - -#include -#include - -#include -#include - -#include - -namespace jlm::llvm -{ - -fpsize -ExtractFloatingPointSize(const ::llvm::Type * type) -{ - JLM_ASSERT(type->isFloatingPointTy()); - - static std::unordered_map map( - { { ::llvm::Type::HalfTyID, fpsize::half }, - { ::llvm::Type::FloatTyID, fpsize::flt }, - { ::llvm::Type::DoubleTyID, fpsize::dbl }, - { ::llvm::Type::X86_FP80TyID, fpsize::x86fp80 }, - { ::llvm::Type::FP128TyID, fpsize::fp128 } }); - - auto i = map.find(type->getTypeID()); - JLM_ASSERT(i != map.end()); - return i->second; -} - -static std::shared_ptr -convert_integer_type(const ::llvm::Type * t, context &) -{ - JLM_ASSERT(t->getTypeID() == ::llvm::Type::IntegerTyID); - auto * type = static_cast(t); - - return rvsdg::bittype::Create(type->getBitWidth()); -} - -static std::shared_ptr -convert_pointer_type(const ::llvm::Type * t, context &) -{ - JLM_ASSERT(t->getTypeID() == ::llvm::Type::PointerTyID); - return PointerType::Create(); -} - -static std::shared_ptr -convert_function_type(const ::llvm::Type * t, context & ctx) -{ - JLM_ASSERT(t->getTypeID() == ::llvm::Type::FunctionTyID); - auto type = ::llvm::cast(t); - - /* arguments */ - std::vector> argumentTypes; - for (size_t n = 0; n < type->getNumParams(); n++) - argumentTypes.push_back(ConvertType(type->getParamType(n), ctx)); - if (type->isVarArg()) - argumentTypes.push_back(create_varargtype()); - argumentTypes.push_back(iostatetype::Create()); - argumentTypes.push_back(MemoryStateType::Create()); - - /* results */ - std::vector> resultTypes; - if (type->getReturnType()->getTypeID() != ::llvm::Type::VoidTyID) - resultTypes.push_back(ConvertType(type->getReturnType(), ctx)); - resultTypes.push_back(iostatetype::Create()); - resultTypes.push_back(MemoryStateType::Create()); - - return rvsdg::FunctionType::Create(std::move(argumentTypes), std::move(resultTypes)); -} - -static std::shared_ptr -convert_fp_type(const ::llvm::Type * t, context &) -{ - static const std::unordered_map<::llvm::Type::TypeID, fpsize> map( - { { ::llvm::Type::HalfTyID, fpsize::half }, - { ::llvm::Type::FloatTyID, fpsize::flt }, - { ::llvm::Type::DoubleTyID, fpsize::dbl }, - { ::llvm::Type::X86_FP80TyID, fpsize::x86fp80 }, - { ::llvm::Type::FP128TyID, fpsize::fp128 } }); - - auto i = map.find(t->getTypeID()); - JLM_ASSERT(i != map.end()); - return fptype::Create(i->second); -} - -static std::shared_ptr -convert_struct_type(const ::llvm::Type * t, context & ctx) -{ - JLM_ASSERT(t->isStructTy()); - auto type = static_cast(t); - - auto isPacked = type->isPacked(); - auto & declaration = *ctx.lookup_declaration(type); - - return type->hasName() ? StructType::Create(type->getName().str(), isPacked, declaration) - : StructType::Create(isPacked, declaration); -} - -static std::shared_ptr -convert_array_type(const ::llvm::Type * t, context & ctx) -{ - JLM_ASSERT(t->isArrayTy()); - auto etype = ConvertType(t->getArrayElementType(), ctx); - return arraytype::Create(std::move(etype), t->getArrayNumElements()); -} - -static std::shared_ptr -convert_fixed_vector_type(const ::llvm::Type * t, context & ctx) -{ - JLM_ASSERT(t->getTypeID() == ::llvm::Type::FixedVectorTyID); - auto type = ConvertType(t->getScalarType(), ctx); - return fixedvectortype::Create( - std::move(type), - ::llvm::cast<::llvm::FixedVectorType>(t)->getNumElements()); -} - -static std::shared_ptr -convert_scalable_vector_type(const ::llvm::Type * t, context & ctx) -{ - JLM_ASSERT(t->getTypeID() == ::llvm::Type::ScalableVectorTyID); - auto type = ConvertType(t->getScalarType(), ctx); - return scalablevectortype::Create( - std::move(type), - ::llvm::cast<::llvm::ScalableVectorType>(t)->getMinNumElements()); -} - -std::shared_ptr -ConvertType(const ::llvm::Type * t, context & ctx) -{ - static std::unordered_map< - ::llvm::Type::TypeID, - std::function(const ::llvm::Type *, context &)>> - map({ { ::llvm::Type::IntegerTyID, convert_integer_type }, - { ::llvm::Type::PointerTyID, convert_pointer_type }, - { ::llvm::Type::FunctionTyID, convert_function_type }, - { ::llvm::Type::HalfTyID, convert_fp_type }, - { ::llvm::Type::FloatTyID, convert_fp_type }, - { ::llvm::Type::DoubleTyID, convert_fp_type }, - { ::llvm::Type::X86_FP80TyID, convert_fp_type }, - { ::llvm::Type::FP128TyID, convert_fp_type }, - { ::llvm::Type::StructTyID, convert_struct_type }, - { ::llvm::Type::ArrayTyID, convert_array_type }, - { ::llvm::Type::FixedVectorTyID, convert_fixed_vector_type }, - { ::llvm::Type::ScalableVectorTyID, convert_scalable_vector_type } }); - - JLM_ASSERT(map.find(t->getTypeID()) != map.end()); - return map[t->getTypeID()](t, ctx); -} - -} diff --git a/jlm/llvm/frontend/LlvmTypeConversion.hpp b/jlm/llvm/frontend/LlvmTypeConversion.hpp deleted file mode 100644 index 5d279e762..000000000 --- a/jlm/llvm/frontend/LlvmTypeConversion.hpp +++ /dev/null @@ -1,51 +0,0 @@ -/* - * Copyright 2014 2015 Nico Reißmann - * See COPYING for terms of redistribution. - */ - -#ifndef JLM_LLVM_FRONTEND_LLVMTYPECONVERSION_HPP -#define JLM_LLVM_FRONTEND_LLVMTYPECONVERSION_HPP - -#include - -#include -#include - -#include - -namespace llvm -{ -class ArrayType; -class Type; -} - -namespace jlm::llvm -{ - -class context; - -fpsize -ExtractFloatingPointSize(const ::llvm::Type * type); - -std::shared_ptr -ConvertType(const ::llvm::Type * type, context & ctx); - -static inline std::shared_ptr -ConvertFunctionType(const ::llvm::FunctionType * type, context & ctx) -{ - auto t = ConvertType(::llvm::cast<::llvm::Type>(type), ctx); - JLM_ASSERT(dynamic_cast(t.get())); - return std::dynamic_pointer_cast(t); -} - -static inline std::shared_ptr -ConvertPointerType(const ::llvm::PointerType * type, context & ctx) -{ - auto t = ConvertType(::llvm::cast<::llvm::Type>(type), ctx); - JLM_ASSERT(dynamic_cast(t.get())); - return std::dynamic_pointer_cast(t); -} - -} - -#endif diff --git a/jlm/llvm/ir/TypeConverter.cpp b/jlm/llvm/ir/TypeConverter.cpp new file mode 100644 index 000000000..4a2eb80c7 --- /dev/null +++ b/jlm/llvm/ir/TypeConverter.cpp @@ -0,0 +1,307 @@ +/* + * Copyright 2025 Nico Reißmann + * See COPYING for terms of redistribution. + */ + +#include +#include +#include + +#include +#include + +namespace jlm::llvm +{ + +fpsize +TypeConverter::ExtractFloatingPointSize(const ::llvm::Type & type) +{ + JLM_ASSERT(type.isFloatingPointTy()); + + switch (type.getTypeID()) + { + case ::llvm::Type::HalfTyID: + return fpsize::half; + case ::llvm::Type::FloatTyID: + return fpsize::flt; + case ::llvm::Type::DoubleTyID: + return fpsize::dbl; + case ::llvm::Type::X86_FP80TyID: + return fpsize::x86fp80; + case ::llvm::Type::FP128TyID: + return fpsize::fp128; + default: + JLM_UNREACHABLE("TypeConverter::ExtractFloatingPointSize: Unsupported floating point size."); + } +} + +::llvm::IntegerType * +TypeConverter::ConvertBitType(const rvsdg::bittype & bitType, ::llvm::LLVMContext & context) +{ + return ::llvm::Type::getIntNTy(context, bitType.nbits()); +} + +::llvm::FunctionType * +TypeConverter::ConvertFunctionType( + const rvsdg::FunctionType & functionType, + ::llvm::LLVMContext & context) +{ + bool isVariableArgument = false; + std::vector<::llvm::Type *> argumentTypes; + for (auto & argumentType : functionType.Arguments()) + { + if (rvsdg::is(argumentType)) + { + isVariableArgument = true; + continue; + } + + if (rvsdg::is(argumentType)) + continue; + if (rvsdg::is(argumentType)) + continue; + + argumentTypes.push_back(ConvertJlmType(*argumentType, context)); + } + + // The return type can either be (ValueType, StateType, StateType, ...) if the function has + // a return value, or (StateType, StateType, ...) if the function returns void. + auto resultType = ::llvm::Type::getVoidTy(context); + if (functionType.NumResults() > 0 && rvsdg::is(functionType.ResultType(0))) + resultType = ConvertJlmType(functionType.ResultType(0), context); + + return ::llvm::FunctionType::get(resultType, argumentTypes, isVariableArgument); +} + +std::shared_ptr +TypeConverter::ConvertFunctionType(const ::llvm::FunctionType & functionType) +{ + // Arguments + std::vector> argumentTypes; + for (size_t n = 0; n < functionType.getNumParams(); n++) + argumentTypes.push_back(ConvertLlvmType(*functionType.getParamType(n))); + if (functionType.isVarArg()) + argumentTypes.push_back(create_varargtype()); + argumentTypes.push_back(iostatetype::Create()); + argumentTypes.push_back(MemoryStateType::Create()); + + // Results + std::vector> resultTypes; + if (functionType.getReturnType()->getTypeID() != ::llvm::Type::VoidTyID) + resultTypes.push_back(ConvertLlvmType(*functionType.getReturnType())); + resultTypes.push_back(iostatetype::Create()); + resultTypes.push_back(MemoryStateType::Create()); + + return rvsdg::FunctionType::Create(std::move(argumentTypes), std::move(resultTypes)); +} + +::llvm::PointerType * +TypeConverter::ConvertPointerType(const PointerType &, ::llvm::LLVMContext & context) +{ + // FIXME: we default the address space to zero + return ::llvm::PointerType::get(context, 0); +} + +std::shared_ptr +TypeConverter::ConvertPointerType(const ::llvm::PointerType & pointerType) +{ + JLM_ASSERT(pointerType.getAddressSpace() == 0); + return PointerType::Create(); +} + +::llvm::ArrayType * +TypeConverter::ConvertArrayType(const arraytype & type, ::llvm::LLVMContext & context) +{ + return ::llvm::ArrayType::get(ConvertJlmType(type.element_type(), context), type.nelements()); +} + +::llvm::Type * +TypeConverter::ConvertFloatingPointType(const fptype & type, ::llvm::LLVMContext & context) +{ + switch (type.size()) + { + case fpsize::half: + return ::llvm::Type::getHalfTy(context); + case fpsize::flt: + return ::llvm::Type::getFloatTy(context); + case fpsize::dbl: + return ::llvm::Type::getDoubleTy(context); + case fpsize::fp128: + return ::llvm::Type::getFP128Ty(context); + case fpsize::x86fp80: + return ::llvm::Type::getX86_FP80Ty(context); + default: + JLM_UNREACHABLE("TypeConverter::ConvertFloatingPointType: Unhandled floating point size."); + } +} + +::llvm::StructType * +TypeConverter::ConvertStructType(const StructType & type, ::llvm::LLVMContext & context) +{ + auto & declaration = type.GetDeclaration(); + + if (StructTypeMap_.HasValue(&declaration)) + return StructTypeMap_.LookupValue(&declaration); + + const auto llvmStructType = ::llvm::StructType::create(context); + StructTypeMap_.Insert(llvmStructType, &declaration); + + std::vector<::llvm::Type *> elements; + for (size_t n = 0; n < declaration.NumElements(); n++) + elements.push_back(ConvertJlmType(declaration.GetElement(n), context)); + + if (type.HasName()) + llvmStructType->setName(type.GetName()); + llvmStructType->setBody(elements, type.IsPacked()); + + return llvmStructType; +} + +::llvm::Type * +TypeConverter::ConvertJlmType(const rvsdg::Type & type, ::llvm::LLVMContext & context) +{ + if (const auto bitType = dynamic_cast(&type)) + { + return ConvertBitType(*bitType, context); + } + + if (const auto functionType = dynamic_cast(&type)) + { + return ConvertFunctionType(*functionType, context); + } + + if (const auto pointerType = dynamic_cast(&type)) + { + return ConvertPointerType(*pointerType, context); + } + + if (const auto arrayType = dynamic_cast(&type)) + { + return ConvertArrayType(*arrayType, context); + } + + if (const auto controlType = dynamic_cast(&type)) + { + return controlType->nalternatives() == 2 ? ::llvm::Type::getInt1Ty(context) + : ::llvm::Type::getInt32Ty(context); + } + + if (const auto floatingPointType = dynamic_cast(&type)) + { + return ConvertFloatingPointType(*floatingPointType, context); + } + + if (const auto structType = dynamic_cast(&type)) + { + return ConvertStructType(*structType, context); + } + + if (const auto fixedVectorType = dynamic_cast(&type)) + { + return ::llvm::VectorType::get( + ConvertJlmType(fixedVectorType->type(), context), + fixedVectorType->size(), + false); + } + + if (const auto scalableVectorType = dynamic_cast(&type)) + { + return ::llvm::VectorType::get( + ConvertJlmType(scalableVectorType->type(), context), + scalableVectorType->size(), + true); + } + + JLM_UNREACHABLE("TypeConverter::ConvertJlmType: Unhandled jlm type."); +} + +std::shared_ptr +TypeConverter::ConvertLlvmType(::llvm::Type & type) +{ + switch (type.getTypeID()) + { + case ::llvm::Type::IntegerTyID: + { + const auto integerType = ::llvm::cast<::llvm::IntegerType>(&type); + return rvsdg::bittype::Create(integerType->getBitWidth()); + } + case ::llvm::Type::PointerTyID: + return ConvertPointerType(*::llvm::cast<::llvm::PointerType>(&type)); + case ::llvm::Type::FunctionTyID: + return ConvertFunctionType(*::llvm::cast<::llvm::FunctionType>(&type)); + case ::llvm::Type::HalfTyID: + return fptype::Create(fpsize::half); + case ::llvm::Type::FloatTyID: + return fptype::Create(fpsize::flt); + case ::llvm::Type::DoubleTyID: + return fptype::Create(fpsize::dbl); + case ::llvm::Type::X86_FP80TyID: + return fptype::Create(fpsize::x86fp80); + case ::llvm::Type::FP128TyID: + return fptype::Create(fpsize::fp128); + case ::llvm::Type::StructTyID: + { + const auto structType = ::llvm::cast<::llvm::StructType>(&type); + const auto isPacked = structType->isPacked(); + auto & declaration = GetOrCreateStructDeclaration(*structType); + + return structType->hasName() + ? StructType::Create(structType->getName().str(), isPacked, declaration) + : StructType::Create(isPacked, declaration); + } + case ::llvm::Type::ArrayTyID: + { + auto elementType = ConvertLlvmType(*type.getArrayElementType()); + return arraytype::Create(std::move(elementType), type.getArrayNumElements()); + } + case ::llvm::Type::FixedVectorTyID: + { + auto scalarType = ConvertLlvmType(*type.getScalarType()); + return fixedvectortype::Create( + std::move(scalarType), + ::llvm::cast<::llvm::FixedVectorType>(&type)->getNumElements()); + } + case ::llvm::Type::ScalableVectorTyID: + { + auto scalarType = ConvertLlvmType(*type.getScalarType()); + return scalablevectortype::Create( + std::move(scalarType), + ::llvm::cast<::llvm::ScalableVectorType>(&type)->getMinNumElements()); + } + default: + JLM_UNREACHABLE("TypeConverter::ConvertLlvmType: Unhandled llvm type."); + } +} + +std::vector> && +TypeConverter::ReleaseStructTypeDeclarations() +{ + StructTypeMap_.Clear(); + return std::move(Declarations_); +} + +const StructType::Declaration & +TypeConverter::GetOrCreateStructDeclaration(::llvm::StructType & structType) +{ + // Return declaration if we already created one for this type instance + if (StructTypeMap_.HasKey(&structType)) + { + return *StructTypeMap_.LookupKey(&structType); + } + + // Otherwise create a new one, insert it, and return it + auto declaration = StructType::Declaration::Create(); + for (size_t n = 0; n < structType.getNumElements(); n++) + { + declaration->Append(ConvertLlvmType(*structType.getElementType(n))); + } + + const auto ptr = declaration.get(); + Declarations_.push_back(std::move(declaration)); + const bool wasInserted = StructTypeMap_.Insert(&structType, ptr); + JLM_ASSERT(wasInserted); + + return *ptr; +} + +} diff --git a/jlm/llvm/ir/TypeConverter.hpp b/jlm/llvm/ir/TypeConverter.hpp new file mode 100644 index 000000000..bbfb6692d --- /dev/null +++ b/jlm/llvm/ir/TypeConverter.hpp @@ -0,0 +1,107 @@ +/* + * Copyright 2025 Nico Reißmann + * See COPYING for terms of redistribution. + */ + +#ifndef JLM_LLVM_IR_TYPECONVERTER_HPP +#define JLM_LLVM_IR_TYPECONVERTER_HPP + +#include +#include + +#include + +namespace llvm +{ +class ArrayType; +class FunctionType; +class IntegerType; +class LLVMContext; +class PointerType; +class StructType; +class Type; +} + +namespace jlm::rvsdg +{ +class bittype; +class ControlType; +class FunctionType; +class Type; +} + +namespace jlm::llvm +{ + +/** + * Converts Llvm to Jlm types and vice versa. + */ +class TypeConverter final +{ +public: + TypeConverter() = default; + + TypeConverter(const TypeConverter &) = delete; + + TypeConverter(const TypeConverter &&) = delete; + + TypeConverter & + operator=(const TypeConverter &) = delete; + + TypeConverter & + operator=(const TypeConverter &&) = delete; + + static fpsize + ExtractFloatingPointSize(const ::llvm::Type & type); + + static ::llvm::IntegerType * + ConvertBitType(const rvsdg::bittype & bitType, ::llvm::LLVMContext & context); + + ::llvm::FunctionType * + ConvertFunctionType(const rvsdg::FunctionType & functionType, ::llvm::LLVMContext & context); + + static ::llvm::PointerType * + ConvertPointerType(const PointerType & type, ::llvm::LLVMContext & context); + + ::llvm::StructType * + ConvertStructType(const StructType & type, ::llvm::LLVMContext & context); + + ::llvm::ArrayType * + ConvertArrayType(const arraytype & type, ::llvm::LLVMContext & context); + + ::llvm::Type * + ConvertJlmType(const rvsdg::Type & type, ::llvm::LLVMContext & context); + + std::shared_ptr + ConvertFunctionType(const ::llvm::FunctionType & functionType); + + static std::shared_ptr + ConvertPointerType(const ::llvm::PointerType & pointerType); + + std::shared_ptr + ConvertLlvmType(::llvm::Type & type); + + /** + * Releases all struct type declarations from the type converter and clears the mapping between + * Jlm struct type declarations and Llvm struct types. The caller is the new owner of the + * declarations. + * + * @return A vector of declarations. + */ + std::vector> && + ReleaseStructTypeDeclarations(); + +private: + static ::llvm::Type * + ConvertFloatingPointType(const fptype & type, ::llvm::LLVMContext & context); + + const StructType::Declaration & + GetOrCreateStructDeclaration(::llvm::StructType & structType); + + std::vector> Declarations_; + util::BijectiveMap<::llvm::StructType *, const StructType::Declaration *> StructTypeMap_; +}; + +} + +#endif // JLM_LLVM_IR_TYPECONVERTER_HPP diff --git a/jlm/llvm/ir/ipgraph-module.hpp b/jlm/llvm/ir/ipgraph-module.hpp index 0f9436559..aab6b1c5e 100644 --- a/jlm/llvm/ir/ipgraph-module.hpp +++ b/jlm/llvm/ir/ipgraph-module.hpp @@ -173,16 +173,14 @@ class ipgraph_module final } /** - * Adds a struct type declaration to the module. The module becomes the owner of the declaration. + * Adds struct type declarations to the module. The module becomes the ownwer of the declarations. * - * @param declaration A declaration that is added to the module. - * @return A reference to the added documentation. + * @param declarations The declarations added to the module */ - const StructType::Declaration & - AddStructTypeDeclaration(std::unique_ptr declaration) + void + SetStructTypeDeclarations(std::vector> && declarations) { - StructTypeDeclarations_.emplace_back(std::move(declaration)); - return *StructTypeDeclarations_.back(); + StructTypeDeclarations_ = std::move(declarations); } /** diff --git a/tests/jlm/llvm/backend/llvm/jlm-llvm/test-type-conversion.cpp b/tests/jlm/llvm/backend/llvm/jlm-llvm/test-type-conversion.cpp deleted file mode 100644 index 4c7df3c68..000000000 --- a/tests/jlm/llvm/backend/llvm/jlm-llvm/test-type-conversion.cpp +++ /dev/null @@ -1,56 +0,0 @@ -/* - * Copyright 2020 Nico Reißmann - * See COPYING for terms of redistribution. - */ - -#include -#include - -#include -#include -#include - -static void -test_structtype(jlm::llvm::jlm2llvm::context & ctx) -{ - using namespace jlm::llvm; - - auto decl1 = StructType::Declaration::Create( - { jlm::rvsdg::bittype::Create(8), jlm::rvsdg::bittype::Create(32) }); - StructType st1("mystruct", false, *decl1); - auto ct = jlm2llvm::convert_type(st1, ctx); - - assert(ct->getName() == "mystruct"); - assert(!ct->isPacked()); - assert(ct->getNumElements() == 2); - assert(ct->getElementType(0)->isIntegerTy(8)); - assert(ct->getElementType(1)->isIntegerTy(32)); - - auto decl2 = StructType::Declaration::Create({ jlm::rvsdg::bittype::Create(32), - jlm::rvsdg::bittype::Create(8), - jlm::rvsdg::bittype::Create(32) }); - StructType st2(true, *decl2); - ct = jlm2llvm::convert_type(st2, ctx); - - assert(ct->getName().empty()); - assert(ct->isPacked()); - assert(ct->getNumElements() == 3); -} - -static int -test() -{ - using namespace jlm::llvm; - - llvm::LLVMContext ctx; - llvm::Module lm("module", ctx); - - ipgraph_module im(jlm::util::filepath(""), "", ""); - jlm2llvm::context jctx(im, lm); - - test_structtype(jctx); - - return 0; -} - -JLM_UNIT_TEST_REGISTER("jlm/llvm/backend/llvm/jlm-llvm/test-type-conversion", test) diff --git a/tests/jlm/llvm/frontend/llvm/LlvmTypeConversionTests.cpp b/tests/jlm/llvm/frontend/llvm/LlvmTypeConversionTests.cpp deleted file mode 100644 index 3333c4752..000000000 --- a/tests/jlm/llvm/frontend/llvm/LlvmTypeConversionTests.cpp +++ /dev/null @@ -1,48 +0,0 @@ -/* - * Copyright 2024 Halvor Linder Henriksen - * See COPYING for terms of redistribution. - */ - -#include -#include - -#include -#include - -static void -TestTypeConversion( - jlm::llvm::context & jlm_context, - llvm::Type * llvm_type, - jlm::llvm::fpsize jlm_type_size) -{ - using namespace llvm; - - auto jlm_type = jlm::llvm::ConvertType(llvm_type, jlm_context); - auto floating_point_type = dynamic_cast(jlm_type.get()); - - assert(floating_point_type && floating_point_type->size() == jlm_type_size); -} - -static int -TypeConversion() -{ - using namespace jlm::llvm; - - llvm::LLVMContext llvm_ctx; - llvm::Module lm("module", llvm_ctx); - - ipgraph_module im(jlm::util::filepath(""), "", ""); - auto jlm_ctx = context(im); - - TestTypeConversion(jlm_ctx, ::llvm::Type::getHalfTy(llvm_ctx), jlm::llvm::fpsize::half); - TestTypeConversion(jlm_ctx, ::llvm::Type::getFloatTy(llvm_ctx), jlm::llvm::fpsize::flt); - TestTypeConversion(jlm_ctx, ::llvm::Type::getDoubleTy(llvm_ctx), jlm::llvm::fpsize::dbl); - TestTypeConversion(jlm_ctx, ::llvm::Type::getX86_FP80Ty(llvm_ctx), jlm::llvm::fpsize::x86fp80); - TestTypeConversion(jlm_ctx, ::llvm::Type::getFP128Ty(llvm_ctx), jlm::llvm::fpsize::fp128); - - return 0; -} - -JLM_UNIT_TEST_REGISTER( - "jlm/llvm/frontend/llvm/LlvmTypeConversionTests-TypeConversion", - TypeConversion) diff --git a/tests/jlm/llvm/ir/TypeConverterTests.cpp b/tests/jlm/llvm/ir/TypeConverterTests.cpp new file mode 100644 index 000000000..5a614e974 --- /dev/null +++ b/tests/jlm/llvm/ir/TypeConverterTests.cpp @@ -0,0 +1,714 @@ +/* + * Copyright 2025 Nico Reißmann + * See COPYING for terms of redistribution. + */ + +#include + +#include +#include +#include +#include + +#include +#include + +static int +LlvmIntegerTypeConversion() +{ + using namespace jlm::llvm; + + // Arrange + llvm::LLVMContext context; + TypeConverter typeConverter; + + const auto i1 = ::llvm::IntegerType::get(context, 1); + const auto i2 = ::llvm::IntegerType::get(context, 2); + const auto i4 = ::llvm::IntegerType::get(context, 4); + const auto i8 = ::llvm::IntegerType::get(context, 8); + const auto i16 = ::llvm::IntegerType::get(context, 16); + const auto i32 = ::llvm::IntegerType::get(context, 32); + const auto i64 = ::llvm::IntegerType::get(context, 64); + + // Act + const auto i1BitType = + std::dynamic_pointer_cast(typeConverter.ConvertLlvmType(*i1)); + const auto i2BitType = + std::dynamic_pointer_cast(typeConverter.ConvertLlvmType(*i2)); + const auto i4BitType = + std::dynamic_pointer_cast(typeConverter.ConvertLlvmType(*i4)); + const auto i8BitType = + std::dynamic_pointer_cast(typeConverter.ConvertLlvmType(*i8)); + const auto i16BitType = + std::dynamic_pointer_cast(typeConverter.ConvertLlvmType(*i16)); + const auto i32BitType = + std::dynamic_pointer_cast(typeConverter.ConvertLlvmType(*i32)); + const auto i64BitType = + std::dynamic_pointer_cast(typeConverter.ConvertLlvmType(*i64)); + + // Assert + assert(i1BitType && i1BitType->nbits() == 1); + assert(i2BitType && i2BitType->nbits() == 2); + assert(i4BitType && i4BitType->nbits() == 4); + assert(i8BitType && i8BitType->nbits() == 8); + assert(i16BitType && i16BitType->nbits() == 16); + assert(i32BitType && i32BitType->nbits() == 32); + assert(i64BitType && i64BitType->nbits() == 64); + + return 0; +} + +JLM_UNIT_TEST_REGISTER( + "jlm/llvm/ir/TypeConverterTests-LlvmIntegerTypeConversion", + LlvmIntegerTypeConversion); + +static int +LlvmPointerTypeConversion() +{ + using namespace jlm::llvm; + + // Arrange + llvm::LLVMContext context; + TypeConverter typeConverter; + + const auto pointerTypeLlvm = ::llvm::PointerType::get(context, 0); + + // Act + const auto pointerTypeJlm = + std::dynamic_pointer_cast(typeConverter.ConvertLlvmType(*pointerTypeLlvm)); + + // Assert + assert(pointerTypeJlm); + + return 0; +} + +JLM_UNIT_TEST_REGISTER( + "jlm/llvm/ir/TypeConverterTests-LlvmPointerTypeConversion", + LlvmPointerTypeConversion); + +static int +LlvmFunctionTypeConversion() +{ + using namespace jlm::llvm; + using namespace jlm::rvsdg; + + // Arrange + llvm::LLVMContext context; + TypeConverter typeConverter; + + const auto voidType = ::llvm::Type::getVoidTy(context); + auto i32Type = ::llvm::Type::getInt32Ty(context); + const auto functionType1Llvm = ::llvm::FunctionType::get(voidType, { i32Type, i32Type }, false); + const auto functionType2Llvm = ::llvm::FunctionType::get(i32Type, {}, false); + const auto functionType3Llvm = ::llvm::FunctionType::get(i32Type, { i32Type, i32Type }, true); + + // Act + const auto functionType1Jlm = std::dynamic_pointer_cast( + typeConverter.ConvertLlvmType(*functionType1Llvm)); + const auto functionType2Jlm = std::dynamic_pointer_cast( + typeConverter.ConvertLlvmType(*functionType2Llvm)); + const auto functionType3Jlm = std::dynamic_pointer_cast( + typeConverter.ConvertLlvmType(*functionType3Llvm)); + + // Assert + assert(functionType1Jlm != nullptr); + assert(functionType1Jlm->NumArguments() == 4); + assert(functionType1Jlm->NumResults() == 2); + auto arguments = functionType1Jlm->Arguments(); + assert(is(arguments[0])); + assert(is(arguments[1])); + assert(is(arguments[2])); + assert(is(arguments[3])); + auto results = functionType1Jlm->Results(); + assert(is(results[0])); + assert(is(results[1])); + + assert(functionType2Jlm != nullptr); + assert(functionType2Jlm->NumArguments() == 2); + assert(functionType2Jlm->NumResults() == 3); + arguments = functionType2Jlm->Arguments(); + assert(is(arguments[0])); + assert(is(arguments[1])); + results = functionType2Jlm->Results(); + assert(is(results[0])); + assert(is(results[1])); + assert(is(results[2])); + + assert(functionType3Jlm != nullptr); + assert(functionType3Jlm->NumArguments() == 5); + assert(functionType3Jlm->NumResults() == 3); + arguments = functionType3Jlm->Arguments(); + assert(is(arguments[0])); + assert(is(arguments[1])); + assert(is(arguments[2])); + assert(is(arguments[3])); + assert(is(arguments[4])); + results = functionType3Jlm->Results(); + assert(is(results[0])); + assert(is(results[1])); + assert(is(results[2])); + + return 0; +} + +JLM_UNIT_TEST_REGISTER( + "jlm/llvm/ir/TypeConverterTests-LlvmFunctionTypeConversion", + LlvmFunctionTypeConversion); + +static int +LlvmFloatingPointTypeConversion() +{ + using namespace jlm::llvm; + + // Arrange + llvm::LLVMContext context; + TypeConverter typeConverter; + + const auto halfTypeLlvm = ::llvm::Type::getHalfTy(context); + const auto floatTypeLlvm = ::llvm::Type::getFloatTy(context); + const auto doubleTypeLlvm = ::llvm::Type::getDoubleTy(context); + const auto x86fp80TypeLlvm = ::llvm::Type::getX86_FP80Ty(context); + const auto fp128TypeLlvm = ::llvm::Type::getFP128Ty(context); + + // Act + const auto halfTypeJlm = + std::dynamic_pointer_cast(typeConverter.ConvertLlvmType(*halfTypeLlvm)); + const auto floatTypeJlm = + std::dynamic_pointer_cast(typeConverter.ConvertLlvmType(*floatTypeLlvm)); + const auto doubleTypeJlm = + std::dynamic_pointer_cast(typeConverter.ConvertLlvmType(*doubleTypeLlvm)); + const auto x86fp80TypeJlm = + std::dynamic_pointer_cast(typeConverter.ConvertLlvmType(*x86fp80TypeLlvm)); + const auto fp128TypeJlm = + std::dynamic_pointer_cast(typeConverter.ConvertLlvmType(*fp128TypeLlvm)); + + // Assert + assert(halfTypeJlm && halfTypeJlm->size() == fpsize::half); + assert(floatTypeJlm && floatTypeJlm->size() == fpsize::flt); + assert(doubleTypeJlm && doubleTypeJlm->size() == fpsize::dbl); + assert(x86fp80TypeJlm && x86fp80TypeJlm->size() == fpsize::x86fp80); + assert(fp128TypeJlm->size() == fpsize::fp128); + + return 0; +} + +JLM_UNIT_TEST_REGISTER( + "jlm/llvm/ir/TypeConverterTests-LlvmFloatingPointTypeConversion", + LlvmFloatingPointTypeConversion); + +static int +LlvmStructTypeConversion() +{ + using namespace jlm::llvm; + + // Arrange + llvm::LLVMContext context; + TypeConverter typeConverter; + + auto i32Type = ::llvm::Type::getInt32Ty(context); + const auto halfType = ::llvm::Type::getHalfTy(context); + const auto structType1Llvm = ::llvm::StructType::get(context, { i32Type, halfType }, false); + const auto structType2Llvm = + ::llvm::StructType::get(context, { i32Type, i32Type, i32Type }, true); + const auto structType3Llvm = ::llvm::StructType::get(context, { i32Type }, true); + structType3Llvm->setName("myStruct"); + + // Act + const auto structType1Jlm = + std::dynamic_pointer_cast(typeConverter.ConvertLlvmType(*structType1Llvm)); + const auto structType2Jlm = + std::dynamic_pointer_cast(typeConverter.ConvertLlvmType(*structType2Llvm)); + const auto structType3Jlm = + std::dynamic_pointer_cast(typeConverter.ConvertLlvmType(*structType3Llvm)); + + const auto structType4Jlm = + std::dynamic_pointer_cast(typeConverter.ConvertLlvmType(*structType1Llvm)); + + // Assert + assert(structType1Jlm); + assert(structType1Jlm->GetDeclaration().NumElements() == 2); + assert(!structType1Jlm->IsPacked()); + assert(!structType1Jlm->HasName()); + + assert(structType2Jlm); + assert(structType2Jlm->GetDeclaration().NumElements() == 3); + assert(structType2Jlm->IsPacked()); + assert(!structType2Jlm->HasName()); + + assert(structType3Jlm); + assert(structType3Jlm->GetDeclaration().NumElements() == 1); + assert(structType3Jlm->IsPacked()); + assert(structType3Jlm->HasName() && structType3Jlm->GetName() == "myStruct"); + + assert(&structType1Jlm->GetDeclaration() != &structType2Jlm->GetDeclaration()); + assert(&structType1Jlm->GetDeclaration() != &structType3Jlm->GetDeclaration()); + assert(&structType1Jlm->GetDeclaration() == &structType4Jlm->GetDeclaration()); + assert(&structType2Jlm->GetDeclaration() != &structType3Jlm->GetDeclaration()); + + const auto declarations = typeConverter.ReleaseStructTypeDeclarations(); + assert(declarations.size() == 3); + + // We released all struct declarations. After that, translating the same type again should get + // us a new declarations. + const auto structType5Jlm = + std::dynamic_pointer_cast(typeConverter.ConvertLlvmType(*structType1Llvm)); + + assert(&structType5Jlm->GetDeclaration() != &structType1Jlm->GetDeclaration()); + + return 0; +} + +JLM_UNIT_TEST_REGISTER( + "jlm/llvm/ir/TypeConverterTests-LlvmStructTypeConversion", + LlvmStructTypeConversion); + +static int +LlvmArrayTypeConversion() +{ + using namespace jlm::llvm; + using namespace jlm::rvsdg; + + // Arrange + llvm::LLVMContext context; + TypeConverter typeConverter; + + const auto i32Type = ::llvm::Type::getInt32Ty(context); + const auto halfType = ::llvm::Type::getHalfTy(context); + const auto arrayType1Llvm = ::llvm::ArrayType::get(i32Type, 4); + const auto arrayType2Llvm = ::llvm::ArrayType::get(halfType, 9); + + // Act + const auto arrayType1Jlm = + std::dynamic_pointer_cast(typeConverter.ConvertLlvmType(*arrayType1Llvm)); + const auto arrayType2Jlm = + std::dynamic_pointer_cast(typeConverter.ConvertLlvmType(*arrayType2Llvm)); + + // Assert + assert(arrayType1Jlm); + assert(is(arrayType1Jlm->element_type())); + assert(arrayType1Jlm->nelements() == 4); + + assert(arrayType2Jlm); + assert(is(arrayType2Jlm->element_type())); + assert(arrayType2Jlm->nelements() == 9); + + return 0; +} + +JLM_UNIT_TEST_REGISTER( + "jlm/llvm/ir/TypeConverterTests-LlvmArrayTypeConversion", + LlvmArrayTypeConversion); + +static int +LlvmVectorTypeConversion() +{ + using namespace jlm::llvm; + using namespace jlm::rvsdg; + + // Arrange + llvm::LLVMContext context; + TypeConverter typeConverter; + + const auto i32Type = ::llvm::Type::getInt32Ty(context); + const auto halfType = ::llvm::Type::getHalfTy(context); + const auto vectorType1Llvm = ::llvm::VectorType::get(i32Type, 4, false); + const auto vectorType2Llvm = ::llvm::VectorType::get(halfType, 9, true); + + // Act + const auto vectorType1Jlm = std::dynamic_pointer_cast( + typeConverter.ConvertLlvmType(*vectorType1Llvm)); + const auto vectorType2Jlm = std::dynamic_pointer_cast( + typeConverter.ConvertLlvmType(*vectorType2Llvm)); + + // Assert + assert(vectorType1Jlm); + assert(is(vectorType1Jlm->type())); + assert(vectorType1Jlm->size() == 4); + + assert(vectorType2Jlm); + assert(is(vectorType2Jlm->type())); + assert(vectorType2Jlm->size() == 9); + + return 0; +} + +JLM_UNIT_TEST_REGISTER( + "jlm/llvm/ir/TypeConverterTests-LlvmVectorTypeConversion", + LlvmVectorTypeConversion); + +static int +JLmBitTypeConversion() +{ + using namespace jlm::llvm; + + // Arrange + llvm::LLVMContext context; + TypeConverter typeConverter; + + const auto i1 = jlm::rvsdg::bittype::Create(1); + const auto i2 = jlm::rvsdg::bittype::Create(2); + const auto i4 = jlm::rvsdg::bittype::Create(4); + const auto i8 = jlm::rvsdg::bittype::Create(8); + const auto i16 = jlm::rvsdg::bittype::Create(16); + const auto i32 = jlm::rvsdg::bittype::Create(32); + const auto i64 = jlm::rvsdg::bittype::Create(64); + + // Act + const auto i1Type = typeConverter.ConvertJlmType(*i1, context); + const auto i2Type = typeConverter.ConvertJlmType(*i2, context); + const auto i4Type = typeConverter.ConvertJlmType(*i4, context); + const auto i8Type = typeConverter.ConvertJlmType(*i8, context); + const auto i16Type = typeConverter.ConvertJlmType(*i16, context); + const auto i32Type = typeConverter.ConvertJlmType(*i32, context); + const auto i64Type = typeConverter.ConvertJlmType(*i64, context); + + // Assert + assert(i1Type->getTypeID() == llvm::Type::IntegerTyID); + assert(i1Type->getIntegerBitWidth() == 1); + + assert(i2Type->getTypeID() == llvm::Type::IntegerTyID); + assert(i2Type->getIntegerBitWidth() == 2); + + assert(i4Type->getTypeID() == llvm::Type::IntegerTyID); + assert(i4Type->getIntegerBitWidth() == 4); + + assert(i8Type->getTypeID() == llvm::Type::IntegerTyID); + assert(i8Type->getIntegerBitWidth() == 8); + + assert(i16Type->getTypeID() == llvm::Type::IntegerTyID); + assert(i16Type->getIntegerBitWidth() == 16); + + assert(i32Type->getTypeID() == llvm::Type::IntegerTyID); + assert(i32Type->getIntegerBitWidth() == 32); + + assert(i64Type->getTypeID() == llvm::Type::IntegerTyID); + assert(i64Type->getIntegerBitWidth() == 64); + + return 0; +} + +JLM_UNIT_TEST_REGISTER("jlm/llvm/ir/TypeConverterTests-JLmBitTypeConversion", JLmBitTypeConversion); + +static int +JlmFunctionTypeConversion() +{ + using namespace jlm::llvm; + using namespace jlm::rvsdg; + + // Arrange + llvm::LLVMContext context; + TypeConverter typeConverter; + + auto bit32Type = bittype::Create(32); + auto ioStateType = iostatetype::Create(); + auto memoryStateType = MemoryStateType::Create(); + auto varArgType = varargtype::Create(); + const auto functionType1Jlm = FunctionType::Create( + { bit32Type, bit32Type, ioStateType, memoryStateType }, + { memoryStateType, ioStateType }); + const auto functionType2Jlm = FunctionType::Create( + { ioStateType, memoryStateType }, + { bit32Type, memoryStateType, ioStateType }); + const auto functionType3Jlm = FunctionType::Create( + { bit32Type, bit32Type, varArgType, ioStateType, memoryStateType }, + { bit32Type, memoryStateType, ioStateType }); + + // Act + const auto functionType1Llvm = + llvm::dyn_cast(typeConverter.ConvertJlmType(*functionType1Jlm, context)); + const auto functionType2Llvm = + llvm::dyn_cast(typeConverter.ConvertJlmType(*functionType2Jlm, context)); + const auto functionType3Llvm = + llvm::dyn_cast(typeConverter.ConvertJlmType(*functionType3Jlm, context)); + + // Assert + assert(functionType1Llvm != nullptr); + assert(functionType1Llvm->getNumParams() == 2); + assert(functionType1Llvm->getParamType(0)->getTypeID() == llvm::Type::IntegerTyID); + assert(functionType1Llvm->getParamType(1)->getTypeID() == llvm::Type::IntegerTyID); + assert(functionType1Llvm->getReturnType()->getTypeID() == llvm::Type::VoidTyID); + assert(!functionType1Llvm->isVarArg()); + + assert(functionType2Llvm != nullptr); + assert(functionType2Llvm->getNumParams() == 0); + assert(functionType2Llvm->getReturnType()->getTypeID() == llvm::Type::IntegerTyID); + assert(!functionType2Llvm->isVarArg()); + + assert(functionType3Llvm != nullptr); + assert(functionType3Llvm->getNumParams() == 2); + assert(functionType3Llvm->getParamType(0)->getTypeID() == llvm::Type::IntegerTyID); + assert(functionType3Llvm->getParamType(1)->getTypeID() == llvm::Type::IntegerTyID); + assert(functionType3Llvm->getReturnType()->getTypeID() == llvm::Type::IntegerTyID); + assert(functionType3Llvm->isVarArg()); + + return 0; +} + +JLM_UNIT_TEST_REGISTER( + "jlm/llvm/ir/TypeConverterTests-JlmFunctionTypeConversion", + JlmFunctionTypeConversion); + +static int +JlmPointerTypeConversion() +{ + using namespace jlm::llvm; + + // Arrange + llvm::LLVMContext context; + TypeConverter typeConverter; + + const auto pointerTypeJlm = PointerType::Create(); + + // Act + const auto pointerTypeLlvm = + llvm::dyn_cast(typeConverter.ConvertJlmType(*pointerTypeJlm, context)); + + // Assert + assert(pointerTypeLlvm); + assert(pointerTypeLlvm->getAddressSpace() == 0); + + return 0; +} + +JLM_UNIT_TEST_REGISTER( + "jlm/llvm/ir/TypeConverterTests-JlmPointerTypeConversion", + JlmPointerTypeConversion); + +static int +JlmArrayTypeConversion() +{ + using namespace jlm::llvm; + using namespace jlm::rvsdg; + + // Arrange + llvm::LLVMContext context; + TypeConverter typeConverter; + + const auto bit32Type = bittype::Create(32); + const auto halfType = fptype::Create(fpsize::half); + const auto arrayType1Jlm = arraytype::Create(bit32Type, 4); + const auto arrayType2Jlm = arraytype::Create(halfType, 9); + + // Act + const auto arrayType1Llvm = typeConverter.ConvertJlmType(*arrayType1Jlm, context); + const auto arrayType2Llvm = typeConverter.ConvertJlmType(*arrayType2Jlm, context); + + // Assert + assert(arrayType1Llvm->isArrayTy()); + assert(arrayType1Llvm->getArrayNumElements() == 4); + assert(arrayType1Llvm->getArrayElementType()->getTypeID() == llvm::Type::IntegerTyID); + + assert(arrayType2Llvm->isArrayTy()); + assert(arrayType2Llvm->getArrayNumElements() == 9); + assert(arrayType2Llvm->getArrayElementType()->getTypeID() == llvm::Type::HalfTyID); + + return 0; +} + +JLM_UNIT_TEST_REGISTER( + "jlm/llvm/ir/TypeConverterTests-JlmArrayTypeConversion", + JlmArrayTypeConversion); + +static int +JlmControlTypeConversion() +{ + using namespace jlm::llvm; + + // Arrange + llvm::LLVMContext context; + TypeConverter typeConverter; + + const auto controlType1 = jlm::rvsdg::ControlType::Create(2); + const auto controlType10 = jlm::rvsdg::ControlType::Create(10); + + // Act + const auto integerType1Llvm = typeConverter.ConvertJlmType(*controlType1, context); + const auto integerType2Llvm = typeConverter.ConvertJlmType(*controlType10, context); + + // Assert + assert(integerType1Llvm->getTypeID() == llvm::Type::IntegerTyID); + assert(integerType1Llvm->getIntegerBitWidth() == 1); + + assert(integerType2Llvm->getTypeID() == llvm::Type::IntegerTyID); + assert(integerType2Llvm->getIntegerBitWidth() == 32); + + return 0; +} + +JLM_UNIT_TEST_REGISTER( + "jlm/llvm/ir/TypeConverterTests-JlmControlTypeConversion", + JlmControlTypeConversion); + +static int +JlmFloatingPointTypeConversion() +{ + using namespace jlm::llvm; + + // Arrange + llvm::LLVMContext context; + TypeConverter typeConverter; + + const auto halfTypeJlm = fptype::Create(fpsize::half); + const auto floatTypeJlm = fptype::Create(fpsize::flt); + const auto doubleTypeJlm = fptype::Create(fpsize::dbl); + const auto x86fp80TypeJlm = fptype::Create(fpsize::x86fp80); + const auto fp128TypeJlm = fptype::Create(fpsize::fp128); + + // Act + const auto halfTypeLlvm = typeConverter.ConvertJlmType(*halfTypeJlm, context); + const auto floatTypeLlvm = typeConverter.ConvertJlmType(*floatTypeJlm, context); + const auto doubleTypeLlvm = typeConverter.ConvertJlmType(*doubleTypeJlm, context); + const auto x86fp80TypeLlvm = typeConverter.ConvertJlmType(*x86fp80TypeJlm, context); + const auto fp128TypeLlvm = typeConverter.ConvertJlmType(*fp128TypeJlm, context); + + // Assert + assert(halfTypeLlvm->getTypeID() == llvm::Type::HalfTyID); + assert(floatTypeLlvm->getTypeID() == llvm::Type::FloatTyID); + assert(doubleTypeLlvm->getTypeID() == llvm::Type::DoubleTyID); + assert(x86fp80TypeLlvm->getTypeID() == llvm::Type::X86_FP80TyID); + assert(fp128TypeLlvm->getTypeID() == llvm::Type::FP128TyID); + + return 0; +} + +JLM_UNIT_TEST_REGISTER( + "jlm/llvm/ir/TypeConverterTests-JlmFloatingPointTypeConversion", + JlmFloatingPointTypeConversion); + +static int +JlmStructTypeConversion() +{ + using namespace jlm::llvm; + + // Arrange + llvm::LLVMContext context; + TypeConverter typeConverter; + + const auto bit32Type = jlm::rvsdg::bittype::Create(32); + const auto halfType = fptype::Create(fpsize::half); + + const auto declaration1 = StructType::Declaration::Create({ bit32Type, halfType }); + const auto declaration2 = StructType::Declaration::Create({ bit32Type, bit32Type, bit32Type }); + const auto declaration3 = StructType::Declaration::Create({ bit32Type }); + + const auto structType1Jlm = StructType::Create(false, *declaration1); + const auto structType2Jlm = StructType::Create(true, *declaration2); + const auto structType3Jlm = StructType::Create("myStruct", true, *declaration3); + + // Act + const auto structType1Llvm = typeConverter.ConvertJlmType(*structType1Jlm, context); + const auto structType2Llvm = typeConverter.ConvertJlmType(*structType2Jlm, context); + const auto structType3Llvm = typeConverter.ConvertJlmType(*structType3Jlm, context); + + const auto structType4Llvm = typeConverter.ConvertJlmType(*structType1Jlm, context); + + // Assert + assert(structType1Llvm->getTypeID() == llvm::Type::StructTyID); + assert(structType1Llvm->getStructNumElements() == 2); + assert(structType1Llvm->getStructElementType(0)->getTypeID() == llvm::Type::IntegerTyID); + assert(structType1Llvm->getStructElementType(1)->getTypeID() == llvm::Type::HalfTyID); + assert(!llvm::dyn_cast(structType1Llvm)->isPacked()); + + assert(structType2Llvm->getTypeID() == llvm::Type::StructTyID); + assert(structType2Llvm->getStructNumElements() == 3); + assert(structType2Llvm->getStructElementType(0)->getTypeID() == llvm::Type::IntegerTyID); + assert(structType2Llvm->getStructElementType(1)->getTypeID() == llvm::Type::IntegerTyID); + assert(structType2Llvm->getStructElementType(2)->getTypeID() == llvm::Type::IntegerTyID); + assert(llvm::dyn_cast(structType2Llvm)->isPacked()); + + assert(structType3Llvm->getTypeID() == llvm::Type::StructTyID); + assert(structType3Llvm->getStructNumElements() == 1); + assert(structType3Llvm->getStructElementType(0)->getTypeID() == llvm::Type::IntegerTyID); + assert(structType3Llvm->getStructName() == "myStruct"); + assert(llvm::dyn_cast(structType3Llvm)->isPacked()); + + assert(structType4Llvm == structType1Llvm); + + // The type converter created no jlm struct types. It is therefore not the owner of any + // declarations. + const auto declarations = typeConverter.ReleaseStructTypeDeclarations(); + assert(declarations.size() == 0); + + // Converting the same type again after the declaration release should give us a new Llvm type + const auto structType5Llvm = typeConverter.ConvertJlmType(*structType1Jlm, context); + assert(structType5Llvm != structType1Llvm); + + return 0; +} + +JLM_UNIT_TEST_REGISTER( + "jlm/llvm/ir/TypeConverterTests-JlmStructTypeConversion", + JlmStructTypeConversion); + +static int +JlmFixedVectorTypeConversion() +{ + using namespace jlm::llvm; + using namespace jlm::rvsdg; + + // Arrange + llvm::LLVMContext context; + TypeConverter typeConverter; + + const auto bit32Type = bittype::Create(32); + const auto fixedVectorType1 = fixedvectortype::Create(bit32Type, 2); + const auto fixedVectorType2 = fixedvectortype::Create(bit32Type, 4); + + // Act + const auto vectorType1 = + llvm::dyn_cast(typeConverter.ConvertJlmType(*fixedVectorType1, context)); + const auto vectorType2 = + llvm::dyn_cast(typeConverter.ConvertJlmType(*fixedVectorType2, context)); + + // Assert + assert(vectorType1->getTypeID() == llvm::Type::FixedVectorTyID); + assert(vectorType1->getElementType()->getTypeID() == llvm::Type::IntegerTyID); + assert(vectorType1->getElementCount().getFixedValue() == 2); + + assert(vectorType2->getTypeID() == llvm::Type::FixedVectorTyID); + assert(vectorType2->getElementType()->getTypeID() == llvm::Type::IntegerTyID); + assert(vectorType2->getElementCount().getFixedValue() == 4); + + return 0; +} + +JLM_UNIT_TEST_REGISTER( + "jlm/llvm/ir/TypeConverterTests-JlmFixedVectorTypeConversion", + JlmFixedVectorTypeConversion); + +static int +JlmScalableVectorTypeConversion() +{ + using namespace jlm::llvm; + using namespace jlm::rvsdg; + + // Arrange + llvm::LLVMContext context; + TypeConverter typeConverter; + + const auto bit32Type = bittype::Create(32); + const auto scalableVectorType1 = scalablevectortype::Create(bit32Type, 2); + const auto scalableVectorType2 = scalablevectortype::Create(bit32Type, 4); + + // Act + const auto vectorType1 = + llvm::dyn_cast(typeConverter.ConvertJlmType(*scalableVectorType1, context)); + const auto vectorType2 = + llvm::dyn_cast(typeConverter.ConvertJlmType(*scalableVectorType2, context)); + + // Assert + assert(vectorType1->getTypeID() == llvm::Type::ScalableVectorTyID); + assert(vectorType1->getElementType()->getTypeID() == llvm::Type::IntegerTyID); + assert(vectorType1->getElementCount().getKnownMinValue() == 2); + + assert(vectorType2->getTypeID() == llvm::Type::ScalableVectorTyID); + assert(vectorType2->getElementType()->getTypeID() == llvm::Type::IntegerTyID); + assert(vectorType2->getElementCount().getKnownMinValue() == 4); + + return 0; +} + +JLM_UNIT_TEST_REGISTER( + "jlm/llvm/ir/TypeConverterTests-JlmScalableVectorTypeConversion", + JlmScalableVectorTypeConversion);