Skip to content

Commit

Permalink
Support bytes concat
Browse files Browse the repository at this point in the history
  • Loading branch information
pgebal committed Nov 17, 2023
1 parent 58811f1 commit effab73
Show file tree
Hide file tree
Showing 12 changed files with 275 additions and 27 deletions.
2 changes: 1 addition & 1 deletion Changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ Language Features:


Compiler Features:

* SMTChecker: Support `bytes.concat` function.

Bugfixes:

Expand Down
8 changes: 4 additions & 4 deletions libsolidity/formal/CHC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -950,7 +950,7 @@ void CHC::nondetCall(ContractDefinition const& _contract, VariableDeclaration co
m_currentContract
);
auto postCallState = std::vector<smtutil::Expression>{state().state()} + currentStateVariables(_contract);
std::vector<smtutil::Expression> stateExprs{error, address, state().abi(), state().crypto()};
std::vector<smtutil::Expression> stateExprs{error, address, state().abi(), state().bytesConcat(), state().crypto()};

auto nondet = (*m_nondetInterfaces.at(&_contract))(stateExprs + preCallState + postCallState);
auto nondetCall = callPredicate(stateExprs + preCallState + postCallState);
Expand Down Expand Up @@ -1033,7 +1033,7 @@ void CHC::externalFunctionCall(FunctionCall const& _funCall)
&_funCall
);
auto postCallState = std::vector<smtutil::Expression>{state().state()} + currentStateVariables();
std::vector<smtutil::Expression> stateExprs{error, state().thisAddress(), state().abi(), state().crypto()};
std::vector<smtutil::Expression> stateExprs{error, state().thisAddress(), state().abi(), state().bytesConcat(), state().crypto()};

auto nondet = (*m_nondetInterfaces.at(m_currentContract))(stateExprs + preCallState + postCallState);
auto nondetCall = callPredicate(stateExprs + preCallState + postCallState);
Expand Down Expand Up @@ -1464,7 +1464,7 @@ void CHC::defineInterfacesAndSummaries(SourceUnit const& _source)
auto errorPost = errorFlag().increaseIndex();
auto nondetPost = smt::nondetInterface(iface, *contract, m_context, 0, 2);

