From c18e28427dd86a5f649aac7d4ba9b98c0bb05ed4 Mon Sep 17 00:00:00 2001 From: Pawel Gebal Date: Wed, 10 Jan 2024 17:22:13 +0100 Subject: [PATCH] refactor --- libsolidity/formal/CHC.cpp | 22 ++++++---- libsolidity/formal/CHC.h | 2 + libsolidity/formal/PredicateInstance.cpp | 55 ++++++++++++++++++------ libsolidity/formal/PredicateInstance.h | 7 +++ libsolidity/formal/PredicateSort.cpp | 22 ++++++++-- libsolidity/formal/PredicateSort.h | 4 +- 6 files changed, 87 insertions(+), 25 deletions(-) diff --git a/libsolidity/formal/CHC.cpp b/libsolidity/formal/CHC.cpp index 213d4bbc591a..5315cfcfda88 100644 --- a/libsolidity/formal/CHC.cpp +++ b/libsolidity/formal/CHC.cpp @@ -940,8 +940,6 @@ void CHC::nondetCall(ContractDefinition const& _contract, VariableDeclaration co for (auto const* var: _contract.stateVariables()) m_context.variable(*var)->increaseIndex(); - auto error = errorFlag().increaseIndex(); - Predicate const& callPredicate = *createSymbolicBlock( nondetInterfaceSort(_contract, state()), "nondet_call_" + uniquePrefix(), @@ -950,7 +948,7 @@ void CHC::nondetCall(ContractDefinition const& _contract, VariableDeclaration co m_currentContract ); auto postCallState = std::vector{state().state()} + currentStateVariables(_contract); - std::vector stateExprs{error, address, state().abi(), state().bytesConcat(), state().crypto()}; + std::vector stateExprs = commonStateExpressions(errorFlag().increaseIndex(), address); auto nondet = (*m_nondetInterfaces.at(&_contract))(stateExprs + preCallState + postCallState); auto nondetCall = callPredicate(stateExprs + preCallState + postCallState); @@ -1024,8 +1022,6 @@ void CHC::externalFunctionCall(FunctionCall const& _funCall) m_context.variable(*var)->increaseIndex(); } - auto error = errorFlag().increaseIndex(); - Predicate const& callPredicate = *createSymbolicBlock( nondetInterfaceSort(*m_currentContract, state()), "nondet_call_" + uniquePrefix(), @@ -1033,7 +1029,7 @@ void CHC::externalFunctionCall(FunctionCall const& _funCall) &_funCall ); auto postCallState = std::vector{state().state()} + currentStateVariables(); - std::vector stateExprs{error, state().thisAddress(), state().abi(), state().bytesConcat(), state().crypto()}; + std::vector stateExprs = commonStateExpressions(errorFlag().increaseIndex(), state().thisAddress()); auto nondet = (*m_nondetInterfaces.at(m_currentContract))(stateExprs + preCallState + postCallState); auto nondetCall = callPredicate(stateExprs + preCallState + postCallState); @@ -1464,7 +1460,9 @@ void CHC::defineInterfacesAndSummaries(SourceUnit const& _source) auto errorPost = errorFlag().increaseIndex(); auto nondetPost = smt::nondetInterface(iface, *contract, m_context, 0, 2); - std::vector args{errorPost, state().thisAddress(), state().abi(), state().bytesConcat(), state().crypto(), state().tx(), state().state(1)}; + std::vector args = commonStateExpressions(errorPost, state().thisAddress()); + args.push_back(state().tx()); + args.push_back(state().state(1)); args += state1 + applyMap(function->parameters(), [this](auto _var) { return valueAtIndex(*_var, 0); }) + std::vector{state().state(2)} + @@ -1829,7 +1827,9 @@ smtutil::Expression CHC::predicate( errorFlag().increaseIndex(); - std::vector args{errorFlag().currentValue(), _contractAddressValue, state().abi(), state().bytesConcat(), state().crypto(), state().tx(), state().state()}; + std::vector args = commonStateExpressions(errorFlag().currentValue(), _contractAddressValue); + args.push_back(state().tx()); + args.push_back(state().state()); auto const* contract = _funDef->annotation().contract; auto const& hierarchy = m_currentContract->annotation().linearizedBaseContracts; @@ -2486,3 +2486,9 @@ void CHC::decreaseBalanceFromOptionsValue(Expression const& _value) { state().addBalance(state().thisAddress(), 0 - expr(_value)); } + + +std::vector CHC::commonStateExpressions(smtutil::Expression error, smtutil::Expression address) +{ + return {error, address, state().abi(), state().bytesConcat(), state().crypto()}; +} diff --git a/libsolidity/formal/CHC.h b/libsolidity/formal/CHC.h index fe7387bd3568..d1d4fb6530ba 100644 --- a/libsolidity/formal/CHC.h +++ b/libsolidity/formal/CHC.h @@ -362,6 +362,8 @@ class CHC: public SMTEncoder /// Adds constraints that decrease the balance of the caller by _value. void decreaseBalanceFromOptionsValue(Expression const& _value); + + std::vector commonStateExpressions(smtutil::Expression error, smtutil::Expression address); //@} /// Predicates. diff --git a/libsolidity/formal/PredicateInstance.cpp b/libsolidity/formal/PredicateInstance.cpp index a3be210fbfdb..e2fa0e87e5ad 100644 --- a/libsolidity/formal/PredicateInstance.cpp +++ b/libsolidity/formal/PredicateInstance.cpp @@ -29,16 +29,14 @@ namespace solidity::frontend::smt smtutil::Expression interfacePre(Predicate const& _pred, ContractDefinition const& _contract, EncodingContext& _context) { auto& state = _context.state(); - std::vector stateExprs{state.thisAddress(0), - state.abi(0), state.bytesConcat(0), state.crypto(0), state.state(0)}; + std::vector stateExprs = getStateExpressionsForInterfacePre(state); return _pred(stateExprs + initialStateVariables(_contract, _context)); } smtutil::Expression interface(Predicate const& _pred, ContractDefinition const& _contract, EncodingContext& _context) { auto const& state = _context.state(); - std::vector stateExprs{state.thisAddress(0), - state.abi(0), state.bytesConcat(0), state.crypto(0), state.state()}; + std::vector stateExprs = getStateExpressionsForInterface(state); return _pred(stateExprs + currentStateVariables(_contract, _context)); } @@ -50,8 +48,7 @@ smtutil::Expression nondetInterface( unsigned _postIdx) { auto const& state = _context.state(); - std::vector stateExprs{state.errorFlag().currentValue(), state.thisAddress(), - state.abi(), state.bytesConcat(), state.crypto()}; + std::vector stateExprs = getStateExpressionsForNondetInterface(state); return _pred( stateExprs + std::vector{_context.state().state(_preIdx)} + @@ -68,7 +65,7 @@ smtutil::Expression constructor(Predicate const& _pred, EncodingContext& _contex return _pred(currentFunctionVariablesForDefinition(*constructor, &contract, _context)); auto& state = _context.state(); - std::vector stateExprs{state.errorFlag().currentValue(), state.thisAddress(0), state.abi(0), state.bytesConcat(0), state.crypto(0), state.tx(0), state.state(0), state.state()}; + std::vector stateExprs = getStateExpressionsForConstructor(state); return _pred(stateExprs + initialStateVariables(contract, _context) + currentStateVariables(contract, _context)); } @@ -79,7 +76,7 @@ smtutil::Expression constructorCall(Predicate const& _pred, EncodingContext& _co return _pred(currentFunctionVariablesForCall(*constructor, &contract, _context, _internal)); auto& state = _context.state(); - std::vector stateExprs{state.errorFlag().currentValue(), _internal ? state.thisAddress(0) : state.thisAddress(), state.abi(0), state.bytesConcat(0), state.crypto(0), _internal ? state.tx(0) : state.tx(), state.state()}; + std::vector stateExprs = getStateExpressionsForCall(state, _internal); state.newState(); stateExprs += std::vector{state.state()}; stateExprs += currentStateVariables(contract, _context); @@ -155,8 +152,7 @@ std::vector currentFunctionVariablesForDefinition( ) { auto& state = _context.state(); - std::vector exprs{state.errorFlag().currentValue(), state.thisAddress(0), - state.abi(0), state.bytesConcat(0), state.crypto(0), state.tx(0), state.state(0)}; + std::vector exprs = getStateExpressionsForDefinition(state); exprs += _contract ? initialStateVariables(*_contract, _context) : std::vector{}; exprs += applyMap(_function.parameters(), [&](auto _var) { return _context.variable(*_var)->valueAtIndex(0); }); exprs += std::vector{state.state()}; @@ -174,8 +170,7 @@ std::vector currentFunctionVariablesForCall( ) { auto& state = _context.state(); - std::vector exprs{state.errorFlag().currentValue(), _internal ? state.thisAddress(0) : state.thisAddress(), - state.abi(0), state.bytesConcat(0), state.crypto(0), _internal ? state.tx(0) : state.tx(), state.state()}; + std::vector exprs = getStateExpressionsForCall(state, _internal); exprs += _contract ? currentStateVariables(*_contract, _context) : std::vector{}; exprs += applyMap(_function.parameters(), [&](auto _var) { return _context.variable(*_var)->currentValue(); }); @@ -197,4 +192,40 @@ std::vector currentBlockVariables(FunctionDefinition const& ); } +std::vector getStateExpressionsForInterfacePre(SymbolicState const& _state) +{ + return {_state.thisAddress(0), + _state.abi(0), _state.bytesConcat(0), _state.crypto(0), _state.state(0)}; +} + +std::vector getStateExpressionsForInterface(SymbolicState const& _state) +{ + return {_state.thisAddress(0), + _state.abi(0), _state.bytesConcat(0), _state.crypto(0), _state.state()}; +} + +std::vector getStateExpressionsForNondetInterface(SymbolicState const& _state) +{ + return {_state.errorFlag().currentValue(), _state.thisAddress(), + _state.abi(), _state.bytesConcat(), _state.crypto()}; +} + +std::vector getStateExpressionsForConstructor(SymbolicState const& _state) +{ + return {_state.errorFlag().currentValue(), _state.thisAddress(0), _state.abi(0), + _state.bytesConcat(0), _state.crypto(0), _state.tx(0), _state.state(0), _state.state()}; +} + +std::vector getStateExpressionsForCall(SymbolicState const& _state, bool _internal) +{ + return {_state.errorFlag().currentValue(), _internal ? _state.thisAddress(0) : _state.thisAddress(), + _state.abi(0), _state.bytesConcat(0), _state.crypto(0), _internal ? _state.tx(0) : _state.tx(), _state.state()}; +} + +std::vector getStateExpressionsForDefinition(SymbolicState const& _state) +{ + return {_state.errorFlag().currentValue(), _state.thisAddress(0), _state.abi(0), + _state.bytesConcat(0), _state.crypto(0), _state.tx(0), _state.state(0)}; +} + } diff --git a/libsolidity/formal/PredicateInstance.h b/libsolidity/formal/PredicateInstance.h index 42a9e2040b36..33c8a8a0dde9 100644 --- a/libsolidity/formal/PredicateInstance.h +++ b/libsolidity/formal/PredicateInstance.h @@ -19,6 +19,7 @@ #pragma once #include +#include namespace solidity::frontend::smt { @@ -94,4 +95,10 @@ std::vector currentBlockVariables( EncodingContext& _context ); +std::vector getStateExpressionsForInterfacePre(SymbolicState const& _state); +std::vector getStateExpressionsForInterface(SymbolicState const& _state); +std::vector getStateExpressionsForNondetInterface(SymbolicState const& _state); +std::vector getStateExpressionsForConstructor(SymbolicState const& _state); +std::vector getStateExpressionsForCall(SymbolicState const& _state, bool _internal); +std::vector getStateExpressionsForDefinition(SymbolicState const& _state); } diff --git a/libsolidity/formal/PredicateSort.cpp b/libsolidity/formal/PredicateSort.cpp index 6df2b67d874f..6d57b04033c3 100644 --- a/libsolidity/formal/PredicateSort.cpp +++ b/libsolidity/formal/PredicateSort.cpp @@ -30,7 +30,10 @@ namespace solidity::frontend::smt SortPointer interfaceSort(ContractDefinition const& _contract, SymbolicState& _state) { return std::make_shared( - std::vector{_state.thisAddressSort(), _state.abiSort(), _state.bytesConcatSort(), _state.cryptoSort(), _state.stateSort()} + stateSorts(_contract), + std::vector{_state.thisAddressSort()} + + getBuiltInFunctionsSorts(_state) + + std::vector{_state.stateSort()} + + stateSorts(_contract), SortProvider::boolSort ); } @@ -40,7 +43,8 @@ SortPointer nondetInterfaceSort(ContractDefinition const& _contract, SymbolicSta auto varSorts = stateSorts(_contract); std::vector stateSort{_state.stateSort()}; return std::make_shared( - std::vector{_state.errorFlagSort(), _state.thisAddressSort(), _state.abiSort(), _state.bytesConcatSort(), _state.cryptoSort()} + + std::vector{_state.errorFlagSort(), _state.thisAddressSort()} + + getBuiltInFunctionsSorts(_state) + stateSort + varSorts + stateSort + @@ -57,7 +61,10 @@ SortPointer constructorSort(ContractDefinition const& _contract, SymbolicState& auto varSorts = stateSorts(_contract); std::vector stateSort{_state.stateSort()}; return std::make_shared( - std::vector{_state.errorFlagSort(), _state.thisAddressSort(), _state.abiSort(), _state.bytesConcatSort(), _state.cryptoSort(), _state.txSort(), _state.stateSort(), _state.stateSort()} + varSorts + varSorts, + std::vector{_state.errorFlagSort(), _state.thisAddressSort()} + + getBuiltInFunctionsSorts(_state) + + std::vector{_state.txSort(), _state.stateSort(), _state.stateSort()} + + varSorts + varSorts, SortProvider::boolSort ); } @@ -69,7 +76,9 @@ SortPointer functionSort(FunctionDefinition const& _function, ContractDefinition auto inputSorts = applyMap(_function.parameters(), smtSort); auto outputSorts = applyMap(_function.returnParameters(), smtSort); return std::make_shared( - std::vector{_state.errorFlagSort(), _state.thisAddressSort(), _state.abiSort(), _state.bytesConcatSort(), _state.cryptoSort(), _state.txSort(), _state.stateSort()} + + std::vector{_state.errorFlagSort(), _state.thisAddressSort()} + + getBuiltInFunctionsSorts(_state) + + std::vector{_state.txSort(), _state.stateSort()} + varSorts + inputSorts + std::vector{_state.stateSort()} + @@ -110,4 +119,9 @@ std::vector stateSorts(ContractDefinition const& _contract) ); } +std::vector getBuiltInFunctionsSorts(SymbolicState& _state) +{ + return {_state.abiSort(), _state.bytesConcatSort(), _state.cryptoSort()}; +} + } diff --git a/libsolidity/formal/PredicateSort.h b/libsolidity/formal/PredicateSort.h index 5638582546e9..c5f63f2a6965 100644 --- a/libsolidity/formal/PredicateSort.h +++ b/libsolidity/formal/PredicateSort.h @@ -74,6 +74,8 @@ smtutil::SortPointer arity0FunctionSort(); /// Helpers -std::vector stateSorts(ContractDefinition const& _contract) ; +std::vector stateSorts(ContractDefinition const& _contract); + +std::vector getBuiltInFunctionsSorts(SymbolicState& _state); }