From 87ead1854929c0d701c4734b946316b53aa1f63e Mon Sep 17 00:00:00 2001
From: Nico Reissmann <nico.reissmann@gmail.com>
Date: Sun, 26 Jan 2025 10:43:08 +0100
Subject: [PATCH] Simplify translation of integer operations in LLVM frontend
 (#770)

---
 .../frontend/LlvmInstructionConversion.cpp    | 423 ++++++++----------
 1 file changed, 186 insertions(+), 237 deletions(-)

diff --git a/jlm/llvm/frontend/LlvmInstructionConversion.cpp b/jlm/llvm/frontend/LlvmInstructionConversion.cpp
index fa43540e3..de086f7a1 100644
--- a/jlm/llvm/frontend/LlvmInstructionConversion.cpp
+++ b/jlm/llvm/frontend/LlvmInstructionConversion.cpp
@@ -455,118 +455,97 @@ convert_unreachable_instruction(::llvm::Instruction * i, tacsvector_t &, context
   return nullptr;
 }
 
-static inline const variable *
-convert_icmp_instruction(::llvm::Instruction * instruction, tacsvector_t & tacs, context & ctx)
+static std::unique_ptr<rvsdg::BinaryOperation>
+ConvertIntegerIcmpPredicate(const ::llvm::CmpInst::Predicate predicate, const std::size_t numBits)
 {
-  JLM_ASSERT(instruction->getOpcode() == ::llvm::Instruction::ICmp);
-  auto & typeConverter = ctx.GetTypeConverter();
-  auto i = ::llvm::cast<const ::llvm::ICmpInst>(instruction);
-  auto t = i->getOperand(0)->getType();
+  switch (predicate)
+  {
+  case ::llvm::CmpInst::ICMP_SLT:
+    return std::make_unique<rvsdg::bitslt_op>(numBits);
+  case ::llvm::CmpInst::ICMP_ULT:
+    return std::make_unique<rvsdg::bitult_op>(numBits);
+  case ::llvm::CmpInst::ICMP_SLE:
+    return std::make_unique<rvsdg::bitsle_op>(numBits);
+  case ::llvm::CmpInst::ICMP_ULE:
+    return std::make_unique<rvsdg::bitule_op>(numBits);
+  case ::llvm::CmpInst::ICMP_EQ:
+    return std::make_unique<rvsdg::biteq_op>(numBits);
+  case ::llvm::CmpInst::ICMP_NE:
+    return std::make_unique<rvsdg::bitne_op>(numBits);
+  case ::llvm::CmpInst::ICMP_SGE:
+    return std::make_unique<rvsdg::bitsge_op>(numBits);
+  case ::llvm::CmpInst::ICMP_UGE:
+    return std::make_unique<rvsdg::bituge_op>(numBits);
+  case ::llvm::CmpInst::ICMP_SGT:
+    return std::make_unique<rvsdg::bitsgt_op>(numBits);
+  case ::llvm::CmpInst::ICMP_UGT:
+    return std::make_unique<rvsdg::bitugt_op>(numBits);
+  default:
+    JLM_UNREACHABLE("ConvertIntegerIcmpPredicate: Unsupported icmp predicate.");
+  }
+}
 
-  static std::
-      unordered_map<const ::llvm::CmpInst::Predicate, std::unique_ptr<rvsdg::Operation> (*)(size_t)>
-          map({ { ::llvm::CmpInst::ICMP_SLT,
-                  [](size_t nbits)
-                  {
-                    rvsdg::bitslt_op op(nbits);
-                    return op.copy();
-                  } },
-                { ::llvm::CmpInst::ICMP_ULT,
-                  [](size_t nbits)
-                  {
-                    rvsdg::bitult_op op(nbits);
-                    return op.copy();
-                  } },
-                { ::llvm::CmpInst::ICMP_SLE,
-                  [](size_t nbits)
-                  {
-                    rvsdg::bitsle_op op(nbits);
-                    return op.copy();
-                  } },
-                { ::llvm::CmpInst::ICMP_ULE,
-                  [](size_t nbits)
-                  {
-                    rvsdg::bitule_op op(nbits);
-                    return op.copy();
-                  } },
-                { ::llvm::CmpInst::ICMP_EQ,
-                  [](size_t nbits)
-                  {
-                    rvsdg::biteq_op op(nbits);
-                    return op.copy();
-                  } },
-                { ::llvm::CmpInst::ICMP_NE,
-                  [](size_t nbits)
-                  {
-                    rvsdg::bitne_op op(nbits);
-                    return op.copy();
-                  } },
-                { ::llvm::CmpInst::ICMP_SGE,
-                  [](size_t nbits)
-                  {
-                    rvsdg::bitsge_op op(nbits);
-                    return op.copy();
-                  } },
-                { ::llvm::CmpInst::ICMP_UGE,
-                  [](size_t nbits)
-                  {
-                    rvsdg::bituge_op op(nbits);
-                    return op.copy();
-                  } },
-                { ::llvm::CmpInst::ICMP_SGT,
-                  [](size_t nbits)
-                  {
-                    rvsdg::bitsgt_op op(nbits);
-                    return op.copy();
-                  } },
-                { ::llvm::CmpInst::ICMP_UGT,
-                  [](size_t nbits)
-                  {
-                    rvsdg::bitugt_op op(nbits);
-                    return op.copy();
-                  } } });
-
-  static std::unordered_map<const ::llvm::CmpInst::Predicate, llvm::cmp> ptrmap(
-      { { ::llvm::CmpInst::ICMP_ULT, cmp::lt },
-        { ::llvm::CmpInst::ICMP_ULE, cmp::le },
-        { ::llvm::CmpInst::ICMP_EQ, cmp::eq },
-        { ::llvm::CmpInst::ICMP_NE, cmp::ne },
-        { ::llvm::CmpInst::ICMP_UGE, cmp::ge },
-        { ::llvm::CmpInst::ICMP_UGT, cmp::gt } });
-
-  auto p = i->getPredicate();
-  auto op1 = ConvertValue(i->getOperand(0), tacs, ctx);
-  auto op2 = ConvertValue(i->getOperand(1), tacs, ctx);
+static std::unique_ptr<rvsdg::BinaryOperation>
+ConvertPointerIcmpPredicate(const ::llvm::CmpInst::Predicate predicate)
+{
+  switch (predicate)
+  {
+  case ::llvm::CmpInst::ICMP_ULT:
+    return std::make_unique<ptrcmp_op>(PointerType::Create(), cmp::lt);
+  case ::llvm::CmpInst::ICMP_ULE:
+    return std::make_unique<ptrcmp_op>(PointerType::Create(), cmp::le);
+  case ::llvm::CmpInst::ICMP_EQ:
+    return std::make_unique<ptrcmp_op>(PointerType::Create(), cmp::eq);
+  case ::llvm::CmpInst::ICMP_NE:
+    return std::make_unique<ptrcmp_op>(PointerType::Create(), cmp::ne);
+  case ::llvm::CmpInst::ICMP_UGE:
+    return std::make_unique<ptrcmp_op>(PointerType::Create(), cmp::ge);
+  case ::llvm::CmpInst::ICMP_UGT:
+    return std::make_unique<ptrcmp_op>(PointerType::Create(), cmp::gt);
+  default:
+    JLM_UNREACHABLE("ConvertPointerIcmpPredicate: Unsupported icmp predicate.");
+  }
+}
 
-  std::unique_ptr<rvsdg::Operation> binop;
+static const variable *
+convert(const ::llvm::ICmpInst * instruction, tacsvector_t & tacs, context & ctx)
+{
+  const auto predicate = instruction->getPredicate();
+  const auto operandType = instruction->getOperand(0)->getType();
+  auto op1 = ConvertValue(instruction->getOperand(0), tacs, ctx);
+  auto op2 = ConvertValue(instruction->getOperand(1), tacs, ctx);
 
-  if (t->isIntegerTy() || (t->isVectorTy() && t->getScalarType()->isIntegerTy()))
+  std::unique_ptr<rvsdg::BinaryOperation> operation;
+  if (operandType->isVectorTy() && operandType->getScalarType()->isIntegerTy())
+  {
+    operation =
+        ConvertIntegerIcmpPredicate(predicate, operandType->getScalarType()->getIntegerBitWidth());
+  }
+  else if (operandType->isVectorTy() && operandType->getScalarType()->isPointerTy())
   {
-    auto it = t->isVectorTy() ? t->getScalarType() : t;
-    binop = map[p](it->getIntegerBitWidth());
+    operation = ConvertPointerIcmpPredicate(predicate);
   }
-  else if (t->isPointerTy() || (t->isVectorTy() && t->getScalarType()->isPointerTy()))
+  else if (operandType->isIntegerTy())
   {
-    auto pt = ::llvm::cast<::llvm::PointerType>(t->isVectorTy() ? t->getScalarType() : t);
-    binop = std::make_unique<ptrcmp_op>(typeConverter.ConvertPointerType(*pt), ptrmap[p]);
+    operation = ConvertIntegerIcmpPredicate(predicate, operandType->getIntegerBitWidth());
+  }
+  else if (operandType->isPointerTy())
+  {
+    operation = ConvertPointerIcmpPredicate(predicate);
   }
   else
-    JLM_UNREACHABLE("This should have never happend.");
-
-  auto type = typeConverter.ConvertLlvmType(*i->getType());
+  {
+    JLM_UNREACHABLE("convert: Unhandled icmp type.");
+  }
 
-  JLM_ASSERT(is<rvsdg::BinaryOperation>(*binop));
-  if (t->isVectorTy())
+  if (operandType->isVectorTy())
   {
-    tacs.push_back(vectorbinary_op::create(
-        *static_cast<rvsdg::BinaryOperation *>(binop.get()),
-        op1,
-        op2,
-        type));
+    const auto instructionType = ctx.GetTypeConverter().ConvertLlvmType(*instruction->getType());
+    tacs.push_back(vectorbinary_op::create(*operation, op1, op2, instructionType));
   }
   else
   {
-    tacs.push_back(tac::create(*static_cast<rvsdg::SimpleOperation *>(binop.get()), { op1, op2 }));
+    tacs.push_back(tac::create(*operation, { op1, op2 }));
   }
 
   return tacs.back()->result(0);
@@ -982,139 +961,109 @@ convert_select_instruction(::llvm::Instruction * i, tacsvector_t & tacs, context
   return tacs.back()->result(0);
 }
 
-static inline const variable *
-convert_binary_operator(::llvm::Instruction * instruction, tacsvector_t & tacs, context & ctx)
+static std::unique_ptr<rvsdg::BinaryOperation>
+ConvertIntegerBinaryOperation(
+    const ::llvm::Instruction::BinaryOps binaryOperation,
+    std::size_t numBits)
 {
-  JLM_ASSERT(::llvm::dyn_cast<const ::llvm::BinaryOperator>(instruction));
-  auto i = ::llvm::cast<const ::llvm::BinaryOperator>(instruction);
+  switch (binaryOperation)
+  {
+  case ::llvm::Instruction::Add:
+    return std::make_unique<rvsdg::bitadd_op>(numBits);
+  case ::llvm::Instruction::And:
+    return std::make_unique<rvsdg::bitand_op>(numBits);
+  case ::llvm::Instruction::AShr:
+    return std::make_unique<rvsdg::bitashr_op>(numBits);
+  case ::llvm::Instruction::LShr:
+    return std::make_unique<rvsdg::bitshr_op>(numBits);
+  case ::llvm::Instruction::Mul:
+    return std::make_unique<rvsdg::bitmul_op>(numBits);
+  case ::llvm::Instruction::Or:
+    return std::make_unique<rvsdg::bitor_op>(numBits);
+  case ::llvm::Instruction::SDiv:
+    return std::make_unique<rvsdg::bitsdiv_op>(numBits);
+  case ::llvm::Instruction::Shl:
+    return std::make_unique<rvsdg::bitshl_op>(numBits);
+  case ::llvm::Instruction::SRem:
+    return std::make_unique<rvsdg::bitsmod_op>(numBits);
+  case ::llvm::Instruction::Sub:
+    return std::make_unique<rvsdg::bitsub_op>(numBits);
+  case ::llvm::Instruction::UDiv:
+    return std::make_unique<rvsdg::bitudiv_op>(numBits);
+  case ::llvm::Instruction::URem:
+    return std::make_unique<rvsdg::bitumod_op>(numBits);
+  case ::llvm::Instruction::Xor:
+    return std::make_unique<rvsdg::bitxor_op>(numBits);
+  default:
+    JLM_UNREACHABLE("ConvertIntegerBinaryOperation: Unsupported integer binary operation");
+  }
+}
 
-  static std::unordered_map<
-      const ::llvm::Instruction::BinaryOps,
-      std::unique_ptr<rvsdg::Operation> (*)(size_t)>
-      bitmap({ { ::llvm::Instruction::Add,
-                 [](size_t nbits)
-                 {
-                   rvsdg::bitadd_op o(nbits);
-                   return o.copy();
-                 } },
-               { ::llvm::Instruction::And,
-                 [](size_t nbits)
-                 {
-                   rvsdg::bitand_op o(nbits);
-                   return o.copy();
-                 } },
-               { ::llvm::Instruction::AShr,
-                 [](size_t nbits)
-                 {
-                   rvsdg::bitashr_op o(nbits);
-                   return o.copy();
-                 } },
-               { ::llvm::Instruction::Sub,
-                 [](size_t nbits)
-                 {
-                   rvsdg::bitsub_op o(nbits);
-                   return o.copy();
-                 } },
-               { ::llvm::Instruction::UDiv,
-                 [](size_t nbits)
-                 {
-                   rvsdg::bitudiv_op o(nbits);
-                   return o.copy();
-                 } },
-               { ::llvm::Instruction::SDiv,
-                 [](size_t nbits)
-                 {
-                   rvsdg::bitsdiv_op o(nbits);
-                   return o.copy();
-                 } },
-               { ::llvm::Instruction::URem,
-                 [](size_t nbits)
-                 {
-                   rvsdg::bitumod_op o(nbits);
-                   return o.copy();
-                 } },
-               { ::llvm::Instruction::SRem,
-                 [](size_t nbits)
-                 {
-                   rvsdg::bitsmod_op o(nbits);
-                   return o.copy();
-                 } },
-               { ::llvm::Instruction::Shl,
-                 [](size_t nbits)
-                 {
-                   rvsdg::bitshl_op o(nbits);
-                   return o.copy();
-                 } },
-               { ::llvm::Instruction::LShr,
-                 [](size_t nbits)
-                 {
-                   rvsdg::bitshr_op o(nbits);
-                   return o.copy();
-                 } },
-               { ::llvm::Instruction::Or,
-                 [](size_t nbits)
-                 {
-                   rvsdg::bitor_op o(nbits);
-                   return o.copy();
-                 } },
-               { ::llvm::Instruction::Xor,
-                 [](size_t nbits)
-                 {
-                   rvsdg::bitxor_op o(nbits);
-                   return o.copy();
-                 } },
-               { ::llvm::Instruction::Mul,
-                 [](size_t nbits)
-                 {
-                   rvsdg::bitmul_op o(nbits);
-                   return o.copy();
-                 } } });
-
-  static std::unordered_map<const ::llvm::Instruction::BinaryOps, llvm::fpop> fpmap(
-      { { ::llvm::Instruction::FAdd, fpop::add },
-        { ::llvm::Instruction::FSub, fpop::sub },
-        { ::llvm::Instruction::FMul, fpop::mul },
-        { ::llvm::Instruction::FDiv, fpop::div },
-        { ::llvm::Instruction::FRem, fpop::mod } });
-
-  static std::unordered_map<const ::llvm::Type::TypeID, llvm::fpsize> fpsizemap(
-      { { ::llvm::Type::HalfTyID, fpsize::half },
-        { ::llvm::Type::FloatTyID, fpsize::flt },
-        { ::llvm::Type::DoubleTyID, fpsize::dbl },
-        { ::llvm::Type::X86_FP80TyID, fpsize::x86fp80 },
-        { ::llvm::Type::FP128TyID, fpsize::fp128 } });
-
-  std::unique_ptr<rvsdg::Operation> operation;
-  auto t = i->getType()->isVectorTy() ? i->getType()->getScalarType() : i->getType();
-  if (t->isIntegerTy())
+static std::unique_ptr<rvsdg::BinaryOperation>
+ConvertFloatingPointBinaryOperation(
+    const ::llvm::Instruction::BinaryOps binaryOperation,
+    fpsize floatingPointSize)
+{
+  switch (binaryOperation)
+  {
+  case ::llvm::Instruction::FAdd:
+    return std::make_unique<fpbin_op>(fpop::add, floatingPointSize);
+  case ::llvm::Instruction::FSub:
+    return std::make_unique<fpbin_op>(fpop::sub, floatingPointSize);
+  case ::llvm::Instruction::FMul:
+    return std::make_unique<fpbin_op>(fpop::mul, floatingPointSize);
+  case ::llvm::Instruction::FDiv:
+    return std::make_unique<fpbin_op>(fpop::div, floatingPointSize);
+  case ::llvm::Instruction::FRem:
+    return std::make_unique<fpbin_op>(fpop::mod, floatingPointSize);
+  default:
+    JLM_UNREACHABLE("ConvertFloatingPointBinaryOperation: Unsupported binary operation");
+  }
+}
+
+static const variable *
+convert(const ::llvm::BinaryOperator * instruction, tacsvector_t & tacs, context & ctx)
+{
+  const auto llvmType = instruction->getType();
+  auto & typeConverter = ctx.GetTypeConverter();
+  const auto opcode = instruction->getOpcode();
+
+  std::unique_ptr<rvsdg::BinaryOperation> operation;
+  if (llvmType->isVectorTy() && llvmType->getScalarType()->isIntegerTy())
   {
-    JLM_ASSERT(bitmap.find(i->getOpcode()) != bitmap.end());
-    operation = bitmap[i->getOpcode()](t->getIntegerBitWidth());
+    const auto numBits = llvmType->getScalarType()->getIntegerBitWidth();
+    operation = ConvertIntegerBinaryOperation(opcode, numBits);
   }
-  else if (t->isFloatingPointTy())
+  else if (llvmType->isVectorTy() && llvmType->getScalarType()->isFloatingPointTy())
   {
-    JLM_ASSERT(fpmap.find(i->getOpcode()) != fpmap.end());
-    JLM_ASSERT(fpsizemap.find(t->getTypeID()) != fpsizemap.end());
-    operation = std::make_unique<fpbin_op>(fpmap[i->getOpcode()], fpsizemap[t->getTypeID()]);
+    const auto size = typeConverter.ExtractFloatingPointSize(*llvmType->getScalarType());
+    operation = ConvertFloatingPointBinaryOperation(opcode, size);
+  }
+  else if (llvmType->isIntegerTy())
+  {
+    operation = ConvertIntegerBinaryOperation(opcode, llvmType->getIntegerBitWidth());
+  }
+  else if (llvmType->isFloatingPointTy())
+  {
+    const auto size = typeConverter.ExtractFloatingPointSize(*llvmType);
+    operation = ConvertFloatingPointBinaryOperation(opcode, size);
   }
   else
-    JLM_ASSERT(0);
-
-  auto type = ctx.GetTypeConverter().ConvertLlvmType(*i->getType());
+  {
+    JLM_ASSERT("convert: Unhandled binary operation type.");
+  }
 
-  auto op1 = ConvertValue(i->getOperand(0), tacs, ctx);
-  auto op2 = ConvertValue(i->getOperand(1), tacs, ctx);
-  JLM_ASSERT(is<rvsdg::BinaryOperation>(*operation));
+  const auto jlmType = typeConverter.ConvertLlvmType(*llvmType);
+  auto operand1 = ConvertValue(instruction->getOperand(0), tacs, ctx);
+  auto operand2 = ConvertValue(instruction->getOperand(1), tacs, ctx);
 
-  if (i->getType()->isVectorTy())
+  if (llvmType->isVectorTy())
   {
-    auto & binop = *static_cast<rvsdg::BinaryOperation *>(operation.get());
-    tacs.push_back(vectorbinary_op::create(binop, op1, op2, type));
+    tacs.push_back(vectorbinary_op::create(*operation, operand1, operand2, jlmType));
   }
   else
   {
-    tacs.push_back(
-        tac::create(*static_cast<rvsdg::SimpleOperation *>(operation.get()), { op1, op2 }));
+    tacs.push_back(tac::create(*operation, { operand1, operand2 }));
   }
 
   return tacs.back()->result(0);
@@ -1297,26 +1246,26 @@ ConvertInstruction(
             { ::llvm::Instruction::Br, convert_branch_instruction },
             { ::llvm::Instruction::Switch, convert_switch_instruction },
             { ::llvm::Instruction::Unreachable, convert_unreachable_instruction },
-            { ::llvm::Instruction::Add, convert_binary_operator },
-            { ::llvm::Instruction::And, convert_binary_operator },
-            { ::llvm::Instruction::AShr, convert_binary_operator },
-            { ::llvm::Instruction::Sub, convert_binary_operator },
-            { ::llvm::Instruction::UDiv, convert_binary_operator },
-            { ::llvm::Instruction::SDiv, convert_binary_operator },
-            { ::llvm::Instruction::URem, convert_binary_operator },
-            { ::llvm::Instruction::SRem, convert_binary_operator },
-            { ::llvm::Instruction::Shl, convert_binary_operator },
-            { ::llvm::Instruction::LShr, convert_binary_operator },
-            { ::llvm::Instruction::Or, convert_binary_operator },
-            { ::llvm::Instruction::Xor, convert_binary_operator },
-            { ::llvm::Instruction::Mul, convert_binary_operator },
-            { ::llvm::Instruction::FAdd, convert_binary_operator },
-            { ::llvm::Instruction::FSub, convert_binary_operator },
-            { ::llvm::Instruction::FMul, convert_binary_operator },
-            { ::llvm::Instruction::FDiv, convert_binary_operator },
-            { ::llvm::Instruction::FRem, convert_binary_operator },
+            { ::llvm::Instruction::Add, convert<::llvm::BinaryOperator> },
+            { ::llvm::Instruction::And, convert<::llvm::BinaryOperator> },
+            { ::llvm::Instruction::AShr, convert<::llvm::BinaryOperator> },
+            { ::llvm::Instruction::Sub, convert<::llvm::BinaryOperator> },
+            { ::llvm::Instruction::UDiv, convert<::llvm::BinaryOperator> },
+            { ::llvm::Instruction::SDiv, convert<::llvm::BinaryOperator> },
+            { ::llvm::Instruction::URem, convert<::llvm::BinaryOperator> },
+            { ::llvm::Instruction::SRem, convert<::llvm::BinaryOperator> },
+            { ::llvm::Instruction::Shl, convert<::llvm::BinaryOperator> },
+            { ::llvm::Instruction::LShr, convert<::llvm::BinaryOperator> },
+            { ::llvm::Instruction::Or, convert<::llvm::BinaryOperator> },
+            { ::llvm::Instruction::Xor, convert<::llvm::BinaryOperator> },
+            { ::llvm::Instruction::Mul, convert<::llvm::BinaryOperator> },
+            { ::llvm::Instruction::FAdd, convert<::llvm::BinaryOperator> },
+            { ::llvm::Instruction::FSub, convert<::llvm::BinaryOperator> },
+            { ::llvm::Instruction::FMul, convert<::llvm::BinaryOperator> },
+            { ::llvm::Instruction::FDiv, convert<::llvm::BinaryOperator> },
+            { ::llvm::Instruction::FRem, convert<::llvm::BinaryOperator> },
             { ::llvm::Instruction::FNeg, convert<::llvm::UnaryOperator> },
-            { ::llvm::Instruction::ICmp, convert_icmp_instruction },
+            { ::llvm::Instruction::ICmp, convert<::llvm::ICmpInst> },
             { ::llvm::Instruction::FCmp, convert_fcmp_instruction },
             { ::llvm::Instruction::Load, convert_load_instruction },
             { ::llvm::Instruction::Store, convert_store_instruction },