From 9ba83df7cce3f50308f949299a6c5e886b6a9c17 Mon Sep 17 00:00:00 2001 From: Jan Ferdinand Sauer Date: Mon, 6 Jan 2025 17:45:42 +0100 Subject: [PATCH] refactor!(leading_zeroes): Impl `BasicSnippet` Remove the implementations of `DeprecatedSnippet`, implement `BasicSnippet` directly. --- .../tasmlib_arithmetic_u32_leading_zeros.json | 24 ++ .../tasmlib_arithmetic_u32_leadingzeros.json | 24 -- .../tasmlib_arithmetic_u64_div_mod.json | 4 +- .../tasmlib_arithmetic_u64_leading_zeros.json | 8 +- .../benchmarks/tasmlib_verifier_xfe_ntt.json | 12 +- tasm-lib/src/arithmetic/u32.rs | 2 +- tasm-lib/src/arithmetic/u32/leading_zeros.rs | 92 ++++++++ tasm-lib/src/arithmetic/u32/leadingzeros.rs | 158 ------------- tasm-lib/src/arithmetic/u64/leading_zeros.rs | 218 ++++++------------ tasm-lib/src/exported_snippets.rs | 2 +- tasm-lib/src/verifier/xfe_ntt.rs | 5 +- 11 files changed, 202 insertions(+), 347 deletions(-) create mode 100644 tasm-lib/benchmarks/tasmlib_arithmetic_u32_leading_zeros.json delete mode 100644 tasm-lib/benchmarks/tasmlib_arithmetic_u32_leadingzeros.json create mode 100644 tasm-lib/src/arithmetic/u32/leading_zeros.rs delete mode 100644 tasm-lib/src/arithmetic/u32/leadingzeros.rs diff --git a/tasm-lib/benchmarks/tasmlib_arithmetic_u32_leading_zeros.json b/tasm-lib/benchmarks/tasmlib_arithmetic_u32_leading_zeros.json new file mode 100644 index 00000000..41cd964c --- /dev/null +++ b/tasm-lib/benchmarks/tasmlib_arithmetic_u32_leading_zeros.json @@ -0,0 +1,24 @@ +[ + { + "name": "tasmlib_arithmetic_u32_leading_zeros", + "benchmark_result": { + "clock_cycle_count": 12, + "hash_table_height": 12, + "u32_table_height": 17, + "op_stack_table_height": 4, + "ram_table_height": 0 + }, + "case": "CommonCase" + }, + { + "name": "tasmlib_arithmetic_u32_leading_zeros", + "benchmark_result": { + "clock_cycle_count": 12, + "hash_table_height": 12, + "u32_table_height": 33, + "op_stack_table_height": 4, + "ram_table_height": 0 + }, + "case": "WorstCase" + } +] \ No newline at end of file diff --git a/tasm-lib/benchmarks/tasmlib_arithmetic_u32_leadingzeros.json b/tasm-lib/benchmarks/tasmlib_arithmetic_u32_leadingzeros.json deleted file mode 100644 index ae10b57f..00000000 --- a/tasm-lib/benchmarks/tasmlib_arithmetic_u32_leadingzeros.json +++ /dev/null @@ -1,24 +0,0 @@ -[ - { - "name": "tasmlib_arithmetic_u32_leadingzeros", - "benchmark_result": { - "clock_cycle_count": 14, - "hash_table_height": 18, - "u32_table_height": 17, - "op_stack_table_height": 8, - "ram_table_height": 0 - }, - "case": "CommonCase" - }, - { - "name": "tasmlib_arithmetic_u32_leadingzeros", - "benchmark_result": { - "clock_cycle_count": 14, - "hash_table_height": 18, - "u32_table_height": 33, - "op_stack_table_height": 8, - "ram_table_height": 0 - }, - "case": "WorstCase" - } -] \ No newline at end of file diff --git a/tasm-lib/benchmarks/tasmlib_arithmetic_u64_div_mod.json b/tasm-lib/benchmarks/tasmlib_arithmetic_u64_div_mod.json index 66d5d334..a7a85978 100644 --- a/tasm-lib/benchmarks/tasmlib_arithmetic_u64_div_mod.json +++ b/tasm-lib/benchmarks/tasmlib_arithmetic_u64_div_mod.json @@ -13,10 +13,10 @@ { "name": "tasmlib_arithmetic_u64_div_mod", "benchmark_result": { - "clock_cycle_count": 8020, + "clock_cycle_count": 8016, "hash_table_height": 522, "u32_table_height": 10526, - "op_stack_table_height": 5390, + "op_stack_table_height": 5382, "ram_table_height": 142 }, "case": "WorstCase" diff --git a/tasm-lib/benchmarks/tasmlib_arithmetic_u64_leading_zeros.json b/tasm-lib/benchmarks/tasmlib_arithmetic_u64_leading_zeros.json index cfbd8d16..9b64bcef 100644 --- a/tasm-lib/benchmarks/tasmlib_arithmetic_u64_leading_zeros.json +++ b/tasm-lib/benchmarks/tasmlib_arithmetic_u64_leading_zeros.json @@ -2,10 +2,10 @@ { "name": "tasmlib_arithmetic_u64_leading_zeros", "benchmark_result": { - "clock_cycle_count": 36, + "clock_cycle_count": 32, "hash_table_height": 30, "u32_table_height": 33, - "op_stack_table_height": 21, + "op_stack_table_height": 13, "ram_table_height": 0 }, "case": "CommonCase" @@ -13,10 +13,10 @@ { "name": "tasmlib_arithmetic_u64_leading_zeros", "benchmark_result": { - "clock_cycle_count": 23, + "clock_cycle_count": 21, "hash_table_height": 30, "u32_table_height": 32, - "op_stack_table_height": 13, + "op_stack_table_height": 9, "ram_table_height": 0 }, "case": "WorstCase" diff --git a/tasm-lib/benchmarks/tasmlib_verifier_xfe_ntt.json b/tasm-lib/benchmarks/tasmlib_verifier_xfe_ntt.json index be5c4092..5db9e364 100644 --- a/tasm-lib/benchmarks/tasmlib_verifier_xfe_ntt.json +++ b/tasm-lib/benchmarks/tasmlib_verifier_xfe_ntt.json @@ -2,10 +2,10 @@ { "name": "tasmlib_verifier_xfe_ntt", "benchmark_result": { - "clock_cycle_count": 116763, - "hash_table_height": 222, + "clock_cycle_count": 116761, + "hash_table_height": 216, "u32_table_height": 10855, - "op_stack_table_height": 111550, + "op_stack_table_height": 111546, "ram_table_height": 13729 }, "case": "CommonCase" @@ -13,10 +13,10 @@ { "name": "tasmlib_verifier_xfe_ntt", "benchmark_result": { - "clock_cycle_count": 257114, - "hash_table_height": 222, + "clock_cycle_count": 257112, + "hash_table_height": 216, "u32_table_height": 24122, - "op_stack_table_height": 245974, + "op_stack_table_height": 245970, "ram_table_height": 30529 }, "case": "WorstCase" diff --git a/tasm-lib/src/arithmetic/u32.rs b/tasm-lib/src/arithmetic/u32.rs index 3ab1db53..ce012210 100644 --- a/tasm-lib/src/arithmetic/u32.rs +++ b/tasm-lib/src/arithmetic/u32.rs @@ -1,6 +1,6 @@ pub mod isodd; pub mod isu32; -pub mod leadingzeros; +pub mod leading_zeros; pub mod next_power_of_two; pub mod or; pub mod overflowingadd; diff --git a/tasm-lib/src/arithmetic/u32/leading_zeros.rs b/tasm-lib/src/arithmetic/u32/leading_zeros.rs new file mode 100644 index 00000000..9fe0ecb6 --- /dev/null +++ b/tasm-lib/src/arithmetic/u32/leading_zeros.rs @@ -0,0 +1,92 @@ +use triton_vm::prelude::*; + +use crate::prelude::*; + +#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)] +pub struct LeadingZeros; + +impl BasicSnippet for LeadingZeros { + fn inputs(&self) -> Vec<(DataType, String)> { + vec![(DataType::U32, "arg".to_string())] + } + + fn outputs(&self) -> Vec<(DataType, String)> { + vec![(DataType::U32, "leading_zeros(arg)".to_string())] + } + + fn entrypoint(&self) -> String { + "tasmlib_arithmetic_u32_leading_zeros".to_string() + } + + fn code(&self, _: &mut Library) -> Vec { + let entrypoint = self.entrypoint(); + let non_zero_label = format!("{entrypoint}_non_zero"); + + triton_asm! { + // BEFORE: _ value + // AFTER: _ (leading zeros in value) + {entrypoint}: + dup 0 + skiz + call {non_zero_label} + + push -1 + mul + addi 32 + + return + + {non_zero_label}: + log_2_floor + addi 1 + return + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::test_prelude::*; + + impl Closure for LeadingZeros { + type Args = u32; + + fn rust_shadow(&self, stack: &mut Vec) { + let arg = pop_encodable::(stack); + push_encodable(stack, &arg.leading_zeros()); + } + + fn pseudorandom_args( + &self, + seed: [u8; 32], + bench_case: Option, + ) -> Self::Args { + match bench_case { + Some(BenchmarkCase::CommonCase) => 1 << 15, + Some(BenchmarkCase::WorstCase) => u32::MAX, + None => StdRng::from_seed(seed).gen(), + } + } + + fn corner_case_args(&self) -> Vec { + vec![0, 1, 2, 3, 1 << 28, 1 << 29, 1 << 30, 1 << 31, u32::MAX] + } + } + + #[test] + fn unit() { + ShadowedClosure::new(LeadingZeros).test(); + } +} + +#[cfg(test)] +mod benches { + use super::*; + use crate::test_prelude::*; + + #[test] + fn benchmark() { + ShadowedClosure::new(LeadingZeros).bench(); + } +} diff --git a/tasm-lib/src/arithmetic/u32/leadingzeros.rs b/tasm-lib/src/arithmetic/u32/leadingzeros.rs deleted file mode 100644 index a10b15c3..00000000 --- a/tasm-lib/src/arithmetic/u32/leadingzeros.rs +++ /dev/null @@ -1,158 +0,0 @@ -use rand::prelude::*; -use triton_vm::prelude::*; - -use crate::empty_stack; -use crate::prelude::*; -use crate::traits::deprecated_snippet::DeprecatedSnippet; -use crate::InitVmState; - -#[derive(Clone, Debug)] -pub struct Leadingzeros; - -impl DeprecatedSnippet for Leadingzeros { - fn entrypoint_name(&self) -> String { - "tasmlib_arithmetic_u32_leadingzeros".to_string() - } - - fn input_field_names(&self) -> Vec { - vec!["value".to_string()] - } - - fn input_types(&self) -> Vec { - vec![DataType::U32] - } - - fn output_field_names(&self) -> Vec { - vec!["leading zeros in value".to_string()] - } - - fn output_types(&self) -> Vec { - vec![DataType::U32] - } - - fn stack_diff(&self) -> isize { - 0 - } - - fn function_code(&self, _library: &mut crate::library::Library) -> String { - let entrypoint = self.entrypoint_name(); - format!( - " - // BEFORE: _ value - // AFTER: _ (leading zeros in value) - {entrypoint}: - dup 0 - skiz - call {entrypoint}_non_zero - - push -1 - mul - push 32 - add - - return - - {entrypoint}_non_zero: - log_2_floor - push 1 - add - return - " - ) - } - - fn crash_conditions(&self) -> Vec { - vec!["Input is not u32".to_owned()] - } - - fn gen_input_states(&self) -> Vec { - let mut ret: Vec = vec![]; - for _ in 0..100 { - let mut stack = empty_stack(); - let value = thread_rng().next_u32(); - let value = BFieldElement::new(value as u64); - stack.push(value); - ret.push(InitVmState::with_stack(stack)); - } - - ret - } - - fn common_case_input_state(&self) -> InitVmState { - InitVmState::with_stack([empty_stack(), vec![BFieldElement::new(1 << 15)]].concat()) - } - - fn worst_case_input_state(&self) -> InitVmState { - InitVmState::with_stack([empty_stack(), vec![BFieldElement::new((1 << 32) - 1)]].concat()) - } - - fn rust_shadowing( - &self, - stack: &mut Vec, - _std_in: Vec, - _secret_in: Vec, - _memory: &mut std::collections::HashMap, - ) { - let value: u32 = stack.pop().unwrap().try_into().unwrap(); - - let value = value.leading_zeros(); - stack.push(BFieldElement::new(value as u64)); - } -} - -#[cfg(test)] -mod tests { - use std::collections::HashMap; - - use super::*; - use crate::test_helpers::test_rust_equivalence_given_input_values_deprecated; - use crate::test_helpers::test_rust_equivalence_multiple_deprecated; - - #[test] - fn snippet_test() { - test_rust_equivalence_multiple_deprecated(&Leadingzeros, true); - } - - #[test] - fn leading_zeros_u32_simple_test() { - prop_safe_leading_zeros(1, Some(31)); - prop_safe_leading_zeros(2, Some(30)); - prop_safe_leading_zeros(3, Some(30)); - prop_safe_leading_zeros(4, Some(29)); - prop_safe_leading_zeros(256, Some(23)); - prop_safe_leading_zeros(123, Some(25)); - prop_safe_leading_zeros(0, Some(32)); - prop_safe_leading_zeros(1 << 31, Some(0)); - prop_safe_leading_zeros(1 << 30, Some(1)); - prop_safe_leading_zeros(1 << 29, Some(2)); - prop_safe_leading_zeros(1 << 28, Some(3)); - prop_safe_leading_zeros(u32::MAX, Some(0)); - } - - fn prop_safe_leading_zeros(value: u32, _expected: Option) { - let mut init_stack = empty_stack(); - init_stack.push(BFieldElement::new(value as u64)); - - let expected = value.leading_zeros(); - let expected = [empty_stack(), vec![BFieldElement::new(expected as u64)]].concat(); - - test_rust_equivalence_given_input_values_deprecated( - &Leadingzeros, - &init_stack, - &[], - HashMap::default(), - Some(&expected), - ); - } -} - -#[cfg(test)] -mod benches { - use super::*; - use crate::snippet_bencher::bench_and_write; - - #[test] - fn u32_leading_zeros_benchmark() { - bench_and_write(Leadingzeros); - } -} diff --git a/tasm-lib/src/arithmetic/u64/leading_zeros.rs b/tasm-lib/src/arithmetic/u64/leading_zeros.rs index be162aab..fe547c3f 100644 --- a/tasm-lib/src/arithmetic/u64/leading_zeros.rs +++ b/tasm-lib/src/arithmetic/u64/leading_zeros.rs @@ -1,191 +1,111 @@ -use rand::prelude::*; use triton_vm::prelude::*; -use crate::arithmetic::u32::leadingzeros::Leadingzeros; -use crate::empty_stack; +use crate::arithmetic::u32::leading_zeros::LeadingZeros as U32LeadingZeroes; use crate::prelude::*; -use crate::traits::deprecated_snippet::DeprecatedSnippet; -use crate::InitVmState; -#[derive(Clone, Debug)] +#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)] pub struct LeadingZeros; -impl DeprecatedSnippet for LeadingZeros { - fn entrypoint_name(&self) -> String { - "tasmlib_arithmetic_u64_leading_zeros".to_string() +impl BasicSnippet for LeadingZeros { + fn inputs(&self) -> Vec<(DataType, String)> { + vec![(DataType::U64, "arg".to_string())] } - fn input_field_names(&self) -> Vec { - vec!["value_hi".to_string(), "value_lo".to_string()] + fn outputs(&self) -> Vec<(DataType, String)> { + vec![(DataType::U32, "leading_zeros(arg)".to_string())] } - fn input_types(&self) -> Vec { - vec![DataType::U64] + fn entrypoint(&self) -> String { + "tasmlib_arithmetic_u64_leading_zeros".to_string() } - fn output_field_names(&self) -> Vec { - vec!["leading zeros in value".to_string()] - } + fn code(&self, library: &mut Library) -> Vec { + let leading_zeros_u32 = library.import(Box::new(U32LeadingZeroes)); - fn output_types(&self) -> Vec { - vec![DataType::U32] - } + let entrypoint = self.entrypoint(); + let hi_is_zero_label = format!("{entrypoint}_hi_is_zero"); - fn stack_diff(&self) -> isize { - -1 - } - - fn function_code(&self, library: &mut crate::library::Library) -> String { - let leading_zeros_u32 = library.import(Box::new(Leadingzeros)); - let entrypoint = self.entrypoint_name(); - format!( - " - // BEFORE: _ value_hi value_lo - // AFTER: _ (leading_zeros as u32) - {entrypoint}: - swap 1 - call {leading_zeros_u32} - // _ value_lo leading_zeros_value_hi - - dup 0 - push 32 - eq - skiz - call {entrypoint}_hi_was_zero - - // _ temp leading_zeros - - swap 1 - pop 1 - return - - {entrypoint}_hi_was_zero: - // _ value_lo 32 - - swap 1 - call {leading_zeros_u32} - // _ 32 leading_zeros_value_lo - - dup 1 - add - // _ 32 leading_zeros - return -" - ) - } + triton_asm!( + // BEFORE: _ value_hi value_lo + // AFTER: _ (leading_zeros as u32) + {entrypoint}: + pick 1 + call {leading_zeros_u32} + // _ value_lo leading_zeros_value_hi - fn crash_conditions(&self) -> Vec { - vec!["Inputs are not u32".to_owned()] - } + dup 0 + push 32 + eq + skiz + call {hi_is_zero_label} - fn gen_input_states(&self) -> Vec { - let mut rng = thread_rng(); - let mut ret = vec![]; - for _ in 0..10 { - ret.push(prepare_state(rng.next_u64())); - } + // _ temp leading_zeros - ret - } + pick 1 + pop 1 + return - fn common_case_input_state(&self) -> InitVmState { - prepare_state(1 << 31) - } + {hi_is_zero_label}: + // _ value_lo 32 - fn worst_case_input_state(&self) -> InitVmState { - prepare_state(1 << 62) - } + pick 1 + call {leading_zeros_u32} + // _ 32 leading_zeros_value_lo - fn rust_shadowing( - &self, - stack: &mut Vec, - _std_in: Vec, - _secret_in: Vec, - _memory: &mut std::collections::HashMap, - ) { - let value_lo: u32 = stack.pop().unwrap().try_into().unwrap(); - let value_hi: u32 = stack.pop().unwrap().try_into().unwrap(); - let value: u64 = ((value_hi as u64) << 32) + value_lo as u64; - - let value = value.leading_zeros(); - stack.push(BFieldElement::new(value as u64)); + addi 32 + // _ 32 leading_zeros + return + ) } } -fn prepare_state(value: u64) -> InitVmState { - let value_hi: u32 = (value >> 32) as u32; - let value_lo: u32 = (value & u32::MAX as u64) as u32; - let mut stack = empty_stack(); - stack.push(BFieldElement::new(value_hi as u64)); - stack.push(BFieldElement::new(value_lo as u64)); - InitVmState::with_stack(stack) -} - #[cfg(test)] mod tests { - use std::collections::HashMap; - use super::*; - use crate::test_helpers::test_rust_equivalence_given_input_values_deprecated; - use crate::test_helpers::test_rust_equivalence_multiple_deprecated; + use crate::test_prelude::*; - #[test] - fn snippet_test() { - test_rust_equivalence_multiple_deprecated(&LeadingZeros, true); - } + impl Closure for LeadingZeros { + type Args = u64; - #[test] - fn leading_zeros_u64_simple_test() { - prop_leading_zeros(1, Some(63)); - prop_leading_zeros(2, Some(62)); - prop_leading_zeros(3, Some(62)); - prop_leading_zeros(4, Some(61)); - prop_leading_zeros(256, Some(55)); - prop_leading_zeros(123, Some(57)); - prop_leading_zeros(0, Some(64)); - prop_leading_zeros(1 << 31, Some(32)); - prop_leading_zeros(1 << 30, Some(33)); - prop_leading_zeros(1 << 29, Some(34)); - prop_leading_zeros(1 << 28, Some(35)); - prop_leading_zeros(u32::MAX as u64, Some(32)); - prop_leading_zeros(1000, Some(54)); - prop_leading_zeros(2000, Some(53)); - prop_leading_zeros(4000, Some(52)); - prop_leading_zeros(4095, Some(52)); - prop_leading_zeros(4096, Some(51)); - prop_leading_zeros(4097, Some(51)); - } - - fn prop_leading_zeros(value: u64, expected: Option) { - let mut init_stack = empty_stack(); - init_stack.push(BFieldElement::new(value >> 32)); - init_stack.push(BFieldElement::new(value & u32::MAX as u64)); + fn rust_shadow(&self, stack: &mut Vec) { + let arg = pop_encodable::(stack); + push_encodable(stack, &arg.leading_zeros()); + } - let leading_zeros = value.leading_zeros(); - if let Some(exp) = expected { - assert_eq!(exp, leading_zeros as u64); + fn pseudorandom_args( + &self, + seed: [u8; 32], + bench_case: Option, + ) -> Self::Args { + match bench_case { + Some(BenchmarkCase::CommonCase) => 1 << 31, + Some(BenchmarkCase::WorstCase) => 1 << 62, + None => StdRng::from_seed(seed).gen(), + } } - let mut expected_stack = empty_stack(); - expected_stack.push(BFieldElement::new(leading_zeros as u64)); + fn corner_case_args(&self) -> Vec { + let small = 0..10; + let medium = (27..35).map(|i| 1 << i); + let large = (0..10).map(|i| u64::MAX - i); + + small.chain(medium).chain(large).collect() + } + } - test_rust_equivalence_given_input_values_deprecated( - &LeadingZeros, - &init_stack, - &[], - HashMap::default(), - Some(&expected_stack), - ); + #[test] + fn unit() { + ShadowedClosure::new(LeadingZeros).test(); } } #[cfg(test)] mod benches { use super::*; - use crate::snippet_bencher::bench_and_write; + use crate::test_prelude::*; #[test] - fn u32_leading_zeros_benchmark() { - bench_and_write(LeadingZeros); + fn benchmark() { + ShadowedClosure::new(LeadingZeros).bench(); } } diff --git a/tasm-lib/src/exported_snippets.rs b/tasm-lib/src/exported_snippets.rs index d9747be5..11fba92e 100644 --- a/tasm-lib/src/exported_snippets.rs +++ b/tasm-lib/src/exported_snippets.rs @@ -76,7 +76,7 @@ pub fn name_to_snippet(fn_name: &str) -> Box { "tasmlib_arithmetic_u32_shiftright" => Box::new(u32::shiftright::Shiftright), "tasmlib_arithmetic_u32_shiftleft" => Box::new(u32::shiftleft::Shiftleft), "tasmlib_arithmetic_u32_or" => Box::new(u32::or::Or), - "tasmlib_arithmetic_u32_leadingzeros" => Box::new(u32::leadingzeros::Leadingzeros), + "tasmlib_arithmetic_u32_leading_zeros" => Box::new(u32::leading_zeros::LeadingZeros), "tasmlib_arithmetic_u32_safepow" => Box::new(u32::safepow::Safepow), "tasmlib_arithmetic_u32_overflowingadd" => Box::new(u32::overflowingadd::Overflowingadd), diff --git a/tasm-lib/src/verifier/xfe_ntt.rs b/tasm-lib/src/verifier/xfe_ntt.rs index 52e48e90..7afef2e0 100644 --- a/tasm-lib/src/verifier/xfe_ntt.rs +++ b/tasm-lib/src/verifier/xfe_ntt.rs @@ -23,8 +23,9 @@ impl BasicSnippet for XfeNtt { fn code(&self, library: &mut Library) -> Vec { let entrypoint = self.entrypoint(); - let tasm_arithmetic_u32_leadingzeros = - library.import(Box::new(crate::arithmetic::u32::leadingzeros::Leadingzeros)); + let tasm_arithmetic_u32_leadingzeros = library.import(Box::new( + crate::arithmetic::u32::leading_zeros::LeadingZeros, + )); #[allow(non_snake_case)] let tasm_list_length___xfe = library.import(Box::new(crate::list::length::Length::new(DataType::Xfe)));