Skip to content

Commit

Permalink
Optimize sum for XFE by batch-reading elements
Browse files Browse the repository at this point in the history
For input sizes of around 200, this reduces clock cycle count by ~60 %.
  • Loading branch information
Sword-Smith committed Feb 6, 2024
1 parent ec3cdb9 commit 2ea4a96
Show file tree
Hide file tree
Showing 2 changed files with 148 additions and 21 deletions.
12 changes: 6 additions & 6 deletions tasm-lib/benchmarks/tasm_list_unsafeimplu32_sum_xfe.json
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
[
{
"name": "tasm_list_unsafeimplu32_sum_xfe",
"clock_cycle_count": 1064,
"hash_table_height": 36,
"u32_table_height": 0,
"clock_cycle_count": 466,
"hash_table_height": 72,
"u32_table_height": 12,
"case": "CommonCase"
},
{
"name": "tasm_list_unsafeimplu32_sum_xfe",
"clock_cycle_count": 10064,
"hash_table_height": 36,
"u32_table_height": 0,
"clock_cycle_count": 3886,
"hash_table_height": 72,
"u32_table_height": 15,
"case": "WorstCase"
}
]
157 changes: 142 additions & 15 deletions tasm-lib/src/list/sum_xfes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use crate::data_type::DataType;
use crate::library::Library;
use crate::traits::basic_snippet::BasicSnippet;
use triton_vm::prelude::*;
use triton_vm::twenty_first::shared_math::x_field_element::EXTENSION_DEGREE;

use super::ListType;

