Skip to content

Commit

Permalink
derive iterable structs
Browse files Browse the repository at this point in the history
  • Loading branch information
ohad-starkware committed Dec 26, 2024
1 parent 5a3e44f commit 222c42f
Show file tree
Hide file tree
Showing 10 changed files with 663 additions and 180 deletions.
20 changes: 20 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[workspace]
members = ["crates/prover", "crates/air_utils"]
members = ["crates/prover", "crates/air_utils", "crates/air_utils_derive"]
resolver = "2"

[workspace.package]
Expand Down
1 change: 1 addition & 0 deletions crates/air_utils/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ bytemuck.workspace = true
itertools.workspace = true
rayon = { version = "1.10.0", optional = false }
stwo-prover = { path = "../prover" }
stwo-air-utils-derive = { path = "../air_utils_derive" }

[lib]
bench = false
1 change: 1 addition & 0 deletions crates/air_utils/src/examples/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@

259 changes: 80 additions & 179 deletions crates/air_utils/src/trace/examle_lookup_data.rs
Original file line number Diff line number Diff line change
@@ -1,197 +1,50 @@
#![allow(unused)]
// TODO(Ohad): write a derive macro for this.
use rayon::iter::plumbing::{bridge, Consumer, Producer, ProducerCallback, UnindexedConsumer};
use rayon::prelude::*;
use stwo_air_utils_derive::StwoIterable;
use stwo_prover::core::backend::simd::m31::{PackedM31, N_LANES};

#[derive(StwoIterable)]
pub struct LookupData {
pub lu0: Vec<[PackedM31; 2]>,
pub lu1: Vec<[PackedM31; 4]>,
}
impl LookupData {
/// # Safety
pub unsafe fn uninitialized(log_size: u32) -> Self {
let length = 1 << log_size;
let n_simd_elems = length / N_LANES;
let mut lu0 = Vec::with_capacity(n_simd_elems);
let mut lu1 = Vec::with_capacity(n_simd_elems);
lu0.set_len(n_simd_elems);
lu1.set_len(n_simd_elems);

Self { lu0, lu1 }
}

pub fn iter_mut(&mut self) -> LookupDataIterMut<'_> {
LookupDataIterMut::new(&mut self.lu0, &mut self.lu1)
}

pub fn par_iter_mut(&mut self) -> ParLookupDataIterMut<'_> {
ParLookupDataIterMut {
lu0: &mut self.lu0,
lu1: &mut self.lu1,
}
}
}

pub struct LookupDataMutChunk<'trace> {
pub lu0: &'trace mut [PackedM31; 2],
pub lu1: &'trace mut [PackedM31; 4],
}
pub struct LookupDataIterMut<'trace> {
lu0: *mut [[PackedM31; 2]],
lu1: *mut [[PackedM31; 4]],
phantom: std::marker::PhantomData<&'trace ()>,
}
impl<'trace> LookupDataIterMut<'trace> {
pub fn new(slice0: &'trace mut [[PackedM31; 2]], slice1: &'trace mut [[PackedM31; 4]]) -> Self {
Self {
lu0: slice0 as *mut _,
lu1: slice1 as *mut _,
phantom: std::marker::PhantomData,
}
}
}
impl<'trace> Iterator for LookupDataIterMut<'trace> {
type Item = LookupDataMutChunk<'trace>;

fn next(&mut self) -> Option<Self::Item> {
if self.lu0.is_empty() {
return None;
}
let item = unsafe {
let (head0, tail0) = self.lu0.split_at_mut(1);
let (head1, tail1) = self.lu1.split_at_mut(1);
self.lu0 = tail0;
self.lu1 = tail1;
LookupDataMutChunk {
lu0: &mut (*head0)[0],
lu1: &mut (*head1)[0],
}
};
Some(item)
}

fn size_hint(&self) -> (usize, Option<usize>) {
let len = self.lu0.len();
(len, Some(len))
}
}

