From 0c8aad0b404960c8c483a005e8daebbe6d0f54a4 Mon Sep 17 00:00:00 2001 From: sword_smith Date: Mon, 2 Sep 2024 13:38:12 +0200 Subject: [PATCH] test(struct): ensure max-field length cannot be exceeded in hashing context --- .../structs/exceed_allowed_field_size.rs | 96 +++++++++++++++---- 1 file changed, 77 insertions(+), 19 deletions(-) diff --git a/src/tests_and_benchmarks/ozk/programs/structs/exceed_allowed_field_size.rs b/src/tests_and_benchmarks/ozk/programs/structs/exceed_allowed_field_size.rs index 03300ef..4f5317f 100644 --- a/src/tests_and_benchmarks/ozk/programs/structs/exceed_allowed_field_size.rs +++ b/src/tests_and_benchmarks/ozk/programs/structs/exceed_allowed_field_size.rs @@ -1,6 +1,7 @@ use arbitrary::Arbitrary; use tasm_lib::prelude::TasmObject; use tasm_lib::triton_vm::prelude::*; +use twenty_first::prelude::AlgebraicHasher; use crate::tests_and_benchmarks::ozk::rust_shadows as tasm; @@ -22,7 +23,7 @@ struct TestStruct { g: Digest, } -fn main() { +fn field_getter() { let test_struct: Box = TestStruct::decode(&tasm::load_from_memory(BFieldElement::new(300))).unwrap(); @@ -33,6 +34,15 @@ fn main() { return; } +fn hash_boxed_struct() { + let test_struct: Box = + TestStruct::decode(&tasm::load_from_memory(BFieldElement::new(300))).unwrap(); + let ts_digest: Digest = Tip5::hash(&test_struct); + tasm::tasmlib_io_write_to_stdout___digest(ts_digest); + + return; +} + #[cfg(test)] mod tests { use std::collections::HashMap; @@ -53,55 +63,103 @@ mod tests { const OBJ_POINTER: BFieldElement = BFieldElement::new(300); - fn prepare_random_object(seed: [u8; 32]) -> TestStruct { + /// Return non-determinism and object used in memory-initialization + fn nd_mem_with_random_obj(seed: [u8; 32]) -> (NonDeterminism, TestStruct) { let mut rng: StdRng = SeedableRng::from_seed(seed); let mut randomness = [0u8; 1_000_000]; rng.fill_bytes(&mut randomness); let mut unstructured = Unstructured::new(&randomness); - TestStruct::arbitrary(&mut unstructured).unwrap() + let ts = TestStruct::arbitrary(&mut unstructured).unwrap(); + + let mut beningn_memory = HashMap::default(); + encode_to_memory(&mut beningn_memory, OBJ_POINTER, &ts); + (NonDeterminism::default().with_ram(beningn_memory), ts) + } + + #[test] + fn exceed_allowed_total_length_test() { + // Positive test + let (benign_nd, ts) = nd_mem_with_random_obj(random()); + let expected_output = Tip5::hash(&ts).values().to_vec(); + let stdin = vec![]; + let native_output = wrap_main_with_io(&hash_boxed_struct)(stdin.clone(), benign_nd.clone()); + assert_eq!(expected_output, native_output); + + let entrypoint = + EntrypointLocation::disk("structs", "exceed_allowed_field_size", "hash_boxed_struct"); + let vm_output = TritonVMTestCase::new(entrypoint.clone()) + .with_non_determinism(benign_nd.clone()) + .execute() + .unwrap(); + assert_eq!(native_output, vm_output.public_output); + + // Negative test: size indicator exceeds max allowed size, but is valid u32. + const OFFSET_FOR_OF_MALICIOUS_SI: BFieldElement = BFieldElement::new(5); + let mut malicious_nd_ram_0 = benign_nd.clone(); + malicious_nd_ram_0.ram.insert( + OBJ_POINTER + OFFSET_FOR_OF_MALICIOUS_SI, + bfe!(DataType::MAX_DYN_FIELD_SIZE + 1), + ); + let err0 = TritonVMTestCase::new(entrypoint.clone()) + .with_non_determinism(malicious_nd_ram_0) + .execute() + .unwrap_err(); + let err0 = err0.downcast::().unwrap(); + assert_eq!(InstructionError::AssertionFailed, err0); + + // Negative test: size indicator is negative + let mut malicious_nd_ram_1 = benign_nd.clone(); + let negative_number = bfe!(-1); + malicious_nd_ram_1 + .ram + .insert(OBJ_POINTER + OFFSET_FOR_OF_MALICIOUS_SI, negative_number); + let err1 = TritonVMTestCase::new(entrypoint.clone()) + .with_non_determinism(malicious_nd_ram_1) + .execute() + .unwrap_err(); + let err1 = err1.downcast::().unwrap(); + assert_eq!(InstructionError::FailedU32Conversion(negative_number), err1); } #[test] fn exceed_allowed_size_indicator_test() { // Positive test - let ts = prepare_random_object(random()); - let mut beningn_memory = HashMap::default(); - encode_to_memory(&mut beningn_memory, OBJ_POINTER, &ts); - let benign_nd = NonDeterminism::default().with_ram(beningn_memory.clone()); + let (benign_nd, ts) = nd_mem_with_random_obj(random()); let expected_output = vec![bfe!(ts.d.len() as u64)]; let stdin = vec![]; - let native_output = wrap_main_with_io(&main)(stdin.clone(), benign_nd.clone()); + let native_output = wrap_main_with_io(&field_getter)(stdin.clone(), benign_nd.clone()); assert_eq!(expected_output, native_output); - let entrypoint = EntrypointLocation::disk("structs", "exceed_allowed_field_size", "main"); + let entrypoint = + EntrypointLocation::disk("structs", "exceed_allowed_field_size", "field_getter"); let vm_output = TritonVMTestCase::new(entrypoint.clone()) - .with_non_determinism(benign_nd) + .with_non_determinism(benign_nd.clone()) .execute() .unwrap(); assert_eq!(native_output, vm_output.public_output); - // Negative test: size indicator exceeds 2^32 + // Negative test: size indicator exceeds max allowed size, but is valid u32. const OFFSET_FOR_OF_MALICIOUS_SI: BFieldElement = BFieldElement::new(5); - let mut malicious_memory0 = beningn_memory.clone(); - malicious_memory0.insert( + let mut malicious_nd_ram_0 = benign_nd.clone(); + malicious_nd_ram_0.ram.insert( OBJ_POINTER + OFFSET_FOR_OF_MALICIOUS_SI, bfe!(DataType::MAX_DYN_FIELD_SIZE + 1), ); - let malicious_nd_0 = NonDeterminism::default().with_ram(malicious_memory0.clone()); let err0 = TritonVMTestCase::new(entrypoint.clone()) - .with_non_determinism(malicious_nd_0) + .with_non_determinism(malicious_nd_ram_0) .execute() .unwrap_err(); let err0 = err0.downcast::().unwrap(); assert_eq!(InstructionError::AssertionFailed, err0); // Negative test: size indicator is negative - let mut malicious_memory1 = beningn_memory.clone(); + let mut malicious_nd_ram_1 = benign_nd.clone(); let negative_number = bfe!(-1); - malicious_memory1.insert(OBJ_POINTER + OFFSET_FOR_OF_MALICIOUS_SI, negative_number); - let malicious_nd_1 = NonDeterminism::default().with_ram(malicious_memory1.clone()); + malicious_nd_ram_1 + .ram + .insert(OBJ_POINTER + OFFSET_FOR_OF_MALICIOUS_SI, negative_number); let err1 = TritonVMTestCase::new(entrypoint.clone()) - .with_non_determinism(malicious_nd_1) + .with_non_determinism(malicious_nd_ram_1) .execute() .unwrap_err(); let err1 = err1.downcast::().unwrap();