Skip to content

Commit

Permalink
Bundle Lllvm/Jlm type conversion in single class (#764)
Browse files Browse the repository at this point in the history
Refactors the type conversion of Llvm/Jlm types and bundles the logic in
a single class instead of distributing it all over the translation
passes.
  • Loading branch information
phate authored Jan 22, 2025
1 parent 7c3c119 commit c762641
Show file tree
Hide file tree
Showing 17 changed files with 1,247 additions and 663 deletions.
9 changes: 3 additions & 6 deletions jlm/llvm/Makefile.sub
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,13 @@
libllvm_SOURCES = \
jlm/llvm/backend/jlm2llvm/instruction.cpp \
jlm/llvm/backend/jlm2llvm/jlm2llvm.cpp \
jlm/llvm/backend/jlm2llvm/type.cpp \
jlm/llvm/backend/dot/DotWriter.cpp \
jlm/llvm/backend/rvsdg2jlm/rvsdg2jlm.cpp \
\
jlm/llvm/frontend/ControlFlowRestructuring.cpp \
jlm/llvm/frontend/InterProceduralGraphConversion.cpp \
jlm/llvm/frontend/LlvmInstructionConversion.cpp \
jlm/llvm/frontend/LlvmModuleConversion.cpp \
jlm/llvm/frontend/LlvmTypeConversion.cpp \
\
jlm/llvm/ir/aggregation.cpp \
jlm/llvm/ir/Annotation.cpp \
Expand All @@ -25,6 +23,7 @@ libllvm_SOURCES = \
jlm/llvm/ir/domtree.cpp \
jlm/llvm/ir/ipgraph.cpp \
jlm/llvm/ir/ipgraph-module.cpp \
jlm/llvm/ir/TypeConverter.cpp \
jlm/llvm/ir/operators/alloca.cpp \
jlm/llvm/ir/operators/call.cpp \
jlm/llvm/ir/operators/delta.cpp \
Expand Down Expand Up @@ -94,7 +93,6 @@ libllvm_HEADERS = \
jlm/llvm/opt/inversion.hpp \
jlm/llvm/opt/RvsdgTreePrinter.hpp \
jlm/llvm/frontend/LlvmModuleConversion.hpp \
jlm/llvm/frontend/LlvmTypeConversion.hpp \
jlm/llvm/frontend/ControlFlowRestructuring.hpp \
jlm/llvm/frontend/LlvmConversionContext.hpp \
jlm/llvm/frontend/LlvmInstructionConversion.hpp \
Expand All @@ -112,6 +110,7 @@ libllvm_HEADERS = \
jlm/llvm/ir/ipgraph.hpp \
jlm/llvm/ir/cfg.hpp \
jlm/llvm/ir/ssa.hpp \
jlm/llvm/ir/TypeConverter.hpp \
jlm/llvm/ir/variable.hpp \
jlm/llvm/ir/basic-block.hpp \
jlm/llvm/ir/print.hpp \
Expand All @@ -135,7 +134,6 @@ libllvm_HEADERS = \
jlm/llvm/backend/rvsdg2jlm/rvsdg2jlm.hpp \
jlm/llvm/backend/rvsdg2jlm/context.hpp \
jlm/llvm/backend/jlm2llvm/jlm2llvm.hpp \
jlm/llvm/backend/jlm2llvm/type.hpp \
jlm/llvm/backend/jlm2llvm/instruction.hpp \
jlm/llvm/backend/jlm2llvm/context.hpp \

Expand All @@ -151,8 +149,6 @@ libllvm_TESTS += \
tests/jlm/llvm/backend/llvm/jlm-llvm/test-function-calls \
tests/jlm/llvm/backend/llvm/jlm-llvm/test-ignore-memory-state \
tests/jlm/llvm/backend/llvm/jlm-llvm/test-select-with-state \
tests/jlm/llvm/backend/llvm/jlm-llvm/test-type-conversion \
tests/jlm/llvm/frontend/llvm/LlvmTypeConversionTests \
tests/jlm/llvm/frontend/llvm/LoadTests \
tests/jlm/llvm/frontend/llvm/MemCpyTests \
tests/jlm/llvm/frontend/llvm/StoreTests \
Expand Down Expand Up @@ -193,6 +189,7 @@ libllvm_TESTS += \
tests/jlm/llvm/ir/TestAnnotation \
tests/jlm/llvm/ir/TestCallSummary \
tests/jlm/llvm/ir/ThreeAddressCodeTests \
tests/jlm/llvm/ir/TypeConverterTests \
tests/jlm/llvm/opt/alias-analyses/TestAgnosticMemoryNodeProvider \
tests/jlm/llvm/opt/alias-analyses/TestAndersen \
tests/jlm/llvm/opt/alias-analyses/TestDifferencePropagation \
Expand Down
17 changes: 5 additions & 12 deletions jlm/llvm/backend/jlm2llvm/context.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#ifndef JLM_LLVM_BACKEND_JLM2LLVM_CONTEXT_HPP
#define JLM_LLVM_BACKEND_JLM2LLVM_CONTEXT_HPP

#include <jlm/llvm/ir/TypeConverter.hpp>
#include <jlm/llvm/ir/types.hpp>
#include <jlm/util/common.hpp>

Expand Down Expand Up @@ -108,26 +109,18 @@ class context final
return it->second;
}

