From 0ea12d2904bfe3d2cf470a30d15069206830ebfe Mon Sep 17 00:00:00 2001 From: Jan Ferdinand Sauer Date: Thu, 19 Dec 2024 15:00:16 +0100 Subject: [PATCH] signoff!: u64::div2 BREAKING CHANGE: Remove the implementation of `DeprecatedSnippet` for `u64::div2::Div2`, implement `BasicSnippet` directly. --- .../tasmlib_arithmetic_u64_div2.json | 4 +- ...alculate_new_peaks_from_leaf_mutation.json | 4 +- .../tasmlib_mmr_verify_from_memory.json | 8 +- tasm-lib/src/arithmetic/u64/div2.rs | 321 +++++++----------- 4 files changed, 138 insertions(+), 199 deletions(-) diff --git a/tasm-lib/benchmarks/tasmlib_arithmetic_u64_div2.json b/tasm-lib/benchmarks/tasmlib_arithmetic_u64_div2.json index 06ab34d4..79ca58d2 100644 --- a/tasm-lib/benchmarks/tasmlib_arithmetic_u64_div2.json +++ b/tasm-lib/benchmarks/tasmlib_arithmetic_u64_div2.json @@ -2,7 +2,7 @@ { "name": "tasmlib_arithmetic_u64_div2", "benchmark_result": { - "clock_cycle_count": 16, + "clock_cycle_count": 14, "hash_table_height": 18, "u32_table_height": 37, "op_stack_table_height": 6, @@ -13,7 +13,7 @@ { "name": "tasmlib_arithmetic_u64_div2", "benchmark_result": { - "clock_cycle_count": 16, + "clock_cycle_count": 14, "hash_table_height": 18, "u32_table_height": 72, "op_stack_table_height": 6, diff --git a/tasm-lib/benchmarks/tasmlib_mmr_calculate_new_peaks_from_leaf_mutation.json b/tasm-lib/benchmarks/tasmlib_mmr_calculate_new_peaks_from_leaf_mutation.json index c1ef095a..089104ba 100644 --- a/tasm-lib/benchmarks/tasmlib_mmr_calculate_new_peaks_from_leaf_mutation.json +++ b/tasm-lib/benchmarks/tasmlib_mmr_calculate_new_peaks_from_leaf_mutation.json @@ -2,7 +2,7 @@ { "name": "tasmlib_mmr_calculate_new_peaks_from_leaf_mutation", "benchmark_result": { - "clock_cycle_count": 2267, + "clock_cycle_count": 2205, "hash_table_height": 408, "u32_table_height": 1077, "op_stack_table_height": 1450, @@ -13,7 +13,7 @@ { "name": "tasmlib_mmr_calculate_new_peaks_from_leaf_mutation", "benchmark_result": { - "clock_cycle_count": 4326, + "clock_cycle_count": 4202, "hash_table_height": 594, "u32_table_height": 1724, "op_stack_table_height": 2824, diff --git a/tasm-lib/benchmarks/tasmlib_mmr_verify_from_memory.json b/tasm-lib/benchmarks/tasmlib_mmr_verify_from_memory.json index 2e1dbdad..bfaff253 100644 --- a/tasm-lib/benchmarks/tasmlib_mmr_verify_from_memory.json +++ b/tasm-lib/benchmarks/tasmlib_mmr_verify_from_memory.json @@ -2,8 +2,8 @@ { "name": "tasmlib_mmr_verify_from_memory", "benchmark_result": { - "clock_cycle_count": 1504, - "hash_table_height": 414, + "clock_cycle_count": 1442, + "hash_table_height": 408, "u32_table_height": 1407, "op_stack_table_height": 1160, "ram_table_height": 160 @@ -13,8 +13,8 @@ { "name": "tasmlib_mmr_verify_from_memory", "benchmark_result": { - "clock_cycle_count": 2850, - "hash_table_height": 600, + "clock_cycle_count": 2726, + "hash_table_height": 594, "u32_table_height": 1566, "op_stack_table_height": 2224, "ram_table_height": 315 diff --git a/tasm-lib/src/arithmetic/u64/div2.rs b/tasm-lib/src/arithmetic/u64/div2.rs index 23d9c285..85a9556a 100644 --- a/tasm-lib/src/arithmetic/u64/div2.rs +++ b/tasm-lib/src/arithmetic/u64/div2.rs @@ -1,237 +1,176 @@ use std::collections::HashMap; -use num::Zero; -use rand::prelude::*; use triton_vm::prelude::*; -use twenty_first::prelude::U32s; -use crate::empty_stack; use crate::prelude::*; -use crate::push_encodable; -use crate::traits::deprecated_snippet::DeprecatedSnippet; -use crate::InitVmState; - -#[derive(Clone, Debug)] +use crate::traits::basic_snippet::Reviewer; +use crate::traits::basic_snippet::SignOffFingerprint; + +/// Integer-divide the argument by 2. +/// +/// ### Behavior +/// +/// ```text +/// BEFORE: _ [arg: u64] +/// AFTER: _ [arg/2: u64] +/// ``` +/// +/// ### Preconditions +/// +/// - the input is properly [`BFieldCodec`] encoded +/// +/// ### Postconditions +/// +/// - the output is properly [`BFieldCodec`] encoded +#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)] pub struct Div2; -impl DeprecatedSnippet for Div2 { - fn entrypoint_name(&self) -> String { - "tasmlib_arithmetic_u64_div2".to_string() +impl BasicSnippet for Div2 { + fn inputs(&self) -> Vec<(DataType, String)> { + vec![(DataType::U64, "arg".to_string())] } - fn input_field_names(&self) -> Vec { - vec!["value_hi".to_string(), "value_lo".to_string()] + fn outputs(&self) -> Vec<(DataType, String)> { + vec![(DataType::U64, "(arg/2)".to_string())] } - fn input_types(&self) -> Vec { - vec![DataType::U64] - } - - fn output_field_names(&self) -> Vec { - vec!["(value / 2)_hi".to_string(), "(value / 2)_lo".to_string()] - } - - fn output_types(&self) -> Vec { - vec![DataType::U64] - } - - fn stack_diff(&self) -> isize { - 0 + fn entrypoint(&self) -> String { + "tasmlib_arithmetic_u64_div2".to_string() } - fn function_code(&self, _library: &mut Library) -> String { - let entrypoint = self.entrypoint_name(); - let two_pow_31 = 1u64 << 31; - - format!( - " - // BEFORE: _ value_hi value_lo - // AFTER: _ (value / 2)_hi (value / 2)_lo - {entrypoint}: - // Divide the lower number - push 2 - swap 1 - div_mod - pop 1 - // stack: _ value_hi (value_lo / 2) - - // Divide the upper number and carry its least significant bit into the lower number - swap 1 - // stack: _ (value_lo / 2) value_hi - - push 2 - swap 1 - div_mod - // stack: _ (value_lo / 2) (value_hi / 2) (value_hi % 2) - - push {two_pow_31} - mul - // stack: _ (value_lo / 2) (value_hi / 2) carry - - swap 1 - swap 2 - // stack: _ (value_hi / 2) carry (value_lo / 2) - - add - // stack: _ (value / 2)_hi (value / 2)_lo - - return - " + fn code(&self, _: &mut Library) -> Vec { + triton_asm!( + // BEFORE: _ arg_hi arg_lo + // AFTER: _ (arg / 2)_hi (arg / 2)_lo + {self.entrypoint()}: + /* divide low part */ + push 2 + pick 1 + div_mod + pop 1 + // _ arg_hi (arg_lo / 2) + + /* divide high part, carry its least significant bit into the low part */ + push 2 + pick 2 + div_mod + // _ (arg_lo / 2) (arg_hi / 2) (arg_hi % 2) + // _ (arg_lo / 2) (arg / 2)_hi (arg_hi % 2) + + push {1_u32 << 31} + hint two_pow_31: u32 = stack[0] + mul + hint carry: u32 = stack[0] + // _ (arg_lo / 2) (arg / 2)_hi carry + + pick 2 + add + // _ (arg / 2)_hi (arg / 2)_lo + + return ) } - fn crash_conditions(&self) -> Vec { - vec![ - "If value_hi is not a u32".to_string(), - "If value_lo is not a u32".to_string(), - ] - } - - fn gen_input_states(&self) -> Vec { - let n: u64 = rand::thread_rng().next_u64(); - let n: U32s<2> = n.try_into().unwrap(); - let mut input_stack = empty_stack(); - - push_encodable(&mut input_stack, &n); - - vec![InitVmState::with_stack(input_stack)] + fn sign_offs(&self) -> HashMap { + let mut sign_offs = HashMap::new(); + sign_offs.insert(Reviewer("ferdinand"), 0xe77a12ba30ef339b.into()); + sign_offs } +} - fn common_case_input_state(&self) -> InitVmState { - InitVmState::with_stack( - [ - empty_stack(), - vec![BFieldElement::zero(), BFieldElement::new(1 << 31)], - ] - .concat(), - ) +#[cfg(test)] +mod tests { + use super::*; + use crate::test_helpers::negative_test; + use crate::test_prelude::*; + + impl Div2 { + fn assert_expected_behavior(&self, arg: u64) { + let initial_stack = self.set_up_test_stack(arg); + + let mut expected_stack = initial_stack.clone(); + self.rust_shadow(&mut expected_stack); + + test_rust_equivalence_given_complete_state( + &ShadowedClosure::new(Self), + &initial_stack, + &[], + &NonDeterminism::default(), + &None, + Some(&expected_stack), + ); + } } - fn worst_case_input_state(&self) -> InitVmState { - let big_number = 1 << 31; - let worst_case_input = [big_number + 1, big_number].map(BFieldElement::new); - let worst_case_stack = [empty_stack(), worst_case_input.to_vec()].concat(); + impl Closure for Div2 { + type Args = u64; - InitVmState::with_stack(worst_case_stack) - } + fn rust_shadow(&self, stack: &mut Vec) { + let arg = pop_encodable::(stack); + push_encodable(stack, &(arg / 2)); + } - fn rust_shadowing( - &self, - stack: &mut Vec, - _std_in: Vec, - _secret_in: Vec, - _memory: &mut HashMap, - ) { - let value_lo: u32 = stack.pop().unwrap().try_into().unwrap(); - let value_hi: u32 = stack.pop().unwrap().try_into().unwrap(); - let value: u64 = ((value_hi as u64) << 32) + value_lo as u64; - let result: u64 = value / 2; - - stack.push(BFieldElement::new(result >> 32)); - stack.push(BFieldElement::new(result & u32::MAX as u64)); + fn pseudorandom_args( + &self, + seed: [u8; 32], + bench_case: Option, + ) -> Self::Args { + match bench_case { + Some(BenchmarkCase::CommonCase) => 0x8000_0000, + Some(BenchmarkCase::WorstCase) => 0xf000_0001_f000_0000, + None => StdRng::from_seed(seed).gen(), + } + } } -} - -#[cfg(test)] -mod tests { - use BFieldElement; - - use super::*; - use crate::empty_stack; - use crate::test_helpers::test_rust_equivalence_given_input_values_deprecated; - use crate::test_helpers::test_rust_equivalence_multiple_deprecated; #[test] - fn div2_u64_test() { - test_rust_equivalence_multiple_deprecated(&Div2, true); + fn rust_shadow() { + ShadowedClosure::new(Div2).test(); } - #[should_panic] - #[test] - fn lo_is_not_u32() { - let mut init_stack = empty_stack(); - init_stack.push(BFieldElement::new(16)); - init_stack.push(BFieldElement::new(u32::MAX as u64 + 1)); - - test_rust_equivalence_given_input_values_deprecated::( - &Div2, - &init_stack, - &[], - HashMap::default(), - None, + #[proptest] + fn lo_is_not_u32(hi: u32, #[strategy(1_u64 << 32..)] lo: u64) { + let stack = [Div2.init_stack_for_isolated_run(), bfe_vec![hi, lo]].concat(); + + let error = InstructionError::OpStackError(OpStackError::FailedU32Conversion(bfe!(lo))); + negative_test( + &ShadowedClosure::new(Div2), + InitVmState::with_stack(stack), + &[error], ); } - #[should_panic] - #[test] - fn hi_is_not_u32() { - let mut init_stack = empty_stack(); - init_stack.push(BFieldElement::new(u32::MAX as u64 + 1)); - init_stack.push(BFieldElement::new(16)); - - test_rust_equivalence_given_input_values_deprecated::( - &Div2, - &init_stack, - &[], - HashMap::default(), - None, + #[proptest] + fn hi_is_not_u32(#[strategy(1_u64 << 32..)] hi: u64, lo: u32) { + let stack = [Div2.init_stack_for_isolated_run(), bfe_vec![hi, lo]].concat(); + + let error = InstructionError::OpStackError(OpStackError::FailedU32Conversion(bfe!(hi))); + negative_test( + &ShadowedClosure::new(Div2), + InitVmState::with_stack(stack), + &[error], ); } #[test] fn div_2_test() { - prop_div_2(0); - prop_div_2(1); - prop_div_2(2); - prop_div_2(3); - prop_div_2(4); - prop_div_2(5); - prop_div_2(6); - prop_div_2(7); - prop_div_2(8); - prop_div_2(1 << 32); - prop_div_2((1 << 32) + 1); - prop_div_2((1 << 32) + 2); - prop_div_2((1 << 32) + 3); - prop_div_2((1 << 32) + 4); - prop_div_2((1 << 63) + 4); - prop_div_2((1 << 63) + 4); - prop_div_2((1 << 63) + (1 << 31)); - prop_div_2((1 << 63) + (1 << 33) + (1 << 32) + (1 << 31)); - - let mut rng = thread_rng(); - for _ in 0..100 { - let value = rng.next_u64(); - prop_div_2(value); - } - } + let small_args = 0..9; + let mid_args = (0..9).map(|offset| (1 << 32) + offset); + let large_args = [0, 4, 1 << 31, 0b111 << 31].map(|offset| (1 << 63) + offset); - fn prop_div_2(value: u64) { - let mut init_stack = empty_stack(); - init_stack.push(BFieldElement::new(value >> 32)); - init_stack.push(BFieldElement::new(value & u32::MAX as u64)); - let mut expected_stack = empty_stack(); - let res = value / 2; - expected_stack.push(BFieldElement::new(res >> 32)); - expected_stack.push(BFieldElement::new(res & u32::MAX as u64)); - - test_rust_equivalence_given_input_values_deprecated::( - &Div2, - &init_stack, - &[], - HashMap::default(), - Some(&expected_stack), - ); + for arg in small_args.chain(mid_args).chain(large_args) { + Div2.assert_expected_behavior(arg); + } } } #[cfg(test)] mod benches { use super::*; - use crate::snippet_bencher::bench_and_write; + use crate::test_prelude::*; #[test] - fn div2_u64_benchmark() { - bench_and_write(Div2); + fn benchmark() { + ShadowedClosure::new(Div2).bench(); } }