diff --git a/Changelog.md b/Changelog.md index 01a77ce17f7a..83e51c5172bc 100644 --- a/Changelog.md +++ b/Changelog.md @@ -4,7 +4,7 @@ Language Features: Compiler Features: - +* SMTChecker: Support `bytes.concat` function with an exception of string literals provided as an argument to `bytes.concat` call. Bugfixes: diff --git a/libsolidity/formal/CHC.cpp b/libsolidity/formal/CHC.cpp index 3cd1524599df..213d4bbc591a 100644 --- a/libsolidity/formal/CHC.cpp +++ b/libsolidity/formal/CHC.cpp @@ -950,7 +950,7 @@ void CHC::nondetCall(ContractDefinition const& _contract, VariableDeclaration co m_currentContract ); auto postCallState = std::vector{state().state()} + currentStateVariables(_contract); - std::vector stateExprs{error, address, state().abi(), state().crypto()}; + std::vector stateExprs{error, address, state().abi(), state().bytesConcat(), state().crypto()}; auto nondet = (*m_nondetInterfaces.at(&_contract))(stateExprs + preCallState + postCallState); auto nondetCall = callPredicate(stateExprs + preCallState + postCallState); @@ -1033,7 +1033,7 @@ void CHC::externalFunctionCall(FunctionCall const& _funCall) &_funCall ); auto postCallState = std::vector{state().state()} + currentStateVariables(); - std::vector stateExprs{error, state().thisAddress(), state().abi(), state().crypto()}; + std::vector stateExprs{error, state().thisAddress(), state().abi(), state().bytesConcat(), state().crypto()}; auto nondet = (*m_nondetInterfaces.at(m_currentContract))(stateExprs + preCallState + postCallState); auto nondetCall = callPredicate(stateExprs + preCallState + postCallState); @@ -1464,7 +1464,7 @@ void CHC::defineInterfacesAndSummaries(SourceUnit const& _source) auto errorPost = errorFlag().increaseIndex(); auto nondetPost = smt::nondetInterface(iface, *contract, m_context, 0, 2); - std::vector args{errorPost, state().thisAddress(), state().abi(), state().crypto(), state().tx(), state().state(1)}; + std::vector args{errorPost, state().thisAddress(), state().abi(), state().bytesConcat(), state().crypto(), state().tx(), state().state(1)}; args += state1 + applyMap(function->parameters(), [this](auto _var) { return valueAtIndex(*_var, 0); }) + std::vector{state().state(2)} + @@ -1829,7 +1829,7 @@ smtutil::Expression CHC::predicate( errorFlag().increaseIndex(); - std::vector args{errorFlag().currentValue(), _contractAddressValue, state().abi(), state().crypto(), state().tx(), state().state()}; + std::vector args{errorFlag().currentValue(), _contractAddressValue, state().abi(), state().bytesConcat(), state().crypto(), state().tx(), state().state()}; auto const* contract = _funDef->annotation().contract; auto const& hierarchy = m_currentContract->annotation().linearizedBaseContracts; diff --git a/libsolidity/formal/Predicate.cpp b/libsolidity/formal/Predicate.cpp index 2ab494fc4d28..741cc6326052 100644 --- a/libsolidity/formal/Predicate.cpp +++ b/libsolidity/formal/Predicate.cpp @@ -229,7 +229,7 @@ std::string Predicate::formatSummaryCall( return {}; } - /// The signature of a function summary predicate is: summary(error, this, abiFunctions, cryptoFunctions, txData, preBlockChainState, preStateVars, preInputVars, postBlockchainState, postStateVars, postInputVars, outputVars). + /// The signature of a function summary predicate is: summary(error, this, abiFunctions, bytesConcatFunctions, cryptoFunctions, txData, preBlockChainState, preStateVars, preInputVars, postBlockchainState, postStateVars, postInputVars, outputVars). /// Here we are interested in preInputVars to format the function call. std::string txModel; @@ -285,7 +285,7 @@ std::string Predicate::formatSummaryCall( } // Here we are interested in txData from the summary predicate. - auto txValues = readTxVars(_args.at(4)); + auto txValues = readTxVars(_args.at(5)); std::vector values; for (auto const& _var: txVars) if (auto v = txValues.at(_var)) @@ -303,7 +303,7 @@ std::string Predicate::formatSummaryCall( auto const* fun = programFunction(); solAssert(fun, ""); - auto first = _args.begin() + 6 + static_cast(stateVars->size()); + auto first = _args.begin() + 7 + static_cast(stateVars->size()); auto last = first + static_cast(fun->parameters().size()); solAssert(first >= _args.begin() && first <= _args.end(), ""); solAssert(last >= _args.begin() && last <= _args.end(), ""); @@ -337,8 +337,8 @@ std::string Predicate::formatSummaryCall( std::vector> Predicate::summaryStateValues(std::vector const& _args) const { - /// The signature of a function summary predicate is: summary(error, this, abiFunctions, cryptoFunctions, txData, preBlockchainState, preStateVars, preInputVars, postBlockchainState, postStateVars, postInputVars, outputVars). - /// The signature of the summary predicate of a contract without constructor is: summary(error, this, abiFunctions, cryptoFunctions, txData, preBlockchainState, postBlockchainState, preStateVars, postStateVars). + /// The signature of a function summary predicate is: summary(error, this, abiFunctions, bytesConcatFunctions, cryptoFunctions, txData, preBlockchainState, preStateVars, preInputVars, postBlockchainState, postStateVars, postInputVars, outputVars). + /// The signature of the summary predicate of a contract without constructor is: summary(error, this, abiFunctions, bytesConcatFunctions, cryptoFunctions, txData, preBlockchainState, postBlockchainState, preStateVars, postStateVars). /// Here we are interested in postStateVars. auto stateVars = stateVariables(); solAssert(stateVars.has_value(), ""); @@ -347,12 +347,12 @@ std::vector> Predicate::summaryStateValues(std::vecto std::vector::const_iterator stateLast; if (auto const* function = programFunction()) { - stateFirst = _args.begin() + 6 + static_cast(stateVars->size()) + static_cast(function->parameters().size()) + 1; + stateFirst = _args.begin() + 7 + static_cast(stateVars->size()) + static_cast(function->parameters().size()) + 1; stateLast = stateFirst + static_cast(stateVars->size()); } else if (programContract()) { - stateFirst = _args.begin() + 7 + static_cast(stateVars->size()); + stateFirst = _args.begin() + 8 + static_cast(stateVars->size()); stateLast = stateFirst + static_cast(stateVars->size()); } else if (programVariable()) @@ -371,7 +371,7 @@ std::vector> Predicate::summaryStateValues(std::vecto std::vector> Predicate::summaryPostInputValues(std::vector const& _args) const { - /// The signature of a function summary predicate is: summary(error, this, abiFunctions, cryptoFunctions, txData, preBlockchainState, preStateVars, preInputVars, postBlockchainState, postStateVars, postInputVars, outputVars). + /// The signature of a function summary predicate is: summary(error, this, abiFunctions, bytesConcatFunctions, cryptoFunctions, txData, preBlockchainState, preStateVars, preInputVars, postBlockchainState, postStateVars, postInputVars, outputVars). /// Here we are interested in postInputVars. auto const* function = programFunction(); solAssert(function, ""); @@ -381,7 +381,7 @@ std::vector> Predicate::summaryPostInputValues(std::v auto const& inParams = function->parameters(); - auto first = _args.begin() + 6 + static_cast(stateVars->size()) * 2 + static_cast(inParams.size()) + 1; + auto first = _args.begin() + 7 + static_cast(stateVars->size()) * 2 + static_cast(inParams.size()) + 1; auto last = first + static_cast(inParams.size()); solAssert(first >= _args.begin() && first <= _args.end(), ""); @@ -395,7 +395,7 @@ std::vector> Predicate::summaryPostInputValues(std::v std::vector> Predicate::summaryPostOutputValues(std::vector const& _args) const { - /// The signature of a function summary predicate is: summary(error, this, abiFunctions, cryptoFunctions, txData, preBlockchainState, preStateVars, preInputVars, postBlockchainState, postStateVars, postInputVars, outputVars). + /// The signature of a function summary predicate is: summary(error, this, abiFunctions, bytesConcatFunctions, cryptoFunctions, txData, preBlockchainState, preStateVars, preInputVars, postBlockchainState, postStateVars, postInputVars, outputVars). /// Here we are interested in outputVars. auto const* function = programFunction(); solAssert(function, ""); @@ -405,7 +405,7 @@ std::vector> Predicate::summaryPostOutputValues(std:: auto const& inParams = function->parameters(); - auto first = _args.begin() + 6 + static_cast(stateVars->size()) * 2 + static_cast(inParams.size()) * 2 + 1; + auto first = _args.begin() + 7 + static_cast(stateVars->size()) * 2 + static_cast(inParams.size()) * 2 + 1; solAssert(first >= _args.begin() && first <= _args.end(), ""); @@ -418,7 +418,7 @@ std::vector> Predicate::summaryPostOutputValues(std:: std::pair>, std::vector> Predicate::localVariableValues(std::vector const& _args) const { /// The signature of a local block predicate is: - /// block(error, this, abiFunctions, cryptoFunctions, txData, preBlockchainState, preStateVars, preInputVars, postBlockchainState, postStateVars, postInputVars, outputVars, localVars). + /// block(error, this, abiFunctions, bytesConcatFunctions, cryptoFunctions, txData, preBlockchainState, preStateVars, preInputVars, postBlockchainState, postStateVars, postInputVars, outputVars, localVars). /// Here we are interested in localVars. auto const* function = programFunction(); solAssert(function, ""); @@ -452,19 +452,19 @@ std::map Predicate::expressionSubstitution(smtutil::Ex auto nArgs = _predExpr.arguments.size(); // The signature of an interface predicate is - // interface(this, abiFunctions, cryptoFunctions, blockchainState, stateVariables). + // interface(this, abiFunctions, bytesConcatFunctions, cryptoFunctions, blockchainState, stateVariables). // An invariant for an interface predicate is a contract // invariant over its state, for example `x <= 0`. if (isInterface()) { solAssert(starts_with(predName, "interface"), ""); subst[_predExpr.arguments.at(0).name] = "address(this)"; - solAssert(nArgs == stateVars.size() + 4, ""); + solAssert(nArgs == stateVars.size() + 5, ""); for (size_t i = nArgs - stateVars.size(); i < nArgs; ++i) - subst[_predExpr.arguments.at(i).name] = stateVars.at(i - 4)->name(); + subst[_predExpr.arguments.at(i).name] = stateVars.at(i - 5)->name(); } // The signature of a nondet interface predicate is - // nondet_interface(error, this, abiFunctions, cryptoFunctions, blockchainState, stateVariables, blockchainState', stateVariables'). + // nondet_interface(error, this, abiFunctions, bytesConcatFunctions, cryptoFunctions, blockchainState, stateVariables, blockchainState', stateVariables'). // An invariant for a nondet interface predicate is a reentrancy property // over the pre and post state variables of a contract, where pre state vars // are represented by the variable's name and post state vars are represented @@ -475,7 +475,7 @@ std::map Predicate::expressionSubstitution(smtutil::Ex solAssert(starts_with(predName, "nondet_interface"), ""); subst[_predExpr.arguments.at(0).name] = ""; subst[_predExpr.arguments.at(1).name] = "address(this)"; - solAssert(nArgs == stateVars.size() * 2 + 6, ""); + solAssert(nArgs == stateVars.size() * 2 + 7, ""); for (size_t i = nArgs - stateVars.size(), s = 0; i < nArgs; ++i, ++s) subst[_predExpr.arguments.at(i).name] = stateVars.at(s)->name() + "'"; for (size_t i = nArgs - (stateVars.size() * 2 + 1), s = 0; i < nArgs - (stateVars.size() + 1); ++i, ++s) diff --git a/libsolidity/formal/PredicateInstance.cpp b/libsolidity/formal/PredicateInstance.cpp index 0bf352dd4777..72351ae3b312 100644 --- a/libsolidity/formal/PredicateInstance.cpp +++ b/libsolidity/formal/PredicateInstance.cpp @@ -29,14 +29,16 @@ namespace solidity::frontend::smt smtutil::Expression interfacePre(Predicate const& _pred, ContractDefinition const& _contract, EncodingContext& _context) { auto& state = _context.state(); - std::vector stateExprs{state.thisAddress(0), state.abi(0), state.crypto(0), state.state(0)}; + std::vector stateExprs{state.thisAddress(0), + state.abi(0), state.bytesConcat(0), state.crypto(0), state.state(0)}; return _pred(stateExprs + initialStateVariables(_contract, _context)); } smtutil::Expression interface(Predicate const& _pred, ContractDefinition const& _contract, EncodingContext& _context) { auto const& state = _context.state(); - std::vector stateExprs{state.thisAddress(0), state.abi(0), state.crypto(0), state.state()}; + std::vector stateExprs{state.thisAddress(0), + state.abi(0), state.bytesConcat(0), state.crypto(0), state.state()}; return _pred(stateExprs + currentStateVariables(_contract, _context)); } @@ -48,7 +50,8 @@ smtutil::Expression nondetInterface( unsigned _postIdx) { auto const& state = _context.state(); - std::vector stateExprs{state.errorFlag().currentValue(), state.thisAddress(), state.abi(), state.crypto()}; + std::vector stateExprs{state.errorFlag().currentValue(), state.thisAddress(), + state.abi(), state.bytesConcat(0), state.crypto()}; return _pred( stateExprs + std::vector{_context.state().state(_preIdx)} + @@ -65,7 +68,7 @@ smtutil::Expression constructor(Predicate const& _pred, EncodingContext& _contex return _pred(currentFunctionVariablesForDefinition(*constructor, &contract, _context)); auto& state = _context.state(); - std::vector stateExprs{state.errorFlag().currentValue(), state.thisAddress(0), state.abi(0), state.crypto(0), state.tx(0), state.state(0), state.state()}; + std::vector stateExprs{state.errorFlag().currentValue(), state.thisAddress(0), state.abi(0), state.bytesConcat(0), state.crypto(0), state.tx(0), state.state(0), state.state()}; return _pred(stateExprs + initialStateVariables(contract, _context) + currentStateVariables(contract, _context)); } @@ -76,7 +79,7 @@ smtutil::Expression constructorCall(Predicate const& _pred, EncodingContext& _co return _pred(currentFunctionVariablesForCall(*constructor, &contract, _context, _internal)); auto& state = _context.state(); - std::vector stateExprs{state.errorFlag().currentValue(), _internal ? state.thisAddress(0) : state.thisAddress(), state.abi(0), state.crypto(0), _internal ? state.tx(0) : state.tx(), state.state()}; + std::vector stateExprs{state.errorFlag().currentValue(), _internal ? state.thisAddress(0) : state.thisAddress(), state.abi(0), state.bytesConcat(0), state.crypto(0), _internal ? state.tx(0) : state.tx(), state.state()}; state.newState(); stateExprs += std::vector{state.state()}; stateExprs += currentStateVariables(contract, _context); @@ -152,7 +155,8 @@ std::vector currentFunctionVariablesForDefinition( ) { auto& state = _context.state(); - std::vector exprs{state.errorFlag().currentValue(), state.thisAddress(0), state.abi(0), state.crypto(0), state.tx(0), state.state(0)}; + std::vector exprs{state.errorFlag().currentValue(), state.thisAddress(0), + state.abi(0), state.bytesConcat(0), state.crypto(0), state.tx(0), state.state(0)}; exprs += _contract ? initialStateVariables(*_contract, _context) : std::vector{}; exprs += applyMap(_function.parameters(), [&](auto _var) { return _context.variable(*_var)->valueAtIndex(0); }); exprs += std::vector{state.state()}; @@ -170,7 +174,8 @@ std::vector currentFunctionVariablesForCall( ) { auto& state = _context.state(); - std::vector exprs{state.errorFlag().currentValue(), _internal ? state.thisAddress(0) : state.thisAddress(), state.abi(0), state.crypto(0), _internal ? state.tx(0) : state.tx(), state.state()}; + std::vector exprs{state.errorFlag().currentValue(), _internal ? state.thisAddress(0) : state.thisAddress(), + state.abi(0), state.bytesConcat(0), state.crypto(0), _internal ? state.tx(0) : state.tx(), state.state()}; exprs += _contract ? currentStateVariables(*_contract, _context) : std::vector{}; exprs += applyMap(_function.parameters(), [&](auto _var) { return _context.variable(*_var)->currentValue(); }); diff --git a/libsolidity/formal/PredicateSort.cpp b/libsolidity/formal/PredicateSort.cpp index 2e95a29a531d..6df2b67d874f 100644 --- a/libsolidity/formal/PredicateSort.cpp +++ b/libsolidity/formal/PredicateSort.cpp @@ -30,7 +30,7 @@ namespace solidity::frontend::smt SortPointer interfaceSort(ContractDefinition const& _contract, SymbolicState& _state) { return std::make_shared( - std::vector{_state.thisAddressSort(), _state.abiSort(), _state.cryptoSort(), _state.stateSort()} + stateSorts(_contract), + std::vector{_state.thisAddressSort(), _state.abiSort(), _state.bytesConcatSort(), _state.cryptoSort(), _state.stateSort()} + stateSorts(_contract), SortProvider::boolSort ); } @@ -40,7 +40,7 @@ SortPointer nondetInterfaceSort(ContractDefinition const& _contract, SymbolicSta auto varSorts = stateSorts(_contract); std::vector stateSort{_state.stateSort()}; return std::make_shared( - std::vector{_state.errorFlagSort(), _state.thisAddressSort(), _state.abiSort(), _state.cryptoSort()} + + std::vector{_state.errorFlagSort(), _state.thisAddressSort(), _state.abiSort(), _state.bytesConcatSort(), _state.cryptoSort()} + stateSort + varSorts + stateSort + @@ -57,7 +57,7 @@ SortPointer constructorSort(ContractDefinition const& _contract, SymbolicState& auto varSorts = stateSorts(_contract); std::vector stateSort{_state.stateSort()}; return std::make_shared( - std::vector{_state.errorFlagSort(), _state.thisAddressSort(), _state.abiSort(), _state.cryptoSort(), _state.txSort(), _state.stateSort(), _state.stateSort()} + varSorts + varSorts, + std::vector{_state.errorFlagSort(), _state.thisAddressSort(), _state.abiSort(), _state.bytesConcatSort(), _state.cryptoSort(), _state.txSort(), _state.stateSort(), _state.stateSort()} + varSorts + varSorts, SortProvider::boolSort ); } @@ -69,7 +69,7 @@ SortPointer functionSort(FunctionDefinition const& _function, ContractDefinition auto inputSorts = applyMap(_function.parameters(), smtSort); auto outputSorts = applyMap(_function.returnParameters(), smtSort); return std::make_shared( - std::vector{_state.errorFlagSort(), _state.thisAddressSort(), _state.abiSort(), _state.cryptoSort(), _state.txSort(), _state.stateSort()} + + std::vector{_state.errorFlagSort(), _state.thisAddressSort(), _state.abiSort(), _state.bytesConcatSort(), _state.cryptoSort(), _state.txSort(), _state.stateSort()} + varSorts + inputSorts + std::vector{_state.stateSort()} + diff --git a/libsolidity/formal/SMTEncoder.cpp b/libsolidity/formal/SMTEncoder.cpp index ae3a06869d19..4e91776dc500 100644 --- a/libsolidity/formal/SMTEncoder.cpp +++ b/libsolidity/formal/SMTEncoder.cpp @@ -651,6 +651,9 @@ void SMTEncoder::endVisit(FunctionCall const& _funCall) if (publicGetter(_funCall.expression())) visitPublicGetter(_funCall); break; + case FunctionType::Kind::BytesConcat: + visitBytesConcat(_funCall); + break; case FunctionType::Kind::ABIDecode: case FunctionType::Kind::ABIEncode: case FunctionType::Kind::ABIEncodePacked: @@ -780,6 +783,44 @@ void SMTEncoder::visitRequire(FunctionCall const& _funCall) addPathImpliedExpression(expr(*args.front())); } +void SMTEncoder::visitBytesConcat(FunctionCall const& _funCall) +{ + auto const& args = _funCall.sortedArguments(); + auto const& [name, inTypes, outType] = state().bytesConcatFunctionTypes(&_funCall); + // bytes.concat call with no arguments returns an empty array + if (args.size() == 0) + { + defineExpr(_funCall, smt::zeroValue(TypeProvider::bytesMemory())); + return; + } + + auto symbFunction = state().bytesConcatFunction(&_funCall); + + solAssert(inTypes.size() == args.size(), ""); + + std::vector symbArgs; + for (unsigned i = 0; i < args.size(); ++i) + if (args.at(i)) + symbArgs.emplace_back(expr(*args.at(i), inTypes.at(i))); + + std::optional arg; + if (inTypes.size() == 1) + arg = expr(*args.at(0), inTypes.at(0)); + else + { + auto inputSort = dynamic_cast(*symbFunction.sort).domain; + arg = smtutil::Expression::tuple_constructor( + smtutil::Expression(std::make_shared(inputSort), ""), + symbArgs + ); + } + + auto out = smtutil::Expression::select(symbFunction, *arg); + defineExpr(_funCall, out); + + return; +} + void SMTEncoder::visitABIFunction(FunctionCall const& _funCall) { auto symbFunction = state().abiFunction(&_funCall); @@ -909,7 +950,6 @@ void SMTEncoder::visitObjectCreation(FunctionCall const& _funCall) smtutil::Expression arraySize = expr(*args.front()); setSymbolicUnknownValue(arraySize, TypeProvider::uint256(), m_context); - auto symbArray = std::dynamic_pointer_cast(m_context.expression(_funCall)); solAssert(symbArray, ""); smt::setSymbolicZeroValue(*symbArray, m_context); @@ -3112,6 +3152,29 @@ std::set> SMTEncoder::collectA return ABIFunctions(_node).abiCalls; } +std::set> SMTEncoder::collectBytesConcatCalls(ASTNode const* _node) +{ + struct BytesConcatFunctions: public ASTConstVisitor + { + BytesConcatFunctions(ASTNode const* _node) { _node->accept(*this); } + void endVisit(FunctionCall const& _funCall) + { + if (*_funCall.annotation().kind == FunctionCallKind::FunctionCall) + switch (dynamic_cast(*_funCall.expression().annotation().type).kind()) + { + case FunctionType::Kind::BytesConcat: + bytesConcatCalls.insert(&_funCall); + break; + default: {} + } + } + + std::set> bytesConcatCalls; + }; + + return BytesConcatFunctions(_node).bytesConcatCalls; +} + std::set SMTEncoder::sourceDependencies(SourceUnit const& _source) { std::set sources; diff --git a/libsolidity/formal/SMTEncoder.h b/libsolidity/formal/SMTEncoder.h index 69010fb1cf54..c556540ae5ba 100644 --- a/libsolidity/formal/SMTEncoder.h +++ b/libsolidity/formal/SMTEncoder.h @@ -122,6 +122,7 @@ class SMTEncoder: public ASTConstVisitor static RationalNumberType const* isConstant(Expression const& _expr); static std::set> collectABICalls(ASTNode const* _node); + static std::set> collectBytesConcatCalls(ASTNode const* _node); /// @returns all the sources that @param _source depends on, /// including itself. @@ -211,6 +212,7 @@ class SMTEncoder: public ASTConstVisitor void visitAssert(FunctionCall const& _funCall); void visitRequire(FunctionCall const& _funCall); void visitABIFunction(FunctionCall const& _funCall); + void visitBytesConcat(FunctionCall const& _funCall); void visitCryptoFunction(FunctionCall const& _funCall); void visitGasLeft(FunctionCall const& _funCall); virtual void visitAddMulMod(FunctionCall const& _funCall); diff --git a/libsolidity/formal/SymbolicState.cpp b/libsolidity/formal/SymbolicState.cpp index 3859e99c6243..585de4d9614a 100644 --- a/libsolidity/formal/SymbolicState.cpp +++ b/libsolidity/formal/SymbolicState.cpp @@ -268,16 +268,19 @@ void SymbolicState::prepareForSourceUnit(SourceUnit const& _source, bool _storag auto allSources = _source.referencedSourceUnits(true); allSources.insert(&_source); std::set> abiCalls; + std::set> bytesConcatCalls; std::set> contracts; for (auto const& source: allSources) { abiCalls += SMTEncoder::collectABICalls(source); + bytesConcatCalls += SMTEncoder::collectBytesConcatCalls(source); for (auto node: source->nodes()) if (auto contract = dynamic_cast(node.get())) contracts.insert(contract); } buildState(contracts, _storage); buildABIFunctions(abiCalls); + buildBytesConcatFunctions(bytesConcatCalls); } /// Private helpers. @@ -355,6 +358,82 @@ void SymbolicState::buildState(std::set> const& _bytesConcatCalls) +{ + std::map functions; + + for (auto const* funCall: _bytesConcatCalls) + { + auto t = dynamic_cast(funCall->expression().annotation().type); + solAssert(t->kind() == FunctionType::Kind::BytesConcat, "Expected bytes.concat function"); + + auto const& args = funCall->sortedArguments(); + + auto argTypes = [](auto const& _args) { + return util::applyMap(_args, [](auto arg) { return arg->annotation().type; }); + }; + + // bytes.concat : (bytes/literal/fixed bytes, ... ) -> bytes + + std::vector inTypes; + inTypes = argTypes(args); + + auto replaceUserDefinedValueTypes = [](auto& _types) { + for (auto& t: _types) + if (auto userType = dynamic_cast(t)) + t = &userType->underlyingType(); + }; + auto replaceStringLiteralTypes = [](auto& _types) { + for (auto& t: _types) + if (t->category() == frontend::Type::Category::StringLiteral) + t = TypeProvider::bytesMemory(); + }; + replaceUserDefinedValueTypes(inTypes); + replaceStringLiteralTypes(inTypes); + + auto name = t->richIdentifier(); + for (auto paramType: inTypes) + name += "_" + paramType->richIdentifier(); + + frontend::Type const* outType = TypeProvider::bytesMemory(); + name += "_" + outType->richIdentifier(); + + m_bytesConcatMembers[funCall] = {name, inTypes, outType}; + + if (functions.count(name)) + continue; + + /// If there is only one parameter, we use that type directly. + /// Otherwise we create a tuple wrapping the necessary types. + auto typesToSort = [](auto const& _types, std::string const& _name) -> std::shared_ptr { + if (_types.size() == 1) + return smtSortAbstractFunction(*_types.front()); + + std::vector inNames; + std::vector sorts; + for (unsigned i = 0; i < _types.size(); ++i) + { + inNames.emplace_back(_name + "_input_" + std::to_string(i)); + sorts.emplace_back(smtSortAbstractFunction(*_types.at(i))); + } + return std::make_shared( + _name + "_input", + inNames, + sorts + ); + }; + + auto functionSort = std::make_shared( + typesToSort(inTypes, name), + smtSortAbstractFunction(*outType) + ); + + functions[name] = functionSort; + } + + m_bytesConcat = std::make_unique("bytesConcat", std::move(functions), m_context); +} + void SymbolicState::buildABIFunctions(std::set> const& _abiFunctions) { std::map functions; @@ -367,7 +446,6 @@ void SymbolicState::buildABIFunctions(std::setparameterTypes(); auto const& returnTypes = t->returnParameterTypes(); - auto argTypes = [](auto const& _args) { return util::applyMap(_args, [](auto arg) { return arg->annotation().type; }); }; @@ -493,3 +571,14 @@ SymbolicState::SymbolicABIFunction const& SymbolicState::abiFunctionTypes(Functi { return m_abiMembers.at(_funCall); } + +smtutil::Expression SymbolicState::bytesConcatFunction(frontend::FunctionCall const* _funCall) +{ + solAssert(m_bytesConcat, ""); + return m_bytesConcat->member(std::get<0>(m_bytesConcatMembers.at(_funCall))); +} + +SymbolicState::SymbolicBytesConcatFunction const& SymbolicState::bytesConcatFunctionTypes(FunctionCall const* _funCall) const +{ + return m_bytesConcatMembers.at(_funCall); +} diff --git a/libsolidity/formal/SymbolicState.h b/libsolidity/formal/SymbolicState.h index 7fe5b028d627..75d3f4c4452c 100644 --- a/libsolidity/formal/SymbolicState.h +++ b/libsolidity/formal/SymbolicState.h @@ -182,7 +182,21 @@ class SymbolicState smtutil::Expression abi() const { solAssert(m_abi, ""); return m_abi->value(); } smtutil::Expression abi(unsigned _idx) const { solAssert(m_abi, ""); return m_abi->value(_idx); } smtutil::SortPointer const& abiSort() const { solAssert(m_abi, ""); return m_abi->sort(); } - void newABI() { solAssert(m_abi, ""); m_abi->newVar(); } + void newABI() { solAssert(m_abi, ""); m_abi->newVar(); } // unused? + //@} + + /// bytes.concat functions. + //@{ + smtutil::Expression bytesConcatFunction(FunctionCall const* _funCall); + using SymbolicBytesConcatFunction = std::tuple< + std::string, + std::vector, + frontend::Type const* + >; + SymbolicBytesConcatFunction const& bytesConcatFunctionTypes(FunctionCall const* _funCall) const; + smtutil::Expression bytesConcat() const { solAssert(m_bytesConcat, ""); return m_bytesConcat->value(); } + smtutil::Expression bytesConcat(unsigned _idx) const { solAssert(m_bytesConcat, ""); return m_bytesConcat->value(_idx); } + smtutil::SortPointer const& bytesConcatSort() const { solAssert(m_bytesConcat, ""); return m_bytesConcat->sort(); } //@} private: @@ -196,6 +210,9 @@ class SymbolicState /// Builds m_abi based on the abi.* calls _abiFunctions. void buildABIFunctions(std::set> const& _abiFunctions); + /// Builds m_bytesConcat based on the bytes.concat calls + void buildBytesConcatFunctions(std::set> const& _bytesConcatCalls); + EncodingContext& m_context; SymbolicIntVariable m_error{ @@ -278,6 +295,12 @@ class SymbolicState /// Maps ABI functions calls to their tuple names generated by /// `buildABIFunctions`. std::map m_abiMembers; + + /// Tuple containing all used bytes.concat functions. + std::unique_ptr m_bytesConcat; + /// Maps bytes.concat functions calls to their tuple names generated by + /// `buildBytesConcatFunctions`. + std::map m_bytesConcatMembers; }; } diff --git a/test/libsolidity/smtCheckerTests/bytes_concat/equals.sol b/test/libsolidity/smtCheckerTests/bytes_concat/equals.sol new file mode 100644 index 000000000000..334a2e0c5ff4 --- /dev/null +++ b/test/libsolidity/smtCheckerTests/bytes_concat/equals.sol @@ -0,0 +1,52 @@ +contract C { + + function concatCall(bytes8 a) public pure returns (bytes memory) { + return bytes.concat(a, a); + } + + function equalArguments1(bytes8 a, bytes8 b, bytes8 c, bytes8 d) public pure { + require(a == c); + require(b == d); + bytes memory concat1 = bytes.concat(a, b); + bytes memory concat2 = bytes.concat(c, d); + assert(keccak256(concat1) == keccak256(concat2)); + } + + function equalArguments2(bytes8 a, bytes8 c) public pure { + require(a == c); + bytes memory concat1 = bytes.concat(a, concatCall(a)); + bytes memory concat2 = bytes.concat(c, concatCall(c)); + assert(keccak256(concat1) == keccak256(concat2)); + } + + function equalLengthFixedBytes(bytes8 a, bytes8 b) public pure { + bytes memory concat1 = bytes.concat(a, b); + bytes memory concat2 = bytes.concat(a, b); + assert(concat1.length == concat2.length); + } + + function equalLengthMemoryBytes(bytes memory a, bytes memory b) public pure { + bytes memory concat1 = bytes.concat(a, b); + bytes memory concat2 = bytes.concat(a, b); + assert(concat1.length == concat2.length); + } + + function equalLengthMixed(bytes memory a, bytes2 b) public pure { + bytes memory concat1 = bytes.concat(a, b); + bytes memory concat2 = bytes.concat(a, b); + assert(concat1.length == concat2.length); + } + + function equalLengthLiterals() public pure { + bytes memory a = hex"aa"; + bytes1 b = bytes1(0xbb); + bytes memory c = "c"; + bytes memory concat1 = bytes.concat(a, b, c); + bytes memory concat2 = bytes.concat(a, b, c); + assert(concat1.length == concat2.length); + } +} +// ==== +// SMTEngine: all +// ---- +// Info 1391: CHC: 6 verification condition(s) proved safe! Enable the model checker option "show proved safe" to see all of them. diff --git a/test/libsolidity/smtCheckerTests/bytes_concat/simple.sol b/test/libsolidity/smtCheckerTests/bytes_concat/simple.sol new file mode 100644 index 000000000000..834d89ee9b4e --- /dev/null +++ b/test/libsolidity/smtCheckerTests/bytes_concat/simple.sol @@ -0,0 +1,35 @@ +contract C { + function zeroArgs() public pure returns(bytes memory) { + bytes memory a = bytes.concat(); + assert(a.length == 0); // zero args call is an empty bytes array + return bytes.concat(); + } + + function oneArg(bytes memory a) public pure returns(bytes memory) { + return bytes.concat(a); + } + + function oneArgFixedBytes(bytes8 a) public pure returns(bytes memory) { + return bytes.concat(a); + } + + function fixedBytes(bytes8 a, bytes8 b) public pure returns(bytes memory) { + return bytes.concat(a, b); + } + + function memoryBytes(bytes memory a, bytes memory b) public pure returns(bytes memory) { + return bytes.concat(a, b); + } + + function mixed(bytes memory a, bytes2 b) public pure returns(bytes memory) { + return bytes.concat(a, b, "StringLiteral"); + } + + function functionCallAsArg(bytes memory a) public pure returns(bytes memory) { + return bytes.concat(oneArg(a), zeroArgs()); + } +} +// ==== +// SMTEngine: all +// ---- +// Info 1391: CHC: 1 verification condition(s) proved safe! Enable the model checker option "show proved safe" to see all of them.