std::vector<smtutil::Expression> args{errorPost, state().thisAddress(), state().abi(), state().crypto(), state().tx(), state().state(1)};
std::vector<smtutil::Expression> args{errorPost, state().thisAddress(), state().abi(), state().bytesConcat(), state().crypto(), state().tx(), state().state(1)};
args += state1 +
applyMap(function->parameters(), [this](auto _var) { return valueAtIndex(*_var, 0); }) +
std::vector<smtutil::Expression>{state().state(2)} +
Expand Down Expand Up @@ -1829,7 +1829,7 @@ smtutil::Expression CHC::predicate(

errorFlag().increaseIndex();

std::vector<smtutil::Expression> args{errorFlag().currentValue(), _contractAddressValue, state().abi(), state().crypto(), state().tx(), state().state()};
std::vector<smtutil::Expression> args{errorFlag().currentValue(), _contractAddressValue, state().abi(), state().bytesConcat(), state().crypto(), state().tx(), state().state()};

auto const* contract = _funDef->annotation().contract;
auto const& hierarchy = m_currentContract->annotation().linearizedBaseContracts;
Expand Down
16 changes: 8 additions & 8 deletions libsolidity/formal/Predicate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ std::string Predicate::formatSummaryCall(
return {};
}

/// The signature of a function summary predicate is: summary(error, this, abiFunctions, cryptoFunctions, txData, preBlockChainState, preStateVars, preInputVars, postBlockchainState, postStateVars, postInputVars, outputVars).
/// The signature of a function summary predicate is: summary(error, this, abiFunctions, bytesConcatFunctions, cryptoFunctions, txData, preBlockChainState, preStateVars, preInputVars, postBlockchainState, postStateVars, postInputVars, outputVars).
/// Here we are interested in preInputVars to format the function call.

std::string txModel;
Expand Down Expand Up @@ -337,8 +337,8 @@ std::string Predicate::formatSummaryCall(

std::vector<std::optional<std::string>> Predicate::summaryStateValues(std::vector<smtutil::Expression> const& _args) const
{
/// The signature of a function summary predicate is: summary(error, this, abiFunctions, cryptoFunctions, txData, preBlockchainState, preStateVars, preInputVars, postBlockchainState, postStateVars, postInputVars, outputVars).
/// The signature of the summary predicate of a contract without constructor is: summary(error, this, abiFunctions, cryptoFunctions, txData, preBlockchainState, postBlockchainState, preStateVars, postStateVars).
/// The signature of a function summary predicate is: summary(error, this, abiFunctions, bytesConcatFunctions, cryptoFunctions, txData, preBlockchainState, preStateVars, preInputVars, postBlockchainState, postStateVars, postInputVars, outputVars).
/// The signature of the summary predicate of a contract without constructor is: summary(error, this, abiFunctions, bytesConcatFunctions, cryptoFunctions, txData, preBlockchainState, postBlockchainState, preStateVars, postStateVars).
/// Here we are interested in postStateVars.
auto stateVars = stateVariables();
solAssert(stateVars.has_value(), "");
Expand Down Expand Up @@ -371,7 +371,7 @@ std::vector<std::optional<std::string>> Predicate::summaryStateValues(std::vecto

std::vector<std::optional<std::string>> Predicate::summaryPostInputValues(std::vector<smtutil::Expression> const& _args) const
{
/// The signature of a function summary predicate is: summary(error, this, abiFunctions, cryptoFunctions, txData, preBlockchainState, preStateVars, preInputVars, postBlockchainState, postStateVars, postInputVars, outputVars).
/// The signature of a function summary predicate is: summary(error, this, abiFunctions, bytesConcatFunctions, cryptoFunctions, txData, preBlockchainState, preStateVars, preInputVars, postBlockchainState, postStateVars, postInputVars, outputVars).
/// Here we are interested in postInputVars.
auto const* function = programFunction();
solAssert(function, "");
Expand All @@ -395,7 +395,7 @@ std::vector<std::optional<std::string>> Predicate::summaryPostInputValues(std::v

std::vector<std::optional<std::string>> Predicate::summaryPostOutputValues(std::vector<smtutil::Expression> const& _args) const
{
/// The signature of a function summary predicate is: summary(error, this, abiFunctions, cryptoFunctions, txData, preBlockchainState, preStateVars, preInputVars, postBlockchainState, postStateVars, postInputVars, outputVars).
/// The signature of a function summary predicate is: summary(error, this, abiFunctions, bytesConcatFunctions, cryptoFunctions, txData, preBlockchainState, preStateVars, preInputVars, postBlockchainState, postStateVars, postInputVars, outputVars).
/// Here we are interested in outputVars.
auto const* function = programFunction();
solAssert(function, "");
Expand All @@ -418,7 +418,7 @@ std::vector<std::optional<std::string>> Predicate::summaryPostOutputValues(std::
std::pair<std::vector<std::optional<std::string>>, std::vector<VariableDeclaration const*>> Predicate::localVariableValues(std::vector<smtutil::Expression> const& _args) const
{
/// The signature of a local block predicate is:
/// block(error, this, abiFunctions, cryptoFunctions, txData, preBlockchainState, preStateVars, preInputVars, postBlockchainState, postStateVars, postInputVars, outputVars, localVars).
/// block(error, this, abiFunctions, bytesConcatFunctions, cryptoFunctions, txData, preBlockchainState, preStateVars, preInputVars, postBlockchainState, postStateVars, postInputVars, outputVars, localVars).
/// Here we are interested in localVars.
auto const* function = programFunction();
solAssert(function, "");
Expand Down Expand Up @@ -452,7 +452,7 @@ std::map<std::string, std::string> Predicate::expressionSubstitution(smtutil::Ex
auto nArgs = _predExpr.arguments.size();

// The signature of an interface predicate is
// interface(this, abiFunctions, cryptoFunctions, blockchainState, stateVariables).
// interface(this, abiFunctions, bytesConcatFunctions, cryptoFunctions, blockchainState, stateVariables).
// An invariant for an interface predicate is a contract
// invariant over its state, for example `x <= 0`.
if (isInterface())
Expand All @@ -464,7 +464,7 @@ std::map<std::string, std::string> Predicate::expressionSubstitution(smtutil::Ex
subst[_predExpr.arguments.at(i).name] = stateVars.at(i - 4)->name();
}
// The signature of a nondet interface predicate is
// nondet_interface(error, this, abiFunctions, cryptoFunctions, blockchainState, stateVariables, blockchainState', stateVariables').
// nondet_interface(error, this, abiFunctions, bytesConcatFunctions, cryptoFunctions, blockchainState, stateVariables, blockchainState', stateVariables').
// An invariant for a nondet interface predicate is a reentrancy property
// over the pre and post state variables of a contract, where pre state vars
// are represented by the variable's name and post state vars are represented
Expand Down
19 changes: 12 additions & 7 deletions libsolidity/formal/PredicateInstance.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,16 @@ namespace solidity::frontend::smt
smtutil::Expression interfacePre(Predicate const& _pred, ContractDefinition const& _contract, EncodingContext& _context)
{
auto& state = _context.state();
std::vector<smtutil::Expression> stateExprs{state.thisAddress(0), state.abi(0), state.crypto(0), state.state(0)};
std::vector<smtutil::Expression> stateExprs{state.thisAddress(0),
state.abi(0), state.bytesConcat(0), state.crypto(0), state.state(0)};
return _pred(stateExprs + initialStateVariables(_contract, _context));
}

smtutil::Expression interface(Predicate const& _pred, ContractDefinition const& _contract, EncodingContext& _context)
{
auto const& state = _context.state();
std::vector<smtutil::Expression> stateExprs{state.thisAddress(0), state.abi(0), state.crypto(0), state.state()};
std::vector<smtutil::Expression> stateExprs{state.thisAddress(0),
state.abi(0), state.bytesConcat(0), state.crypto(0), state.state()};
return _pred(stateExprs + currentStateVariables(_contract, _context));
}

Expand All @@ -48,7 +50,8 @@ smtutil::Expression nondetInterface(
unsigned _postIdx)
{
auto const& state = _context.state();
std::vector<smtutil::Expression> stateExprs{state.errorFlag().currentValue(), state.thisAddress(), state.abi(), state.crypto()};
std::vector<smtutil::Expression> stateExprs{state.errorFlag().currentValue(), state.thisAddress(),
state.abi(), state.bytesConcat(0), state.crypto()};
return _pred(
stateExprs +
std::vector<smtutil::Expression>{_context.state().state(_preIdx)} +
Expand All @@ -65,7 +68,7 @@ smtutil::Expression constructor(Predicate const& _pred, EncodingContext& _contex
return _pred(currentFunctionVariablesForDefinition(*constructor, &contract, _context));

auto& state = _context.state();
std::vector<smtutil::Expression> stateExprs{state.errorFlag().currentValue(), state.thisAddress(0), state.abi(0), state.crypto(0), state.tx(0), state.state(0), state.state()};
std::vector<smtutil::Expression> stateExprs{state.errorFlag().currentValue(), state.thisAddress(0), state.abi(0), state.bytesConcat(0), state.crypto(0), state.tx(0), state.state(0), state.state()};
return _pred(stateExprs + initialStateVariables(contract, _context) + currentStateVariables(contract, _context));
}

Expand All @@ -76,7 +79,7 @@ smtutil::Expression constructorCall(Predicate const& _pred, EncodingContext& _co
return _pred(currentFunctionVariablesForCall(*constructor, &contract, _context, _internal));

auto& state = _context.state();
std::vector<smtutil::Expression> stateExprs{state.errorFlag().currentValue(), _internal ? state.thisAddress(0) : state.thisAddress(), state.abi(0), state.crypto(0), _internal ? state.tx(0) : state.tx(), state.state()};
std::vector<smtutil::Expression> stateExprs{state.errorFlag().currentValue(), _internal ? state.thisAddress(0) : state.thisAddress(), state.abi(0), state.bytesConcat(0), state.crypto(0), _internal ? state.tx(0) : state.tx(), state.state()};
state.newState();
stateExprs += std::vector<smtutil::Expression>{state.state()};
stateExprs += currentStateVariables(contract, _context);
Expand Down Expand Up @@ -152,7 +155,8 @@ std::vector<smtutil::Expression> currentFunctionVariablesForDefinition(
)
{
auto& state = _context.state();
std::vector<smtutil::Expression> exprs{state.errorFlag().currentValue(), state.thisAddress(0), state.abi(0), state.crypto(0), state.tx(0), state.state(0)};
std::vector<smtutil::Expression> exprs{state.errorFlag().currentValue(), state.thisAddress(0),
state.abi(0), state.bytesConcat(0), state.crypto(0), state.tx(0), state.state(0)};
exprs += _contract ? initialStateVariables(*_contract, _context) : std::vector<smtutil::Expression>{};
exprs += applyMap(_function.parameters(), [&](auto _var) { return _context.variable(*_var)->valueAtIndex(0); });
exprs += std::vector<smtutil::Expression>{state.state()};
Expand All @@ -170,7 +174,8 @@ std::vector<smtutil::Expression> currentFunctionVariablesForCall(
)
{
auto& state = _context.state();
std::vector<smtutil::Expression> exprs{state.errorFlag().currentValue(), _internal ? state.thisAddress(0) : state.thisAddress(), state.abi(0), state.crypto(0), _internal ? state.tx(0) : state.tx(), state.state()};
std::vector<smtutil::Expression> exprs{state.errorFlag().currentValue(), _internal ? state.thisAddress(0) : state.thisAddress(),
state.abi(0), state.bytesConcat(0), state.crypto(0), _internal ? state.tx(0) : state.tx(), state.state()};
exprs += _contract ? currentStateVariables(*_contract, _context) : std::vector<smtutil::Expression>{};
exprs += applyMap(_function.parameters(), [&](auto _var) { return _context.variable(*_var)->currentValue(); });

Expand Down
8 changes: 4 additions & 4 deletions libsolidity/formal/PredicateSort.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ namespace solidity::frontend::smt
SortPointer interfaceSort(ContractDefinition const& _contract, SymbolicState& _state)
{
return std::make_shared<FunctionSort>(
std::vector<SortPointer>{_state.thisAddressSort(), _state.abiSort(), _state.cryptoSort(), _state.stateSort()} + stateSorts(_contract),
std::vector<SortPointer>{_state.thisAddressSort(), _state.abiSort(), _state.bytesConcatSort(), _state.cryptoSort(), _state.stateSort()} + stateSorts(_contract),
SortProvider::boolSort
);
}
Expand All @@ -40,7 +40,7 @@ SortPointer nondetInterfaceSort(ContractDefinition const& _contract, SymbolicSta
auto varSorts = stateSorts(_contract);
std::vector<SortPointer> stateSort{_state.stateSort()};
return std::make_shared<FunctionSort>(
std::vector<SortPointer>{_state.errorFlagSort(), _state.thisAddressSort(), _state.abiSort(), _state.cryptoSort()} +
std::vector<SortPointer>{_state.errorFlagSort(), _state.thisAddressSort(), _state.abiSort(), _state.bytesConcatSort(), _state.cryptoSort()} +
stateSort +
varSorts +
stateSort +
Expand All @@ -57,7 +57,7 @@ SortPointer constructorSort(ContractDefinition const& _contract, SymbolicState&
auto varSorts = stateSorts(_contract);
std::vector<SortPointer> stateSort{_state.stateSort()};
return std::make_shared<FunctionSort>(
std::vector<SortPointer>{_state.errorFlagSort(), _state.thisAddressSort(), _state.abiSort(), _state.cryptoSort(), _state.txSort(), _state.stateSort(), _state.stateSort()} + varSorts + varSorts,
std::vector<SortPointer>{_state.errorFlagSort(), _state.thisAddressSort(), _state.abiSort(), _state.bytesConcatSort(), _state.cryptoSort(), _state.txSort(), _state.stateSort(), _state.stateSort()} + varSorts + varSorts,
SortProvider::boolSort
);
}
Expand All @@ -69,7 +69,7 @@ SortPointer functionSort(FunctionDefinition const& _function, ContractDefinition
auto inputSorts = applyMap(_function.parameters(), smtSort);
auto outputSorts = applyMap(_function.returnParameters(), smtSort);
return std::make_shared<FunctionSort>(
std::vector<SortPointer>{_state.errorFlagSort(), _state.thisAddressSort(), _state.abiSort(), _state.cryptoSort(), _state.txSort(), _state.stateSort()} +
std::vector<SortPointer>{_state.errorFlagSort(), _state.thisAddressSort(), _state.abiSort(), _state.bytesConcatSort(), _state.cryptoSort(), _state.txSort(), _state.stateSort()} +
varSorts +
inputSorts +
std::vector<SortPointer>{_state.stateSort()} +
Expand Down
66 changes: 65 additions & 1 deletion libsolidity/formal/SMTEncoder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -651,6 +651,9 @@ void SMTEncoder::endVisit(FunctionCall const& _funCall)
if (publicGetter(_funCall.expression()))
visitPublicGetter(_funCall);
break;
case FunctionType::Kind::BytesConcat:
visitBytesConcat(_funCall);
break;
case FunctionType::Kind::ABIDecode:
case FunctionType::Kind::ABIEncode:
case FunctionType::Kind::ABIEncodePacked:
Expand Down Expand Up @@ -780,6 +783,44 @@ void SMTEncoder::visitRequire(FunctionCall const& _funCall)
addPathImpliedExpression(expr(*args.front()));
}

void SMTEncoder::visitBytesConcat(FunctionCall const& _funCall)
{
auto symbFunction = state().bytesConcatFunction(&_funCall);
auto const& [name, inTypes, outType] = state().bytesConcatFunctionTypes(&_funCall);
auto const& args = _funCall.sortedArguments();

solAssert(inTypes.size() == args.size(), "");

// bytes.concat call with no arguments returns an empty array
if (args.size()== 0)
{
defineExpr(_funCall, smt::zeroValue(TypeProvider::bytesMemory()));
return;
}

std::vector<smtutil::Expression> symbArgs;
for (unsigned i = 0; i < args.size(); ++i)
if (args.at(i))
symbArgs.emplace_back(expr(*args.at(i), inTypes.at(i)));

std::optional<smtutil::Expression> arg;
if (inTypes.size() == 1)
arg = expr(*args.at(0), inTypes.at(0));
else
{
auto inputSort = dynamic_cast<smtutil::ArraySort&>(*symbFunction.sort).domain;
arg = smtutil::Expression::tuple_constructor(
smtutil::Expression(std::make_shared<smtutil::SortSort>(inputSort), ""),
symbArgs
);
}

auto out = smtutil::Expression::select(symbFunction, *arg);
defineExpr(_funCall, out);

return;
}

void SMTEncoder::visitABIFunction(FunctionCall const& _funCall)
{
auto symbFunction = state().abiFunction(&_funCall);
Expand Down Expand Up @@ -909,7 +950,7 @@ void SMTEncoder::visitObjectCreation(FunctionCall const& _funCall)

smtutil::Expression arraySize = expr(*args.front());
setSymbolicUnknownValue(arraySize, TypeProvider::uint256(), m_context);

// bytes.concat(fun1())
auto symbArray = std::dynamic_pointer_cast<smt::SymbolicArrayVariable>(m_context.expression(_funCall));
solAssert(symbArray, "");
smt::setSymbolicZeroValue(*symbArray, m_context);
Expand Down Expand Up @@ -3112,6 +3153,29 @@ std::set<FunctionCall const*, ASTCompareByID<FunctionCall>> SMTEncoder::collectA
return ABIFunctions(_node).abiCalls;
}

std::set<FunctionCall const*, ASTCompareByID<FunctionCall>> SMTEncoder::collectBytesConcatCalls(ASTNode const* _node)
{
struct BytesConcatFunctions: public ASTConstVisitor
{
BytesConcatFunctions(ASTNode const* _node) { _node->accept(*this); }
void endVisit(FunctionCall const& _funCall)
{
if (*_funCall.annotation().kind == FunctionCallKind::FunctionCall)
switch (dynamic_cast<FunctionType const&>(*_funCall.expression().annotation().type).kind())
{
case FunctionType::Kind::BytesConcat:
bytesConcatCalls.insert(&_funCall);
break;
default: {}
}
}

std::set<FunctionCall const*, ASTCompareByID<FunctionCall>> bytesConcatCalls;
};

return BytesConcatFunctions(_node).bytesConcatCalls;
}

std::set<SourceUnit const*, ASTNode::CompareByID> SMTEncoder::sourceDependencies(SourceUnit const& _source)
{
std::set<SourceUnit const*, ASTNode::CompareByID> sources;
Expand Down
2 changes: 2 additions & 0 deletions libsolidity/formal/SMTEncoder.h
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ class SMTEncoder: public ASTConstVisitor
static RationalNumberType const* isConstant(Expression const& _expr);

static std::set<FunctionCall const*, ASTCompareByID<FunctionCall>> collectABICalls(ASTNode const* _node);
static std::set<FunctionCall const*, ASTCompareByID<FunctionCall>> collectBytesConcatCalls(ASTNode const* _node);

/// @returns all the sources that @param _source depends on,
/// including itself.
Expand Down Expand Up @@ -211,6 +212,7 @@ class SMTEncoder: public ASTConstVisitor
void visitAssert(FunctionCall const& _funCall);
void visitRequire(FunctionCall const& _funCall);
void visitABIFunction(FunctionCall const& _funCall);
void visitBytesConcat(FunctionCall const& _funCall);
void visitCryptoFunction(FunctionCall const& _funCall);
void visitGasLeft(FunctionCall const& _funCall);
virtual void visitAddMulMod(FunctionCall const& _funCall);
Expand Down
Loading

0 comments on commit effab73

Please sign in to comment.