Skip to content

Commit

Permalink
Move and rename lambda node
Browse files Browse the repository at this point in the history
Move lambda node to rvsdg, and rename it to LambdaNode / LambdaOperation.
Provide "baseline" LambdaOperation and derived llvm-specific features.
The name LambdaOperation is provisional and may later change to better
capture its intent.
  • Loading branch information
caleridas committed Jan 28, 2025
1 parent fc0546f commit 309e5d0
Show file tree
Hide file tree
Showing 88 changed files with 1,521 additions and 1,431 deletions.
5 changes: 3 additions & 2 deletions jlm/hls/backend/rhls2firrtl/RhlsToFirrtlConverter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2813,13 +2813,14 @@ RhlsToFirrtlConverter::TraceStructuralOutput(rvsdg::StructuralOutput * output)

// Emit a circuit
circt::firrtl::CircuitOp
RhlsToFirrtlConverter::MlirGen(const llvm::lambda::node * lambdaNode)
RhlsToFirrtlConverter::MlirGen(const rvsdg::LambdaNode * lambdaNode)
{

// Ensure consistent naming across runs
create_node_names(lambdaNode->subregion());
// The same name is used for the circuit and main module
auto moduleName = Builder_->getStringAttr(lambdaNode->GetOperation().name() + "_lambda_mod");
auto moduleName = Builder_->getStringAttr(
dynamic_cast<llvm::LlvmLambdaOperation &>(lambdaNode->GetOperation()).name() + "_lambda_mod");
// Create the top level FIRRTL circuit
auto circuit = Builder_->create<circt::firrtl::CircuitOp>(Builder_->getUnknownLoc(), moduleName);
// The body will be populated with a list of modules
Expand Down
2 changes: 1 addition & 1 deletion jlm/hls/backend/rhls2firrtl/RhlsToFirrtlConverter.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ class RhlsToFirrtlConverter : public BaseHLS
operator=(RhlsToFirrtlConverter &&) = delete;

circt::firrtl::CircuitOp
MlirGen(const llvm::lambda::node * lamdaNode);
MlirGen(const rvsdg::LambdaNode * lamdaNode);

void
WriteModuleToFile(const circt::firrtl::FModuleOp fModuleOp, const rvsdg::Node * node);
Expand Down
4 changes: 2 additions & 2 deletions jlm/hls/backend/rhls2firrtl/base-hls.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -159,11 +159,11 @@ BaseHLS::create_node_names(rvsdg::Region * r)
}
}