Expand All @@ -13,7 +14,8 @@ struct SumOfXfes {
impl BasicSnippet for SumOfXfes {
fn inputs(&self) -> Vec<(DataType, String)> {
vec![(
// For naming the input argument, I just follow what `Rust` calls this argument
// For naming the input argument, I just follow what `Rust` calls this argument in
// its `sum` method.
DataType::List(Box::new(DataType::Xfe)),
"self".to_owned(),
)]
Expand All @@ -32,7 +34,48 @@ impl BasicSnippet for SumOfXfes {
}

fn code(&self, _library: &mut Library) -> Vec<LabelledInstruction> {
assert_eq!(
3, EXTENSION_DEGREE,
"Code only works for extension degree = 3, got: {EXTENSION_DEGREE}"
);
let entrypoint = self.entrypoint();
let accumulate_5_elements_loop_label = format!("{entrypoint}_acc_5_elements_loop");
let accumulate_5_elements_loop = triton_asm!(
// Invariant: _ *end_loop *element_last_word [acc; 3]
{accumulate_5_elements_loop_label}:
// _ *end_loop *element_last_word [acc; 3]
dup 4
dup 4
eq
skiz
return
// _ *end_loop *element_last_word [acc]

dup 3
read_mem 5
read_mem 5
read_mem 5
// _ *end_loop *element_last_word [acc] [elem_4] [elem_3] [elem_2] [elem_1] [elem_0] (*element_last_word - 15)

pop 1
// _ *end_loop *element_last_word [acc] [elem_4] [elem_3] [elem_2] [elem_1] [elem_0]

xxadd
xxadd
xxadd
xxadd
xxadd
// _ *end_loop *element_last_word [acc']

swap 3
push -15
add
swap 3
// _ *end_loop *element_last_word' [acc']

recurse
);

let accumulate_one_element_loop_label = format!("{entrypoint}_acc_1_element_loop");
let accumulate_one_element_loop = triton_asm!(
// Invariant: _ *end_loop *element_last_word [acc; 3]
Expand Down Expand Up @@ -75,12 +118,12 @@ impl BasicSnippet for SumOfXfes {
let adjust_loops_end_condition_for_metadata = match self.list_type.metadata_size() {
1 => triton_asm!(),
2 => triton_asm!(
// _ *list something
swap 1
// _ *list s3 s2 s1 s0
swap 4
push 1
add
swap 1
// _ (*list + 1) something
swap 4
// _ (*list + 1) s3 s2 s1 s0
),
n => panic!("Unhandled metadata size. Got: {n}"),
};
Expand All @@ -104,32 +147,65 @@ impl BasicSnippet for SumOfXfes {
add
// _ *list *last_word

{&adjust_loops_end_condition_for_metadata}
// _ *end_condition *last_word
// Get pointer to *end_loop that is the loop termination condition
push 5
dup 2
// _ *list *last_word 5 *list

read_mem 1
pop 1
// _ *list *last_word 5 len

div_mod
// _ *list *last_word (len / 5) (len % 5)

swap 1
pop 1
// _ *list *last_word (len % 5)

push {EXTENSION_DEGREE}
mul
// _ *list *last_word ((len % 5) * 3)

dup 2
add
// _ *list *last_word ((len % 5) * 3 + *list)

{&adjust_offset_for_metadata}
// _ *list *last_word *end_5_loop

swap 1
push 0
push 0
push 0
// _ *end_condition *last_word [acc]
// _ *list *end_5_loop *last_word [acc]

call {accumulate_5_elements_loop_label}
// _ *list *end_5_loop *end_5_loop [acc]

swap 1
swap 2
swap 3
pop 1
// _ *list *end_5_loop [acc]

{&adjust_loops_end_condition_for_metadata}
// _ *end_condition_1_loop *end_5_loop [acc]

call {accumulate_one_element_loop_label}
// _ *end_condition *last_word [acc]
// _ *end_condition_1_loop *end_5_loop [acc]

// _ *end_condition *last_word acc_2 acc_1 acc_0
swap 2
swap 4
// _ acc_2 *last_word acc_0 acc_1 *end_condition

pop 1
// _ acc_2 *last_word acc_0 acc_1

swap 2
pop 1
// _ acc_2 acc_1 acc_0
// _ [acc]

return

{&accumulate_one_element_loop}
{&accumulate_5_elements_loop}
)
}
}
Expand All @@ -139,13 +215,15 @@ mod tests {
use std::collections::HashMap;

use itertools::Itertools;
use rand::random;
use rand::rngs::StdRng;
use rand::Rng;
use rand::SeedableRng;
use triton_vm::twenty_first::shared_math::x_field_element::EXTENSION_DEGREE;

use super::*;
use crate::snippet_bencher::BenchmarkCase;
use crate::test_helpers::test_rust_equivalence_given_complete_state;
use crate::traits::function::Function;
use crate::traits::function::FunctionInitialState;
use crate::traits::function::ShadowedFunction;
Expand Down Expand Up @@ -234,6 +312,55 @@ mod tests {
})
.test()
}

/// Runs the snippet on a hand-built, two-element unsafe list placed at a fixed
/// address and checks that the final stack is the isolated-run base stack
/// followed by the three coefficients of the expected sum (highest index first).
#[test]
fn sum_xfes_unit_test_unsafe_list() {
    let snippet = SumOfXfes {
        list_type: ListType::Unsafe,
    };

    // Two random summands; the reference sum is computed with Rust's own `sum`.
    let summands: Vec<XFieldElement> = vec![random(), random()];
    let expected_sum: XFieldElement = summands.clone().into_iter().sum();

    // Lay the list out in RAM, starting at a fixed pointer.
    let list_pointer = BFieldElement::new(1u64 << 33);
    let mut ram = HashMap::default();
    insert_xfe_list_into_memory(list_pointer, summands, &mut ram);
    let non_determinism = NonDeterminism::default().with_ram(ram);

    // Initial stack: base stack with the list pointer on top.
    let mut initial_stack = snippet.init_stack_for_isolated_run();
    initial_stack.push(list_pointer);

    // Final stack: base stack with the sum's coefficients on top, coefficient 0 topmost.
    let mut final_stack = snippet.init_stack_for_isolated_run();
    final_stack.extend([
        expected_sum.coefficients[2],
        expected_sum.coefficients[1],
        expected_sum.coefficients[0],
    ]);

    test_rust_equivalence_given_complete_state(
        &ShadowedFunction::new(snippet),
        &initial_stack,
        &[],
        &non_determinism,
        &None,
        0,
        Some(&final_stack),
    );
}

/// Writes `list` into `memory` starting at `list_pointer`.
///
/// The word at `list_pointer` receives the list's length. The coefficients of
/// each element follow at consecutive addresses, in the order they appear in
/// each element's `coefficients` field.
fn insert_xfe_list_into_memory(
    list_pointer: BFieldElement,
    list: Vec<XFieldElement>,
    memory: &mut HashMap<BFieldElement, BFieldElement>,
) {
    // Length word first.
    memory.insert(list_pointer, BFieldElement::new(list.len() as u64));

    // Then every coefficient of every element, one address after another.
    let mut address = list_pointer;
    for coefficient in list.iter().flat_map(|xfe| xfe.coefficients.iter()) {
        address.increment();
        memory.insert(address, *coefficient);
    }
}
}

#[cfg(test)]
Expand Down

0 comments on commit 2ea4a96

Please sign in to comment.