Skip to content

Commit

Permalink
Add FlattenBitConcatOperation normalization (#724)
Browse files Browse the repository at this point in the history
  • Loading branch information
phate authored Jan 8, 2025
1 parent 55160ba commit 14268ac
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 1 deletion.
35 changes: 35 additions & 0 deletions jlm/rvsdg/bitstring/concat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -368,4 +368,39 @@ bitconcat_op::copy() const
return std::make_unique<bitconcat_op>(*this);
}

static std::vector<std::shared_ptr<const bittype>>
GetTypesFromOperands(const std::vector<rvsdg::output *> & args)
{
std::vector<std::shared_ptr<const bittype>> types;
for (const auto arg : args)
{
types.push_back(std::dynamic_pointer_cast<const bittype>(arg->Type()));
}
return types;
}

std::optional<std::vector<rvsdg::output *>>
FlattenBitConcatOperation(const bitconcat_op &, const std::vector<rvsdg::output *> & operands)
{
JLM_ASSERT(!operands.empty());

const auto newOperands = base::detail::associative_flatten(
operands,
[](jlm::rvsdg::output * arg)
{
// FIXME: switch to comparing operator, not just typeid, after
// converting "concat" to not be a binary operator anymore
return is<bitconcat_op>(output::GetNode(*arg));
});

if (operands == newOperands)
{
JLM_ASSERT(newOperands.size() == 2);
return std::nullopt;
}

JLM_ASSERT(newOperands.size() > 2);
return outputs(&CreateOpNode<bitconcat_op>(newOperands, GetTypesFromOperands(newOperands)));
}

}
5 changes: 5 additions & 0 deletions jlm/rvsdg/bitstring/concat.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,11 @@ class bitconcat_op final : public BinaryOperation
jlm::rvsdg::output *
bitconcat(const std::vector<jlm::rvsdg::output *> & operands);

std::optional<std::vector<rvsdg::output *>>
FlattenBitConcatOperation(
const bitconcat_op & operation,
const std::vector<rvsdg::output *> & operands);

}

#endif
11 changes: 10 additions & 1 deletion tests/jlm/rvsdg/bitstring/bitstring.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include <test-operation.hpp>

#include <jlm/rvsdg/bitstring.hpp>
#include <jlm/rvsdg/NodeNormalization.hpp>
#include <jlm/rvsdg/view.hpp>

static int
Expand Down Expand Up @@ -1176,8 +1177,10 @@ ConcatFlattening()
{
using namespace jlm::rvsdg;

// Arrange & Act
// Arrange
Graph graph;
const auto nf = graph.GetNodeNormalForm(typeid(bitconcat_op));
nf->set_mutable(false);

auto x = &jlm::tests::GraphImport::Create(graph, bittype::Create(8), "x");
auto y = &jlm::tests::GraphImport::Create(graph, bittype::Create(8), "y");
Expand All @@ -1189,6 +1192,12 @@ ConcatFlattening()
auto & ex = jlm::tests::GraphExport::Create(*concatResult2, "dummy");
view(graph, stdout);

// Act
const auto concatNode = output::GetNode(*ex.origin());
ReduceNode<bitconcat_op>(FlattenBitConcatOperation, *concatNode);

view(graph, stdout);

// Assert
auto node = output::GetNode(*ex.origin());
assert(dynamic_cast<const bitconcat_op *>(&node->GetOperation()));
Expand Down

0 comments on commit 14268ac

Please sign in to comment.