From 36512991e228480fff49c39637ebe80ae92d2a88 Mon Sep 17 00:00:00 2001 From: enitrat Date: Wed, 9 Oct 2024 19:52:34 +0200 Subject: [PATCH] feat: scarb pytest --- cairo/kakarot-ssj/Scarb.toml | 4 + cairo/kakarot-ssj/crates/evm/Scarb.toml | 4 + cairo/kakarot-ssj/crates/evm/src/errors.cairo | 30 +- cairo/kakarot-ssj/crates/evm/src/stack.cairo | 131 ++++++- cairo/kakarot-ssj/crates/utils/Scarb.toml | 4 + cairo/kakarot-ssj/crates/utils/src/lib.cairo | 6 + .../crates/utils/src/pytests/from_array.cairo | 4 + .../crates/utils/src/pytests/json.cairo | 84 ++++ cairo/kakarot-ssj/py_tests/__init__.py | 0 cairo/kakarot-ssj/py_tests/evm/__init__.py | 0 .../kakarot-ssj/py_tests/evm/src/__init__.py | 0 .../kakarot-ssj/py_tests/evm/src/conftest.py | 60 +++ .../py_tests/evm/src/test_stack.py | 127 ++++++ .../py_tests/test_utils/__init__.py | 0 .../py_tests/test_utils/deserializer.py | 77 ++++ .../py_tests/test_utils/serializer.py | 85 ++++ .../kakarot-ssj/py_tests/test_utils/types.py | 362 ++++++++++++++++++ pyproject.toml | 1 + 18 files changed, 963 insertions(+), 16 deletions(-) create mode 100644 cairo/kakarot-ssj/crates/utils/src/pytests/from_array.cairo create mode 100644 cairo/kakarot-ssj/crates/utils/src/pytests/json.cairo create mode 100644 cairo/kakarot-ssj/py_tests/__init__.py create mode 100644 cairo/kakarot-ssj/py_tests/evm/__init__.py create mode 100644 cairo/kakarot-ssj/py_tests/evm/src/__init__.py create mode 100644 cairo/kakarot-ssj/py_tests/evm/src/conftest.py create mode 100644 cairo/kakarot-ssj/py_tests/evm/src/test_stack.py create mode 100644 cairo/kakarot-ssj/py_tests/test_utils/__init__.py create mode 100644 cairo/kakarot-ssj/py_tests/test_utils/deserializer.py create mode 100644 cairo/kakarot-ssj/py_tests/test_utils/serializer.py create mode 100644 cairo/kakarot-ssj/py_tests/test_utils/types.py diff --git a/cairo/kakarot-ssj/Scarb.toml b/cairo/kakarot-ssj/Scarb.toml index b0d420965..9a71c57b8 100644 --- a/cairo/kakarot-ssj/Scarb.toml +++ b/cairo/kakarot-ssj/Scarb.toml @@ -15,3 +15,7 @@ starknet = "2.8.2" [workspace.tool.fmt] sort-module-level-items = true + +[features] +default = [] +pytest = [] diff --git a/cairo/kakarot-ssj/crates/evm/Scarb.toml b/cairo/kakarot-ssj/crates/evm/Scarb.toml index 1180c54a4..b96e24306 100644 --- a/cairo/kakarot-ssj/crates/evm/Scarb.toml +++ b/cairo/kakarot-ssj/crates/evm/Scarb.toml @@ -23,3 +23,7 @@ fmt.workspace = true [scripts] test = "snforge test --max-n-steps 4294967295" test-profiling = "snforge test --max-n-steps 4294967295 --build-profile" + +[features] +default = [] +pytest = [] diff --git a/cairo/kakarot-ssj/crates/evm/src/errors.cairo b/cairo/kakarot-ssj/crates/evm/src/errors.cairo index b6d17eaa5..3f55f3f11 100644 --- a/cairo/kakarot-ssj/crates/evm/src/errors.cairo +++ b/cairo/kakarot-ssj/crates/evm/src/errors.cairo @@ -66,25 +66,25 @@ pub enum EVMError { pub impl EVMErrorImpl of EVMErrorTrait { fn to_string(self: EVMError) -> felt252 { match self { - EVMError::StackOverflow => 'stack overflow', - EVMError::StackUnderflow => 'stack underflow', + EVMError::StackOverflow => 'StackOverflow', + EVMError::StackUnderflow => 'StackUnderflow', EVMError::TypeConversionError(error_message) => error_message, EVMError::NumericOperations(error_message) => error_message, - EVMError::InsufficientBalance => 'insufficient balance', - EVMError::ReturnDataOutOfBounds => 'return data out of bounds', - EVMError::InvalidJump => 'invalid jump destination', - EVMError::InvalidCode => 'invalid code', - EVMError::NotImplemented => 'not implemented', + EVMError::InsufficientBalance => 'InsufficientBalance', + EVMError::ReturnDataOutOfBounds => 'ReturnDataOutOfBounds', + EVMError::InvalidJump => 'InvalidJump', + EVMError::InvalidCode => 'InvalidCode', + EVMError::NotImplemented => 'NotImplemented', EVMError::InvalidParameter(error_message) => error_message, // TODO: refactor with dynamic strings once supported - EVMError::InvalidOpcode => 'invalid opcode'.into(), - EVMError::WriteInStaticContext => 'write protection', - EVMError::Collision => 'create collision'.into(), - EVMError::OutOfGas => 'out of gas'.into(), - EVMError::Assertion => 'assertion failed'.into(), - EVMError::DepthLimit => 'max call depth exceeded'.into(), - EVMError::MemoryLimitOOG => 'memory limit out of gas'.into(), - EVMError::NonceOverflow => 'nonce overflow'.into(), + EVMError::InvalidOpcode => 'InvalidOpcode', + EVMError::WriteInStaticContext => 'WriteInStaticContext', + EVMError::Collision => 'Collision', + EVMError::OutOfGas => 'OutOfGas', + EVMError::Assertion => 'Assertion', + EVMError::DepthLimit => 'DepthLimit', + EVMError::MemoryLimitOOG => 'MemoryLimitOOG', + EVMError::NonceOverflow => 'NonceOverflow', } } diff --git a/cairo/kakarot-ssj/crates/evm/src/stack.cairo b/cairo/kakarot-ssj/crates/evm/src/stack.cairo index 74ddcba68..5a2e3a216 100644 --- a/cairo/kakarot-ssj/crates/evm/src/stack.cairo +++ b/cairo/kakarot-ssj/crates/evm/src/stack.cairo @@ -209,9 +209,17 @@ impl StackImpl of StackTrait { fn pop_n(ref self: Stack, mut n: usize) -> Result, EVMError> { ensure(!(n > self.len()), EVMError::StackUnderflow)?; let mut popped_items = ArrayTrait::::new(); + let mut err = Result::Ok(array![]); for _ in 0..n { - popped_items.append(self.pop().unwrap()); + let popped_item = self.pop(); + match popped_item { + Result::Ok(item) => popped_items.append(item), + Result::Err(pop_error) => { err = Result::Err(pop_error); break;}, + }; }; + if err.is_err() { + return err; + } Result::Ok(popped_items) } @@ -634,3 +642,124 @@ mod tests { } } } + +#[cfg(feature: 'pytest')] +mod pytests { + //! Pytests are tests that are run with the scarb-pytest framework. + //! This framework allows for testing based on various inputs provided by a third-party test + //! runner such as pytest or cargo test. + use core::fmt::{Formatter}; + use crate::errors::{EVMErrorTrait}; + use utils::pytests::json::{JsonMut, Json}; + use utils::pytests::from_array::FromArray; + use crate::stack::{Stack, StackTrait}; + + impl StackJSON of JsonMut { + fn to_json(ref self: Stack) -> ByteArray { + let mut json: ByteArray = ""; + let mut formatter = Formatter { buffer: json }; + write!(formatter, "[").unwrap(); + for i in 0 + ..self + .len() { + let item = self.items.get(i.into()).deref(); + write!(formatter, "{}", item).unwrap(); + if i != self.len() - 1 { + write!(formatter, ", ").unwrap(); + } + }; + write!(formatter, "]").unwrap(); + formatter.buffer + } + } + + impl StackFromArray of FromArray { + type Output = Stack; + fn from_array(array: Span) -> Self::Output { + let mut stack = StackTrait::new(); + for item in array { + stack.push(*item).expect('Stack FromArray failed'); + }; + stack + } + } + + fn test__stack_push(values: Span) -> ByteArray { + let mut stack = StackTrait::new(); + let mut err = Result::Ok(()); + for value in values { + match stack.push(*value) { + Result::Ok(()) => (), + Result::Err(evm_error) => { + err = Result::Err(evm_error); + break; + }, + }; + }; + if err.is_err() { + core::panic_with_felt252(err.unwrap_err().to_string()); + }; + stack.to_json() + } + + fn test__stack_pop(stack: Span) -> ByteArray { + let mut stack = StackFromArray::from_array(stack); + let mut err = Result::Ok(()); + let value = match stack.pop() { + Result::Ok(value) => value, + Result::Err(evm_error) => { err = Result::Err(evm_error); 0}, + }; + if err.is_err() { + core::panic_with_felt252(err.unwrap_err().to_string()); + }; + let mut output: (Stack, u256) = (stack, value); + output.to_json() + } + + fn test__stack_pop_n(stack: Span, n: usize) -> ByteArray { + let mut stack = StackFromArray::from_array(stack); + let mut err = Result::Ok(()); + let values = match stack.pop_n(n) { + Result::Ok(values) => values, + Result::Err(evm_error) => { err = Result::Err(evm_error); array![]}, + }; + if err.is_err() { + core::panic_with_felt252(err.unwrap_err().to_string()); + }; + let mut output: (Stack, Span) = (stack, values.span()); + output.to_json() + } + + fn test__stack_peek(stack: Span, index: usize) -> ByteArray { + let mut stack = StackFromArray::from_array(stack); + let mut err = Result::Ok(()); + let value = match stack.peek_at(index) { + Result::Ok(value) => value, + Result::Err(evm_error) => { err = Result::Err(evm_error); 0}, + }; + if err.is_err() { + core::panic_with_felt252(err.unwrap_err().to_string()); + }; + + let mut output: (Stack, u256) = (stack, value); + output.to_json() + } + + fn test__stack_swap(stack: Span, index: usize) -> ByteArray { + let mut stack = StackFromArray::from_array(stack); + let mut err = Result::Ok(()); + match stack.swap_i(index) { + Result::Ok(()) => (), + Result::Err(evm_error) => { err = Result::Err(evm_error); }, + }; + if err.is_err() { + core::panic_with_felt252(err.unwrap_err().to_string()); + }; + stack.to_json() + } + + fn test__stack_new() -> ByteArray { + let mut stack: Stack = Default::default(); + stack.to_json() + } +} diff --git a/cairo/kakarot-ssj/crates/utils/Scarb.toml b/cairo/kakarot-ssj/crates/utils/Scarb.toml index 83548065f..9854ba2fd 100644 --- a/cairo/kakarot-ssj/crates/utils/Scarb.toml +++ b/cairo/kakarot-ssj/crates/utils/Scarb.toml @@ -23,3 +23,7 @@ assert_macros = "2.8.2" [scripts] test = "snforge test --max-n-steps 4294967295" test-profiling = "snforge test --max-n-steps 4294967295 --build-profile" + +[features] +default = [] +pytest = [] diff --git a/cairo/kakarot-ssj/crates/utils/src/lib.cairo b/cairo/kakarot-ssj/crates/utils/src/lib.cairo index 05a2f2975..858cdea8a 100644 --- a/cairo/kakarot-ssj/crates/utils/src/lib.cairo +++ b/cairo/kakarot-ssj/crates/utils/src/lib.cairo @@ -14,3 +14,9 @@ pub mod set; pub mod test_data; pub mod traits; pub mod utils; + +// #[cfg(feature: 'pytest')] +pub mod pytests { + pub mod json; + pub mod from_array; +} diff --git a/cairo/kakarot-ssj/crates/utils/src/pytests/from_array.cairo b/cairo/kakarot-ssj/crates/utils/src/pytests/from_array.cairo new file mode 100644 index 000000000..81db9df2f --- /dev/null +++ b/cairo/kakarot-ssj/crates/utils/src/pytests/from_array.cairo @@ -0,0 +1,4 @@ +pub trait FromArray { + type Output; + fn from_array(array: Span) -> Self::Output; +} diff --git a/cairo/kakarot-ssj/crates/utils/src/pytests/json.cairo b/cairo/kakarot-ssj/crates/utils/src/pytests/json.cairo new file mode 100644 index 000000000..f70c78b01 --- /dev/null +++ b/cairo/kakarot-ssj/crates/utils/src/pytests/json.cairo @@ -0,0 +1,84 @@ +use core::fmt::Formatter; + +pub trait JsonMut { + fn to_json(ref self: T) -> ByteArray; +} + +pub trait Json { + fn to_json(self: @T) -> ByteArray; +} + +impl TupleTwoJson, +Json, +Destruct, +Json> of Json<(T1, T2)> { + fn to_json(self: @(T1, T2)) -> ByteArray { + let (t1, t2) = self; + format!("[{}, {}]", t1.to_json(), t2.to_json()) + } +} + +impl TupleThreeJson, +Json, +Destruct, +Json, +Destruct, +Json> of Json<(T1, T2, T3)> { + fn to_json(self: @(T1, T2, T3)) -> ByteArray { + let (t1, t2, t3) = self; + format!("[{}, {}, {}]", t1.to_json(), t2.to_json(), t3.to_json()) + } +} + +impl TupleTwoJsonMut, +JsonMut, +Destruct, +Json> of JsonMut<(T1, T2)> { + fn to_json(ref self: (T1, T2)) -> ByteArray { + let (mut t1, mut t2) = self; + let res = format!("[{}, {}]", t1.to_json(), t2.to_json()); + self = (t1, t2); + res + } +} + +impl TupleThreeJsonMut, +JsonMut, +Destruct, +JsonMut, +Destruct, +JsonMut> of JsonMut<(T1, T2, T3)> { + fn to_json(ref self: (T1, T2, T3)) -> ByteArray { + let (mut t1, mut t2, mut t3) = self; + let res = format!("[{}, {}, {}]", t1.to_json(), t2.to_json(), t3.to_json()); + self = (t1, t2, t3); + res + } +} + +impl SpanJSON, +Drop, +Copy, +Json, +PartialEq> of Json> { + fn to_json(self: @Span) -> ByteArray { + let self = *self; + let mut json: ByteArray = ""; + let mut formatter = Formatter { buffer: json }; + write!(formatter, "[").expect('JSON formatting failed'); + for value in self { + let value = *value; + write!(formatter, "{}", value.to_json()).expect('JSON formatting failed'); + if value != *self.at(self.len() - 1) { + write!(formatter, ", ").expect('JSON formatting failed'); + } + }; + write!(formatter, "]").expect('JSON formatting failed'); + formatter.buffer + } +} + +impl SpanJsonMut, +Drop, +Copy, +JsonMut, +PartialEq> of JsonMut> { + fn to_json(ref self: Span) -> ByteArray { + self.to_json() + } +} + + + +impl U256Json = integer_json::IntegerJSON; +impl U128Json = integer_json::IntegerJSON; +impl U64Json = integer_json::IntegerJSON; +impl U32Json = integer_json::IntegerJSON; +impl U16Json = integer_json::IntegerJSON; +impl U8Json = integer_json::IntegerJSON; + +pub mod integer_json { + use super::Json; + + pub(crate) impl IntegerJSON, +Drop, +Copy> of Json { + fn to_json(self: @T) -> ByteArray { + format!("{}", *self) + } + } +} diff --git a/cairo/kakarot-ssj/py_tests/__init__.py b/cairo/kakarot-ssj/py_tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/cairo/kakarot-ssj/py_tests/evm/__init__.py b/cairo/kakarot-ssj/py_tests/evm/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/cairo/kakarot-ssj/py_tests/evm/src/__init__.py b/cairo/kakarot-ssj/py_tests/evm/src/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/cairo/kakarot-ssj/py_tests/evm/src/conftest.py b/cairo/kakarot-ssj/py_tests/evm/src/conftest.py new file mode 100644 index 000000000..2ffdb27a3 --- /dev/null +++ b/cairo/kakarot-ssj/py_tests/evm/src/conftest.py @@ -0,0 +1,60 @@ +import json +import re +import subprocess +from typing import Any, Callable, Tuple, Type, Union + +import pytest +from py_tests.test_utils.deserializer import Deserializer +from py_tests.test_utils.serializer import Serializer +from py_tests.test_utils.types import ByteArray + + +@pytest.fixture +def cairo_run() -> Callable[[str, Union[Type[Any], Tuple[Type[Any], ...]], ...], Any]: + def _cairo_run( + function_name: str, + output_type: Union[Type[Any], Tuple[Type[Any], ...]], + *args: Any, + ) -> Any: + # Serialize arguments into a compatible format for scarb cairo-run + # JSON encode the serialized arguments - [1,2,3] -> "[1,2,3]" + serialized_args = json.dumps(Serializer.serialize_args(args)) + + command = [ + "scarb", + "pytest", + "-p", + "evm", + "--function", + function_name, + serialized_args, + "--no-build", + ] + + try: + result = subprocess.run( + command, + cwd="cairo/kakarot-ssj", + capture_output=True, + text=True, + check=True, + ) + except subprocess.CalledProcessError as e: + raise RuntimeError(f"Command failed with error: {e.stderr}") from e + + stdout = result.stdout.strip() + # Extract panic message if present + panic_match = re.search(r"Run panicked with \[\d+ \(\'(.*?)\'\)", stdout) + if panic_match: + raise ValueError(f"Run panicked with: {panic_match.group(1)}") + + match = re.search( + r"Run completed successfully, returning (\[.*?\])", result.stdout + ) + if not match: + raise ValueError("No array found in the output") + + output = ByteArray(json.loads(match.group(1))) + return Deserializer.deserialize(output, output_type) + + return _cairo_run diff --git a/cairo/kakarot-ssj/py_tests/evm/src/test_stack.py b/cairo/kakarot-ssj/py_tests/evm/src/test_stack.py new file mode 100644 index 000000000..2f884d7d0 --- /dev/null +++ b/cairo/kakarot-ssj/py_tests/evm/src/test_stack.py @@ -0,0 +1,127 @@ +from typing import List + +import pytest +from py_tests.test_utils.types import U256, Stack + + +class TestStack: + def test_new(self, cairo_run): + result = cairo_run("test__stack_new", Stack) + assert result == Stack([]) + + def test_push(self, cairo_run): + values = [U256(0x10), U256(0x20)] + result = cairo_run("test__stack_push", Stack, values) + assert result == Stack([U256(0x10), U256(0x20)]) + + def test_pop(self, cairo_run): + stack = [U256(0x10), U256(0x20), U256(0x30)] + result = cairo_run("test__stack_pop", (Stack, U256), stack) + assert result[0] == Stack([U256(0x10), U256(0x20)]) + assert result[1] == U256(0x30) + + def test_pop_n(self, cairo_run): + stack = [U256(0x10), U256(0x20), U256(0x30), U256(0x40), U256(0x50)] + n = 3 + result = cairo_run("test__stack_pop_n", (Stack, List[U256]), stack, n) + assert result[0] == Stack([U256(0x10), U256(0x20)]) + assert result[1] == [U256(0x50), U256(0x40), U256(0x30)] + + def test_peek(self, cairo_run): + stack = [U256(0x10), U256(0x20), U256(0x30)] + index = 1 + result = cairo_run("test__stack_peek", (Stack, U256), stack, index) + assert result[0] == Stack([U256(0x10), U256(0x20), U256(0x30)]) + assert result[1] == U256(0x20) + + def test_swap(self, cairo_run): + stack = [U256(0x1), U256(0x2), U256(0x3), U256(0x4)] + index = 2 + result = cairo_run("test__stack_swap", Stack, stack, index) + assert result == Stack([U256(0x1), U256(0x4), U256(0x3), U256(0x2)]) + + def test_push_when_full(self, cairo_run): + # Create a full stack (1024 elements) + values = [U256(i) for i in range(1024)] + result = cairo_run("test__stack_push", Stack, values) + assert result == Stack(values) + + # Try to push one more element + with pytest.raises(Exception) as exc_info: + cairo_run("test__stack_push", Stack, values + [U256(1024)]) + assert "StackOverflow" in str(exc_info.value) + + def test_pop_when_empty(self, cairo_run): + with pytest.raises(Exception) as exc_info: + cairo_run("test__stack_pop", (Stack, U256), []) + assert "StackUnderflow" in str(exc_info.value) + + def test_peek_when_empty(self, cairo_run): + with pytest.raises(Exception) as exc_info: + cairo_run("test__stack_peek", (Stack, U256), [], 0) + assert "StackUnderflow" in str(exc_info.value) + + def test_pop_n_underflow(self, cairo_run): + stack = [U256(0x10), U256(0x20)] + with pytest.raises(Exception) as exc_info: + cairo_run("test__stack_pop_n", (Stack, List[U256]), stack, 3) + assert "StackUnderflow" in str(exc_info.value) + + def test_swap_underflow(self, cairo_run): + stack = [U256(0x1), U256(0x2)] + with pytest.raises(Exception) as exc_info: + cairo_run("test__stack_swap", Stack, stack, 2) + assert "StackUnderflow" in str(exc_info.value) + + def test_push_multiple_and_pop_multiple(self, cairo_run): + values = [U256(0x10), U256(0x20), U256(0x30), U256(0x40)] + result = cairo_run("test__stack_push", Stack, values) + assert result == Stack(values) + + result = cairo_run("test__stack_pop_n", (Stack, List[U256]), values, 2) + assert result[0] == Stack([U256(0x10), U256(0x20)]) + assert result[1] == [U256(0x40), U256(0x30)] + + def test_peek_at_various_indices(self, cairo_run): + stack = [U256(0x10), U256(0x20), U256(0x30), U256(0x40)] + for i, expected in enumerate(reversed(stack)): + result = cairo_run("test__stack_peek", (Stack, U256), stack, i) + assert result[0] == Stack(stack) + assert result[1] == expected + + def test_swap_various_indices(self, cairo_run): + stack = [U256(0x1), U256(0x2), U256(0x3), U256(0x4)] + for i in range(1, len(stack)): + expected = stack.copy() + expected[-1], expected[-1 - i] = expected[-1 - i], expected[-1] + result = cairo_run("test__stack_swap", Stack, stack, i) + assert result == Stack(expected) + + def test_push_pop_peek_combination(self, cairo_run): + # Push some values + values = [U256(0x10), U256(0x20), U256(0x30)] + result = cairo_run("test__stack_push", Stack, values) + assert result == Stack(values) + + # Peek at the top + result = cairo_run("test__stack_peek", (Stack, U256), values, 0) + assert result[0] == Stack(values) + assert result[1] == U256(0x30) + + # Pop one value + result = cairo_run("test__stack_pop", (Stack, U256), values) + assert result[0] == Stack([U256(0x10), U256(0x20)]) + assert result[1] == U256(0x30) + + # Push another value + result = cairo_run( + "test__stack_push", Stack, [U256(0x10), U256(0x20), U256(0x40)] + ) + assert result == Stack([U256(0x10), U256(0x20), U256(0x40)]) + + # Peek at index 1 + result = cairo_run( + "test__stack_peek", (Stack, U256), [U256(0x10), U256(0x20), U256(0x40)], 1 + ) + assert result[0] == Stack([U256(0x10), U256(0x20), U256(0x40)]) + assert result[1] == U256(0x20) diff --git a/cairo/kakarot-ssj/py_tests/test_utils/__init__.py b/cairo/kakarot-ssj/py_tests/test_utils/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/cairo/kakarot-ssj/py_tests/test_utils/deserializer.py b/cairo/kakarot-ssj/py_tests/test_utils/deserializer.py new file mode 100644 index 000000000..fadd79a8f --- /dev/null +++ b/cairo/kakarot-ssj/py_tests/test_utils/deserializer.py @@ -0,0 +1,77 @@ +import json +from typing import Any, Tuple, Type, Union, get_origin + +from py_tests.test_utils.types import TYPE_MAP, ByteArray, handle_list_type + + +class Deserializer: + """ + A utility class for deserializing ByteArray objects into specific types or tuples of types. + """ + + @staticmethod + def deserialize( + byte_array: ByteArray, output_type: Union[Type[Any], Tuple[Type[Any], ...]] + ) -> Any: + """ + Deserialize a ByteArray object into a specified output type or tuple of types. + + Args: + ---- + byte_array (ByteArray): The ByteArray object to deserialize. + output_type (Union[Type[Any], Tuple[Type[Any], ...]]): The desired output type(s). + + Returns: + ------- + Any: An instance of the specified output_type or a tuple of instances. + + Raises: + ------ + ValueError: If the output_type is not supported or if the data doesn't match the expected format. + + """ + json_string = byte_array.to_string() + data = json.loads(json_string) + + if isinstance(output_type, tuple): + if not isinstance(data, list) or len(data) != len(output_type): + raise ValueError( + f"Expected a list of {len(output_type)} elements, got {data}" + ) + return tuple( + Deserializer._deserialize_single(json.dumps(item), t) + for t, item in zip(output_type, data) + ) + else: + return Deserializer._deserialize_single(json_string, output_type) + + @staticmethod + def _deserialize_single(json_string: str, output_type: Type[Any]) -> Any: + """ + Deserialize a JSON string into a single specified output type. + + Args: + ---- + json_string (str): The JSON string to deserialize. + output_type (Type[Any]): The desired output type. + + Returns: + ------- + Any: An instance of the specified output_type. + + Raises: + ------ + ValueError: If the output_type is not supported. + + """ + # Handle deserialization based on the output_type + # 1. If output_type is a list, use a specialized list handling function + # 2. If output_type is not in the predefined TYPE_MAP, raise an error + # 3. Otherwise, use the appropriate type class to deserialize from JSON + if get_origin(output_type) is list: + return handle_list_type(json_string, output_type) + elif output_type.__name__ not in TYPE_MAP: + raise ValueError(f"Unsupported output type: {output_type.__name__}") + + type_class = TYPE_MAP[output_type.__name__] + return type_class.from_json(json_string) diff --git a/cairo/kakarot-ssj/py_tests/test_utils/serializer.py b/cairo/kakarot-ssj/py_tests/test_utils/serializer.py new file mode 100644 index 000000000..eddfc6c02 --- /dev/null +++ b/cairo/kakarot-ssj/py_tests/test_utils/serializer.py @@ -0,0 +1,85 @@ +from typing import Any, List + + +class Serializer: + """ + A utility class for serializing arguments, particularly useful for handling + nested structures and custom objects with 'to_felt_array' method. + """ + + @staticmethod + def serialize_args(args: List[Any]) -> List[Any]: + """ + Serialize a list of arguments, maintaining the structure for nested lists. + + Args: + ---- + args (List[Any]): A list of arguments to be serialized. + + Returns: + ------- + List[Any]: A list of serialized arguments, preserving the original structure. + + Example: + ------- + >>> Serializer.serialize_args([[U256(0x10), U256(0x20)], 3]) + [[16, 0, 32, 0], 3] + + """ + serialized_args = [] + for arg in args: + if isinstance(arg, list): + # Keep the list structure, but serialize its contents + serialized_args.append(Serializer.serialize_list(arg)) + else: + serialized_args.append(Serializer.serialize_single(arg)) + return serialized_args + + @staticmethod + def serialize_list(arg_list: List[Any]) -> List[Any]: + """ + Serialize a list of arguments, flattening any nested structures. + + Args: + ---- + arg_list (List[Any]): A list of arguments to be serialized. + + Returns: + ------- + List[Any]: A flattened list of serialized arguments. + + Example: + ------- + >>> Serializer.serialize_list([U256(0x10), U256(0x20)]) + [16, 0, 32, 0] + + """ + return [item for arg in arg_list for item in Serializer.serialize_single(arg)] + + @staticmethod + def serialize_single(arg: Any) -> Any: + """ + Serialize a single argument. + + If the argument has a 'to_felt_array' method, it uses that for serialization. + Otherwise, it returns the argument as is. + + Args: + ---- + arg (Any): The argument to be serialized. + + Returns: + ------- + Any: The serialized argument, either as a felt array or the original value. + + Example: + ------- + >>> Serializer.serialize_single(5) + 5 + >>> Serializer.serialize_single(U256(0x10)) + [16, 0] + + """ + if hasattr(arg, "to_felt_array"): + return arg.to_felt_array() + return arg # Return single values as is, without wrapping in a list diff --git a/cairo/kakarot-ssj/py_tests/test_utils/types.py b/cairo/kakarot-ssj/py_tests/test_utils/types.py new file mode 100644 index 000000000..11f7016d4 --- /dev/null +++ b/cairo/kakarot-ssj/py_tests/test_utils/types.py @@ -0,0 +1,362 @@ +import json +from typing import Any, Generic, List, Tuple, Type, TypeVar, Union, get_args + +BYTE_ARRAY_MAGIC = 0x046A6158A16A947E5916B2A2CA68501A45E93D7110E81AA2D6438B1C57C879A3 +BYTES_IN_WORD = 31 + +T = TypeVar("T") + + +class ListType(Generic[T]): + """ + A generic class for handling list types in JSON deserialization. + """ + + @classmethod + def from_json(cls, json_string: str, item_type: Type[T]) -> List[T]: + """ + Deserialize a JSON string into a list of specified item type. + + Args: + ---- + json_string (str): The JSON string to deserialize. + item_type (Type[T]): The type of items in the list. + + Returns: + ------- + List[T]: A list of deserialized items of the specified type. + + Raises: + ------ + ValueError: If the JSON string doesn't represent a list. + + """ + data = json.loads(json_string) + if not isinstance(data, list): + raise ValueError( + f"Invalid JSON format for List. Expected a list, got {type(data)}" + ) + return [ + TYPE_MAP[item_type.__name__].from_json(json.dumps(item)) for item in data + ] + + +def parse_json_output( + json_string: str, output_type: Union[Type, Tuple[Type, ...], List[Type]] +) -> Any: + """ + Parse a JSON string into the specified output type(s). + + Args: + ---- + json_string (str): The JSON string to parse. + output_type (Union[Type, Tuple[Type, ...], List[Type]]): The desired output type(s). + + Returns: + ------- + Any: The parsed data in the specified output type(s). + + Raises: + ------ + ValueError: If the JSON data doesn't match the expected output type(s). + + """ + data = json.loads(json_string) + + if isinstance(output_type, tuple): + if not isinstance(data, list) or len(data) != len(output_type): + raise ValueError( + f"Expected a list of {len(output_type)} elements, got {data}" + ) + return tuple( + TYPE_MAP[t.__name__].from_json(json.dumps(item)) + for t, item in zip(output_type, data) + ) + elif isinstance(output_type, list): + if not isinstance(data, list): + raise ValueError(f"Expected a list, got {type(data)}") + return [ + TYPE_MAP[output_type[0].__name__].from_json(json.dumps(item)) + for item in data + ] + else: + return TYPE_MAP[output_type.__name__].from_json(json_string) + + +class ByteArray: + """ + Represents a byte array and provides methods for string conversion and debugging. + """ + + def __init__(self, felt_array: List[int]): + self.felt_array = felt_array + + def to_string(self) -> str: + """ + Convert the ByteArray to a string representation. + + Returns + ------- + str: A formatted string representation of the ByteArray. + + """ + return self.format_for_debug() + + def format_for_debug(self) -> str: + """ + Format the ByteArray for debugging purposes. + Note: The input felt array is expected to begin with the BYTE_ARRAY_MAGIC number. + Otherwise, it's interpreted as a debug string directly. + + This method processes the felt_array and formats each item for debugging. + If the result is a single string item, it returns that item directly. + Otherwise, it returns a formatted string with each item on a new line, + prefixed with '[DEBUG]' for non-string items. + + + Source: https://github.com/starkware-libs/cairo/blob/main/crates/cairo-lang-runner/src/casm_run/mod.rs#L2281 + + Returns + ------- + str: A formatted string representation of the ByteArray for debugging. + + """ + items = [] + i = 0 + while i < len(self.felt_array): + item, is_string, consumed = self.format_next_item(self.felt_array[i:]) + items.append((item, is_string)) + i += consumed + + if len(items) == 1 and items[0][1]: + return items[0][0] + + return "".join( + f"{item}\n" if is_string else f"[DEBUG]\t{item}\n" + for item, is_string in items + ).strip() # Remove trailing newline + + @staticmethod + def format_next_item(values: List[int]) -> Tuple[str, bool, int]: + if not values: + return None, False, 0 + + if values[0] == BYTE_ARRAY_MAGIC: + string, consumed = ByteArray.try_format_string(values) + if string is not None: + return string, True, consumed + + return ByteArray.format_short_string(values[0]), False, 1 + + @staticmethod + def format_short_string(value: int) -> str: + as_string = ByteArray.as_cairo_short_string(value) + if as_string: + return f"{value:#x} ('{as_string}')" + return f"{value:#x}" + + @staticmethod + def try_format_string(values: List[int]) -> Tuple[Union[str, None], int]: + if len(values) < 4: + return None, 0 + + num_full_words = values[1] + full_words = values[2 : 2 + num_full_words] + pending_word = values[2 + num_full_words] + pending_word_len = values[2 + num_full_words + 1] + + if len(full_words) != num_full_words: + return None, 0 + + full_words_string = "".join( + ByteArray.as_cairo_short_string_ex(word, BYTES_IN_WORD) or "" + for word in full_words + ) + pending_word_string = ByteArray.as_cairo_short_string_ex( + pending_word, pending_word_len + ) + + if pending_word_string is None: + return None, 0 + + result = full_words_string + pending_word_string + return result, 2 + num_full_words + 2 + + @staticmethod + def as_cairo_short_string(felt: int) -> Union[str, None]: + as_string = "" + is_end = False + for byte in felt.to_bytes(32, "big"): + if byte == 0: + is_end = True + elif is_end: + return None + elif ByteArray.is_ascii_graphic(byte) or ByteArray.is_ascii_whitespace( + byte + ): + as_string += chr(byte) + else: + return None + return as_string + + @staticmethod + def is_ascii_graphic(byte: int) -> bool: + return 33 <= byte <= 126 # b'!' to b'~' + + @staticmethod + def is_ascii_whitespace(byte: int) -> bool: + return byte in (9, 10, 12, 13, 32) # \t, \n, \f, \r, space + + @staticmethod + def as_cairo_short_string_ex(felt: int, length: int) -> Union[str, None]: + if length == 0: + return "" if felt == 0 else None + if length > 31: + return None + + bytes_data = felt.to_bytes(32, "big") + bytes_data = bytes_data[-length:] # Take last 'length' bytes + + as_string = "" + for byte in bytes_data: + if byte == 0: + as_string += r"\0" + elif ByteArray.is_ascii_graphic(byte) or ByteArray.is_ascii_whitespace( + byte + ): + as_string += chr(byte) + else: + as_string += f"\\x{byte:02x}" + + # Prepend missing nulls + missing_nulls = length - len(bytes_data) + as_string = r"\0" * missing_nulls + as_string + + return as_string + + +class UIntBase: + """ + Base class for unsigned integer types with a maximum value. + """ + + def __init__(self, value: int): + if not 0 <= value <= self.MAX_VALUE: + raise ValueError( + f"{self.__class__.__name__} value must be between 0 and {self.MAX_VALUE}" + ) + self.value = value + + def __eq__(self, other: "UIntBase") -> bool: + return isinstance(other, self.__class__) and self.value == other.value + + def __repr__(self) -> str: + return f"{self.value:#x}" + + @classmethod + def from_json(cls, json_string: str) -> "UIntBase": + value = json.loads(json_string) + if not isinstance(value, int): + raise ValueError( + f"Invalid JSON format for {cls.__name__}. Expected a single integer." + ) + return cls(value) + + def to_felt_array(self) -> List[int]: + return [self.value] + + +class U8(UIntBase): + MAX_VALUE = 2**8 - 1 + + +class U16(UIntBase): + MAX_VALUE = 2**16 - 1 + + +class U32(UIntBase): + MAX_VALUE = 2**32 - 1 + + +class U64(UIntBase): + MAX_VALUE = 2**64 - 1 + + +class U128(UIntBase): + MAX_VALUE = 2**128 - 1 + + +class U256(UIntBase): + """ + Represents an unsigned 256-bit integer. + """ + + MAX_VALUE = 2**256 - 1 + + def to_felt_array(self) -> List[int]: + """ + Convert the U256 value to a list of field elements. + + Returns + ------- + List[int]: A list containing two 128-bit integers representing the U256 value. + + """ + return [self.value & 2**128 - 1, self.value >> 128] + + +class Stack: + """ + Represents a stack of U256 values. + """ + + def __init__(self, values: List[U256]): + self.values = values + + def __eq__(self, other: "Stack") -> bool: + return self.values == other.values + + def __repr__(self) -> str: + return f"Stack({self.values})" + + @classmethod + def from_json(cls, json_string: str) -> "Stack": + int_array = json.loads(json_string) + if not isinstance(int_array, list): + raise ValueError( + f"Invalid JSON format: expected list, got {type(int_array)}" + ) + values = [U256(int_value) for int_value in int_array] + return cls(values) + + +# Add a dictionary to map type names to their respective classes +TYPE_MAP = { + "U8": U8, + "U16": U16, + "U32": U32, + "U64": U64, + "U128": U128, + "U256": U256, + "Stack": Stack, + "Tuple": Tuple, + "List": ListType, +} + + +def handle_list_type(json_string: str, list_type: Type[List[Any]]) -> List[Any]: + """ + Handle deserialization of list types. + + Args: + ---- + json_string (str): The JSON string to deserialize. + list_type (Type[List[Any]]): The type of the list. + + Returns: + ------- + List[Any]: A list of deserialized items. + + """ + item_type = get_args(list_type)[0] + return ListType.from_json(json_string, item_type) diff --git a/pyproject.toml b/pyproject.toml index 365ec16ab..625ba55bf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,6 +19,7 @@ deploy = "kakarot_scripts.deploy_kakarot:main_sync" ef_tests = "kakarot_scripts.ef_tests.fetch:generate_tests" [tool.uv.sources] +py_tests = { path = "./cairo/kakarot-ssj/py_tests" } kakarot-scripts = { path = "./kakarot_scripts" } ethereum = { git = "https://github.com/ethereum/execution-specs.git" }