Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

theta: change API for mapping loop variables #675

Merged
merged 1 commit into from
Jan 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 9 additions & 8 deletions jlm/hls/backend/rvsdg2rhls/ThetaConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,26 +22,27 @@ ConvertThetaNode(rvsdg::ThetaNode & theta)
// smap.
for (size_t i = 0; i < theta.ninputs(); i++)
{
auto loopvar = theta.MapInputLoopVar(*theta.input(i));
// Check if the input is a loop invariant such that a loop constant buffer should be created.
// Memory state inputs are not loop variables containting a value, so we ignor these.
if (is_invariant(theta.input(i))
&& !jlm::rvsdg::is<jlm::llvm::MemoryStateType>(theta.input(i)->Type()))
if (ThetaLoopVarIsInvariant(loopvar)
&& !jlm::rvsdg::is<jlm::llvm::MemoryStateType>(loopvar.input->Type()))
{
smap.insert(theta.input(i)->argument(), loop->add_loopconst(theta.input(i)->origin()));
smap.insert(loopvar.pre, loop->add_loopconst(loopvar.input->origin()));
branches.push_back(nullptr);
// The HLS loop has no output for this input. The users of the theta output is
// therefore redirected to the input origin, as the value is loop invariant.
theta.output(i)->divert_users(theta.input(i)->origin());
loopvar.output->divert_users(loopvar.input->origin());
}
else
{
jlm::rvsdg::output * buffer;
loop->add_loopvar(theta.input(i)->origin(), &buffer);
smap.insert(theta.input(i)->argument(), buffer);
loop->AddLoopVar(loopvar.input->origin(), &buffer);
smap.insert(loopvar.pre, buffer);
// buffer out is only used by branch
branches.push_back(*buffer->begin());
// divert theta outputs
theta.output(i)->divert_users(loop->output(loop->noutputs() - 1));
loopvar.output->divert_users(loop->output(loop->noutputs() - 1));
}
}

Expand All @@ -54,7 +55,7 @@ ConvertThetaNode(rvsdg::ThetaNode & theta)
{
if (branches[i])
{
branches[i]->divert_to(smap.lookup(theta.input(i)->result()->origin()));
branches[i]->divert_to(smap.lookup(theta.MapInputLoopVar(*theta.input(i)).post->origin()));
}
}

