Skip to content

Commit

Permalink
refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
pgebal committed Jan 10, 2024
1 parent df1b15e commit c18e284
Show file tree
Hide file tree
Showing 6 changed files with 87 additions and 25 deletions.
22 changes: 14 additions & 8 deletions libsolidity/formal/CHC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -940,8 +940,6 @@ void CHC::nondetCall(ContractDefinition const& _contract, VariableDeclaration co
for (auto const* var: _contract.stateVariables())
m_context.variable(*var)->increaseIndex();

auto error = errorFlag().increaseIndex();

Predicate const& callPredicate = *createSymbolicBlock(
nondetInterfaceSort(_contract, state()),
"nondet_call_" + uniquePrefix(),
Expand All @@ -950,7 +948,7 @@ void CHC::nondetCall(ContractDefinition const& _contract, VariableDeclaration co
m_currentContract
);
auto postCallState = std::vector<smtutil::Expression>{state().state()} + currentStateVariables(_contract);
std::vector<smtutil::Expression> stateExprs{error, address, state().abi(), state().bytesConcat(), state().crypto()};
std::vector<smtutil::Expression> stateExprs = commonStateExpressions(errorFlag().increaseIndex(), address);

auto nondet = (*m_nondetInterfaces.at(&_contract))(stateExprs + preCallState + postCallState);
auto nondetCall = callPredicate(stateExprs + preCallState + postCallState);
Expand Down Expand Up @@ -1024,16 +1022,14 @@ void CHC::externalFunctionCall(FunctionCall const& _funCall)
m_context.variable(*var)->increaseIndex();
}

auto error = errorFlag().increaseIndex();

Predicate const& callPredicate = *createSymbolicBlock(
nondetInterfaceSort(*m_currentContract, state()),
"nondet_call_" + uniquePrefix(),
PredicateType::ExternalCallUntrusted,
&_funCall
);
auto postCallState = std::vector<smtutil::Expression>{state().state()} + currentStateVariables();
std::vector<smtutil::Expression> stateExprs{error, state().thisAddress(), state().abi(), state().bytesConcat(), state().crypto()};
std::vector<smtutil::Expression> stateExprs = commonStateExpressions(errorFlag().increaseIndex(), state().thisAddress());

auto nondet = (*m_nondetInterfaces.at(m_currentContract))(stateExprs + preCallState + postCallState);
auto nondetCall = callPredicate(stateExprs + preCallState + postCallState);
Expand Down Expand Up @@ -1464,7 +1460,9 @@ void CHC::defineInterfacesAndSummaries(SourceUnit const& _source)
auto errorPost = errorFlag().increaseIndex();
auto nondetPost = smt::nondetInterface(iface, *contract, m_context, 0, 2);

std::vector<smtutil::Expression> args{errorPost, state().thisAddress(), state().abi(), state().bytesConcat(), state().crypto(), state().tx(), state().state(1)};
std::vector<smtutil::Expression> args = commonStateExpressions(errorPost, state().thisAddress());
args.push_back(state().tx());
args.push_back(state().state(1));
args += state1 +
applyMap(function->parameters(), [this](auto _var) { return valueAtIndex(*_var, 0); }) +
std::vector<smtutil::Expression>{state().state(2)} +
Expand Down Expand Up @@ -1829,7 +1827,9 @@ smtutil::Expression CHC::predicate(

errorFlag().increaseIndex();

std::vector<smtutil::Expression> args{errorFlag().currentValue(), _contractAddressValue, state().abi(), state().bytesConcat(), state().crypto(), state().tx(), state().state()};
std::vector<smtutil::Expression> args = commonStateExpressions(errorFlag().currentValue(), _contractAddressValue);
args.push_back(state().tx());
args.push_back(state().state());

auto const* contract = _funDef->annotation().contract;
auto const& hierarchy = m_currentContract->annotation().linearizedBaseContracts;
Expand Down Expand Up @@ -2486,3 +2486,9 @@ void CHC::decreaseBalanceFromOptionsValue(Expression const& _value)
{
state().addBalance(state().thisAddress(), 0 - expr(_value));
}


std::vector<smtutil::Expression> CHC::commonStateExpressions(smtutil::Expression error, smtutil::Expression address)
{
return {error, address, state().abi(), state().bytesConcat(), state().crypto()};
}
2 changes: 2 additions & 0 deletions libsolidity/formal/CHC.h
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,8 @@ class CHC: public SMTEncoder

/// Adds constraints that decrease the balance of the caller by _value.
void decreaseBalanceFromOptionsValue(Expression const& _value);

std::vector<smtutil::Expression> commonStateExpressions(smtutil::Expression error, smtutil::Expression address);
//@}

/// Predicates.
Expand Down
55 changes: 43 additions & 12 deletions libsolidity/formal/PredicateInstance.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,16 +29,14 @@ namespace solidity::frontend::smt
smtutil::Expression interfacePre(Predicate const& _pred, ContractDefinition const& _contract, EncodingContext& _context)
{
auto& state = _context.state();
std::vector<smtutil::Expression> stateExprs{state.thisAddress(0),
state.abi(0), state.bytesConcat(0), state.crypto(0), state.state(0)};
std::vector<smtutil::Expression> stateExprs = getStateExpressionsForInterfacePre(state);
return _pred(stateExprs + initialStateVariables(_contract, _context));
}

smtutil::Expression interface(Predicate const& _pred, ContractDefinition const& _contract, EncodingContext& _context)
{
auto const& state = _context.state();
std::vector<smtutil::Expression> stateExprs{state.thisAddress(0),
state.abi(0), state.bytesConcat(0), state.crypto(0), state.state()};
std::vector<smtutil::Expression> stateExprs = getStateExpressionsForInterface(state);
return _pred(stateExprs + currentStateVariables(_contract, _context));
}

Expand All @@ -50,8 +48,7 @@ smtutil::Expression nondetInterface(
unsigned _postIdx)
{
auto const& state = _context.state();
std::vector<smtutil::Expression> stateExprs{state.errorFlag().currentValue(), state.thisAddress(),
state.abi(), state.bytesConcat(), state.crypto()};
std::vector<smtutil::Expression> stateExprs = getStateExpressionsForNondetInterface(state);
return _pred(
stateExprs +
std::vector<smtutil::Expression>{_context.state().state(_preIdx)} +
Expand All @@ -68,7 +65,7 @@ smtutil::Expression constructor(Predicate const& _pred, EncodingContext& _contex
return _pred(currentFunctionVariablesForDefinition(*constructor, &contract, _context));

auto& state = _context.state();
std::vector<smtutil::Expression> stateExprs{state.errorFlag().currentValue(), state.thisAddress(0), state.abi(0), state.bytesConcat(0), state.crypto(0), state.tx(0), state.state(0), state.state()};
std::vector<smtutil::Expression> stateExprs = getStateExpressionsForConstructor(state);
return _pred(stateExprs + initialStateVariables(contract, _context) + currentStateVariables(contract, _context));
}

Expand All @@ -79,7 +76,7 @@ smtutil::Expression constructorCall(Predicate const& _pred, EncodingContext& _co
return _pred(currentFunctionVariablesForCall(*constructor, &contract, _context, _internal));

auto& state = _context.state();
std::vector<smtutil::Expression> stateExprs{state.errorFlag().currentValue(), _internal ? state.thisAddress(0) : state.thisAddress(), state.abi(0), state.bytesConcat(0), state.crypto(0), _internal ? state.tx(0) : state.tx(), state.state()};
std::vector<smtutil::Expression> stateExprs = getStateExpressionsForCall(state, _internal);
state.newState();
stateExprs += std::vector<smtutil::Expression>{state.state()};
stateExprs += currentStateVariables(contract, _context);
Expand Down Expand Up @@ -155,8 +152,7 @@ std::vector<smtutil::Expression> currentFunctionVariablesForDefinition(
)
{
auto& state = _context.state();
std::vector<smtutil::Expression> exprs{state.errorFlag().currentValue(), state.thisAddress(0),
state.abi(0), state.bytesConcat(0), state.crypto(0), state.tx(0), state.state(0)};
std::vector<smtutil::Expression> exprs = getStateExpressionsForDefinition(state);
exprs += _contract ? initialStateVariables(*_contract, _context) : std::vector<smtutil::Expression>{};
exprs += applyMap(_function.parameters(), [&](auto _var) { return _context.variable(*_var)->valueAtIndex(0); });
exprs += std::vector<smtutil::Expression>{state.state()};
Expand All @@ -174,8 +170,7 @@ std::vector<smtutil::Expression> currentFunctionVariablesForCall(
)
{
auto& state = _context.state();
std::vector<smtutil::Expression> exprs{state.errorFlag().currentValue(), _internal ? state.thisAddress(0) : state.thisAddress(),
state.abi(0), state.bytesConcat(0), state.crypto(0), _internal ? state.tx(0) : state.tx(), state.state()};
std::vector<smtutil::Expression> exprs = getStateExpressionsForCall(state, _internal);
exprs += _contract ? currentStateVariables(*_contract, _context) : std::vector<smtutil::Expression>{};
exprs += applyMap(_function.parameters(), [&](auto _var) { return _context.variable(*_var)->currentValue(); });

Expand All @@ -197,4 +192,40 @@ std::vector<smtutil::Expression> currentBlockVariables(FunctionDefinition const&
);
}

std::vector<smtutil::Expression> getStateExpressionsForInterfacePre(SymbolicState const& _state)
{
return {_state.thisAddress(0),
_state.abi(0), _state.bytesConcat(0), _state.crypto(0), _state.state(0)};
}

std::vector<smtutil::Expression> getStateExpressionsForInterface(SymbolicState const& _state)
{
return {_state.thisAddress(0),
_state.abi(0), _state.bytesConcat(0), _state.crypto(0), _state.state()};
}

std::vector<smtutil::Expression> getStateExpressionsForNondetInterface(SymbolicState const& _state)
{
return {_state.errorFlag().currentValue(), _state.thisAddress(),
_state.abi(), _state.bytesConcat(), _state.crypto()};
}

std::vector<smtutil::Expression> getStateExpressionsForConstructor(SymbolicState const& _state)
{
return {_state.errorFlag().currentValue(), _state.thisAddress(0), _state.abi(0),
_state.bytesConcat(0), _state.crypto(0), _state.tx(0), _state.state(0), _state.state()};
}

std::vector<smtutil::Expression> getStateExpressionsForCall(SymbolicState const& _state, bool _internal)
{
return {_state.errorFlag().currentValue(), _internal ? _state.thisAddress(0) : _state.thisAddress(),
_state.abi(0), _state.bytesConcat(0), _state.crypto(0), _internal ? _state.tx(0) : _state.tx(), _state.state()};
}

std::vector<smtutil::Expression> getStateExpressionsForDefinition(SymbolicState const& _state)
{
return {_state.errorFlag().currentValue(), _state.thisAddress(0), _state.abi(0),
_state.bytesConcat(0), _state.crypto(0), _state.tx(0), _state.state(0)};
}

}
7 changes: 7 additions & 0 deletions libsolidity/formal/PredicateInstance.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#pragma once

#include <libsolidity/formal/Predicate.h>
#include <libsolidity/formal/SymbolicState.h>

namespace solidity::frontend::smt
{
Expand Down Expand Up @@ -94,4 +95,10 @@ std::vector<smtutil::Expression> currentBlockVariables(
EncodingContext& _context
);

std::vector<smtutil::Expression> getStateExpressionsForInterfacePre(SymbolicState const& _state);
std::vector<smtutil::Expression> getStateExpressionsForInterface(SymbolicState const& _state);
std::vector<smtutil::Expression> getStateExpressionsForNondetInterface(SymbolicState const& _state);
std::vector<smtutil::Expression> getStateExpressionsForConstructor(SymbolicState const& _state);
std::vector<smtutil::Expression> getStateExpressionsForCall(SymbolicState const& _state, bool _internal);
std::vector<smtutil::Expression> getStateExpressionsForDefinition(SymbolicState const& _state);
}
22 changes: 18 additions & 4 deletions libsolidity/formal/PredicateSort.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,10 @@ namespace solidity::frontend::smt
SortPointer interfaceSort(ContractDefinition const& _contract, SymbolicState& _state)
{
return std::make_shared<FunctionSort>(
std::vector<SortPointer>{_state.thisAddressSort(), _state.abiSort(), _state.bytesConcatSort(), _state.cryptoSort(), _state.stateSort()} + stateSorts(_contract),
std::vector<SortPointer>{_state.thisAddressSort()} +
getBuiltInFunctionsSorts(_state) +
std::vector<SortPointer>{_state.stateSort()} +
stateSorts(_contract),
SortProvider::boolSort
);
}
Expand All @@ -40,7 +43,8 @@ SortPointer nondetInterfaceSort(ContractDefinition const& _contract, SymbolicSta
auto varSorts = stateSorts(_contract);
std::vector<SortPointer> stateSort{_state.stateSort()};
return std::make_shared<FunctionSort>(
std::vector<SortPointer>{_state.errorFlagSort(), _state.thisAddressSort(), _state.abiSort(), _state.bytesConcatSort(), _state.cryptoSort()} +
std::vector<SortPointer>{_state.errorFlagSort(), _state.thisAddressSort()} +
getBuiltInFunctionsSorts(_state) +
stateSort +
varSorts +
stateSort +
Expand All @@ -57,7 +61,10 @@ SortPointer constructorSort(ContractDefinition const& _contract, SymbolicState&
auto varSorts = stateSorts(_contract);
std::vector<SortPointer> stateSort{_state.stateSort()};
return std::make_shared<FunctionSort>(
std::vector<SortPointer>{_state.errorFlagSort(), _state.thisAddressSort(), _state.abiSort(), _state.bytesConcatSort(), _state.cryptoSort(), _state.txSort(), _state.stateSort(), _state.stateSort()} + varSorts + varSorts,
std::vector<SortPointer>{_state.errorFlagSort(), _state.thisAddressSort()} +
getBuiltInFunctionsSorts(_state) +
std::vector<SortPointer>{_state.txSort(), _state.stateSort(), _state.stateSort()} +
varSorts + varSorts,
SortProvider::boolSort
);
}
Expand All @@ -69,7 +76,9 @@ SortPointer functionSort(FunctionDefinition const& _function, ContractDefinition
auto inputSorts = applyMap(_function.parameters(), smtSort);
auto outputSorts = applyMap(_function.returnParameters(), smtSort);
return std::make_shared<FunctionSort>(
std::vector<SortPointer>{_state.errorFlagSort(), _state.thisAddressSort(), _state.abiSort(), _state.bytesConcatSort(), _state.cryptoSort(), _state.txSort(), _state.stateSort()} +
std::vector<SortPointer>{_state.errorFlagSort(), _state.thisAddressSort()} +
getBuiltInFunctionsSorts(_state) +
std::vector<SortPointer>{_state.txSort(), _state.stateSort()} +
varSorts +
inputSorts +
std::vector<SortPointer>{_state.stateSort()} +
Expand Down Expand Up @@ -110,4 +119,9 @@ std::vector<SortPointer> stateSorts(ContractDefinition const& _contract)
);
}

std::vector<SortPointer> getBuiltInFunctionsSorts(SymbolicState& _state)
{
return {_state.abiSort(), _state.bytesConcatSort(), _state.cryptoSort()};
}

}
4 changes: 3 additions & 1 deletion libsolidity/formal/PredicateSort.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@ smtutil::SortPointer arity0FunctionSort();

/// Helpers

std::vector<smtutil::SortPointer> stateSorts(ContractDefinition const& _contract) ;
std::vector<smtutil::SortPointer> stateSorts(ContractDefinition const& _contract);

std::vector<smtutil::SortPointer> getBuiltInFunctionsSorts(SymbolicState& _state);

}

0 comments on commit c18e284

Please sign in to comment.