impl ExactSizeIterator for LookupDataIterMut<'_> {}
impl DoubleEndedIterator for LookupDataIterMut<'_> {
fn next_back(&mut self) -> Option<Self::Item> {
if self.lu0.is_empty() {
return None;
}
let item = unsafe {
let (head0, tail0) = self.lu0.split_at_mut(self.lu0.len() - 1);
let (head1, tail1) = self.lu1.split_at_mut(self.lu1.len() - 1);
self.lu0 = head0;
self.lu1 = head1;
LookupDataMutChunk {
lu0: &mut (*tail0)[0],
lu1: &mut (*tail1)[0],
}
};
Some(item)
}
}

struct RowProducer<'trace> {
lu0: &'trace mut [[PackedM31; 2]],
lu1: &'trace mut [[PackedM31; 4]],
}

impl<'trace> Producer for RowProducer<'trace> {
type Item = LookupDataMutChunk<'trace>;

fn split_at(self, index: usize) -> (Self, Self) {
let (lu0, rh0) = self.lu0.split_at_mut(index);
let (lu1, rh1) = self.lu1.split_at_mut(index);
(RowProducer { lu0, lu1 }, RowProducer { lu0: rh0, lu1: rh1 })
}

type IntoIter = LookupDataIterMut<'trace>;

fn into_iter(self) -> Self::IntoIter {
LookupDataIterMut::new(self.lu0, self.lu1)
}
}

pub struct ParLookupDataIterMut<'trace> {
lu0: &'trace mut [[PackedM31; 2]],
lu1: &'trace mut [[PackedM31; 4]],
}

impl<'trace> ParLookupDataIterMut<'trace> {
pub fn new(slice0: &'trace mut [[PackedM31; 2]], slice1: &'trace mut [[PackedM31; 4]]) -> Self {
Self {
lu0: slice0,
lu1: slice1,
}
}
}

impl<'trace> ParallelIterator for ParLookupDataIterMut<'trace> {
type Item = LookupDataMutChunk<'trace>;

fn drive_unindexed<D>(self, consumer: D) -> D::Result
where
D: UnindexedConsumer<Self::Item>,
{
bridge(self, consumer)
}

fn opt_len(&self) -> Option<usize> {
Some(self.len())
}
}

impl IndexedParallelIterator for ParLookupDataIterMut<'_> {
fn len(&self) -> usize {
self.lu0.len()
}

fn drive<D: Consumer<Self::Item>>(self, consumer: D) -> D::Result {
bridge(self, consumer)
}

fn with_producer<CB: ProducerCallback<Self::Item>>(self, callback: CB) -> CB::Output {
callback.callback(RowProducer {
lu0: self.lu0,
lu1: self.lu1,
})
}
lu0: Vec<PackedM31>,
lu1: Vec<[PackedM31; 2]>,
lu2: [Vec<[PackedM31; 2]>; 2],
}