inline ::llvm::StructType *
structtype(const StructType::Declaration * dcl)
TypeConverter &
GetTypeConverter()
{
auto it = structtypes_.find(dcl);
return it != structtypes_.end() ? it->second : nullptr;
}

inline void
add_structtype(const StructType::Declaration * dcl, ::llvm::StructType * type)
{
JLM_ASSERT(structtypes_.find(dcl) == structtypes_.end());
structtypes_[dcl] = type;
return TypeConverter_;
}

private:
::llvm::Module & lm_;
ipgraph_module & im_;
std::unordered_map<const llvm::variable *, ::llvm::Value *> variables_;
std::unordered_map<const llvm::cfg_node *, ::llvm::BasicBlock *> nodes_;
std::unordered_map<const StructType::Declaration *, ::llvm::StructType *> structtypes_;
TypeConverter TypeConverter_;
};

}
Expand Down
68 changes: 51 additions & 17 deletions jlm/llvm/backend/jlm2llvm/instruction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

#include <jlm/llvm/backend/jlm2llvm/context.hpp>
#include <jlm/llvm/backend/jlm2llvm/instruction.hpp>
#include <jlm/llvm/backend/jlm2llvm/type.hpp>

#include <llvm/IR/IRBuilder.h>
#include <llvm/IR/Module.h>
Expand Down Expand Up @@ -164,14 +163,17 @@ convert_undef(
context & ctx)
{
JLM_ASSERT(is<UndefValueOperation>(op));
auto & llvmContext = ctx.llvm_module().getContext();
auto & typeConverter = ctx.GetTypeConverter();

auto & resultType = *op.result(0);

// MemoryState has no llvm representation.
if (is<MemoryStateType>(resultType))
return nullptr;

return ::llvm::UndefValue::get(convert_type(resultType, ctx));
auto type = typeConverter.ConvertJlmType(resultType, llvmContext);
return ::llvm::UndefValue::get(type);
}

static ::llvm::Value *
Expand All @@ -181,7 +183,11 @@ convert(
::llvm::IRBuilder<> &,
context & ctx)
{
return ::llvm::PoisonValue::get(convert_type(operation.GetType(), ctx));
auto & llvmContext = ctx.llvm_module().getContext();
auto & typeConverter = ctx.GetTypeConverter();

auto type = typeConverter.ConvertJlmType(operation.GetType(), llvmContext);
return ::llvm::PoisonValue::get(type);
}

