Skip to content

Commit

Permalink
smtr: Make Rosette compatible
Browse files Browse the repository at this point in the history
Convert most of the operators, except pmux and memory.
Convert formatting for non-stateful modules.
  • Loading branch information
KrystalDelusion committed Jul 13, 2024
1 parent 42e54c3 commit 5468bdf
Showing 1 changed file with 77 additions and 124 deletions.
201 changes: 77 additions & 124 deletions backends/functional/smtlib_rosette.cc
Original file line number Diff line number Diff line change
Expand Up @@ -107,12 +107,12 @@ template <class NodeNames> struct SmtPrintVisitor {

std::string slice(Node, Node a, int, int offset, int out_width)
{
return format("((_ extract %2 %1) %0)", np(a), offset, offset + out_width - 1);
return format("(extract %2 %1 %0)", np(a), offset, offset + out_width - 1);
}

std::string zero_extend(Node, Node a, int, int out_width) { return format("((_ zero_extend %1) %0)", np(a), out_width - a.width()); }
std::string zero_extend(Node, Node a, int, int out_width) { return format("(zero-extend %0 (bitvector %1))", np(a), out_width); }

std::string sign_extend(Node, Node a, int, int out_width) { return format("((_ sign_extend %1) %0)", np(a), out_width - a.width()); }
std::string sign_extend(Node, Node a, int, int out_width) { return format("(sign-extend %0 (bitvector %1))", np(a), out_width); }

std::string concat(Node, Node a, int, Node b, int) { return format("(concat %0 %1)", np(a), np(b)); }

Expand All @@ -136,137 +136,64 @@ template <class NodeNames> struct SmtPrintVisitor {

std::string unary_minus(Node, Node a, int) { return format("(bvneg %0)", np(a)); }

std::string reduce_and(Node, Node a, int) {
std::stringstream ss;
// We use ite to set the result to bit vector, to ensure appropriate type
ss << "(ite (= " << np(a) << " #b" << std::string(a.width(), '1') << ") #b1 #b0)";
return ss.str();
}
std::string reduce_and(Node, Node a, int) { return format("(apply bvand (bitvector->bits %0))", np(a)); }

std::string reduce_or(Node, Node a, int)
{
std::stringstream ss;
// We use ite to set the result to bit vector, to ensure appropriate type
ss << "(ite (= " << np(a) << " #b" << std::string(a.width(), '0') << ") #b0 #b1)";
return ss.str();
}
std::string reduce_or(Node, Node a, int) { return format("(apply bvor (bitvector->bits %0))", np(a)); }

std::string reduce_xor(Node, Node a, int) {
std::stringstream ss;
ss << "(bvxor ";
for (int i = 0; i < a.width(); ++i) {
if (i > 0) ss << " ";
ss << "((_ extract " << i << " " << i << ") " << np(a) << ")";
}
ss << ")";
return ss.str();
}
std::string reduce_xor(Node, Node a, int) { return format("(apply bvxor (bitvector->bits %0))", np(a)); }

std::string equal(Node, Node a, Node b, int) {
return format("(ite (= %0 %1) #b1 #b0)", np(a), np(b));
return format("(bool->bitvector (bveq %0 %1))", np(a), np(b));
}

std::string not_equal(Node, Node a, Node b, int) {
return format("(ite (distinct %0 %1) #b1 #b0)", np(a), np(b));
return format("(bool->bitvector (not (bveq %0 %1)))", np(a), np(b));
}

std::string signed_greater_than(Node, Node a, Node b, int) {
return format("(ite (bvsgt %0 %1) #b1 #b0)", np(a), np(b));
return format("(bool->bitvector (bvsgt %0 %1))", np(a), np(b));
}

std::string signed_greater_equal(Node, Node a, Node b, int) {
return format("(ite (bvsge %0 %1) #b1 #b0)", np(a), np(b));
return format("(bool->bitvector (bvsge %0 %1))", np(a), np(b));
}

std::string unsigned_greater_than(Node, Node a, Node b, int) {
return format("(ite (bvugt %0 %1) #b1 #b0)", np(a), np(b));
return format("(bool->bitvector (bvugt %0 %1))", np(a), np(b));
}

std::string unsigned_greater_equal(Node, Node a, Node b, int) {
return format("(ite (bvuge %0 %1) #b1 #b0)", np(a), np(b));
}

std::string logical_shift_left(Node, Node a, Node b, int, int) {
// Get the bit-widths of a and b
int bit_width_a = a.width();
int bit_width_b = b.width();

// Extend b to match the bit-width of a if necessary
std::ostringstream oss;
if (bit_width_a > bit_width_b) {
oss << "((_ zero_extend " << (bit_width_a - bit_width_b) << ") " << np(b) << ")";
} else {
oss << np(b); // No extension needed if b's width is already sufficient
}
std::string b_extended = oss.str();

// Format the bvshl operation with the extended b
oss.str(""); // Clear the stringstream
oss << "(bvshl " << np(a) << " " << b_extended << ")";
return oss.str();
}

std::string logical_shift_right(Node, Node a, Node b, int, int) {
// Get the bit-widths of a and b
int bit_width_a = a.width();
int bit_width_b = b.width();

// Extend b to match the bit-width of a if necessary
std::ostringstream oss;
if (bit_width_a > bit_width_b) {
oss << "((_ zero_extend " << (bit_width_a - bit_width_b) << ") " << np(b) << ")";
} else {
oss << np(b); // No extension needed if b's width is already sufficient
}
std::string b_extended = oss.str();

// Format the bvlshr operation with the extended b
oss.str(""); // Clear the stringstream
oss << "(bvlshr " << np(a) << " " << b_extended << ")";
return oss.str();
return format("(bool->bitvector (bvuge %0 %1))", np(a), np(b));
}

std::string arithmetic_shift_right(Node, Node a, Node b, int, int) {
// Get the bit-widths of a and b
int bit_width_a = a.width();
int bit_width_b = b.width();
std::string logical_shift_left(Node, Node a, Node b, int, int) { return format("(bvshl %0 %1)", np(a), np(b)); }

// Extend b to match the bit-width of a if necessary
std::ostringstream oss;
if (bit_width_a > bit_width_b) {
oss << "((_ zero_extend " << (bit_width_a - bit_width_b) << ") " << np(b) << ")";
} else {
oss << np(b); // No extension needed if b's width is already sufficient
}
std::string b_extended = oss.str();
std::string logical_shift_right(Node, Node a, Node b, int, int) { return format("(bvlshr %0 %1)", np(a), np(b)); }

// Format the bvashr operation with the extended b
oss.str(""); // Clear the stringstream
oss << "(bvashr " << np(a) << " " << b_extended << ")";
return oss.str();
}
std::string arithmetic_shift_right(Node, Node a, Node b, int, int) { return format("(bvashr %0 %1)", np(a), np(b)); }

std::string mux(Node, Node a, Node b, Node s, int) {
return format("(ite (= %2 #b1) %0 %1)", np(a), np(b), np(s));
}
std::string mux(Node, Node a, Node b, Node s, int) { return format("(if %2 %0 %1)", np(a), np(b), np(s)); }

// How does pmux?
std::string pmux(Node, Node a, Node b, Node s, int, int)
{
// Assume s is a bit vector, combine a and b based on the selection bits
return format("(pmux %0 %1 %2)", np(a), np(b), np(s));
}

std::string constant(Node, RTLIL::Const value) { return format("#b%0", value.as_string()); }
std::string constant(Node, RTLIL::Const value) { return format("(bv #b%0 %1)", value.as_string(), value.size()); }

std::string input(Node, IdString name) { return format("%0", scope[name]); }

// How does state?
std::string state(Node, IdString name) { return format("(%0 current_state)", scope[name]); }

// How does memory?
std::string memory_read(Node, Node mem, Node addr, int, int) { return format("(select %0 %1)", np(mem), np(addr)); }

std::string memory_write(Node, Node mem, Node addr, Node data, int, int) { return format("(store %0 %1 %2)", np(mem), np(addr), np(data)); }

std::string undriven(Node, int width) { return format("#b%0", std::string(width, '0')); }
std::string undriven(Node, int width) { return format("(bv 0 %0)", width); }
};

struct SmtModule {
Expand All @@ -281,23 +208,46 @@ struct SmtModule {
const bool stateful = ir.state().size() != 0;
SmtWriter writer(out);

writer.print("(declare-fun %s () Bool)\n\n", name.c_str());
// Rosette lang header
writer.print("#lang rosette\n\n");
std::string end_part = "\n";
std::string indent = "\t";

// Not sure if this is actually necessary or not, so make it optional I guess?
bool guarded = true;

writer.print("(declare-datatypes () ((Inputs (mk_inputs");
// ???
// writer.print("(declare-fun %s () Bool)\n\n", name.c_str());

// Inputs
std::stringstream input_list;
std::stringstream input_values;
for (const auto &input : ir.inputs()) {
std::string input_name = scope[input.first];
writer.print(" (%s (_ BitVec %d))", input_name.c_str(), input.second.width());
auto input_name = scope[input.first];
input_list << input_name << " ";
if (guarded) {
input_values << end_part << indent << indent << indent;
auto width = input.second.width();
input_values << "(extract " << width-1 << " 0 (concat (bv 0 " << width << ") " << input_name << "))";
}
}
writer.print("))))\n\n");
writer.print("(struct Inputs (%s)", input_list.str().c_str());
if (guarded) {
writer.print("%s%s#:guard (lambda (%sname)%s", end_part.c_str(), indent.c_str(), input_list.str().c_str(), end_part.c_str());
writer.print("%s%s(values%s))", indent.c_str(), indent.c_str(), input_values.str().c_str());
}
writer.print(")\n");

writer.print("(declare-datatypes () ((Outputs (mk_outputs");
// Outputs
writer.print("(struct Outputs (");
for (const auto &output : ir.outputs()) {
std::string output_name = scope[output.first];
writer.print(" (%s (_ BitVec %d))", output_name.c_str(), output.second.width());
auto output_name = scope[output.first];
writer.print("%s ", output_name.c_str());
}
writer.print("))))\n");
writer.print("))\n");

if (stateful) {
// ?
writer.print("(declare-datatypes () ((State (mk_state");
for (const auto &state : ir.state()) {
std::string state_name = scope[state.first];
Expand All @@ -308,21 +258,21 @@ struct SmtModule {
writer.print("(declare-datatypes () ((Pair (mk-pair (outputs Outputs) (next_state State)))))\n");
}

if (stateful)
writer.print("(define-fun %s_step ((current_state State) (inputs Inputs)) Pair", name.c_str());
else
writer.print("(define-fun %s_step ((inputs Inputs)) Outputs", name.c_str());
// Function start
writer.print("(define (%s_step inputs)%s", name.c_str(), end_part.c_str());

writer.print(" (let (");
// Bind inputs
writer.print("%s(let (", indent.c_str());
for (const auto &input : ir.inputs()) {
std::string input_name = scope[input.first];
writer.print(" (%s (%s inputs))", input_name.c_str(), input_name.c_str());
auto input_name = scope[input.first];
writer.print("[%s (Inputs-%s inputs)] ", input_name.c_str(), input_name.c_str());
}
writer.print(" )");
writer.print(")");

auto node_to_string = [&](FunctionalIR::Node n) { return scope[n.name()]; };
SmtPrintVisitor<decltype(node_to_string)> visitor(node_to_string, scope);

// Bind operators
for (auto it = ir.begin(); it != ir.end(); ++it) {
const FunctionalIR::Node &node = *it;

Expand All @@ -332,10 +282,12 @@ struct SmtModule {
std::string node_name = scope[node.name()];
std::string node_expr = node.visit(visitor);

writer.print(" (let ( (%s %s))", node_name.c_str(), node_expr.c_str());
writer.print(" (let ([%s %s])", node_name.c_str(), node_expr.c_str());
}

// Bind next state
if (stateful) {
// ?
writer.print(" (let ( (next_state (mk_state ");
for (const auto &state : ir.state()) {
std::string state_name = scope[state.first];
Expand All @@ -345,7 +297,9 @@ struct SmtModule {
writer.print(" )))");
}

// Bind outputs
if (stateful) {
// ?
writer.print(" (let ( (outputs (mk_outputs ");
for (const auto &output : ir.outputs()) {
std::string output_name = scope[output.first];
Expand All @@ -354,27 +308,26 @@ struct SmtModule {
writer.print(" )))");

writer.print("(mk-pair outputs next_state)");
writer.print(" )"); // Closing outputs let statement
writer.print(" )"); // Closing next_state let statement
}
else {
writer.print(" (mk_outputs ");
writer.print(" (Outputs ");
for (const auto &output : ir.outputs()) {
std::string output_name = scope[output.first];
writer.print(" %s", output_name.c_str());
auto output_name = scope[output.first];
writer.print("%s ", output_name.c_str());
}
writer.print(" )"); // Closing mk_outputs
}
if (stateful) {
writer.print(" )"); // Closing outputs let statement
writer.print(" )"); // Closing next_state let statement
writer.print(")"); // Closing outputs
}

// Close the nested lets
for (size_t i = 0; i < ir.size() - ir.inputs().size(); ++i) {
writer.print(" )"); // Closing each node
for (auto i = ir.inputs().size(); i < ir.size(); ++i) {
writer.print(")"); // Closing each node
}
if (ir.size() == ir.inputs().size())
writer.print(" )"); // Corner case
writer.print(")"); // Corner case

writer.print(" )"); // Closing inputs let statement
writer.print(")"); // Closing inputs let statement
writer.print(")\n"); // Closing step function
}
};
Expand Down

0 comments on commit 5468bdf

Please sign in to comment.