#[cfg(test)]
mod tests {
use itertools::{all, Itertools};
use rayon::iter::{IndexedParallelIterator, ParallelIterator};
use rayon::slice::ParallelSlice;
use stwo_prover::core::backend::simd::m31::{PackedM31, N_LANES};
use stwo_prover::core::backend::simd::m31::{PackedM31, LOG_N_LANES, N_LANES};
use stwo_prover::core::fields::m31::M31;

use crate::trace::component_trace::ComponentTrace;
use crate::trace::examle_lookup_data::LookupData;

#[test]
fn test_lookup_data() {
fn test_derived_par_lookup_data() {
const N_COLUMNS: usize = 5;
const LOG_SIZE: u32 = 8;
let mut trace = ComponentTrace::<N_COLUMNS>::zeroed(LOG_SIZE);
let arr = (0..1 << LOG_SIZE).map(M31::from).collect_vec();
let mut lookup_data = unsafe { LookupData::uninitialized(LOG_SIZE) };
let expected: (Vec<_>, Vec<_>) = arr
let mut lookup_data = unsafe { LookupData::uninitialized(LOG_SIZE - LOG_N_LANES) };
let expected: (Vec<_>, Vec<_>, Vec<_>) = arr
.array_chunks::<N_LANES>()
.map(|x| {
let x = PackedM31::from_array(*x);
let x1 = x + PackedM31::broadcast(M31(1));
let x2 = x + x1;
let x3 = x + x1 + x2;
let x4 = x + x1 + x2 + x3;
([x, x4], [x1, x1.double(), x2, x2.double()])
(
x4,
[x1, x1.double()],
([x2, x2.double()], [x3, x3.double()]),
)
})
.unzip();
.multiunzip();

trace
.par_iter_mut()
Expand All @@ -205,24 +58,72 @@ mod tests {
*row[2] = *row[0] + *row[1];
*row[3] = *row[0] + *row[1] + *row[2];
*row[4] = *row[0] + *row[1] + *row[2] + *row[3];
*lookup_data.lu0 = [*row[0], *row[4]];
*lookup_data.lu1 = [*row[1], row[1].double(), *row[2], row[2].double()];
})
*lookup_data.lu0 = *row[4];
*lookup_data.lu1 = [*row[1], row[1].double()];
*lookup_data.lu2[0] = [*row[2], row[2].double()];
*lookup_data.lu2[1] = [*row[3], row[3].double()];
});
});

assert!(all(
lookup_data.lu0.into_iter().zip(expected.0),
|(actual, expected)| actual[0].to_array() == expected[0].to_array()
&& actual[1].to_array() == expected[1].to_array()
));
let actual = (
lookup_data.lu0,
lookup_data.lu1,
(lookup_data.lu2[0].clone(), lookup_data.lu2[1].clone()),
);

assert!(
all(
expected
.0
.iter()
.zip(actual.0.iter())
.map(|(expected, actual)| (expected.to_array(), actual.to_array())),
|(expected, actual)| expected == actual
),
"Failed on Vec<PackedM31>"
);
assert!(
all(
expected
.1
.iter()
.zip(actual.1.iter())
.map(|(expected, actual)| (
expected
.into_iter()
.flat_map(|v| v.to_array())
.collect_vec(),
actual.into_iter().flat_map(|v| v.to_array()).collect_vec()
)),
|(expected, actual)| expected == actual
),
"Failed on Vec<[PackedM31; 2]>"
);
assert!(all(
lookup_data.lu1.into_iter().zip(expected.1),
|(actual, expected)| {
actual[0].to_array() == expected[0].to_array()
&& actual[1].to_array() == expected[1].to_array()
&& actual[2].to_array() == expected[2].to_array()
&& actual[3].to_array() == expected[3].to_array()
}
expected
.2
.iter()
.map(|expected| expected.0)
.zip(actual.2 .0.into_iter())
.map(|(expected, actual)| (
expected.iter().flat_map(|v| v.to_array()).collect_vec(),
actual.iter().flat_map(|v| v.to_array()).collect_vec()
)),
|(expected, actual)| expected == actual
));
assert!(
all(
expected
.2
.iter()
.map(|e| e.1)
.zip(actual.2 .1.into_iter())
.map(|(expected, actual)| (
expected.iter().flat_map(|v| v.to_array()).collect_vec(),
actual.iter().flat_map(|v| v.to_array()).collect_vec()
)),
|(expected, actual)| expected == actual
),
"Failed on [Vec<[PackedM31; 2]>; 2]"
);
}
}
13 changes: 13 additions & 0 deletions crates/air_utils_derive/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
[package]
name = "stwo-air-utils-derive"
version = "0.1.0"
edition = "2021"

[lib]
proc-macro = true

[dependencies]
syn = "2.0.90"
quote = "1.0.37"
itertools = "0.13.0"
proc-macro2 = "1.0.92"
Empty file.
Empty file.
Loading

0 comments on commit 222c42f

Please sign in to comment.