From b7c1cdf5215a04cd5ab256a9825373cca27f9cd7 Mon Sep 17 00:00:00 2001 From: Emily Schmidt Date: Thu, 20 Jun 2024 16:40:02 +0100 Subject: [PATCH] add new generic compute graph and rewrite c++ functional backend to use it --- Makefile | 2 +- backends/functional/Makefile.inc | 3 +- backends/functional/cxx.cc | 392 +++++++---------------- backends/functional/test_generic.cc | 56 ++++ kernel/functional.h | 22 +- kernel/{graphtools.h => functionalir.cc} | 74 ++++- kernel/functionalir.h | 381 ++++++++++++++++++++++ 7 files changed, 628 insertions(+), 302 deletions(-) create mode 100644 backends/functional/test_generic.cc rename kernel/{graphtools.h => functionalir.cc} (88%) create mode 100644 kernel/functionalir.h diff --git a/Makefile b/Makefile index d16188b87b0..64122624722 100644 --- a/Makefile +++ b/Makefile @@ -648,7 +648,7 @@ $(eval $(call add_include_file,backends/rtlil/rtlil_backend.h)) OBJS += kernel/driver.o kernel/register.o kernel/rtlil.o kernel/log.o kernel/calc.o kernel/yosys.o OBJS += kernel/binding.o OBJS += kernel/cellaigs.o kernel/celledges.o kernel/satgen.o kernel/scopeinfo.o kernel/qcsat.o kernel/mem.o kernel/ffmerge.o kernel/ff.o kernel/yw.o kernel/json.o kernel/fmt.o -OBJS += kernel/drivertools.o +OBJS += kernel/drivertools.o kernel/functionalir.o ifeq ($(ENABLE_ZLIB),1) OBJS += kernel/fstdata.o endif diff --git a/backends/functional/Makefile.inc b/backends/functional/Makefile.inc index c712d2aefe5..f4b968ef371 100644 --- a/backends/functional/Makefile.inc +++ b/backends/functional/Makefile.inc @@ -1,2 +1,3 @@ OBJS += backends/functional/cxx.o -OBJS += backends/functional/smtlib.o +#OBJS += backends/functional/smtlib.o +OBJS += backends/functional/test_generic.o diff --git a/backends/functional/cxx.cc b/backends/functional/cxx.cc index 08d9ba7914b..789ab7949ba 100644 --- a/backends/functional/cxx.cc +++ b/backends/functional/cxx.cc @@ -18,10 +18,7 @@ */ #include "kernel/yosys.h" -#include "kernel/drivertools.h" -#include "kernel/topo_scc.h" -#include "kernel/functional.h" -#include "kernel/graphtools.h" +#include "kernel/functionalir.h" USING_YOSYS_NAMESPACE PRIVATE_NAMESPACE_BEGIN @@ -84,38 +81,17 @@ struct CxxScope { }; struct CxxType { - bool _is_memory; - int _width; - int _addr_width; -public: - CxxType() : _is_memory(false), _width(0), _addr_width(0) { } - CxxType(int width) : _is_memory(false), _width(width), _addr_width(0) { } - CxxType(int addr_width, int data_width) : _is_memory(true), _width(data_width), _addr_width(addr_width) { } - static CxxType signal(int width) { return CxxType(width); } - static CxxType memory(int addr_width, int data_width) { return CxxType(addr_width, data_width); } - bool is_signal() const { return !_is_memory; } - bool is_memory() const { return _is_memory; } - int width() const { log_assert(is_signal()); return _width; } - int addr_width() const { log_assert(is_memory()); return _addr_width; } - int data_width() const { log_assert(is_memory()); return _width; } + FunctionalIR::Sort sort; + CxxType(FunctionalIR::Sort sort) : sort(sort) {} std::string to_string() const { - if(_is_memory) { - return stringf("Memory<%d, %d>", addr_width(), data_width()); + if(sort.is_memory()) { + return stringf("Memory<%d, %d>", sort.addr_width(), sort.data_width()); + } else if(sort.is_signal()) { + return stringf("Signal<%d>", sort.width()); } else { - return stringf("Signal<%d>", width()); + log_error("unknown sort"); } } - bool operator ==(CxxType const& other) const { - if(_is_memory != other._is_memory) return false; - if(_is_memory && _addr_width != other._addr_width) return false; - return _width == other._width; - } - unsigned int hash() const { - if(_is_memory) - return mkhash(1, mkhash(_width, _addr_width)); - else - return mkhash(0, _width); - } }; struct CxxWriter { @@ -135,9 +111,8 @@ struct CxxStruct { dict types; CxxScope scope; bool generate_methods; - int count; - CxxStruct(std::string name, bool generate_methods = false, int count = 0) - : name(name), generate_methods(generate_methods), count(count) { + CxxStruct(std::string name, bool generate_methods = false) + : name(name), generate_methods(generate_methods) { scope.reserve("out"); scope.reserve("dump"); } @@ -159,7 +134,7 @@ struct CxxStruct { if (generate_methods) { // Add size method f.printf("\tint size() const {\n"); - f.printf("\t\treturn %d;\n", count); + f.printf("\t\treturn %d;\n", types.size()); f.printf("\t}\n\n"); // Add get_input method @@ -197,119 +172,87 @@ struct CxxStruct { } }; -struct CxxFunction { - IdString name; - CxxType type; - dict parameters; - - CxxFunction(IdString name, CxxType type) : name(name), type(type) {} - CxxFunction(IdString name, CxxType type, dict parameters) : name(name), type(type), parameters(parameters) {} - - bool operator==(CxxFunction const &other) const { - return name == other.name && parameters == other.parameters && type == other.type; - } - - unsigned int hash() const { - return mkhash(name.hash(), mkhash(type.hash(), parameters.hash())); - } -}; - -typedef ComputeGraph CxxComputeGraph; - -class CxxComputeGraphFactory { - CxxComputeGraph &graph; - using T = CxxComputeGraph::Ref; - static bool is_single_output(IdString type) - { - auto it = yosys_celltypes.cell_types.find(type); - return it != yosys_celltypes.cell_types.end() && it->second.outputs.size() <= 1; - } +struct CxxTemplate { + vector> _v; public: - CxxComputeGraphFactory(CxxComputeGraph &g) : graph(g) {} - T slice(T a, int in_width, int offset, int out_width) { - log_assert(offset + out_width <= in_width); - return graph.add(CxxFunction(ID($$slice), out_width, {{ID(offset), offset}}), 0, std::array{a}); - } - T extend(T a, int in_width, int out_width, bool is_signed) { - log_assert(in_width < out_width); - if(is_signed) - return graph.add(CxxFunction(ID($sign_extend), out_width, {{ID(WIDTH), out_width}}), 0, std::array{a}); - else - return graph.add(CxxFunction(ID($zero_extend), out_width, {{ID(WIDTH), out_width}}), 0, std::array{a}); - } - T concat(T a, int a_width, T b, int b_width) { - return graph.add(CxxFunction(ID($$concat), a_width + b_width), 0, std::array{a, b}); - } - T add(T a, T b, int width) { return graph.add(CxxFunction(ID($add), width), 0, std::array{a, b}); } - T sub(T a, T b, int width) { return graph.add(CxxFunction(ID($sub), width), 0, std::array{a, b}); } - T bitwise_and(T a, T b, int width) { return graph.add(CxxFunction(ID($and), width), 0, std::array{a, b}); } - T bitwise_or(T a, T b, int width) { return graph.add(CxxFunction(ID($or), width), 0, std::array{a, b}); } - T bitwise_xor(T a, T b, int width) { return graph.add(CxxFunction(ID($xor), width), 0, std::array{a, b}); } - T bitwise_not(T a, int width) { return graph.add(CxxFunction(ID($not), width), 0, std::array{a}); } - T neg(T a, int width) { return graph.add(CxxFunction(ID($neg), width), 0, std::array{a}); } - T mux(T a, T b, T s, int width) { return graph.add(CxxFunction(ID($mux), width), 0, std::array{a, b, s}); } - T pmux(T a, T b, T s, int width, int) { return graph.add(CxxFunction(ID($pmux), width), 0, std::array{a, b, s}); } - T reduce_and(T a, int) { return graph.add(CxxFunction(ID($reduce_and), 1), 0, std::array{a}); } - T reduce_or(T a, int) { return graph.add(CxxFunction(ID($reduce_or), 1), 0, std::array{a}); } - T reduce_xor(T a, int) { return graph.add(CxxFunction(ID($reduce_xor), 1), 0, std::array{a}); } - T eq(T a, T b, int) { return graph.add(CxxFunction(ID($eq), 1), 0, std::array{a, b}); } - T ne(T a, T b, int) { return graph.add(CxxFunction(ID($ne), 1), 0, std::array{a, b}); } - T gt(T a, T b, int) { return graph.add(CxxFunction(ID($gt), 1), 0, std::array{a, b}); } - T ge(T a, T b, int) { return graph.add(CxxFunction(ID($ge), 1), 0, std::array{a, b}); } - T ugt(T a, T b, int) { return graph.add(CxxFunction(ID($ugt), 1), 0, std::array{a, b}); } - T uge(T a, T b, int) { return graph.add(CxxFunction(ID($uge), 1), 0, std::array{a, b}); } - T logical_shift_left(T a, T b, int y_width, int) { return graph.add(CxxFunction(ID($shl), y_width, {{ID(WIDTH), y_width}}), 0, std::array{a, b}); } - T logical_shift_right(T a, T b, int y_width, int) { return graph.add(CxxFunction(ID($shr), y_width, {{ID(WIDTH), y_width}}), 0, std::array{a, b}); } - T arithmetic_shift_right(T a, T b, int y_width, int) { return graph.add(CxxFunction(ID($asr), y_width, {{ID(WIDTH), y_width}}), 0, std::array{a, b}); } - - T constant(RTLIL::Const value) { - return graph.add(CxxFunction(ID($$const), value.size(), {{ID(value), value}}), 0); - } - T input(IdString name, int width) { return graph.add(CxxFunction(ID($$input), width, {{name, {}}}), 0); } - T state(IdString name, int width) { return graph.add(CxxFunction(ID($$state), width, {{name, {}}}), 0); } - T state_memory(IdString name, int addr_width, int data_width) { - return graph.add(CxxFunction(ID($$state), CxxType::memory(addr_width, data_width), {{name, {}}}), 0); - } - T cell_output(T cell, IdString type, IdString name, int width) { - if (is_single_output(type)) - return cell; - else - return graph.add(CxxFunction(ID($$cell_output), width, {{name, {}}}), 0, std::array{cell}); - } - T multiple(vector args, int width) { - return graph.add(CxxFunction(ID($$multiple), width), 0, args); - } - T undriven(int width) { - return graph.add(CxxFunction(ID($$undriven), width), 0); - } - - T memory_read(T mem, T addr, int addr_width, int data_width) { - return graph.add(CxxFunction(ID($memory_read), data_width), 0, std::array{mem, addr}); + CxxTemplate(std::string fmt) { + std::string buf; + for(auto it = fmt.begin(); it != fmt.end(); it++){ + if(*it == '%'){ + it++; + log_assert(it != fmt.end()); + if(*it == '%') + buf += *it; + else { + log_assert(*it >= '0' && *it <= '9'); + _v.emplace_back(std::move(buf)); + _v.emplace_back((int)(*it - '0')); + } + }else + buf += *it; + } + if(!buf.empty()) + _v.emplace_back(std::move(buf)); } - T memory_write(T mem, T addr, T data, int addr_width, int data_width) { - return graph.add(CxxFunction(ID($memory_write), CxxType::memory(addr_width, data_width)), 0, std::array{mem, addr, data}); + template static std::string format(CxxTemplate fmt, Args&&... args) { + vector strs = {args...}; + std::string result; + for(auto &v : fmt._v){ + if(std::string *s = std::get_if(&v)) + result += *s; + else if(int *i = std::get_if(&v)) + result += strs[*i]; + else + log_error("missing case"); + } + return result; } +}; - T create_pending(int width) { - return graph.add(CxxFunction(ID($$pending), width), 0); - } - void update_pending(T pending, T node) { - log_assert(pending.function().name == ID($$pending)); - pending.set_function(CxxFunction(ID($$buf), pending.function().type)); - pending.append_arg(node); - } - void declare_output(T node, IdString name, int) { - node.assign_key(name); - } - void declare_state(T node, IdString name, int) { - node.assign_key(name); - } - void declare_state_memory(T node, IdString name, int, int) { - node.assign_key(name); - } - void suggest_name(T node, IdString name) { - node.sparse_attr() = name; +template struct CxxPrintVisitor { + using Node = FunctionalIR::Node; + NodeNames np; + CxxStruct &input_struct; + CxxStruct &state_struct; + CxxPrintVisitor(NodeNames np, CxxStruct &input_struct, CxxStruct &state_struct) : np(np), input_struct(input_struct), state_struct(state_struct) { } + template std::string arg_to_string(T n) { return std::to_string(n); } + template<> std::string arg_to_string(std::string n) { return n; } + template<> std::string arg_to_string(Node n) { return np(n); } + template std::string format(std::string fmt, Args&&... args) { + return CxxTemplate::format(fmt, arg_to_string(args)...); } + std::string buf(Node, Node n) { return np(n); } + std::string slice(Node, Node a, int, int offset, int out_width) { return format("slice<%2>(%0, %1)", a, offset, out_width); } + std::string zero_extend(Node, Node a, int, int out_width) { return format("$zero_extend<%1>(%0)", a, out_width); } + std::string sign_extend(Node, Node a, int, int out_width) { return format("$sign_extend<%1>(%0)", a, out_width); } + std::string concat(Node, Node a, int, Node b, int) { return format("concat(%0, %1)", a, b); } + std::string add(Node, Node a, Node b, int) { return format("$add(%0, %1)", a, b); } + std::string sub(Node, Node a, Node b, int) { return format("$sub(%0, %1)", a, b); } + std::string bitwise_and(Node, Node a, Node b, int) { return format("$and(%0, %1)", a, b); } + std::string bitwise_or(Node, Node a, Node b, int) { return format("$or(%0, %1)", a, b); } + std::string bitwise_xor(Node, Node a, Node b, int) { return format("$xor(%0, %1)", a, b); } + std::string bitwise_not(Node, Node a, int) { return format("$not(%0)", a); } + std::string unary_minus(Node, Node a, int) { return format("$neg(%0)", a); } + std::string reduce_and(Node, Node a, int) { return format("$reduce_and(%0)", a); } + std::string reduce_or(Node, Node a, int) { return format("$reduce_or(%0)", a); } + std::string reduce_xor(Node, Node a, int) { return format("$reduce_xor(%0)", a); } + std::string equal(Node, Node a, Node b, int) { return format("$eq(%0, %1)", a, b); } + std::string not_equal(Node, Node a, Node b, int) { return format("$ne(%0, %1)", a, b); } + std::string signed_greater_than(Node, Node a, Node b, int) { return format("$gt(%0, %1)", a, b); } + std::string signed_greater_equal(Node, Node a, Node b, int) { return format("$ge(%0, %1)", a, b); } + std::string unsigned_greater_than(Node, Node a, Node b, int) { return format("$ugt(%0, %1)", a, b); } + std::string unsigned_greater_equal(Node, Node a, Node b, int) { return format("$uge(%0, %1)", a, b); } + std::string logical_shift_left(Node, Node a, Node b, int, int) { return format("$shl<%2>(%0, %1)", a, b, a.width()); } + std::string logical_shift_right(Node, Node a, Node b, int, int) { return format("$shr<%2>(%0, %1)", a, b, a.width()); } + std::string arithmetic_shift_right(Node, Node a, Node b, int, int) { return format("$asr<%2>(%0, %1)", a, b, a.width()); } + std::string mux(Node, Node a, Node b, Node s, int) { return format("$mux(%0, %1, %2)", a, b, s); } + std::string pmux(Node, Node a, Node b, Node s, int, int) { return format("$pmux(%0, %1, %2)", a, b, s); } + std::string constant(Node, RTLIL::Const value) { return format("$const<%0>(%1)", value.size(), value.as_int()); } + std::string input(Node, IdString name) { return format("input.%0", input_struct[name]); } + std::string state(Node, IdString name) { return format("current_state.%0", state_struct[name]); } + std::string memory_read(Node, Node mem, Node addr, int, int) { return format("$memory_read(%0, %1)", mem, addr); } + std::string memory_write(Node, Node mem, Node addr, Node data, int, int) { return format("$memory_write(%0, %1, %2)", mem, addr, data); } + std::string undriven(Node, int width) { return format("$const<%0>(0)", width); } }; struct FunctionalCxxBackend : public Backend @@ -322,88 +265,24 @@ struct FunctionalCxxBackend : public Backend log("\n"); } - CxxComputeGraph calculate_compute_graph(RTLIL::Module *module) - { - CxxComputeGraph compute_graph; - CxxComputeGraphFactory factory(compute_graph); - ComputeGraphConstruction construction(factory); - construction.add_module(module); - construction.process_queue(); - - // Perform topo sort and detect SCCs - CxxComputeGraph::SccAdaptor compute_graph_scc(compute_graph); - - bool scc = false; - std::vector perm; - topo_sorted_sccs(compute_graph_scc, [&](int *begin, int *end) { - perm.insert(perm.end(), begin, end); - if (end > begin + 1) - { - log_warning("SCC:"); - for (int *i = begin; i != end; ++i) - log(" %d(%s)(%s)", *i, compute_graph[*i].function().name.c_str(), compute_graph[*i].has_sparse_attr() ? compute_graph[*i].sparse_attr().c_str() : ""); - log("\n"); - scc = true; - } - }, /* sources_first */ true); - compute_graph.permute(perm); - if(scc) log_error("combinational loops, aborting\n"); - - // Forward $$buf - std::vector alias; - perm.clear(); - - for (int i = 0; i < compute_graph.size(); ++i) - { - auto node = compute_graph[i]; - if (node.function().name == ID($$buf) && node.arg(0).index() < i) - { - int target_index = alias[node.arg(0).index()]; - auto target_node = compute_graph[perm[target_index]]; - if(!target_node.has_sparse_attr() && node.has_sparse_attr()){ - IdString id = node.sparse_attr(); - target_node.sparse_attr() = id; - } - alias.push_back(target_index); - } - else - { - alias.push_back(GetSize(perm)); - perm.push_back(i); - } - } - compute_graph.permute(perm, alias); - return compute_graph; - } - - void printCxx(std::ostream &stream, std::string, std::string const & name, CxxComputeGraph &compute_graph) + void printCxx(std::ostream &stream, std::string, std::string const & name, Module *module) { - dict inputs, state; + auto ir = FunctionalIR::from_module(module); CxxWriter f(stream); - // Dump the compute graph - for (int i = 0; i < compute_graph.size(); ++i) - { - auto ref = compute_graph[i]; - if(ref.function().name == ID($$input)) - inputs[ref.function().parameters.begin()->first] = ref.function().type; - if(ref.function().name == ID($$state)) - state[ref.function().parameters.begin()->first] = ref.function().type; - } f.printf("#include \"sim.h\"\n"); f.printf("#include \n"); - CxxStruct input_struct(name + "_Inputs", true, inputs.size()); - for (auto const &input : inputs) - input_struct.insert(input.first, input.second); + CxxStruct input_struct(name + "_Inputs", true); + for (auto [name, sort] : ir.inputs()) + input_struct.insert(name, sort); CxxStruct output_struct(name + "_Outputs"); - for (auto const &key : compute_graph.keys()) - if(state.count(key.first) == 0) - output_struct.insert(key.first, compute_graph[key.second].function().type); + for (auto [name, sort] : ir.outputs()) + output_struct.insert(name, sort); CxxStruct state_struct(name + "_State"); - for (auto const &state_var : state) - state_struct.insert(state_var.first, state_var.second); + for (auto [name, sort] : ir.state()) + state_struct.insert(name, sort); - idict node_names; + dict node_names; CxxScope locals; input_struct.print(f); @@ -415,73 +294,17 @@ struct FunctionalCxxBackend : public Backend locals.reserve("output"); locals.reserve("current_state"); locals.reserve("next_state"); - for (int i = 0; i < compute_graph.size(); ++i) - { - auto ref = compute_graph[i]; - auto type = ref.function().type; - std::string name; - if(ref.has_sparse_attr()) - name = locals.insert(ref.sparse_attr()); - else - name = locals.insert("\\n" + std::to_string(i)); - node_names(name); - if(ref.function().name == ID($$input)) - f.printf("\t%s %s = input.%s;\n", type.to_string().c_str(), name.c_str(), input_struct[ref.function().parameters.begin()->first].c_str()); - else if(ref.function().name == ID($$state)) - f.printf("\t%s %s = current_state.%s;\n", type.to_string().c_str(), name.c_str(), state_struct[ref.function().parameters.begin()->first].c_str()); - else if(ref.function().name == ID($$buf)) - f.printf("\t%s %s = %s;\n", type.to_string().c_str(), name.c_str(), node_names[ref.arg(0).index()].c_str()); - else if(ref.function().name == ID($$cell_output)) - f.printf("\t%s %s = %s.%s;\n", type.to_string().c_str(), name.c_str(), node_names[ref.arg(0).index()].c_str(), RTLIL::unescape_id(ref.function().parameters.begin()->first).c_str()); - else if(ref.function().name == ID($$const)){ - auto c = ref.function().parameters.begin()->second; - if(c.size() <= 32){ - f.printf("\t%s %s = $const<%d>(%#x);\n", type.to_string().c_str(), name.c_str(), type.width(), (uint32_t) c.as_int()); - }else{ - f.printf("\t%s %s = $const<%d>({%#x", type.to_string().c_str(), name.c_str(), type.width(), (uint32_t) c.as_int()); - while(c.size() > 32){ - c = c.extract(32, c.size() - 32); - f.printf(", %#x", c.as_int()); - } - f.printf("});\n"); - } - }else if(ref.function().name == ID($$undriven)) - f.printf("\t%s %s; //undriven\n", type.to_string().c_str(), name.c_str()); - else if(ref.function().name == ID($$slice)) - f.printf("\t%s %s = slice<%d>(%s, %d);\n", type.to_string().c_str(), name.c_str(), type.width(), node_names[ref.arg(0).index()].c_str(), ref.function().parameters.at(ID(offset)).as_int()); - else if(ref.function().name == ID($$concat)){ - f.printf("\tauto %s = concat(", name.c_str()); - for (int i = 0, end = ref.size(); i != end; ++i){ - if(i > 0) - f.printf(", "); - f.printf("%s", node_names[ref.arg(i).index()].c_str()); - } - f.printf(");\n"); - }else{ - f.printf("\t"); - f.printf("%s %s = %s", type.to_string().c_str(), name.c_str(), log_id(ref.function().name)); - if(ref.function().parameters.count(ID(WIDTH))){ - f.printf("<%d>", ref.function().parameters.at(ID(WIDTH)).as_int()); - } - f.printf("("); - for (int i = 0, end = ref.size(); i != end; ++i) - f.printf("%s%s", i>0?", ":"", node_names[ref.arg(i).index()].c_str()); - f.printf("); //"); - for (auto const ¶m : ref.function().parameters) - { - if (param.second.empty()) - f.printf("[%s]", log_id(param.first)); - else - f.printf("[%s=%s]", log_id(param.first), log_const(param.second)); - } - f.printf("\n"); - } - } - - for (auto const &key : compute_graph.keys()) + auto node_to_string = [&](FunctionalIR::Node n) { return node_names.at(n.id()); }; + for (auto node : ir) { - f.printf("\t%s.%s = %s;\n", state.count(key.first) > 0 ? "next_state" : "output", state_struct[key.first].c_str(), node_names[key.second].c_str()); + std::string name = locals.insert(node.name()); + node_names.emplace(node.id(), name); + f.printf("\t%s %s = %s;\n", CxxType(node.sort()).to_string().c_str(), name.c_str(), node.visit(CxxPrintVisitor(node_to_string, input_struct, state_struct)).c_str()); } + for (auto [name, sort] : ir.state()) + f.printf("\tnext_state.%s = %s;\n", state_struct[name].c_str(), node_to_string(ir.get_state_next_node(name)).c_str()); + for (auto [name, sort] : ir.outputs()) + f.printf("\toutput.%s = %s;\n", output_struct[name].c_str(), node_to_string(ir.get_output_node(name)).c_str()); f.printf("}\n"); } @@ -494,8 +317,7 @@ struct FunctionalCxxBackend : public Backend for (auto module : design->selected_modules()) { log("Dumping module `%s'.\n", module->name.c_str()); - auto compute_graph = calculate_compute_graph(module); - printCxx(*f, filename, RTLIL::unescape_id(module->name), compute_graph); + printCxx(*f, filename, RTLIL::unescape_id(module->name), module); } } } FunctionalCxxBackend; diff --git a/backends/functional/test_generic.cc b/backends/functional/test_generic.cc new file mode 100644 index 00000000000..172ac7fd6f5 --- /dev/null +++ b/backends/functional/test_generic.cc @@ -0,0 +1,56 @@ +/* + * yosys -- Yosys Open SYnthesis Suite + * + * Copyright (C) 2024 Emily Schmidt + * + * Permission to use, copy, modify, and/or distribute this software for any + * purpose with or without fee is hereby granted, provided that the above + * copyright notice and this permission notice appear in all copies. + * + * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES + * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF + * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR + * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES + * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN + * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF + * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + * + */ + +#include "kernel/yosys.h" +#include "kernel/functionalir.h" + +USING_YOSYS_NAMESPACE +PRIVATE_NAMESPACE_BEGIN + +struct FunctionalTestGeneric : public Pass +{ + FunctionalTestGeneric() : Pass("test_generic", "test the generic compute graph") {} + + void help() override + { + // |---v---|---v---|---v---|---v---|---v---|---v---|---v---|---v---|---v---|---v---| + log("\n"); + } + + void execute(std::vector args, RTLIL::Design *design) override + { + log_header(design, "Executing Test Generic.\n"); + + size_t argidx = 1; + extra_args(args, argidx, design); + + for (auto module : design->selected_modules()) { + log("Dumping module `%s'.\n", module->name.c_str()); + auto fir = FunctionalIR::from_module(module); + for(auto node : fir) + std::cout << RTLIL::unescape_id(node.name()) << " = " << node.to_string([](auto n) { return RTLIL::unescape_id(n.name()); }) << "\n"; + for(auto [name, sort] : fir.outputs()) + std::cout << RTLIL::unescape_id(name) << " = " << RTLIL::unescape_id(fir.get_output_node(name).name()) << "\n"; + for(auto [name, sort] : fir.state()) + std::cout << RTLIL::unescape_id(name) << " = " << RTLIL::unescape_id(fir.get_state_next_node(name).name()) << "\n"; + } + } +} FunctionalCxxBackend; + +PRIVATE_NAMESPACE_END diff --git a/kernel/functional.h b/kernel/functional.h index e5ee8824099..09a8826433c 100644 --- a/kernel/functional.h +++ b/kernel/functional.h @@ -123,6 +123,8 @@ struct ComputeGraph Node &deref() const { this->check(); return this->graph_->nodes[this->index_]; } public: + Ref(BaseRef ref) : Ref(ref.graph_, ref.index_) {} + void set_function(Fn const &function) const { deref().fn_index = this->graph_->functions(function); @@ -224,7 +226,7 @@ struct ComputeGraph } template - Ref add(Fn const &function, Attr const &attr, T const &args) + Ref add(Fn const &function, Attr const &attr, T &&args) { Ref added = add(function, attr); for (auto arg : args) @@ -233,7 +235,23 @@ struct ComputeGraph } template - Ref add(Fn const &function, Attr &&attr, T const &args) + Ref add(Fn const &function, Attr &&attr, T &&args) + { + Ref added = add(function, std::move(attr)); + for (auto arg : args) + added.append_arg(arg); + return added; + } + + Ref add(Fn const &function, Attr const &attr, std::initializer_list args) + { + Ref added = add(function, attr); + for (auto arg : args) + added.append_arg(arg); + return added; + } + + Ref add(Fn const &function, Attr &&attr, std::initializer_list args) { Ref added = add(function, std::move(attr)); for (auto arg : args) diff --git a/kernel/graphtools.h b/kernel/functionalir.cc similarity index 88% rename from kernel/graphtools.h rename to kernel/functionalir.cc index 4fb7aacf444..27765ac26c5 100644 --- a/kernel/graphtools.h +++ b/kernel/functionalir.cc @@ -17,15 +17,8 @@ * */ -#ifndef GRAPHTOOLS_H -#define GRAPHTOOLS_H +#include "kernel/functionalir.h" -#include "kernel/yosys.h" -#include "kernel/drivertools.h" -#include "kernel/functional.h" -#include "kernel/mem.h" - -USING_YOSYS_NAMESPACE YOSYS_NAMESPACE_BEGIN template @@ -196,7 +189,7 @@ class CellSimplifier { }; template -class ComputeGraphConstruction { +class FunctionalIRConstruction { std::deque queue; dict graph_nodes; idict cells; @@ -218,7 +211,7 @@ class ComputeGraphConstruction { return it->second; } public: - ComputeGraphConstruction(Factory &f) : factory(f), simplifier(f) {} + FunctionalIRConstruction(Factory &f) : factory(f), simplifier(f) {} void add_module(Module *module) { driver_map.add(module); @@ -238,8 +231,9 @@ class ComputeGraphConstruction { memories[mem.cell] = &mem; } } - T concatenate_read_results(Mem *mem, vector results) + T concatenate_read_results(Mem *, vector results) { + /* TODO: write code to check that this is ok to do */ if(results.size() == 0) return factory.undriven(0); T node = results[0]; @@ -381,6 +375,60 @@ class ComputeGraphConstruction { } }; -YOSYS_NAMESPACE_END +FunctionalIR FunctionalIR::from_module(Module *module) { + FunctionalIR ir; + auto factory = ir.factory(); + FunctionalIRConstruction ctor(factory); + ctor.add_module(module); + ctor.process_queue(); + ir.topological_sort(); + ir.forward_buf(); + return ir; +} + +void FunctionalIR::topological_sort() { + Graph::SccAdaptor compute_graph_scc(_graph); + bool scc = false; + std::vector perm; + topo_sorted_sccs(compute_graph_scc, [&](int *begin, int *end) { + perm.insert(perm.end(), begin, end); + if (end > begin + 1) + { + log_warning("SCC:"); + for (int *i = begin; i != end; ++i) + log(" %d", *i); + log("\n"); + scc = true; + } + }, /* sources_first */ true); + _graph.permute(perm); + if(scc) log_error("combinational loops, aborting\n"); +} -#endif +void FunctionalIR::forward_buf() { + std::vector perm, alias; + perm.clear(); + + for (int i = 0; i < _graph.size(); ++i) + { + auto node = _graph[i]; + if (node.function().fn() == Fn::buf && node.arg(0).index() < i) + { + int target_index = alias[node.arg(0).index()]; + auto target_node = _graph[perm[target_index]]; + if(!target_node.has_sparse_attr() && node.has_sparse_attr()){ + IdString id = node.sparse_attr(); + target_node.sparse_attr() = id; + } + alias.push_back(target_index); + } + else + { + alias.push_back(GetSize(perm)); + perm.push_back(i); + } + } + _graph.permute(perm, alias); +} + +YOSYS_NAMESPACE_END diff --git a/kernel/functionalir.h b/kernel/functionalir.h new file mode 100644 index 00000000000..7eca9d87f83 --- /dev/null +++ b/kernel/functionalir.h @@ -0,0 +1,381 @@ +/* + * yosys -- Yosys Open SYnthesis Suite + * + * Copyright (C) 2024 Emily Schmidt + * + * Permission to use, copy, modify, and/or distribute this software for any + * purpose with or without fee is hereby granted, provided that the above + * copyright notice and this permission notice appear in all copies. + * + * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES + * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF + * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR + * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES + * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN + * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF + * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + * + */ + +#ifndef FUNCTIONALIR_H +#define FUNCTIONALIR_H + +#include "kernel/yosys.h" +#include "kernel/functional.h" +#include "kernel/drivertools.h" +#include "kernel/mem.h" +#include "kernel/topo_scc.h" + +USING_YOSYS_NAMESPACE +YOSYS_NAMESPACE_BEGIN + +class FunctionalIR { + enum class Fn { + invalid, + buf, + slice, + zero_extend, + sign_extend, + concat, + add, + sub, + bitwise_and, + bitwise_or, + bitwise_xor, + bitwise_not, + reduce_and, + reduce_or, + reduce_xor, + unary_minus, + equal, + not_equal, + signed_greater_than, + signed_greater_equal, + unsigned_greater_than, + unsigned_greater_equal, + logical_shift_left, + logical_shift_right, + arithmetic_shift_right, + mux, + pmux, + constant, + input, + state, + multiple, + undriven, + memory_read, + memory_write + }; +public: + class Sort { + std::variant> _v; + public: + explicit Sort(int width) : _v(width) { } + Sort(int addr_width, int data_width) : _v(std::make_pair(addr_width, data_width)) { } + bool is_signal() const { return _v.index() == 0; } + bool is_memory() const { return _v.index() == 1; } + int width() const { return std::get<0>(_v); } + int addr_width() const { return std::get<1>(_v).first; } + int data_width() const { return std::get<1>(_v).second; } + bool operator==(Sort const& other) const { return _v == other._v; } + unsigned int hash() const { return mkhash(_v); } + }; +private: + class NodeData { + Fn _fn; + std::variant< + std::monostate, + RTLIL::Const, + IdString, + int + > _extra; + public: + NodeData() : _fn(Fn::invalid) {} + NodeData(Fn fn) : _fn(fn) {} + template NodeData(Fn fn, T &&extra) : _fn(fn), _extra(std::forward(extra)) {} + Fn fn() const { return _fn; } + const RTLIL::Const &as_const() const { return std::get(_extra); } + IdString as_idstring() const { return std::get(_extra); } + int as_int() const { return std::get(_extra); } + int hash() const { + return mkhash((unsigned int) _fn, mkhash(_extra)); + } + bool operator==(NodeData const &other) const { + return _fn == other._fn && _extra == other._extra; + } + }; + struct Attr { + Sort sort; + }; + using Graph = ComputeGraph>; + Graph _graph; + dict _inputs; + dict _outputs; + dict _state; + void add_input(IdString name, Sort sort) { + auto [it, found] = _inputs.emplace(name, std::move(sort)); + if(found) + log_assert(it->second == sort); + } + void add_state(IdString name, Sort sort) { + auto [it, found] = _state.emplace(name, std::move(sort)); + if(found) + log_assert(it->second == sort); + } + void add_output(IdString name, Sort sort) { + auto [it, found] = _outputs.emplace(name, std::move(sort)); + if(found) + log_assert(it->second == sort); + } +public: + class Factory; + class Node { + friend class Factory; + friend class FunctionalIR; + Graph::Ref _ref; + Node(Graph::Ref ref) : _ref(ref) { } + operator Graph::Ref() { return _ref; } + template struct PrintVisitor { + NodePrinter np; + PrintVisitor(NodePrinter np) : np(np) { } + std::string buf(Node, Node n) { return "buf(" + np(n) + ")"; } + std::string slice(Node, Node a, int, int offset, int out_width) { return "slice(" + np(a) + ", " + std::to_string(offset) + ", " + std::to_string(out_width) + ")"; } + std::string zero_extend(Node, Node a, int, int out_width) { return "zero_extend(" + np(a) + ", " + std::to_string(out_width) + ")"; } + std::string sign_extend(Node, Node a, int, int out_width) { return "sign_extend(" + np(a) + ", " + std::to_string(out_width) + ")"; } + std::string concat(Node, Node a, int, Node b, int) { return "concat(" + np(a) + ", " + np(b) + ")"; } + std::string add(Node, Node a, Node b, int) { return "add(" + np(a) + ", " + np(b) + ")"; } + std::string sub(Node, Node a, Node b, int) { return "sub(" + np(a) + ", " + np(b) + ")"; } + std::string bitwise_and(Node, Node a, Node b, int) { return "bitwise_and(" + np(a) + ", " + np(b) + ")"; } + std::string bitwise_or(Node, Node a, Node b, int) { return "bitwise_or(" + np(a) + ", " + np(b) + ")"; } + std::string bitwise_xor(Node, Node a, Node b, int) { return "bitwise_xor(" + np(a) + ", " + np(b) + ")"; } + std::string bitwise_not(Node, Node a, int) { return "bitwise_not(" + np(a) + ")"; } + std::string unary_minus(Node, Node a, int) { return "unary_minus(" + np(a) + ")"; } + std::string reduce_and(Node, Node a, int) { return "reduce_and(" + np(a) + ")"; } + std::string reduce_or(Node, Node a, int) { return "reduce_or(" + np(a) + ")"; } + std::string reduce_xor(Node, Node a, int) { return "reduce_xor(" + np(a) + ")"; } + std::string equal(Node, Node a, Node b, int) { return "equal(" + np(a) + ", " + np(b) + ")"; } + std::string not_equal(Node, Node a, Node b, int) { return "not_equal(" + np(a) + ", " + np(b) + ")"; } + std::string signed_greater_than(Node, Node a, Node b, int) { return "signed_greater_than(" + np(a) + ", " + np(b) + ")"; } + std::string signed_greater_equal(Node, Node a, Node b, int) { return "signed_greater_equal(" + np(a) + ", " + np(b) + ")"; } + std::string unsigned_greater_than(Node, Node a, Node b, int) { return "unsigned_greater_than(" + np(a) + ", " + np(b) + ")"; } + std::string unsigned_greater_equal(Node, Node a, Node b, int) { return "unsigned_greater_equal(" + np(a) + ", " + np(b) + ")"; } + std::string logical_shift_left(Node, Node a, Node b, int, int) { return "logical_shift_left(" + np(a) + ", " + np(b) + ")"; } + std::string logical_shift_right(Node, Node a, Node b, int, int) { return "logical_shift_right(" + np(a) + ", " + np(b) + ")"; } + std::string arithmetic_shift_right(Node, Node a, Node b, int, int) { return "arithmetic_shift_right(" + np(a) + ", " + np(b) + ")"; } + std::string mux(Node, Node a, Node b, Node s, int) { return "mux(" + np(a) + ", " + np(b) + ", " + np(s) + ")"; } + std::string pmux(Node, Node a, Node b, Node s, int, int) { return "pmux(" + np(a) + ", " + np(b) + ", " + np(s) + ")"; } + std::string constant(Node, RTLIL::Const value) { return "constant(" + value.as_string() + ")"; } + std::string input(Node, IdString name) { return "input(" + name.str() + ")"; } + std::string state(Node, IdString name) { return "state(" + name.str() + ")"; } + std::string memory_read(Node, Node mem, Node addr, int, int) { return "memory_read(" + np(mem) + ", " + np(addr) + ")"; } + std::string memory_write(Node, Node mem, Node addr, Node data, int, int) { return "memory_write(" + np(mem) + ", " + np(addr) + ", " + np(data) + ")"; } + std::string undriven(Node, int width) { return "undriven(" + std::to_string(width) + ")"; } + }; + public: + int id() const { return _ref.index(); } + IdString name() const { + if(_ref.has_sparse_attr()) + return _ref.sparse_attr(); + else + return std::string("\\n") + std::to_string(id()); + } + Sort sort() const { return _ref.attr().sort; } + int width() const { return sort().width(); } + Node arg(int n) const { return Node(_ref.arg(n)); } + template auto visit(Visitor v) const + { + switch(_ref.function().fn()) { + case Fn::invalid: log_error("invalid node in visit"); break; + case Fn::buf: return v.buf(*this, arg(0)); break; + case Fn::slice: return v.slice(*this, arg(0), arg(0).width(), _ref.function().as_int(), sort().width()); break; + case Fn::zero_extend: return v.zero_extend(*this, arg(0), arg(0).width(), width()); break; + case Fn::sign_extend: return v.sign_extend(*this, arg(0), arg(0).width(), width()); break; + case Fn::concat: return v.concat(*this, arg(0), arg(0).width(), arg(1), arg(1).width()); break; + case Fn::add: return v.add(*this, arg(0), arg(1), sort().width()); break; + case Fn::sub: return v.sub(*this, arg(0), arg(1), sort().width()); break; + case Fn::bitwise_and: return v.bitwise_and(*this, arg(0), arg(1), sort().width()); break; + case Fn::bitwise_or: return v.bitwise_or(*this, arg(0), arg(1), sort().width()); break; + case Fn::bitwise_xor: return v.bitwise_xor(*this, arg(0), arg(1), sort().width()); break; + case Fn::bitwise_not: return v.bitwise_not(*this, arg(0), sort().width()); break; + case Fn::unary_minus: return v.bitwise_not(*this, arg(0), sort().width()); break; + case Fn::reduce_and: return v.reduce_and(*this, arg(0), arg(0).width()); break; + case Fn::reduce_or: return v.reduce_or(*this, arg(0), arg(0).width()); break; + case Fn::reduce_xor: return v.reduce_xor(*this, arg(0), arg(0).width()); break; + case Fn::equal: return v.equal(*this, arg(0), arg(1), arg(0).width()); break; + case Fn::not_equal: return v.not_equal(*this, arg(0), arg(1), arg(0).width()); break; + case Fn::signed_greater_than: return v.signed_greater_than(*this, arg(0), arg(1), arg(0).width()); break; + case Fn::signed_greater_equal: return v.signed_greater_equal(*this, arg(0), arg(1), arg(0).width()); break; + case Fn::unsigned_greater_than: return v.unsigned_greater_than(*this, arg(0), arg(1), arg(0).width()); break; + case Fn::unsigned_greater_equal: return v.unsigned_greater_equal(*this, arg(0), arg(1), arg(0).width()); break; + case Fn::logical_shift_left: return v.logical_shift_left(*this, arg(0), arg(1), arg(0).width(), arg(1).width()); break; + case Fn::logical_shift_right: return v.logical_shift_right(*this, arg(0), arg(1), arg(0).width(), arg(1).width()); break; + case Fn::arithmetic_shift_right: return v.arithmetic_shift_right(*this, arg(0), arg(1), arg(0).width(), arg(1).width()); break; + case Fn::mux: return v.mux(*this, arg(0), arg(1), arg(2), arg(0).width()); break; + case Fn::pmux: return v.pmux(*this, arg(0), arg(1), arg(2), arg(0).width(), arg(2).width()); break; + case Fn::constant: return v.constant(*this, _ref.function().as_const()); break; + case Fn::input: return v.input(*this, _ref.function().as_idstring()); break; + case Fn::state: return v.state(*this, _ref.function().as_idstring()); break; + case Fn::memory_read: return v.memory_read(*this, arg(0), arg(1), arg(1).width(), width()); break; + case Fn::memory_write: return v.memory_write(*this, arg(0), arg(1), arg(2), arg(1).width(), arg(2).width()); break; + case Fn::multiple: log_error("multiple in visit"); break; + case Fn::undriven: return v.undriven(*this, width()); break; + } + } + template std::string to_string(NodePrinter np) + { + return visit(PrintVisitor(np)); + } + /* TODO: delete */ int size() const { return sort().width(); } + }; + class Factory { + FunctionalIR &_ir; + friend class FunctionalIR; + explicit Factory(FunctionalIR &ir) : _ir(ir) {} + Node add(NodeData &&fn, Sort &&sort, std::initializer_list args) { + Graph::Ref ref = _ir._graph.add(std::move(fn), {std::move(sort)}); + for (auto arg : args) + ref.append_arg(Graph::Ref(arg)); + return ref; + } + void check_basic_binary(Node const &a, Node const &b) { log_assert(a.sort().is_signal() && a.sort() == b.sort()); } + void check_shift(Node const &a, Node const &b) { log_assert(a.sort().is_signal() && b.sort().is_signal()); } + void check_unary(Node const &a) { log_assert(a.sort().is_signal()); } + public: + Node slice(Node a, int, int offset, int out_width) { + log_assert(a.sort().is_signal() && offset + out_width <= a.sort().width()); + return add(NodeData(Fn::slice, offset), Sort(out_width), {a}); + } + Node extend(Node a, int, int out_width, bool is_signed) { + log_assert(a.sort().is_signal() && a.sort().width() < out_width); + if(is_signed) + return add(Fn::sign_extend, Sort(out_width), {a}); + else + return add(Fn::zero_extend, Sort(out_width), {a}); + } + Node concat(Node a, int, Node b, int) { + log_assert(a.sort().is_signal() && b.sort().is_signal()); + return add(Fn::concat, Sort(a.sort().width() + b.sort().width()), {a, b}); + } + Node add(Node a, Node b, int) { check_basic_binary(a, b); return add(Fn::add, a.sort(), {a, b}); } + Node sub(Node a, Node b, int) { check_basic_binary(a, b); return add(Fn::sub, a.sort(), {a, b}); } + Node bitwise_and(Node a, Node b, int) { check_basic_binary(a, b); return add(Fn::bitwise_and, a.sort(), {a, b}); } + Node bitwise_or(Node a, Node b, int) { check_basic_binary(a, b); return add(Fn::bitwise_or, a.sort(), {a, b}); } + Node bitwise_xor(Node a, Node b, int) { check_basic_binary(a, b); return add(Fn::bitwise_xor, a.sort(), {a, b}); } + Node bitwise_not(Node a, int) { check_unary(a); return add(Fn::bitwise_not, a.sort(), {a}); } + Node unary_minus(Node a, int) { check_unary(a); return add(Fn::unary_minus, a.sort(), {a}); } + Node reduce_and(Node a, int) { check_unary(a); return add(Fn::reduce_and, Sort(1), {a}); } + Node reduce_or(Node a, int) { check_unary(a); return add(Fn::reduce_or, Sort(1), {a}); } + Node reduce_xor(Node a, int) { check_unary(a); return add(Fn::reduce_xor, Sort(1), {a}); } + Node equal(Node a, Node b, int) { check_basic_binary(a, b); return add(Fn::equal, Sort(1), {a, b}); } + Node not_equal(Node a, Node b, int) { check_basic_binary(a, b); return add(Fn::not_equal, Sort(1), {a, b}); } + Node signed_greater_than(Node a, Node b, int) { check_basic_binary(a, b); return add(Fn::signed_greater_than, Sort(1), {a, b}); } + Node signed_greater_equal(Node a, Node b, int) { check_basic_binary(a, b); return add(Fn::signed_greater_equal, Sort(1), {a, b}); } + Node unsigned_greater_than(Node a, Node b, int) { check_basic_binary(a, b); return add(Fn::unsigned_greater_than, Sort(1), {a, b}); } + Node unsigned_greater_equal(Node a, Node b, int) { check_basic_binary(a, b); return add(Fn::unsigned_greater_equal, Sort(1), {a, b}); } + Node logical_shift_left(Node a, Node b, int, int) { check_shift(a, b); return add(Fn::logical_shift_left, a.sort(), {a, b}); } + Node logical_shift_right(Node a, Node b, int, int) { check_shift(a, b); return add(Fn::logical_shift_right, a.sort(), {a, b}); } + Node arithmetic_shift_right(Node a, Node b, int, int) { check_shift(a, b); return add(Fn::arithmetic_shift_right, a.sort(), {a, b}); } + Node mux(Node a, Node b, Node s, int) { + log_assert(a.sort().is_signal() && a.sort() == b.sort() && s.sort() == Sort(1)); + return add(Fn::mux, a.sort(), {a, b, s}); + } + Node pmux(Node a, Node b, Node s, int, int) { + log_assert(a.sort().is_signal() && b.sort().is_signal() && s.sort().is_signal() && a.sort().width() * s.sort().width() == b.sort().width()); + return add(Fn::pmux, a.sort(), {a, b, s}); + } + Node memory_read(Node mem, Node addr, int, int) { + log_assert(mem.sort().is_memory() && addr.sort().is_signal() && mem.sort().addr_width() == addr.sort().width()); + return add(Fn::memory_read, Sort(mem.sort().data_width()), {mem, addr}); + } + Node memory_write(Node mem, Node addr, Node data, int, int) { + log_assert(mem.sort().is_memory() && addr.sort().is_signal() && data.sort().is_signal() && + mem.sort().addr_width() == addr.sort().width() && mem.sort().data_width() == data.sort().width()); + return add(Fn::memory_write, mem.sort(), {mem, addr, data}); + } + Node constant(RTLIL::Const value) { + return add(NodeData(Fn::constant, std::move(value)), Sort(value.size()), {}); + } + Node create_pending(int width) { + return add(Fn::buf, Sort(width), {}); + } + void update_pending(Node node, Node value) { + log_assert(node._ref.function() == Fn::buf && node._ref.size() == 0 && node.sort() == value.sort()); + node._ref.append_arg(value._ref); + } + Node input(IdString name, int width) { + _ir.add_input(name, Sort(width)); + return add(NodeData(Fn::input, name), Sort(width), {}); + } + Node state(IdString name, int width) { + _ir.add_state(name, Sort(width)); + return add(NodeData(Fn::state, name), Sort(width), {}); + } + Node state_memory(IdString name, int addr_width, int data_width) { + _ir.add_state(name, Sort(addr_width, data_width)); + return add(NodeData(Fn::state, name), Sort(addr_width, data_width), {}); + } + Node cell_output(Node node, IdString, IdString, int) { + return node; + } + Node multiple(vector args, int width) { + auto node = add(Fn::multiple, Sort(width), {}); + for(const auto &arg : args) + node._ref.append_arg(arg._ref); + return node; + } + Node undriven(int width) { + return add(Fn::undriven, Sort(width), {}); + } + void declare_output(Node node, IdString name, int width) { + _ir.add_output(name, Sort(width)); + node._ref.assign_key({name, false}); + } + void declare_state(Node node, IdString name, int width) { + _ir.add_state(name, Sort(width)); + node._ref.assign_key({name, true}); + } + void declare_state_memory(Node node, IdString name, int addr_width, int data_width) { + _ir.add_state(name, Sort(addr_width, data_width)); + node._ref.assign_key({name, true}); + } + void suggest_name(Node node, IdString name) { + node._ref.sparse_attr() = name; + } + + /* TODO delete this later*/ + Node eq(Node a, Node b, int) { return equal(a, b, 0); } + Node ne(Node a, Node b, int) { return not_equal(a, b, 0); } + Node gt(Node a, Node b, int) { return signed_greater_than(a, b, 0); } + Node ge(Node a, Node b, int) { return signed_greater_equal(a, b, 0); } + Node ugt(Node a, Node b, int) { return unsigned_greater_than(a, b, 0); } + Node uge(Node a, Node b, int) { return unsigned_greater_equal(a, b, 0); } + Node neg(Node a, int) { return unary_minus(a, 0); } + }; + static FunctionalIR from_module(Module *module); + Factory factory() { return Factory(*this); } + int size() const { return _graph.size(); } + Node operator[](int i) { return _graph[i]; } + void topological_sort(); + void forward_buf(); + dict inputs() const { return _inputs; } + dict outputs() const { return _outputs; } + dict state() const { return _state; } + Node get_output_node(IdString name) { return _graph({name, false}); } + Node get_state_next_node(IdString name) { return _graph({name, true}); } + class Iterator { + friend class FunctionalIR; + FunctionalIR &_ir; + int _index; + Iterator(FunctionalIR &ir, int index) : _ir(ir), _index(index) {} + public: + Node operator*() { return _ir._graph[_index]; } + Iterator &operator++() { _index++; return *this; } + bool operator!=(Iterator const &other) const { return _index != other._index; } + }; + Iterator begin() { return Iterator(*this, 0); } + Iterator end() { return Iterator(*this, _graph.size()); } +}; + +YOSYS_NAMESPACE_END + +#endif