Skip to content

Commit

Permalink
theta: change API for mapping loop variables
Browse files Browse the repository at this point in the history
Provide an API to loop nodes that allows mapping the various
representation pieces of a loop variable. Remove
Theta{Input|Output|Argument|Result}.
  • Loading branch information
caleridas committed Jan 3, 2025
1 parent 087e4af commit 09d0296
Show file tree
Hide file tree
Showing 51 changed files with 1,054 additions and 1,041 deletions.
17 changes: 9 additions & 8 deletions jlm/hls/backend/rvsdg2rhls/ThetaConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,26 +22,27 @@ ConvertThetaNode(rvsdg::ThetaNode & theta)
// smap.
for (size_t i = 0; i < theta.ninputs(); i++)
{
auto loopvar = theta.MapInputLoopVar(*theta.input(i));
// Check if the input is a loop invariant such that a loop constant buffer should be created.
// Memory state inputs are not loop variables containting a value, so we ignor these.
if (is_invariant(theta.input(i))
&& !jlm::rvsdg::is<jlm::llvm::MemoryStateType>(theta.input(i)->Type()))
if (ThetaLoopVarIsInvariant(loopvar)
&& !jlm::rvsdg::is<jlm::llvm::MemoryStateType>(loopvar.input->Type()))
{
smap.insert(theta.input(i)->argument(), loop->add_loopconst(theta.input(i)->origin()));
smap.insert(loopvar.pre, loop->add_loopconst(loopvar.input->origin()));
branches.push_back(nullptr);
// The HLS loop has no output for this input. The users of the theta output is
// therefore redirected to the input origin, as the value is loop invariant.
theta.output(i)->divert_users(theta.input(i)->origin());
loopvar.output->divert_users(loopvar.input->origin());
}
else
{
jlm::rvsdg::output * buffer;
loop->add_loopvar(theta.input(i)->origin(), &buffer);
smap.insert(theta.input(i)->argument(), buffer);
loop->AddLoopVar(loopvar.input->origin(), &buffer);
smap.insert(loopvar.pre, buffer);
// buffer out is only used by branch
branches.push_back(*buffer->begin());
// divert theta outputs
theta.output(i)->divert_users(loop->output(loop->noutputs() - 1));
loopvar.output->divert_users(loop->output(loop->noutputs() - 1));
}
}

Expand All @@ -54,7 +55,7 @@ ConvertThetaNode(rvsdg::ThetaNode & theta)
{
if (branches[i])
{
branches[i]->divert_to(smap.lookup(theta.input(i)->result()->origin()));
branches[i]->divert_to(smap.lookup(theta.MapInputLoopVar(*theta.input(i)).post->origin()));
}
}

