Skip to content

Commit

Permalink
feat: add delegatecall in CairoLib (#1251)
Browse files Browse the repository at this point in the history
<!--- Please provide a general summary of your changes in the title
above -->

<!-- Give an estimate of the time you spent on this PR in terms of work
days.
Did you spend 0.5 days on this PR or rather 2 days?  -->

Time spent on this PR: 0.4d

## Pull request type

<!-- Please try to limit your pull request to one type,
submit multiple pull requests if needed. -->

Please check the type of change your PR introduces:

- [ ] Bugfix
- [ ] Feature
- [ ] Code style update (formatting, renaming)
- [ ] Refactoring (no functional changes, no api changes)
- [ ] Build related changes
- [ ] Documentation content changes
- [ ] Other (please describe):

## What is the current behavior?

<!-- Please describe the current behavior that you are modifying,
or link to a relevant issue. -->

Resolves #<Issue number>

## What is the new behavior?

<!-- Please describe the behavior or changes that are being added by
this PR. -->

- Adds a feature to delegatecall to a CairoLib to preserve the
msg.sender context
-
-

<!-- Reviewable:start -->
- - -
This change is [<img src="https://reviewable.io/review_button.svg"
height="34" align="absmiddle"
alt="Reviewable"/>](https://reviewable.io/reviews/kkrt-labs/kakarot/1251)
<!-- Reviewable:end -->

---------

Co-authored-by: Clément Walter <[email protected]>
  • Loading branch information
enitrat and ClementWalter authored Jul 4, 2024
1 parent 083cb41 commit 22cbf53
Show file tree
Hide file tree
Showing 8 changed files with 147 additions and 80 deletions.
13 changes: 8 additions & 5 deletions solidity_contracts/src/CairoPrecompiles/CairoCounterCaller.sol
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ using CairoLib for uint256;

contract CairoCounterCaller {
/// @dev The cairo contract to call
uint256 cairoCounter;
uint256 immutable cairoCounter;

/// @dev The cairo function selector to call - `inc`
uint256 constant FUNCTION_SELECTOR_INC = uint256(keccak256("inc")) % 2 ** 250;
Expand All @@ -25,18 +25,21 @@ contract CairoCounterCaller {
}

function getCairoCounter() public view returns (uint256 counterValue) {
bytes memory returnData = cairoCounter.staticcallContract(FUNCTION_SELECTOR_GET);
bytes memory returnData = cairoCounter.staticcallCairo(FUNCTION_SELECTOR_GET);

// The return data is a 256-bit integer, so we can directly cast it to uint256
return abi.decode(returnData, (uint256));
}

/// @notice Calls the Cairo contract to increment its internal counter
/// @dev The delegatecall preserves the caller's context, so the caller's address will
/// be the caller of this function.
function incrementCairoCounter() external {
cairoCounter.callContract("inc");
cairoCounter.delegatecallCairo("inc");
}

/// @notice Calls the Cairo contract to set its internal counter to an arbitrary value
/// @dev Called with a regular call, the caller's address will be this contract's address
/// @dev The counter value is split into two 128-bit values to match the Cairo contract's expected inputs (u256 is composed of two u128s)
/// @param newCounter The new counter value to set
function setCairoCounter(uint256 newCounter) external {
Expand All @@ -47,13 +50,13 @@ contract CairoCounterCaller {
uint256[] memory data = new uint256[](2);
data[0] = uint256(newCounterLow);
data[1] = uint256(newCounterHigh);
cairoCounter.callContract(FUNCTION_SELECTOR_SET_COUNTER, data);
cairoCounter.callCairo(FUNCTION_SELECTOR_SET_COUNTER, data);
}

/// @notice Calls the Cairo contract to get the (starknet) address of the last caller
/// @return lastCaller The starknet address of the last caller
function getLastCaller() external view returns (uint256 lastCaller) {
bytes memory returnData = cairoCounter.staticcallContract(FUNCTION_SELECTOR_GET_LAST_CALLER);
bytes memory returnData = cairoCounter.staticcallCairo(FUNCTION_SELECTOR_GET_LAST_CALLER);

return abi.decode(returnData, (uint256));
}
Expand Down
87 changes: 71 additions & 16 deletions solidity_contracts/src/CairoPrecompiles/CairoLib.sol
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@ library CairoLib {

/// @notice Performs a low-level call to a Cairo contract deployed on the Starknet appchain.
/// @dev Used with intent to modify the state of the Cairo contract.
/// @param contractAddress The address of the Cairo contract.
/// @param functionSelector The function selector of the Cairo contract function to be called.
/// @param data The input data for the Cairo contract function.
/// @return returnData The return data from the Cairo contract function.
function callContract(uint256 contractAddress, uint256 functionSelector, uint256[] memory data)
function callCairo(uint256 contractAddress, uint256 functionSelector, uint256[] memory data)
internal
returns (bytes memory returnData)
{
Expand All @@ -26,27 +27,77 @@ library CairoLib {

/// @notice Performs a low-level call to a Cairo contract deployed on the Starknet appchain.
/// @dev Used with intent to modify the state of the Cairo contract.
/// @param contractAddress The address of the Cairo contract.
/// @param functionSelector The function selector of the Cairo contract function to be called.
/// @return returnData The return data from the Cairo contract function.
function callContract(uint256 contractAddress, uint256 functionSelector)
function callCairo(uint256 contractAddress, uint256 functionSelector) internal returns (bytes memory returnData) {
uint256[] memory data = new uint256[](0);
return callCairo(contractAddress, functionSelector, data);
}

/// @notice Performs a low-level call to a Cairo contract deployed on the Starknet appchain.
/// @dev Used with intent to modify the state of the Cairo contract.
/// @param functionName The name of the Cairo contract function to be called.
/// @return returnData The return data from the Cairo contract function.
function callCairo(uint256 contractAddress, string memory functionName)
internal
returns (bytes memory returnData)
{
uint256[] memory data = new uint256[](0);
return callContract(contractAddress, functionSelector, data);
uint256 functionSelector = uint256(keccak256(bytes(functionName))) % 2 ** 250;
return callCairo(contractAddress, functionSelector, data);
}

/// @notice Performs a low-level call to a Cairo contract deployed on the Starknet appchain.
/// @notice Performs a low-level delegatecall to a Cairo contract deployed on the Starknet appchain.
/// @dev Used with intent to modify the state of the Cairo contract.
/// @dev Using delegatecall preserves the context of the calling contract, and the execution of the
/// callee contract is performed using the `msg.sender` of the calling contract.
/// @param contractAddress The address of the Cairo contract.
/// @param functionSelector The function selector of the Cairo contract function to be called.
/// @param data The input data for the Cairo contract function.
/// @return returnData The return data from the Cairo contract function.
function delegatecallCairo(uint256 contractAddress, uint256 functionSelector, uint256[] memory data)
internal
returns (bytes memory returnData)
{
bytes memory callData =
abi.encodeWithSignature("call_contract(uint256,uint256,uint256[])", contractAddress, functionSelector, data);

(bool success, bytes memory result) = CAIRO_PRECOMPILE_ADDRESS.delegatecall(callData);
require(success, "CairoLib: call_contract failed");

returnData = result;
}

/// @notice Performs a low-level delegatecall to a Cairo contract deployed on the Starknet appchain.
/// @dev Used with intent to modify the state of the Cairo contract.
/// @dev Using delegatecall preserves the context of the calling contract, and the execution of the
/// callee contract is performed using the `msg.sender` of the calling contract.
/// @param contractAddress The address of the Cairo contract.
/// @param functionSelector The function selector of the Cairo contract function to be called.
/// @return returnData The return data from the Cairo contract function.
function delegatecallCairo(uint256 contractAddress, uint256 functionSelector)
internal
returns (bytes memory returnData)
{
uint256[] memory data = new uint256[](0);
return delegatecallCairo(contractAddress, functionSelector, data);
}

/// @notice Performs a low-level delegatecall to a Cairo contract deployed on the Starknet appchain.
/// @dev Used with intent to modify the state of the Cairo contract.
/// @dev Using delegatecall preserves the context of the calling contract, and the execution of the
/// callee contract is performed using the `msg.sender` of the calling contract.
/// @param contractAddress The address of the Cairo contract.
/// @param functionName The name of the Cairo contract function to be called.
/// @return returnData The return data from the Cairo contract function.
function callContract(uint256 contractAddress, string memory functionName)
function delegatecallCairo(uint256 contractAddress, string memory functionName)
internal
returns (bytes memory returnData)
{
uint256[] memory data = new uint256[](0);
uint256 functionSelector = uint256(keccak256(bytes(functionName))) % 2 ** 250;
return callContract(contractAddress, functionSelector, data);
return delegatecallCairo(contractAddress, functionSelector, data);
}

/// @notice Performs a low-level call to a Cairo contract deployed on the Starknet appchain.
Expand All @@ -55,7 +106,7 @@ library CairoLib {
/// @param functionSelector The function selector of the Cairo contract function to be called.
/// @param data The input data for the Cairo contract function.
/// @return returnData The return data from the Cairo contract function.
function staticcallContract(uint256 contractAddress, uint256 functionSelector, uint256[] memory data)
function staticcallCairo(uint256 contractAddress, uint256 functionSelector, uint256[] memory data)
internal
view
returns (bytes memory returnData)
Expand All @@ -74,36 +125,36 @@ library CairoLib {
/// @param contractAddress The address of the Cairo contract.
/// @param functionSelector The function selector of the Cairo contract function to be called.
/// @return returnData The return data from the Cairo contract function.
function staticcallContract(uint256 contractAddress, uint256 functionSelector)
function staticcallCairo(uint256 contractAddress, uint256 functionSelector)
internal
view
returns (bytes memory returnData)
{
uint256[] memory data = new uint256[](0);
return staticcallContract(contractAddress, functionSelector, data);
return staticcallCairo(contractAddress, functionSelector, data);
}

/// @notice Performs a low-level call to a Cairo contract deployed on the Starknet appchain.
/// @dev Used with intent to read the state of the Cairo contract.
/// @param contractAddress The address of the Cairo contract.
/// @param functionName The name of the Cairo contract function to be called.
/// @return returnData The return data from the Cairo contract function.
function staticcallContract(uint256 contractAddress, string memory functionName)
function staticcallCairo(uint256 contractAddress, string memory functionName)
internal
view
returns (bytes memory returnData)
{
uint256[] memory data = new uint256[](0);
uint256 functionSelector = uint256(keccak256(bytes(functionName))) % 2 ** 250;
return staticcallContract(contractAddress, functionSelector, data);
return staticcallCairo(contractAddress, functionSelector, data);
}

/// @dev Performs a low-level call to a Cairo class declared on the Starknet appchain.
/// @param classHash The class hash of the Cairo class.
/// @param functionSelector The function selector of the Cairo class function to be called.
/// @param data The input data for the Cairo class function.
/// @return returnData The return data from the Cairo class function.
function libraryCall(uint256 classHash, uint256 functionSelector, uint256[] memory data)
function libraryCallCairo(uint256 classHash, uint256 functionSelector, uint256[] memory data)
internal
view
returns (bytes memory returnData)
Expand All @@ -121,23 +172,27 @@ library CairoLib {
/// @param classHash The class hash of the Cairo class.
/// @param functionSelector The function selector of the Cairo class function to be called.
/// @return returnData The return data from the Cairo class function.
function libraryCall(uint256 classHash, uint256 functionSelector) internal view returns (bytes memory returnData) {
function libraryCallCairo(uint256 classHash, uint256 functionSelector)
internal
view
returns (bytes memory returnData)
{
uint256[] memory data = new uint256[](0);
return libraryCall(classHash, functionSelector, data);
return libraryCallCairo(classHash, functionSelector, data);
}

/// @dev Performs a low-level call to a Cairo class declared on the Starknet appchain.
/// @param classHash The class hash of the Cairo class.
/// @param functionName The name of the Cairo class function to be called.
/// @return returnData The return data from the Cairo class function.
function libraryCall(uint256 classHash, string memory functionName)
function libraryCallCairo(uint256 classHash, string memory functionName)
internal
view
returns (bytes memory returnData)
{
uint256[] memory data = new uint256[](0);
uint256 functionSelector = uint256(keccak256(bytes(functionName))) % 2 ** 250;
return libraryCall(classHash, functionSelector, data);
return libraryCallCairo(classHash, functionSelector, data);
}

/// @notice Performs a low-level call to send a message from the Kakarot to the Ethereum network.
Expand Down
2 changes: 1 addition & 1 deletion solidity_contracts/src/CairoPrecompiles/PragmaCaller.sol
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ contract PragmaCaller {
data[2] = request.expirationTimestamp;
}

bytes memory returnData = pragmaOracle.staticcallContract(FUNCTION_SELECTOR_GET_DATA_MEDIAN, data);
bytes memory returnData = pragmaOracle.staticcallCairo(FUNCTION_SELECTOR_GET_DATA_MEDIAN, data);

assembly {
// Load the values from the return data
Expand Down
14 changes: 6 additions & 8 deletions src/kakarot/interpreter.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -65,27 +65,25 @@ namespace Interpreter {
let is_pc_ge_code_len = is_le(evm.message.bytecode_len, pc);
if (is_pc_ge_code_len != FALSE) {
let is_precompile = Precompiles.is_precompile(evm.message.code_address.evm);
let caller_address = evm.message.caller;
// Caller of the contract that is calling the precompile
if (is_precompile != FALSE) {
// If the precompile is called straight from an EOA, the sender_context is the EOA
// Otherwise, the sender_context is the caller of the contract that is calling the precompile
// This is only relevant for the Kakarot Cairo Module precompile.
let parent_context = evm.message.parent;
let is_parent_zero = Helpers.is_zero(cast(parent_context, felt));
if (is_parent_zero != FALSE) {
tempvar sender_context = evm.message.caller;
// Case A: The precompile is called straight from an EOA
tempvar caller_code_address = evm.message.caller;
} else {
tempvar sender_context = parent_context.evm.message.caller;
// Case B: The precompile is called from a contract
tempvar caller_code_address = parent_context.evm.message.code_address.evm;
}
tempvar caller_address = evm.message.caller;
let (
output_len, output, gas_used, precompile_reverted
) = Precompiles.exec_precompile(
evm.message.code_address.evm,
evm.message.calldata_len,
evm.message.calldata,
caller_code_address,
caller_address,
sender_context,
);
let evm = EVM.charge_gas(evm, gas_used);
let evm_reverted = is_not_zero(evm.reverted);
Expand Down
16 changes: 7 additions & 9 deletions src/kakarot/precompiles/kakarot_precompiles.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,13 @@ namespace KakarotPrecompiles {
// @notice Executes a cairo contract/class.
// @param input_len The length of the input in bytes.
// @param input The input data.
// @param caller_address The address of the contract that calls the precompile
// @param sender_context The address of the sender in the context of the caller contract.
// @param caller_address The address of the caller of the precompile. Delegatecall rules apply.
func cairo_precompile{
syscall_ptr: felt*,
pedersen_ptr: HashBuiltin*,
range_check_ptr,
bitwise_ptr: BitwiseBuiltin*,
}(input_len: felt, input: felt*, caller_address: felt, sender_context: felt) -> (
}(input_len: felt, input: felt*, caller_address: felt) -> (
output_len: felt, output: felt*, gas_used: felt, reverted: felt
) {
alloc_locals;
Expand Down Expand Up @@ -79,8 +78,8 @@ namespace KakarotPrecompiles {
let (data_len, data) = Helpers.load_256_bits_array(data_bytes_len, data_ptr);

if (selector == CALL_CONTRACT_SOLIDITY_SELECTOR) {
let sender_starknet_address = Account.get_registered_starknet_address(sender_context);
let is_not_deployed = Helpers.is_zero(sender_starknet_address);
let caller_starknet_address = Account.get_registered_starknet_address(caller_address);
let is_not_deployed = Helpers.is_zero(caller_starknet_address);
if (is_not_deployed != FALSE) {
let (revert_reason_len, revert_reason) = Errors.accountNotDeployed();
Expand All @@ -90,7 +89,7 @@ namespace KakarotPrecompiles {
}

let (retdata_len, retdata, success) = IAccount.execute_starknet_call(
sender_starknet_address, to_starknet_address, starknet_selector, data_len, data
caller_starknet_address, to_starknet_address, starknet_selector, data_len, data
);
let (output) = alloc();
let output_len = retdata_len * 32;
Expand All @@ -115,14 +114,13 @@ namespace KakarotPrecompiles {
// @notice Sends a message to a message to L1.
// @param input_len The length of the input in bytes.
// @param input The input data.
// @param caller_address The address of the contract that calls the precompile
// @param sender_context unused
// @param caller_address unused
func cairo_message{
syscall_ptr: felt*,
pedersen_ptr: HashBuiltin*,
range_check_ptr,
bitwise_ptr: BitwiseBuiltin*,
}(input_len: felt, input: felt*, caller_address: felt, sender_context: felt) -> (
}(input_len: felt, input: felt*, caller_address: felt) -> (
output_len: felt, output: felt*, gas_used: felt, reverted: felt
) {
alloc_locals;
Expand Down
Loading

0 comments on commit 22cbf53

Please sign in to comment.