const jlm::llvm::lambda::node *
const jlm::rvsdg::LambdaNode *
BaseHLS::get_hls_lambda(llvm::RvsdgModule & rm)
{
auto region = &rm.Rvsdg().GetRootRegion();
auto ln = dynamic_cast<const llvm::lambda::node *>(region->Nodes().begin().ptr());
auto ln = dynamic_cast<const rvsdg::LambdaNode *>(region->Nodes().begin().ptr());
if (region->nnodes() == 1 && ln)
{
return ln;
Expand Down
10 changes: 5 additions & 5 deletions jlm/hls/backend/rhls2firrtl/base-hls.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ class BaseHLS
static std::string
get_port_name(jlm::rvsdg::output * port);

const llvm::lambda::node *
const rvsdg::LambdaNode *
get_hls_lambda(llvm::RvsdgModule & rm);

void
Expand All @@ -81,7 +81,7 @@ class BaseHLS
* @return the arguments that represent memory responses
*/
std::vector<rvsdg::RegionArgument *>
get_mem_resps(const llvm::lambda::node & lambda)
get_mem_resps(const rvsdg::LambdaNode & lambda)
{
std::vector<rvsdg::RegionArgument *> mem_resps;
for (auto arg : lambda.subregion()->Arguments())
Expand All @@ -99,7 +99,7 @@ class BaseHLS
* @return the results that represent memory requests
*/
std::vector<rvsdg::RegionResult *>
get_mem_reqs(const llvm::lambda::node & lambda)
get_mem_reqs(const rvsdg::LambdaNode & lambda)
{
std::vector<rvsdg::RegionResult *> mem_resps;
for (auto result : lambda.subregion()->Results())
Expand All @@ -118,7 +118,7 @@ class BaseHLS
* @return the arguments of the lambda that represent kernel inputs
*/
std::vector<rvsdg::RegionArgument *>
get_reg_args(const llvm::lambda::node & lambda)
get_reg_args(const rvsdg::LambdaNode & lambda)
{
std::vector<rvsdg::RegionArgument *> args;
for (auto argument : lambda.subregion()->Arguments())
Expand All @@ -136,7 +136,7 @@ class BaseHLS
* @return the results of the lambda that represent the kernel outputs
*/
std::vector<rvsdg::RegionResult *>
get_reg_results(const llvm::lambda::node & lambda)
get_reg_results(const rvsdg::LambdaNode & lambda)
{
std::vector<rvsdg::RegionResult *> results;
for (auto result : lambda.subregion()->Results())
Expand Down
2 changes: 1 addition & 1 deletion jlm/hls/backend/rhls2firrtl/json-hls.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ JsonHLS::GetText(llvm::RvsdgModule & rm)
{
std::ostringstream json;
const auto & ln = *get_hls_lambda(rm);
auto function_name = ln.GetOperation().name();
auto function_name = dynamic_cast<llvm::LlvmLambdaOperation &>(ln.GetOperation()).name();
auto file_name = get_base_file_name(rm);
json << "{\n";

Expand Down
7 changes: 4 additions & 3 deletions jlm/hls/backend/rhls2firrtl/verilator-harness-hls.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ ConvertToCType(const rvsdg::Type * type)
* @return the return type of the kernel as written in C, or nullopt if it has no return value.
*/
std::optional<std::string>
GetReturnTypeAsC(const llvm::lambda::node & kernel)
GetReturnTypeAsC(const rvsdg::LambdaNode & kernel)
{
const auto & results = kernel.GetOperation().type().Results();

Expand All @@ -65,7 +65,7 @@ GetReturnTypeAsC(const llvm::lambda::node & kernel)
* @return a tuple (number of parameters, string of parameters, string of call arguments)
*/
std::tuple<size_t, std::string, std::string>
GetParameterListAsC(const llvm::lambda::node & kernel)
GetParameterListAsC(const rvsdg::LambdaNode & kernel)
{
size_t argument_index = 0;
std::ostringstream parameters;
Expand Down Expand Up @@ -97,7 +97,8 @@ VerilatorHarnessHLS::GetText(llvm::RvsdgModule & rm)
{
std::ostringstream cpp;
const auto & kernel = *get_hls_lambda(rm);
const auto & function_name = kernel.GetOperation().name();
const auto & function_name =
dynamic_cast<llvm::LlvmLambdaOperation &>(kernel.GetOperation()).name();

// The request and response parts of memory queues
const auto mem_reqs = get_mem_reqs(kernel);
Expand Down
2 changes: 1 addition & 1 deletion jlm/hls/backend/rvsdg2rhls/DeadNodeElimination.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ EliminateDeadNodes(llvm::RvsdgModule & rvsdgModule)
throw util::error("Root should have only one node now");
}

auto lambdaNode = dynamic_cast<const llvm::lambda::node *>(rootRegion.Nodes().begin().ptr());
auto lambdaNode = dynamic_cast<const rvsdg::LambdaNode *>(rootRegion.Nodes().begin().ptr());
if (!lambdaNode)
{
throw util::error("Node needs to be a lambda");
Expand Down
15 changes: 6 additions & 9 deletions jlm/hls/backend/rvsdg2rhls/UnusedStateRemoval.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,9 @@ IsPassthroughResult(const rvsdg::input & result)
}

static void
RemoveUnusedStatesFromLambda(llvm::lambda::node & lambdaNode)
RemoveUnusedStatesFromLambda(rvsdg::LambdaNode & lambdaNode)
{
const auto & op = lambdaNode.GetOperation();
const auto & op = dynamic_cast<llvm::LlvmLambdaOperation &>(lambdaNode.GetOperation());
auto & oldFunctionType = op.type();

std::vector<std::shared_ptr<const jlm::rvsdg::Type>> newArgumentTypes;
Expand Down Expand Up @@ -65,12 +65,9 @@ RemoveUnusedStatesFromLambda(llvm::lambda::node & lambdaNode)
}

auto newFunctionType = rvsdg::FunctionType::Create(newArgumentTypes, newResultTypes);
auto newLambda = llvm::lambda::node::create(
lambdaNode.region(),
newFunctionType,
op.name(),
op.linkage(),
op.attributes());
auto newLambda = rvsdg::LambdaNode::Create(
*lambdaNode.region(),
llvm::LlvmLambdaOperation::Create(newFunctionType, op.name(), op.linkage(), op.attributes()));

rvsdg::SubstitutionMap substitutionMap;
for (const auto & ctxvar : lambdaNode.GetContextVars())
Expand Down Expand Up @@ -210,7 +207,7 @@ RemoveUnusedStatesInStructuralNode(rvsdg::StructuralNode & structuralNode)
{
RemoveUnusedStatesFromThetaNode(*thetaNode);
}
else if (auto lambdaNode = dynamic_cast<llvm::lambda::node *>(&structuralNode))
else if (auto lambdaNode = dynamic_cast<rvsdg::LambdaNode *>(&structuralNode))
{
RemoveUnusedStatesFromLambda(*lambdaNode);
}
Expand Down
2 changes: 1 addition & 1 deletion jlm/hls/backend/rvsdg2rhls/add-prints.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ route_to_region(jlm::rvsdg::output * output, rvsdg::Region * region)
{
output = theta->AddLoopVar(output).pre;
}
else if (auto lambda = dynamic_cast<llvm::lambda::node *>(region->node()))
else if (auto lambda = dynamic_cast<rvsdg::LambdaNode *>(region->node()))
{
output = lambda->AddContextVar(*output).inner;
}
Expand Down
17 changes: 7 additions & 10 deletions jlm/hls/backend/rvsdg2rhls/add-triggers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,10 @@ get_trigger(rvsdg::Region * region)
return nullptr;
}

jlm::llvm::lambda::node *
add_lambda_argument(llvm::lambda::node * ln, std::shared_ptr<const jlm::rvsdg::Type> type)
jlm::rvsdg::LambdaNode *
add_lambda_argument(rvsdg::LambdaNode * ln, std::shared_ptr<const jlm::rvsdg::Type> type)
{
const auto & op = ln->GetOperation();
const auto & op = dynamic_cast<llvm::LlvmLambdaOperation &>(ln->GetOperation());
auto old_fcttype = op.type();
std::vector<std::shared_ptr<const jlm::rvsdg::Type>> new_argument_types;
for (size_t i = 0; i < old_fcttype.NumArguments(); ++i)
Expand All @@ -45,12 +45,9 @@ add_lambda_argument(llvm::lambda::node * ln, std::shared_ptr<const jlm::rvsdg::T
new_result_types.push_back(old_fcttype.Results()[i]);
}
auto new_fcttype = rvsdg::FunctionType::Create(new_argument_types, new_result_types);
auto new_lambda = llvm::lambda::node::create(
ln->region(),
new_fcttype,
op.name(),
op.linkage(),
op.attributes());
auto new_lambda = rvsdg::LambdaNode::Create(
*ln->region(),
llvm::LlvmLambdaOperation::Create(new_fcttype, op.name(), op.linkage(), op.attributes()));

rvsdg::SubstitutionMap smap;
for (const auto & ctxvar : ln->GetContextVars())
Expand Down Expand Up @@ -95,7 +92,7 @@ add_triggers(rvsdg::Region * region)
{
if (rvsdg::is<rvsdg::StructuralOperation>(node))
{
if (auto ln = dynamic_cast<llvm::lambda::node *>(node))
if (auto ln = dynamic_cast<rvsdg::LambdaNode *>(node))
{
// check here in order not to process removed and re-added node twice
if (!get_trigger(ln->subregion()))
Expand Down
4 changes: 2 additions & 2 deletions jlm/hls/backend/rvsdg2rhls/add-triggers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ namespace jlm::hls
rvsdg::output *
get_trigger(rvsdg::Region * region);

llvm::lambda::node *
add_lambda_argument(llvm::lambda::node * ln, const rvsdg::Type * type);
rvsdg::LambdaNode *
add_lambda_argument(rvsdg::LambdaNode * ln, const rvsdg::Type * type);

void
add_triggers(rvsdg::Region * region);
Expand Down
2 changes: 1 addition & 1 deletion jlm/hls/backend/rvsdg2rhls/check-rhls.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ check_rhls(llvm::RvsdgModule & rm)
{
throw jlm::util::error("Root should have only one node now");
}
auto ln = dynamic_cast<const llvm::lambda::node *>(root->Nodes().begin().ptr());
auto ln = dynamic_cast<const rvsdg::LambdaNode *>(root->Nodes().begin().ptr());
if (!ln)
{
throw jlm::util::error("Node needs to be a lambda");
Expand Down
2 changes: 1 addition & 1 deletion jlm/hls/backend/rvsdg2rhls/dae-conv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -465,7 +465,7 @@ process_loopnode(loop_node * loopNode)
void
dae_conv(rvsdg::Region * region)
{
auto lambda = dynamic_cast<const jlm::llvm::lambda::node *>(region->Nodes().begin().ptr());
auto lambda = dynamic_cast<const jlm::rvsdg::LambdaNode *>(region->Nodes().begin().ptr());
bool changed;
do
{
Expand Down
2 changes: 1 addition & 1 deletion jlm/hls/backend/rvsdg2rhls/distribute-constants.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ hls::distribute_constants(rvsdg::Region * region)
{
if (rvsdg::is<rvsdg::StructuralOperation>(node))
{
if (auto ln = dynamic_cast<llvm::lambda::node *>(node))
if (auto ln = dynamic_cast<rvsdg::LambdaNode *>(node))
{
distribute_constants(ln->subregion());
}
Expand Down
13 changes: 7 additions & 6 deletions jlm/hls/backend/rvsdg2rhls/instrument-ref.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,13 @@
namespace jlm::hls
{

llvm::lambda::node *
change_function_name(llvm::lambda::node * ln, const std::string & name)
rvsdg::LambdaNode *
change_function_name(rvsdg::LambdaNode * ln, const std::string & name)
{
const auto & op = ln->GetOperation();
auto lambda =
llvm::lambda::node::create(ln->region(), op.Type(), name, op.linkage(), op.attributes());
const auto & op = dynamic_cast<llvm::LlvmLambdaOperation &>(ln->GetOperation());
auto lambda = rvsdg::LambdaNode::Create(
*ln->region(),
llvm::LlvmLambdaOperation::Create(op.Type(), name, op.linkage(), op.attributes()));

/* add context variables */
rvsdg::SubstitutionMap subregionmap;
Expand Down Expand Up @@ -61,7 +62,7 @@ instrument_ref(llvm::RvsdgModule & rm)
{
auto & graph = rm.Rvsdg();
auto root = &graph.GetRootRegion();
auto lambda = dynamic_cast<llvm::lambda::node *>(root->Nodes().begin().ptr());
auto lambda = dynamic_cast<rvsdg::LambdaNode *>(root->Nodes().begin().ptr());

auto newLambda = change_function_name(lambda, "instrumented_ref");

Expand Down
25 changes: 11 additions & 14 deletions jlm/hls/backend/rvsdg2rhls/mem-conv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ trace_function_calls(

jlm::rvsdg::SimpleNode *
find_decouple_response(
const jlm::llvm::lambda::node * lambda,
const jlm::rvsdg::LambdaNode * lambda,
const jlm::rvsdg::bitconstant_op * request_constant)
{
jlm::rvsdg::output * response_function = nullptr;
Expand Down Expand Up @@ -264,7 +264,7 @@ find_decouple_response(

jlm::rvsdg::SimpleNode *
replace_decouple(
const jlm::llvm::lambda::node * lambda,
const jlm::rvsdg::LambdaNode * lambda,
jlm::rvsdg::SimpleNode * decouple_request,
jlm::rvsdg::output * resp)
{
Expand Down Expand Up @@ -527,7 +527,7 @@ IsDecoupledFunctionPointer(

void
jlm::hls::TracePointerArguments(
const jlm::llvm::lambda::node * lambda,
const jlm::rvsdg::LambdaNode * lambda,
port_load_store_decouple & portNodes)
{
for (size_t i = 0; i < lambda->subregion()->narguments(); ++i)
Expand Down Expand Up @@ -568,13 +568,13 @@ jlm::hls::MemoryConverter(jlm::llvm::RvsdgModule & rm)
//

auto root = &rm.Rvsdg().GetRootRegion();
auto lambda = dynamic_cast<jlm::llvm::lambda::node *>(root->Nodes().begin().ptr());
auto lambda = dynamic_cast<jlm::rvsdg::LambdaNode *>(root->Nodes().begin().ptr());

//
// Converting loads and stores to explicitly use memory ports
// This modifies the function signature so we create a new lambda node to replace the old one
//
const auto & op = lambda->GetOperation();
const auto & op = dynamic_cast<llvm::LlvmLambdaOperation &>(lambda->GetOperation());
auto oldFunctionType = op.type();
std::vector<std::shared_ptr<const jlm::rvsdg::Type>> newArgumentTypes;
for (size_t i = 0; i < oldFunctionType.NumArguments(); ++i)
Expand Down Expand Up @@ -641,12 +641,9 @@ jlm::hls::MemoryConverter(jlm::llvm::RvsdgModule & rm)
// Create new lambda and copy the region from the old lambda
//
auto newFunctionType = jlm::rvsdg::FunctionType::Create(newArgumentTypes, newResultTypes);
auto newLambda = jlm::llvm::lambda::node::create(
lambda->region(),
newFunctionType,
op.name(),
op.linkage(),
op.attributes());
auto newLambda = jlm::rvsdg::LambdaNode::Create(
*lambda->region(),
llvm::LlvmLambdaOperation::Create(newFunctionType, op.name(), op.linkage(), op.attributes()));

rvsdg::SubstitutionMap smap;
for (const auto & ctxvar : lambda->GetContextVars())
Expand Down Expand Up @@ -727,7 +724,7 @@ jlm::hls::MemoryConverter(jlm::llvm::RvsdgModule & rm)

// Need to get the lambda from the root since remote_unused_state replaces the lambda
JLM_ASSERT(root->nnodes() == 1);
newLambda = jlm::util::AssertedCast<jlm::llvm::lambda::node>(root->Nodes().begin().ptr());
newLambda = jlm::util::AssertedCast<jlm::rvsdg::LambdaNode>(root->Nodes().begin().ptr());

// Go through in reverse since we are removing things
auto ctxvars = newLambda->GetContextVars();
Expand Down Expand Up @@ -755,7 +752,7 @@ jlm::hls::MemoryConverter(jlm::llvm::RvsdgModule & rm)

jlm::rvsdg::output *
jlm::hls::ConnectRequestResponseMemPorts(
const jlm::llvm::lambda::node * lambda,
const jlm::rvsdg::LambdaNode * lambda,
size_t argumentIndex,
rvsdg::SubstitutionMap & smap,
const std::vector<jlm::rvsdg::SimpleNode *> & originalLoadNodes,
Expand Down Expand Up @@ -926,7 +923,7 @@ jlm::hls::ReplaceStore(rvsdg::SubstitutionMap & smap, const jlm::rvsdg::SimpleNo
jlm::rvsdg::SimpleNode *
ReplaceDecouple(
jlm::rvsdg::SubstitutionMap & smap,
const jlm::llvm::lambda::node * lambda,
const jlm::rvsdg::LambdaNode * lambda,
jlm::rvsdg::SimpleNode * originalDecoupleRequest,
jlm::rvsdg::output * response)
{
Expand Down
4 changes: 2 additions & 2 deletions jlm/hls/backend/rvsdg2rhls/mem-conv.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ typedef std::vector<std::tuple<
* @param portNodes A vector where each element contains all memory operations traced from a pointer
*/
void
TracePointerArguments(const llvm::lambda::node * lambda, port_load_store_decouple & portNodes);
TracePointerArguments(const rvsdg::LambdaNode * lambda, port_load_store_decouple & portNodes);

void
MemoryConverter(llvm::RvsdgModule & rm);
Expand All @@ -42,7 +42,7 @@ MemoryConverter(llvm::RvsdgModule & rm);
*/
jlm::rvsdg::output *
ConnectRequestResponseMemPorts(
const llvm::lambda::node * lambda,
const rvsdg::LambdaNode * lambda,
size_t argumentIndex,
rvsdg::SubstitutionMap & smap,
const std::vector<jlm::rvsdg::SimpleNode *> & originalLoadNodes,
Expand Down
Loading

0 comments on commit 309e5d0

Please sign in to comment.