Expand Down
2 changes: 1 addition & 1 deletion jlm/hls/backend/rvsdg2rhls/add-prints.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ route_to_region(jlm::rvsdg::output * output, rvsdg::Region * region)
}
else if (auto theta = dynamic_cast<rvsdg::ThetaNode *>(region->node()))
{
output = theta->add_loopvar(output)->argument();
output = theta->AddLoopVar(output).pre;
}
else if (auto lambda = dynamic_cast<llvm::lambda::node *>(region->node()))
{
Expand Down
2 changes: 1 addition & 1 deletion jlm/hls/backend/rvsdg2rhls/add-triggers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ add_triggers(rvsdg::Region * region)
{
JLM_ASSERT(trigger != nullptr);
JLM_ASSERT(get_trigger(t->subregion()) == nullptr);
t->add_loopvar(trigger);
t->AddLoopVar(trigger);
add_triggers(t->subregion());
}
else if (auto gn = dynamic_cast<rvsdg::GammaNode *>(node))
Expand Down
24 changes: 11 additions & 13 deletions jlm/hls/backend/rvsdg2rhls/distribute-constants.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,29 +24,27 @@ distribute_constant(const rvsdg::SimpleOperation & op, rvsdg::simple_output * ou
changed = false;
for (auto user : *out)
{
auto node = rvsdg::input::GetNode(*user);
if (auto ti = dynamic_cast<rvsdg::ThetaInput *>(user))
if (auto theta = rvsdg::TryGetOwnerNode<rvsdg::ThetaNode>(*user))
{
auto arg = ti->argument();
auto res = ti->result();
if (res->origin() == arg)
auto loopvar = theta->MapInputLoopVar(*user);
if (loopvar.post->origin() == loopvar.pre)
{
// pass-through
auto arg_replacement = dynamic_cast<rvsdg::simple_output *>(
rvsdg::SimpleNode::create_normalized(ti->node()->subregion(), op, {})[0]);
ti->argument()->divert_users(arg_replacement);
ti->output()->divert_users(
rvsdg::SimpleNode::create_normalized(theta->subregion(), op, {})[0]);
loopvar.pre->divert_users(arg_replacement);
loopvar.output->divert_users(
rvsdg::SimpleNode::create_normalized(out->region(), op, {})[0]);
distribute_constant(op, arg_replacement);
arg->region()->RemoveResult(res->index());
arg->region()->RemoveArgument(arg->index());
arg->region()->node()->RemoveInput(arg->input()->index());
arg->region()->node()->RemoveOutput(res->output()->index());
theta->subregion()->RemoveResult(loopvar.post->index());
theta->subregion()->RemoveArgument(loopvar.pre->index());
theta->RemoveInput(loopvar.input->index());
theta->RemoveOutput(loopvar.output->index());
changed = true;
break;
}
}
if (auto gammaNode = dynamic_cast<rvsdg::GammaNode *>(node))
if (auto gammaNode = rvsdg::TryGetOwnerNode<rvsdg::GammaNode>(*user))
{
if (gammaNode->predicate() == user)
{
Expand Down
2 changes: 1 addition & 1 deletion jlm/hls/backend/rvsdg2rhls/mem-queue.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ separate_load_edge(
auto loop_node = jlm::util::AssertedCast<jlm::hls::loop_node>(sti->node());
jlm::rvsdg::output * buffer;

addr_edge = loop_node->add_loopvar(addr_edge, &buffer);
addr_edge = loop_node->AddLoopVar(addr_edge, &buffer);
addr_edge_user->divert_to(addr_edge);
mem_edge = find_loop_output(sti);
auto sti_arg = sti->arguments.first();
Expand Down
21 changes: 10 additions & 11 deletions jlm/hls/backend/rvsdg2rhls/mem-sep.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -116,9 +116,9 @@ route_through(rvsdg::Region * target, jlm::rvsdg::output * response)
}
else if (auto tn = dynamic_cast<rvsdg::ThetaNode *>(target->node()))
{
auto lv = tn->add_loopvar(parent_response);
parrent_user->divert_to(lv);
return lv->argument();
auto lv = tn->AddLoopVar(parent_response);
parrent_user->divert_to(lv.output);
return lv.pre;
}
JLM_UNREACHABLE("THIS SHOULD NOT HAPPEN");
}
Expand Down Expand Up @@ -183,13 +183,12 @@ trace_edge(
JLM_ASSERT(new_edge->nusers() == 1);
auto user = *common_edge->begin();
auto new_next = *new_edge->begin();
auto node = rvsdg::input::GetNode(*user);
if (auto res = dynamic_cast<rvsdg::RegionResult *>(user))
{
// end of region reached
return res;
}
else if (auto gammaNode = dynamic_cast<rvsdg::GammaNode *>(node))
else if (auto gammaNode = rvsdg::TryGetOwnerNode<rvsdg::GammaNode>(*user))
{
auto ip = gammaNode->AddEntryVar(new_edge);
std::vector<jlm::rvsdg::output *> vec;
Expand All @@ -208,13 +207,13 @@ trace_edge(
common_edge = subres->output();
}
}
else if (auto ti = dynamic_cast<rvsdg::ThetaInput *>(user))
else if (auto theta = rvsdg::TryGetOwnerNode<rvsdg::ThetaNode>(*user))
{
auto tn = ti->node();
auto lv = tn->add_loopvar(new_edge);
trace_edge(ti->argument(), lv->argument(), load_nodes, store_nodes, decouple_nodes);
common_edge = ti->output();
new_edge = lv;
auto olv = theta->MapInputLoopVar(*user);
auto lv = theta->AddLoopVar(new_edge);
trace_edge(olv.pre, lv.pre, load_nodes, store_nodes, decouple_nodes);
common_edge = olv.output;
new_edge = lv.output;
new_next->divert_to(new_edge);
}
else if (auto si = dynamic_cast<jlm::rvsdg::simple_input *>(user))
Expand Down
4 changes: 2 additions & 2 deletions jlm/hls/backend/rvsdg2rhls/rvsdg2rhls.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,9 +113,9 @@ trace_call(jlm::rvsdg::input * input)

auto argument = dynamic_cast<const rvsdg::RegionArgument *>(input->origin());
const jlm::rvsdg::output * result;
if (auto to = dynamic_cast<const rvsdg::ThetaOutput *>(input->origin()))
if (auto theta = rvsdg::TryGetOwnerNode<rvsdg::ThetaNode>(*input->origin()))
{
result = trace_call(to->input());
result = trace_call(theta->MapOutputLoopVar(*input->origin()).input);
}
else if (argument == nullptr)
{
Expand Down
2 changes: 1 addition & 1 deletion jlm/hls/ir/hls.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ ExitResult::Copy(rvsdg::output & origin, rvsdg::StructuralOutput * output)
}

rvsdg::StructuralOutput *
loop_node::add_loopvar(jlm::rvsdg::output * origin, jlm::rvsdg::output ** buffer)
loop_node::AddLoopVar(jlm::rvsdg::output * origin, jlm::rvsdg::output ** buffer)
{
auto input = rvsdg::StructuralInput::create(this, origin, origin->Type());
auto output = rvsdg::StructuralOutput::create(this, origin->Type());
Expand Down
2 changes: 1 addition & 1 deletion jlm/hls/ir/hls.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -787,7 +787,7 @@ class loop_node final : public rvsdg::StructuralNode
add_backedge(std::shared_ptr<const jlm::rvsdg::Type> type);

rvsdg::StructuralOutput *
add_loopvar(jlm::rvsdg::output * origin, jlm::rvsdg::output ** buffer = nullptr);
AddLoopVar(jlm::rvsdg::output * origin, jlm::rvsdg::output ** buffer = nullptr);

jlm::rvsdg::output *
add_loopconst(jlm::rvsdg::output * origin);
Expand Down
66 changes: 39 additions & 27 deletions jlm/hls/opt/cne.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -183,33 +183,43 @@ congruent(jlm::rvsdg::output * o1, jlm::rvsdg::output * o2, vset & vs, cnectx &
if (o1->type() != o2->type())
return false;

if (is<rvsdg::ThetaArgument>(o1) && is<rvsdg::ThetaArgument>(o2))
if (auto theta1 = rvsdg::TryGetRegionParentNode<rvsdg::ThetaNode>(*o1))
{
JLM_ASSERT(o1->region()->node() == o2->region()->node());
auto a1 = static_cast<rvsdg::RegionArgument *>(o1);
auto a2 = static_cast<rvsdg::RegionArgument *>(o2);
vs.insert(a1, a2);
auto i1 = a1->input(), i2 = a2->input();
if (!congruent(a1->input()->origin(), a2->input()->origin(), vs, ctx))
return false;
if (auto theta2 = rvsdg::TryGetRegionParentNode<rvsdg::ThetaNode>(*o2))
{
JLM_ASSERT(o1->region()->node() == o2->region()->node());
auto loopvar1 = theta1->MapPreLoopVar(*o1);
auto loopvar2 = theta2->MapPreLoopVar(*o2);
vs.insert(o1, o2);
auto i1 = loopvar1.input, i2 = loopvar2.input;
if (!congruent(loopvar1.input->origin(), loopvar2.input->origin(), vs, ctx))
return false;

auto output1 = o1->region()->node()->output(i1->index());
auto output2 = o2->region()->node()->output(i2->index());
return congruent(output1, output2, vs, ctx);
auto output1 = o1->region()->node()->output(i1->index());
auto output2 = o2->region()->node()->output(i2->index());
return congruent(output1, output2, vs, ctx);
}
}

auto n1 = jlm::rvsdg::output::GetNode(*o1);
auto n2 = jlm::rvsdg::output::GetNode(*o2);
if (is<jlm::rvsdg::ThetaOperation>(n1) && is<jlm::rvsdg::ThetaOperation>(n2) && n1 == n2)
if (auto theta1 = rvsdg::TryGetOwnerNode<rvsdg::ThetaNode>(*o1))
{
auto so1 = static_cast<StructuralOutput *>(o1);
auto so2 = static_cast<StructuralOutput *>(o2);
vs.insert(o1, o2);
auto r1 = so1->results.first();
auto r2 = so2->results.first();
return congruent(r1->origin(), r2->origin(), vs, ctx);
if (auto theta2 = rvsdg::TryGetOwnerNode<rvsdg::ThetaNode>(*o2))
{
if (theta1 == theta2)
{
vs.insert(o1, o2);
auto loopvar1 = theta1->MapOutputLoopVar(*o1);
auto loopvar2 = theta2->MapOutputLoopVar(*o2);
auto r1 = loopvar1.post;
auto r2 = loopvar2.post;
return congruent(r1->origin(), r2->origin(), vs, ctx);
}
}
}

auto n1 = jlm::rvsdg::output::GetNode(*o1);
auto n2 = jlm::rvsdg::output::GetNode(*o2);

auto a1 = dynamic_cast<rvsdg::RegionArgument *>(o1);
auto a2 = dynamic_cast<rvsdg::RegionArgument *>(o2);
if (a1 && is<hls::loop_op>(a1->region()->node()) && a2 && is<hls::loop_op>(a2->region()->node()))
Expand Down Expand Up @@ -331,10 +341,12 @@ mark_theta(const rvsdg::StructuralNode * node, cnectx & ctx)
{
auto input1 = theta->input(i1);
auto input2 = theta->input(i2);
if (congruent(input1->argument(), input2->argument(), ctx))
auto loopvar1 = theta->MapInputLoopVar(*input1);
auto loopvar2 = theta->MapInputLoopVar(*input2);
if (congruent(loopvar1.pre, loopvar2.pre, ctx))
{
ctx.mark(input1->argument(), input2->argument());
ctx.mark(input1->output(), input2->output());
ctx.mark(loopvar1.pre, loopvar2.pre);
ctx.mark(loopvar1.output, loopvar2.output);
}
}
}
Expand Down Expand Up @@ -530,11 +542,11 @@ divert_theta(rvsdg::StructuralNode * node, cnectx & ctx)
auto theta = static_cast<rvsdg::ThetaNode *>(node);
auto subregion = node->subregion(0);

for (const auto & lv : *theta)
for (const auto & lv : theta->GetLoopVars())
{
JLM_ASSERT(ctx.set(lv->argument())->size() == ctx.set(lv)->size());
divert_users(lv->argument(), ctx);
divert_users(lv, ctx);
JLM_ASSERT(ctx.set(lv.pre)->size() == ctx.set(lv.output)->size());
divert_users(lv.pre, ctx);
divert_users(lv.output, ctx);
}

divert(subregion, ctx);
Expand Down
5 changes: 3 additions & 2 deletions jlm/hls/util/view.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -367,9 +367,10 @@ region_to_dot(rvsdg::Region * region)
{
dot << edge(be->argument(), be, true);
}
else if (auto to = dynamic_cast<rvsdg::ThetaOutput *>(region->result(i)->output()))
else if (auto theta = rvsdg::TryGetOwnerNode<rvsdg::ThetaNode>(*region->result(i)->output()))
{
dot << edge(to->argument(), to->result(), true);
auto loopvar = theta->MapOutputLoopVar(*region->result(i)->output());
dot << edge(loopvar.pre, loopvar.post, true);
}
}

Expand Down
13 changes: 7 additions & 6 deletions jlm/llvm/frontend/InterProceduralGraphConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -765,7 +765,7 @@ Convert(
* Add loop variables
*/
auto & demandSet = demandMap.Lookup<LoopAnnotationSet>(loopAggregationNode);
std::unordered_map<const variable *, rvsdg::ThetaOutput *> thetaOutputMap;
std::unordered_map<const variable *, rvsdg::ThetaNode::LoopVar> thetaLoopVarMap;
for (auto & v : demandSet.LoopVariables().Variables())
{
rvsdg::output * value = nullptr;
Expand All @@ -778,8 +778,9 @@ Convert(
{
value = outerVariableMap.lookup(&v);
}
thetaOutputMap[&v] = theta->add_loopvar(value);
thetaVariableMap.insert(&v, thetaOutputMap[&v]->argument());
auto loopvar = theta->AddLoopVar(value);
thetaLoopVarMap[&v] = loopvar;
thetaVariableMap.insert(&v, loopvar.pre);
}

/*
Expand All @@ -797,8 +798,8 @@ Convert(
*/
for (auto & v : demandSet.LoopVariables().Variables())
{
JLM_ASSERT(thetaOutputMap.find(&v) != thetaOutputMap.end());
thetaOutputMap[&v]->result()->divert_to(thetaVariableMap.lookup(&v));
JLM_ASSERT(thetaLoopVarMap.find(&v) != thetaLoopVarMap.end());
thetaLoopVarMap[&v].post->divert_to(thetaVariableMap.lookup(&v));
}

/*
Expand All @@ -820,7 +821,7 @@ Convert(
for (auto & v : demandSet.LoopVariables().Variables())
{
JLM_ASSERT(outerVariableMap.contains(&v));
outerVariableMap.insert(&v, thetaOutputMap[&v]);
outerVariableMap.insert(&v, thetaLoopVarMap[&v].output);
}
}

Expand Down
Loading
Loading