static ::llvm::Value *
Expand All @@ -192,6 +198,8 @@ convert(
context & ctx)
{
auto function = ctx.value(args[0]);
auto & llvmContext = ctx.llvm_module().getContext();
auto & typeConverter = ctx.GetTypeConverter();

std::vector<::llvm::Value *> operands;
for (size_t n = 1; n < args.size(); n++)
Expand All @@ -216,7 +224,7 @@ convert(
operands.push_back(ctx.value(argument));
}

auto ftype = convert_type(*op.GetFunctionType(), ctx);
auto ftype = typeConverter.ConvertFunctionType(*op.GetFunctionType(), llvmContext);
return builder.CreateCall(ftype, function, operands);
}

Expand Down Expand Up @@ -277,13 +285,15 @@ convert_phi(
{
JLM_ASSERT(is<phi_op>(op));
auto & phi = *static_cast<const llvm::phi_op *>(&op);
auto & llvmContext = ctx.llvm_module().getContext();
auto & typeConverter = ctx.GetTypeConverter();

if (rvsdg::is<iostatetype>(phi.type()))
return nullptr;
if (rvsdg::is<MemoryStateType>(phi.type()))
return nullptr;

auto t = convert_type(phi.type(), ctx);
auto t = typeConverter.ConvertJlmType(phi.type(), llvmContext);
return builder.CreatePHI(t, op.narguments());
}

Expand All @@ -296,7 +306,10 @@ CreateLoadInstruction(
::llvm::IRBuilder<> & builder,
context & ctx)
{
auto type = convert_type(loadedType, ctx);
auto & llvmContext = ctx.llvm_module().getContext();
auto & typeConverter = ctx.GetTypeConverter();

auto type = typeConverter.ConvertJlmType(loadedType, llvmContext);
auto loadInstruction = builder.CreateLoad(type, ctx.value(address), isVolatile);
loadInstruction->setAlignment(::llvm::Align(alignment));
return loadInstruction;
Expand Down Expand Up @@ -385,8 +398,10 @@ convert_alloca(
{
JLM_ASSERT(is<alloca_op>(op));
auto & aop = *static_cast<const llvm::alloca_op *>(&op);
auto & llvmContext = ctx.llvm_module().getContext();
auto & typeConverter = ctx.GetTypeConverter();

auto t = convert_type(aop.value_type(), ctx);
auto t = typeConverter.ConvertJlmType(aop.value_type(), llvmContext);
auto i = builder.CreateAlloca(t, ctx.value(args[0]));
i->setAlignment(::llvm::Align(aop.alignment()));
return i;
Expand All @@ -401,9 +416,11 @@ convert_getelementptr(
{
JLM_ASSERT(is<GetElementPtrOperation>(op) && args.size() >= 2);
auto & pop = *static_cast<const GetElementPtrOperation *>(&op);
auto & llvmContext = ctx.llvm_module().getContext();
auto & typeConverter = ctx.GetTypeConverter();

std::vector<::llvm::Value *> indices;
auto t = convert_type(pop.GetPointeeType(), ctx);
auto t = typeConverter.ConvertJlmType(pop.GetPointeeType(), llvmContext);
for (size_t n = 1; n < args.size(); n++)
indices.push_back(ctx.value(args[n]));

Expand Down Expand Up @@ -506,6 +523,8 @@ convert(
context & ctx)
{
JLM_ASSERT(is<ConstantArray>(op));
::llvm::LLVMContext & llvmContext = ctx.llvm_module().getContext();
auto & typeConverter = ctx.GetTypeConverter();

std::vector<::llvm::Constant *> data;
for (size_t n = 0; n < operands.size(); n++)
Expand All @@ -516,7 +535,7 @@ convert(
}

auto at = std::dynamic_pointer_cast<const arraytype>(op.result(0));
auto type = convert_type(*at, ctx);
auto type = typeConverter.ConvertArrayType(*at, llvmContext);
return ::llvm::ConstantArray::get(type, data);
}

Expand All @@ -527,7 +546,10 @@ convert(
::llvm::IRBuilder<> &,
context & ctx)
{
auto type = convert_type(*op.result(0), ctx);
::llvm::LLVMContext & llvmContext = ctx.llvm_module().getContext();
auto & typeConverter = ctx.GetTypeConverter();

auto type = typeConverter.ConvertJlmType(*op.result(0), llvmContext);
return ::llvm::ConstantAggregateZero::get(type);
}

Expand Down Expand Up @@ -642,11 +664,14 @@ convert(
::llvm::IRBuilder<> &,
context & ctx)
{
::llvm::LLVMContext & llvmContext = ctx.llvm_module().getContext();
auto & typeConverter = ctx.GetTypeConverter();

std::vector<::llvm::Constant *> operands;
for (const auto & arg : args)
operands.push_back(::llvm::cast<::llvm::Constant>(ctx.value(arg)));

auto t = convert_type(op.type(), ctx);
auto t = typeConverter.ConvertStructType(op.type(), llvmContext);
return ::llvm::ConstantStruct::get(t, operands);
}

Expand All @@ -657,7 +682,10 @@ convert(
::llvm::IRBuilder<> &,
context & ctx)
{
auto pointerType = convert_type(operation.GetPointerType(), ctx);
::llvm::LLVMContext & llvmContext = ctx.llvm_module().getContext();
auto & typeConverter = ctx.GetTypeConverter();

auto pointerType = typeConverter.ConvertPointerType(operation.GetPointerType(), llvmContext);
return ::llvm::ConstantPointerNull::get(pointerType);
}

Expand Down Expand Up @@ -850,22 +878,24 @@ convert_cast(
context & ctx)
{
JLM_ASSERT(::llvm::Instruction::isCast(OPCODE));
auto & typeConverter = ctx.GetTypeConverter();
::llvm::LLVMContext & llvmContext = ctx.llvm_module().getContext();
auto dsttype = std::dynamic_pointer_cast<const rvsdg::ValueType>(op.result(0));
auto operand = operands[0];

if (auto vt = dynamic_cast<const fixedvectortype *>(&operand->type()))
{
auto type = convert_type(fixedvectortype(dsttype, vt->size()), ctx);
auto type = typeConverter.ConvertJlmType(fixedvectortype(dsttype, vt->size()), llvmContext);
return builder.CreateCast(OPCODE, ctx.value(operand), type);
}

if (auto vt = dynamic_cast<const scalablevectortype *>(&operand->type()))
{
auto type = convert_type(scalablevectortype(dsttype, vt->size()), ctx);
auto type = typeConverter.ConvertJlmType(scalablevectortype(dsttype, vt->size()), llvmContext);
return builder.CreateCast(OPCODE, ctx.value(operand), type);
}

auto type = convert_type(*dsttype, ctx);
auto type = typeConverter.ConvertJlmType(*dsttype, llvmContext);
return builder.CreateCast(OPCODE, ctx.value(operand), type);
}

Expand All @@ -888,9 +918,10 @@ convert(
context & ctx)
{
JLM_ASSERT(args.size() == 1);
auto & typeConverter = ctx.GetTypeConverter();
auto & lm = ctx.llvm_module();

auto fcttype = convert_type(op.fcttype(), ctx);
auto fcttype = typeConverter.ConvertFunctionType(op.fcttype(), lm.getContext());
auto function = lm.getOrInsertFunction("malloc", fcttype);
auto operands = std::vector<::llvm::Value *>(1, ctx.value(args[0]));
return builder.CreateCall(function, operands);
Expand All @@ -903,9 +934,12 @@ convert(
::llvm::IRBuilder<> & builder,
context & ctx)
{
auto & typeConverter = ctx.GetTypeConverter();
auto & llvmmod = ctx.llvm_module();

auto fcttype = convert_type(rvsdg::FunctionType({ op.argument(0) }, {}), ctx);
auto fcttype = typeConverter.ConvertFunctionType(
rvsdg::FunctionType({ op.argument(0) }, {}),
llvmmod.getContext());
auto function = llvmmod.getOrInsertFunction("free", fcttype);
auto operands = std::vector<::llvm::Value *>(1, ctx.value(args[0]));
return builder.CreateCall(function, operands);
Expand Down
15 changes: 10 additions & 5 deletions jlm/llvm/backend/jlm2llvm/jlm2llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
#include <jlm/llvm/backend/jlm2llvm/context.hpp>
#include <jlm/llvm/backend/jlm2llvm/instruction.hpp>
#include <jlm/llvm/backend/jlm2llvm/jlm2llvm.hpp>
#include <jlm/llvm/backend/jlm2llvm/type.hpp>

#include <llvm/IR/BasicBlock.h>
#include <llvm/IR/IRBuilder.h>
Expand Down Expand Up @@ -102,6 +101,8 @@ static void
create_switch(const cfg_node * node, context & ctx)
{
JLM_ASSERT(node->noutedges() >= 2);
::llvm::LLVMContext & llvmContext = ctx.llvm_module().getContext();
auto & typeConverter = ctx.GetTypeConverter();
auto bb = static_cast<const basic_block *>(node);
::llvm::IRBuilder<> builder(ctx.basic_block(node));

Expand All @@ -120,7 +121,8 @@ create_switch(const cfg_node * node, context & ctx)
for (const auto & alt : *mop)
{
auto & type = *std::static_pointer_cast<const rvsdg::bittype>(mop->argument(0));
auto value = ::llvm::ConstantInt::get(convert_type(type, ctx), alt.first);
auto value =
::llvm::ConstantInt::get(typeConverter.ConvertBitType(type, llvmContext), alt.first);
sw->addCase(value, ctx.basic_block(node->outedge(alt.second)->sink()));
}
}
Expand Down Expand Up @@ -294,9 +296,11 @@ ConvertIntAttribute(const llvm::int_attribute & attribute, context & ctx)
static ::llvm::Attribute
ConvertTypeAttribute(const llvm::type_attribute & attribute, context & ctx)
{
auto & typeConverter = ctx.GetTypeConverter();
auto & llvmContext = ctx.llvm_module().getContext();

auto kind = convert_attribute_kind(attribute.kind());
auto type = convert_type(attribute.type(), ctx);
auto type = typeConverter.ConvertJlmType(attribute.type(), llvmContext);
return ::llvm::Attribute::get(llvmContext, kind, type);
}

Expand Down Expand Up @@ -499,6 +503,7 @@ convert_linkage(const llvm::linkage & linkage)
static void
convert_ipgraph(context & ctx)
{
auto & typeConverter = ctx.GetTypeConverter();
auto & jm = ctx.module();
auto & lm = ctx.llvm_module();

Expand All @@ -509,7 +514,7 @@ convert_ipgraph(context & ctx)

if (auto dataNode = dynamic_cast<const data_node *>(&node))
{
auto type = convert_type(*dataNode->GetValueType(), ctx);
auto type = typeConverter.ConvertJlmType(*dataNode->GetValueType(), lm.getContext());
auto linkage = convert_linkage(dataNode->linkage());

auto gv = new ::llvm::GlobalVariable(
Expand All @@ -524,7 +529,7 @@ convert_ipgraph(context & ctx)
}
else if (auto n = dynamic_cast<const function_node *>(&node))
{
auto type = convert_type(n->fcttype(), ctx);
auto type = typeConverter.ConvertFunctionType(n->fcttype(), lm.getContext());
auto linkage = convert_linkage(n->linkage());
auto f = ::llvm::Function::Create(type, linkage, n->name(), &lm);
ctx.insert(v, f);
Expand Down
Loading

0 comments on commit c762641

Please sign in to comment.