Expand Down
2 changes: 1 addition & 1 deletion jlm/hls/backend/rvsdg2rhls/add-prints.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ route_to_region(jlm::rvsdg::output * output, rvsdg::Region * region)
}
else if (auto theta = dynamic_cast<rvsdg::ThetaNode *>(region->node()))
{
output = theta->add_loopvar(output)->argument();
output = theta->AddLoopVar(output).pre;
}
else if (auto lambda = dynamic_cast<llvm::lambda::node *>(region->node()))
{
Expand Down
2 changes: 1 addition & 1 deletion jlm/hls/backend/rvsdg2rhls/add-triggers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ add_triggers(rvsdg::Region * region)
{
JLM_ASSERT(trigger != nullptr);
JLM_ASSERT(get_trigger(t->subregion()) == nullptr);
t->add_loopvar(trigger);
t->AddLoopVar(trigger);
add_triggers(t->subregion());
}
else if (auto gn = dynamic_cast<rvsdg::GammaNode *>(node))
Expand Down
24 changes: 11 additions & 13 deletions jlm/hls/backend/rvsdg2rhls/distribute-constants.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,29 +24,27 @@ distribute_constant(const rvsdg::SimpleOperation & op, rvsdg::simple_output * ou
changed = false;
for (auto user : *out)
{
auto node = rvsdg::input::GetNode(*user);
if (auto ti = dynamic_cast<rvsdg::ThetaInput *>(user))
if (auto theta = rvsdg::TryGetOwnerNode<rvsdg::ThetaNode>(*user))
{
auto arg = ti->argument();
auto res = ti->result();
if (res->origin() == arg)
auto loopvar = theta->MapInputLoopVar(*user);
if (loopvar.post->origin() == loopvar.pre)
{
// pass-through
auto arg_replacement = dynamic_cast<rvsdg::simple_output *>(
rvsdg::SimpleNode::create_normalized(ti->node()->subregion(), op, {})[0]);
ti->argument()->divert_users(arg_replacement);
ti->output()->divert_users(
rvsdg::SimpleNode::create_normalized(theta->subregion(), op, {})[0]);
loopvar.pre->divert_users(arg_replacement);
loopvar.output->divert_users(
rvsdg::SimpleNode::create_normalized(out->region(), op, {})[0]);
distribute_constant(op, arg_replacement);
arg->region()->RemoveResult(res->index());
arg->region()->RemoveArgument(arg->index());
arg->region()->node()->RemoveInput(arg->input()->index());
arg->region()->node()->RemoveOutput(res->output()->index());
theta->subregion()->RemoveResult(loopvar.post->index());
theta->subregion()->RemoveArgument(loopvar.pre->index());
theta->RemoveInput(loopvar.input->index());
theta->RemoveOutput(loopvar.output->index());
changed = true;
break;
}
}
if (auto gammaNode = dynamic_cast<rvsdg::GammaNode *>(node))
if (auto gammaNode = rvsdg::TryGetOwnerNode<rvsdg::GammaNode>(*user))
{
if (gammaNode->predicate() == user)
{
Expand Down
2 changes: 1 addition & 1 deletion jlm/hls/backend/rvsdg2rhls/mem-queue.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ separate_load_edge(
auto loop_node = jlm::util::AssertedCast<jlm::hls::loop_node>(sti->node());
jlm::rvsdg::output * buffer;

addr_edge = loop_node->add_loopvar(addr_edge, &buffer);
addr_edge = loop_node->AddLoopVar(addr_edge, &buffer);
addr_edge_user->divert_to(addr_edge);
mem_edge = find_loop_output(sti);
auto sti_arg = sti->arguments.first();
Expand Down
21 changes: 10 additions & 11 deletions jlm/hls/backend/rvsdg2rhls/mem-sep.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -116,9 +116,9 @@ route_through(rvsdg::Region * target, jlm::rvsdg::output * response)
}
else if (auto tn = dynamic_cast<rvsdg::ThetaNode *>(target->node()))
{
auto lv = tn->add_loopvar(parent_response);
parrent_user->divert_to(lv);
return lv->argument();
auto lv = tn->AddLoopVar(parent_response);
parrent_user->divert_to(lv.output);
return lv.pre;
}
JLM_UNREACHABLE("THIS SHOULD NOT HAPPEN");
}
Expand Down Expand Up @@ -183,13 +183,12 @@ trace_edge(
JLM_ASSERT(new_edge->nusers() == 1);
auto user = *common_edge->begin();
auto new_next = *new_edge->begin();
auto node = rvsdg::input::GetNode(*user);
if (auto res = dynamic_cast<rvsdg::RegionResult *>(user))
{
// end of region reached
return res;
}
else if (auto gammaNode = dynamic_cast<rvsdg::GammaNode *>(node))
else if (auto gammaNode = rvsdg::TryGetOwnerNode<rvsdg::GammaNode>(*user))
{
auto ip = gammaNode->AddEntryVar(new_edge);
std::vector<jlm::rvsdg::output *> vec;
Expand All @@ -208,13 +207,13 @@ trace_edge(
common_edge = subres->output();
}
}
else if (auto ti = dynamic_cast<rvsdg::ThetaInput *>(user))
else if (auto theta = rvsdg::TryGetOwnerNode<rvsdg::ThetaNode>(*user))
{
auto tn = ti->node();
auto lv = tn->add_loopvar(new_edge);
trace_edge(ti->argument(), lv->argument(), load_nodes, store_nodes, decouple_nodes);
common_edge = ti->output();
new_edge = lv;
auto olv = theta->MapInputLoopVar(*user);
auto lv = theta->AddLoopVar(new_edge);
trace_edge(olv.pre, lv.pre, load_nodes, store_nodes, decouple_nodes);
common_edge = olv.output;
new_edge = lv.output;
new_next->divert_to(new_edge);
}
else if (auto si = dynamic_cast<jlm::rvsdg::simple_input *>(user))
Expand Down
4 changes: 2 additions & 2 deletions jlm/hls/backend/rvsdg2rhls/rvsdg2rhls.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,9 +113,9 @@ trace_call(jlm::rvsdg::input * input)

auto argument = dynamic_cast<const rvsdg::RegionArgument *>(input->origin());
const jlm::rvsdg::output * result;
if (auto to = dynamic_cast<const rvsdg::ThetaOutput *>(input->origin()))
if (auto theta = rvsdg::TryGetOwnerNode<rvsdg::ThetaNode>(*input->origin()))
{
result = trace_call(to->input());
result = trace_call(theta->MapOutputLoopVar(*input->origin()).input);
}
else if (argument == nullptr)
{
Expand Down
2 changes: 1 addition & 1 deletion jlm/hls/ir/hls.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ ExitResult::Copy(rvsdg::output & origin, rvsdg::StructuralOutput * output)
}

rvsdg::StructuralOutput *
loop_node::add_loopvar(jlm::rvsdg::output * origin, jlm::rvsdg::output ** buffer)
loop_node::AddLoopVar(jlm::rvsdg::output * origin, jlm::rvsdg::output ** buffer)
{
auto input = rvsdg::StructuralInput::create(this, origin, origin->Type());
auto output = rvsdg::StructuralOutput::create(this, origin->Type());
Expand Down
2 changes: 1 addition & 1 deletion jlm/hls/ir/hls.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -787,7 +787,7 @@ class loop_node final : public rvsdg::StructuralNode
add_backedge(std::shared_ptr<const jlm::rvsdg::Type> type);

rvsdg::StructuralOutput *
add_loopvar(jlm::rvsdg::output * origin, jlm::rvsdg::output ** buffer = nullptr);
AddLoopVar(jlm::rvsdg::output * origin, jlm::rvsdg::output ** buffer = nullptr);

jlm::rvsdg::output *
add_loopconst(jlm::rvsdg::output * origin);
Expand Down
62 changes: 35 additions & 27 deletions jlm/hls/opt/cne.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -183,33 +183,40 @@ congruent(jlm::rvsdg::output * o1, jlm::rvsdg::output * o2, vset & vs, cnectx &
if (o1->type() != o2->type())
return false;

if (is<rvsdg::ThetaArgument>(o1) && is<rvsdg::ThetaArgument>(o2))
if (auto theta1 = rvsdg::TryGetRegionParentNode<rvsdg::ThetaNode>(*o1))
{
JLM_ASSERT(o1->region()->node() == o2->region()->node());
auto a1 = static_cast<rvsdg::RegionArgument *>(o1);
auto a2 = static_cast<rvsdg::RegionArgument *>(o2);
vs.insert(a1, a2);
auto i1 = a1->input(), i2 = a2->input();
if (!congruent(a1->input()->origin(), a2->input()->origin(), vs, ctx))
return false;
if (auto theta2 = rvsdg::TryGetRegionParentNode<rvsdg::ThetaNode>(*o2))
{
JLM_ASSERT(o1->region()->node() == o2->region()->node());
auto loopvar1 = theta1->MapPreLoopVar(*o1);
auto loopvar2 = theta2->MapPreLoopVar(*o2);
vs.insert(o1, o2);
auto i1 = loopvar1.input, i2 = loopvar2.input;
if (!congruent(loopvar1.input->origin(), loopvar2.input->origin(), vs, ctx))
return false;

auto output1 = o1->region()->node()->output(i1->index());
auto output2 = o2->region()->node()->output(i2->index());
return congruent(output1, output2, vs, ctx);
auto output1 = o1->region()->node()->output(i1->index());
auto output2 = o2->region()->node()->output(i2->index());
return congruent(output1, output2, vs, ctx);
}
}

auto n1 = jlm::rvsdg::output::GetNode(*o1);
auto n2 = jlm::rvsdg::output::GetNode(*o2);
if (is<jlm::rvsdg::ThetaOperation>(n1) && is<jlm::rvsdg::ThetaOperation>(n2) && n1 == n2)
if (auto theta1 = rvsdg::TryGetOwnerNode<rvsdg::ThetaNode>(*o1))
{
auto so1 = static_cast<StructuralOutput *>(o1);
auto so2 = static_cast<StructuralOutput *>(o2);
vs.insert(o1, o2);
auto r1 = so1->results.first();
auto r2 = so2->results.first();
return congruent(r1->origin(), r2->origin(), vs, ctx);
if (auto theta2 = rvsdg::TryGetOwnerNode<rvsdg::ThetaNode>(*o2))
{
vs.insert(o1, o2);
auto loopvar1 = theta1->MapOutputLoopVar(*o1);
auto loopvar2 = theta2->MapOutputLoopVar(*o2);
auto r1 = loopvar1.post;
auto r2 = loopvar2.post;
return congruent(r1->origin(), r2->origin(), vs, ctx);
}
}

auto n1 = jlm::rvsdg::output::GetNode(*o1);
auto n2 = jlm::rvsdg::output::GetNode(*o2);

auto a1 = dynamic_cast<rvsdg::RegionArgument *>(o1);
auto a2 = dynamic_cast<rvsdg::RegionArgument *>(o2);
if (a1 && is<hls::loop_op>(a1->region()->node()) && a2 && is<hls::loop_op>(a2->region()->node()))
Expand Down Expand Up @@ -331,10 +338,12 @@ mark_theta(const rvsdg::StructuralNode * node, cnectx & ctx)
{
auto input1 = theta->input(i1);
auto input2 = theta->input(i2);
if (congruent(input1->argument(), input2->argument(), ctx))
auto loopvar1 = theta->MapInputLoopVar(*input1);
auto loopvar2 = theta->MapInputLoopVar(*input2);
if (congruent(loopvar1.pre, loopvar2.pre, ctx))
{
ctx.mark(input1->argument(), input2->argument());
ctx.mark(input1->output(), input2->output());
ctx.mark(loopvar1.pre, loopvar2.pre);
ctx.mark(loopvar1.output, loopvar2.output);
}
}
}
Expand Down Expand Up @@ -530,11 +539,10 @@ divert_theta(rvsdg::StructuralNode * node, cnectx & ctx)
auto theta = static_cast<rvsdg::ThetaNode *>(node);
auto subregion = node->subregion(0);

for (const auto & lv : *theta)
for (const auto & lv : theta->GetLoopVars())
{
JLM_ASSERT(ctx.set(lv->argument())->size() == ctx.set(lv)->size());
divert_users(lv->argument(), ctx);
divert_users(lv, ctx);
JLM_ASSERT(ctx.set(lv.pre)->size() == ctx.set(lv.output)->size());
divert_users(lv.pre, ctx);
}

divert(subregion, ctx);
Expand Down
5 changes: 3 additions & 2 deletions jlm/hls/util/view.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -367,9 +367,10 @@ region_to_dot(rvsdg::Region * region)
{
dot << edge(be->argument(), be, true);
}
else if (auto to = dynamic_cast<rvsdg::ThetaOutput *>(region->result(i)->output()))
else if (auto theta = rvsdg::TryGetOwnerNode<rvsdg::ThetaNode>(*region->result(i)->output()))
{
dot << edge(to->argument(), to->result(), true);
auto loopvar = theta->MapOutputLoopVar(*region->result(i)->output());
dot << edge(loopvar.pre, loopvar.post, true);
}
}

Expand Down
13 changes: 7 additions & 6 deletions jlm/llvm/frontend/InterProceduralGraphConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -765,7 +765,7 @@ Convert(
* Add loop variables
*/
auto & demandSet = demandMap.Lookup<LoopAnnotationSet>(loopAggregationNode);
std::unordered_map<const variable *, rvsdg::ThetaOutput *> thetaOutputMap;
std::unordered_map<const variable *, rvsdg::ThetaNode::LoopVar> thetaLoopVarMap;
for (auto & v : demandSet.LoopVariables().Variables())
{
rvsdg::output * value = nullptr;
Expand All @@ -778,8 +778,9 @@ Convert(
{
value = outerVariableMap.lookup(&v);
}
thetaOutputMap[&v] = theta->add_loopvar(value);
thetaVariableMap.insert(&v, thetaOutputMap[&v]->argument());
auto loopvar = theta->AddLoopVar(value);
thetaLoopVarMap[&v] = loopvar;
thetaVariableMap.insert(&v, loopvar.pre);
}

/*
Expand All @@ -797,8 +798,8 @@ Convert(
*/
for (auto & v : demandSet.LoopVariables().Variables())
{
JLM_ASSERT(thetaOutputMap.find(&v) != thetaOutputMap.end());
thetaOutputMap[&v]->result()->divert_to(thetaVariableMap.lookup(&v));
JLM_ASSERT(thetaLoopVarMap.find(&v) != thetaLoopVarMap.end());
thetaLoopVarMap[&v].post->divert_to(thetaVariableMap.lookup(&v));
}

/*
Expand All @@ -820,7 +821,7 @@ Convert(
for (auto & v : demandSet.LoopVariables().Variables())
{
JLM_ASSERT(outerVariableMap.contains(&v));
outerVariableMap.insert(&v, thetaOutputMap[&v]);
outerVariableMap.insert(&v, thetaLoopVarMap[&v].output);
}
}

Expand Down
Loading

0 comments on commit 09d0296

Please sign in to comment.