diff --git a/Cargo.lock b/Cargo.lock index c87fc6628..5eecc10f9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -593,6 +593,15 @@ dependencies = [ "either", ] +[[package]] +name = "itertools" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "413ee7dfc52ee1a4949ceeb7dbc8a33f2d6c088194d9f922fb8318faf1f01186" +dependencies = [ + "either", +] + [[package]] name = "itoa" version = "1.0.14" @@ -1042,9 +1051,20 @@ dependencies = [ "bytemuck", "itertools 0.12.1", "rayon", + "stwo-air-utils-derive", "stwo-prover", ] +[[package]] +name = "stwo-air-utils-derive" +version = "0.1.0" +dependencies = [ + "itertools 0.13.0", + "proc-macro2", + "quote", + "syn 2.0.90", +] + [[package]] name = "stwo-prover" version = "0.1.1" diff --git a/Cargo.toml b/Cargo.toml index fadd620de..d4bb782ab 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,5 +1,5 @@ [workspace] -members = ["crates/prover", "crates/air_utils"] +members = ["crates/prover", "crates/air_utils", "crates/air_utils_derive"] resolver = "2" [workspace.package] diff --git a/crates/air_utils/Cargo.toml b/crates/air_utils/Cargo.toml index 7d09a7eaf..4463faf8b 100644 --- a/crates/air_utils/Cargo.toml +++ b/crates/air_utils/Cargo.toml @@ -8,6 +8,7 @@ bytemuck.workspace = true itertools.workspace = true rayon = { version = "1.10.0", optional = false } stwo-prover = { path = "../prover" } +stwo-air-utils-derive = { path = "../air_utils_derive" } [lib] bench = false diff --git a/crates/air_utils/src/examples/mod.rs b/crates/air_utils/src/examples/mod.rs new file mode 100644 index 000000000..e69de29bb diff --git a/crates/air_utils/src/lib.rs b/crates/air_utils/src/lib.rs index dd5257d2f..626907b98 100644 --- a/crates/air_utils/src/lib.rs +++ b/crates/air_utils/src/lib.rs @@ -1,2 +1,3 @@ #![feature(exact_size_is_empty, raw_slice_split, portable_simd, array_chunks)] pub mod trace; +pub mod examples; \ No newline at end of file diff --git a/crates/air_utils/src/trace/component_trace.rs b/crates/air_utils/src/trace/component_trace.rs index 5c96211ba..a66a7347d 100644 --- a/crates/air_utils/src/trace/component_trace.rs +++ b/crates/air_utils/src/trace/component_trace.rs @@ -65,7 +65,7 @@ impl ComponentTrace { #[allow(clippy::uninit_vec)] pub unsafe fn uninitialized(log_size: u32) -> Self { let data = [(); N].map(|_| { - let n_simd_elems = (1 << log_size) / N_LANES; + let n_simd_elems = ((1 << log_size) as usize).div_ceil(N_LANES); let mut vec = Vec::with_capacity(n_simd_elems); vec.set_len(n_simd_elems); vec diff --git a/crates/air_utils/src/trace/examle_lookup_data.rs b/crates/air_utils/src/trace/examle_lookup_data.rs index 31fc9ea95..8bfe885e7 100644 --- a/crates/air_utils/src/trace/examle_lookup_data.rs +++ b/crates/air_utils/src/trace/examle_lookup_data.rs @@ -1,167 +1,21 @@ +#![allow(unused)] // TODO(Ohad): write a derive macro for this. +use stwo_air_utils_derive::StwoIterable; +use stwo_prover::core::backend::simd::m31::PackedM31; +use stwo_prover::core::backend::simd::m31::N_LANES; use rayon::iter::plumbing::{bridge, Consumer, Producer, ProducerCallback, UnindexedConsumer}; use rayon::prelude::*; -use stwo_prover::core::backend::simd::m31::{PackedM31, N_LANES}; +#[derive(StwoIterable)] +#[allow(dead_code)] pub struct LookupData { - pub lu0: Vec<[PackedM31; 2]>, - pub lu1: Vec<[PackedM31; 4]>, + lu3: [Vec<[PackedM31; 16]>; 4], + lu0: Vec<[PackedM31; 2]>, + lu1: Vec<[PackedM31; 4]>, + lu2: Vec<[PackedM31; 8]>, + lu4: [Vec<[PackedM31; 32]>; 4], } -impl LookupData { - /// # Safety - pub unsafe fn uninitialized(log_size: u32) -> Self { - let length = 1 << log_size; - let n_simd_elems = length / N_LANES; - let mut lu0 = Vec::with_capacity(n_simd_elems); - let mut lu1 = Vec::with_capacity(n_simd_elems); - lu0.set_len(n_simd_elems); - lu1.set_len(n_simd_elems); - Self { lu0, lu1 } - } - - pub fn iter_mut(&mut self) -> LookupDataIterMut<'_> { - LookupDataIterMut::new(&mut self.lu0, &mut self.lu1) - } - - pub fn par_iter_mut(&mut self) -> ParLookupDataIterMut<'_> { - ParLookupDataIterMut { - lu0: &mut self.lu0, - lu1: &mut self.lu1, - } - } -} - -pub struct LookupDataMutChunk<'trace> { - pub lu0: &'trace mut [PackedM31; 2], - pub lu1: &'trace mut [PackedM31; 4], -} -pub struct LookupDataIterMut<'trace> { - lu0: *mut [[PackedM31; 2]], - lu1: *mut [[PackedM31; 4]], - phantom: std::marker::PhantomData<&'trace ()>, -} -impl<'trace> LookupDataIterMut<'trace> { - pub fn new(slice0: &'trace mut [[PackedM31; 2]], slice1: &'trace mut [[PackedM31; 4]]) -> Self { - Self { - lu0: slice0 as *mut _, - lu1: slice1 as *mut _, - phantom: std::marker::PhantomData, - } - } -} -impl<'trace> Iterator for LookupDataIterMut<'trace> { - type Item = LookupDataMutChunk<'trace>; - - fn next(&mut self) -> Option { - if self.lu0.is_empty() { - return None; - } - let item = unsafe { - let (head0, tail0) = self.lu0.split_at_mut(1); - let (head1, tail1) = self.lu1.split_at_mut(1); - self.lu0 = tail0; - self.lu1 = tail1; - LookupDataMutChunk { - lu0: &mut (*head0)[0], - lu1: &mut (*head1)[0], - } - }; - Some(item) - } - - fn size_hint(&self) -> (usize, Option) { - let len = self.lu0.len(); - (len, Some(len)) - } -} - -impl ExactSizeIterator for LookupDataIterMut<'_> {} -impl DoubleEndedIterator for LookupDataIterMut<'_> { - fn next_back(&mut self) -> Option { - if self.lu0.is_empty() { - return None; - } - let item = unsafe { - let (head0, tail0) = self.lu0.split_at_mut(self.lu0.len() - 1); - let (head1, tail1) = self.lu1.split_at_mut(self.lu1.len() - 1); - self.lu0 = head0; - self.lu1 = head1; - LookupDataMutChunk { - lu0: &mut (*tail0)[0], - lu1: &mut (*tail1)[0], - } - }; - Some(item) - } -} - -struct RowProducer<'trace> { - lu0: &'trace mut [[PackedM31; 2]], - lu1: &'trace mut [[PackedM31; 4]], -} - -impl<'trace> Producer for RowProducer<'trace> { - type Item = LookupDataMutChunk<'trace>; - - fn split_at(self, index: usize) -> (Self, Self) { - let (lu0, rh0) = self.lu0.split_at_mut(index); - let (lu1, rh1) = self.lu1.split_at_mut(index); - (RowProducer { lu0, lu1 }, RowProducer { lu0: rh0, lu1: rh1 }) - } - - type IntoIter = LookupDataIterMut<'trace>; - - fn into_iter(self) -> Self::IntoIter { - LookupDataIterMut::new(self.lu0, self.lu1) - } -} - -pub struct ParLookupDataIterMut<'trace> { - lu0: &'trace mut [[PackedM31; 2]], - lu1: &'trace mut [[PackedM31; 4]], -} - -impl<'trace> ParLookupDataIterMut<'trace> { - pub fn new(slice0: &'trace mut [[PackedM31; 2]], slice1: &'trace mut [[PackedM31; 4]]) -> Self { - Self { - lu0: slice0, - lu1: slice1, - } - } -} - -impl<'trace> ParallelIterator for ParLookupDataIterMut<'trace> { - type Item = LookupDataMutChunk<'trace>; - - fn drive_unindexed(self, consumer: D) -> D::Result - where - D: UnindexedConsumer, - { - bridge(self, consumer) - } - - fn opt_len(&self) -> Option { - Some(self.len()) - } -} - -impl IndexedParallelIterator for ParLookupDataIterMut<'_> { - fn len(&self) -> usize { - self.lu0.len() - } - - fn drive>(self, consumer: D) -> D::Result { - bridge(self, consumer) - } - - fn with_producer>(self, callback: CB) -> CB::Output { - callback.callback(RowProducer { - lu0: self.lu0, - lu1: self.lu1, - }) - } -} #[cfg(test)] mod tests { @@ -181,7 +35,7 @@ mod tests { let mut trace = ComponentTrace::::zeroed(LOG_SIZE); let arr = (0..1 << LOG_SIZE).map(M31::from).collect_vec(); let mut lookup_data = unsafe { LookupData::uninitialized(LOG_SIZE) }; - let expected: (Vec<_>, Vec<_>) = arr + let expected: (Vec<_>, Vec<_>, Vec<_>) = arr .array_chunks::() .map(|x| { let x = PackedM31::from_array(*x); @@ -189,9 +43,22 @@ mod tests { let x2 = x + x1; let x3 = x + x1 + x2; let x4 = x + x1 + x2 + x3; - ([x, x4], [x1, x1.double(), x2, x2.double()]) + ( + [x, x4], + [x1, x1.double(), x2, x2.double()], + [ + x3, + x3.double(), + x4, + x4.double(), + x, + x.double(), + x1, + x1.double(), + ], + ) }) - .unzip(); + .multiunzip(); trace .par_iter_mut() @@ -207,6 +74,35 @@ mod tests { *row[4] = *row[0] + *row[1] + *row[2] + *row[3]; *lookup_data.lu0 = [*row[0], *row[4]]; *lookup_data.lu1 = [*row[1], row[1].double(), *row[2], row[2].double()]; + *lookup_data.lu2 = [ + *row[3], + row[3].double(), + *row[4], + row[4].double(), + *row[0], + row[0].double(), + *row[1], + row[1].double(), + ]; + *lookup_data.lu3[0] = [ + *row[3], + row[3].double(), + *row[4], + row[4].double(), + *row[0], + row[0].double(), + *row[1], + row[1].double(), + *row[3], + row[3].double(), + *row[4], + row[4].double(), + *row[0], + row[0].double(), + *row[1], + row[1].double(), + ] + }) }); diff --git a/crates/air_utils_derive/Cargo.toml b/crates/air_utils_derive/Cargo.toml new file mode 100644 index 000000000..0f36c43af --- /dev/null +++ b/crates/air_utils_derive/Cargo.toml @@ -0,0 +1,13 @@ +[package] +name = "stwo-air-utils-derive" +version = "0.1.0" +edition = "2021" + +[lib] +proc-macro = true + +[dependencies] +syn = "2.0.90" +quote = "1.0.37" +itertools = "0.13.0" +proc-macro2 = "1.0.92" diff --git a/crates/air_utils_derive/src/expand_impl.rs b/crates/air_utils_derive/src/expand_impl.rs new file mode 100644 index 000000000..e69de29bb diff --git a/crates/air_utils_derive/src/expand_iterator.rs b/crates/air_utils_derive/src/expand_iterator.rs new file mode 100644 index 000000000..e69de29bb diff --git a/crates/air_utils_derive/src/iterable_field.rs b/crates/air_utils_derive/src/iterable_field.rs new file mode 100644 index 000000000..7edca6ac6 --- /dev/null +++ b/crates/air_utils_derive/src/iterable_field.rs @@ -0,0 +1,232 @@ +use proc_macro2::TokenStream; +use syn::{Expr, Ident, Lifetime, Type}; +use quote::{format_ident, quote}; + +pub(super) trait IterableField { + fn name(&self) -> &Ident; + fn r#type(&self) -> &Type; + fn mut_slice_type(&self, lifetime: &Lifetime) -> TokenStream; + fn mut_chunk_type(&self, lifetime: &Lifetime) -> TokenStream; + fn mut_ptr_type(&self) -> TokenStream; + fn uninitialized(&self) -> TokenStream; + fn split_first(&self) -> TokenStream; + fn split_last(&self, length: &Ident) -> TokenStream; + fn split_at(&self, index: Ident) -> TokenStream; + fn as_mut_slice(&self) -> TokenStream; + fn as_mut_ptr(&self) -> TokenStream; + fn len(&self) -> TokenStream; +} + +pub(super) struct PlainVec { + pub(super) name: Ident, + pub(super) r#type: Type, +} +impl IterableField for PlainVec { + fn name(&self) -> &Ident { + &self.name + } + + fn r#type(&self) -> &Type { + &self.r#type + } + + fn uninitialized(&self) -> TokenStream { + let name = self.name(); + quote! { + let mut #name= Vec::with_capacity(n_simd_elems); + #name.set_len(n_simd_elems); + } + } + + fn split_first(&self) -> TokenStream { + let name = self.name(); + let head = format_ident!("{}_head", name); + let tail = format_ident!("{}_tail", name); + quote! { + let (#head, #tail) = self.#name.split_at_mut(1); + self.#name = #tail; + let #name = &mut (*(#head))[0]; + } + } + + fn split_last(&self, length: &Ident) -> TokenStream { + let name = self.name(); + let head = format_ident!("{}_head", name); + let tail = format_ident!("{}_tail", name); + quote! { + let ( + #head, + #tail, + ) = self.#name.split_at_mut(#length - 1); + self.#name = #head; + let #name = &mut (*#tail)[0]; + } + } + + fn split_at(&self, index: Ident) -> TokenStream { + let name = self.name(); + let tail = format_ident!("{}_tail", name); + quote! { + let ( + #name, + #tail + ) = self.#name.split_at_mut(#index); + } + } + + fn as_mut_slice(&self) -> TokenStream { + let name = self.name(); + quote! { + #name.as_mut_slice() + } + } + + fn len(&self) -> TokenStream { + let name = self.name(); + quote! { + #name.len() + } + } + + fn mut_slice_type(&self, lifetime: &Lifetime) -> TokenStream { + let r#type = &self.r#type; + quote! { + &#lifetime mut [#r#type] + } + } + + fn mut_ptr_type(&self) -> TokenStream { + let r#type = &self.r#type; + quote! { + *mut [#r#type] + } + } + + fn as_mut_ptr(&self) -> TokenStream { + let name = self.name(); + quote! { + #name as *mut _ + } + } + + fn mut_chunk_type(&self, lifetime: &Lifetime) -> TokenStream { + let r#type = &self.r#type; + quote! { + &#lifetime mut #r#type + } + } +} + +#[allow(dead_code)] +pub(super) struct ArrayOfVecs { + pub(super) name: Ident, + pub(super) r#type: Type, + pub(super) inner_type: Type, + pub(super) outer_array_size: Expr, +} +impl IterableField for ArrayOfVecs { + fn name(&self) -> &Ident { + &self.name + } + + fn r#type(&self) -> &Type { + &self.r#type + } + + fn uninitialized(&self) -> TokenStream { + let name = self.name(); + let outer_array_size = &self.outer_array_size; + quote! { + let #name = [(); #outer_array_size].map(|_| { + let mut vec = Vec::with_capacity(n_simd_elems); + vec.set_len(n_simd_elems); + vec + }); + } + } + + fn split_first(&self) -> TokenStream { + let name = self.name(); + quote! { + let #name = self.#name.each_mut().map(|v| { + let (head, tail) = v.split_at_mut(1); + *v = tail; + &mut (*head)[0] + }); + } + } + + fn split_last(&self, length: &Ident) -> TokenStream { + let name = self.name(); + quote! { + let #name = self.#name.each_mut().map(|v| { + let (head, tail) = v.split_at_mut(#length - 1); + *v = head; + &mut (*tail)[0] + }); + } + } + + fn split_at(&self, index: Ident) -> TokenStream { + let name = self.name(); + let tail = format_ident!("{}_tail", name); + let array_size = &self.outer_array_size; + quote! { + let ( + mut #name, + mut #tail + ):([_; #array_size],[_; #array_size]) = unsafe { (std::mem::zeroed(), std::mem::zeroed()) }; + self.#name.into_iter().enumerate().for_each(|(i, v)| { + let (head, tail) = v.split_at_mut(#index); + #name[i] = head; + #tail[i] = tail; + }); + } + } + + fn as_mut_slice(&self) -> TokenStream { + let name = self.name(); + quote! { + #name.each_mut().map(|v| v.as_mut_slice()) + } + } + + fn len(&self) -> TokenStream { + let name = self.name(); + quote! { + #name[0].len() + } + } + + fn mut_slice_type(&self, lifetime: &Lifetime) -> TokenStream { + let inner_type = &self.inner_type; + let outer_array_size = &self.outer_array_size; + quote! { + [&#lifetime mut [#inner_type]; #outer_array_size] + } + } + + fn mut_ptr_type(&self) -> TokenStream { + let inner_type = &self.inner_type; + let outer_array_size = &self.outer_array_size; + quote! { + [*mut [#inner_type]; #outer_array_size] + } + } + + // From mut slice to mut ptr. + fn as_mut_ptr(&self) -> TokenStream { + let name = self.name(); + quote! { + #name.map(|v| v as *mut _) + } + } + + fn mut_chunk_type(&self, lifetime: &Lifetime) -> TokenStream { + let inner_type = &self.inner_type; + let array_size = &self.outer_array_size; + quote! { + [&#lifetime mut #inner_type; #array_size] + } + } +} \ No newline at end of file diff --git a/crates/air_utils_derive/src/lib.rs b/crates/air_utils_derive/src/lib.rs new file mode 100644 index 000000000..b9d7138e8 --- /dev/null +++ b/crates/air_utils_derive/src/lib.rs @@ -0,0 +1,297 @@ +#![allow(dead_code)] +#![allow(unused_variables)] +mod iterable_field; +mod expand_iterator; +mod expand_impl; +use iterable_field::{ArrayOfVecs, IterableField, PlainVec}; +use itertools::Itertools; +use proc_macro2::Span; +use quote::{format_ident, quote}; +use syn::{parse_macro_input, Data, DeriveInput, Fields, Lifetime, Type}; + +#[proc_macro_derive(StwoIterable)] +pub fn derive_stwo_iterable(input: proc_macro::TokenStream) -> proc_macro::TokenStream { + let input = parse_macro_input!(input as DeriveInput); + let struct_name = &input.ident; + let input = match input.data { + Data::Struct(data_struct) => data_struct, + _ => panic!("Expected struct"), + }; + + let fields = match input.fields { + Fields::Named(fields) => fields.named, + _ => panic!("Expected named fields"), + }; + + let mut non_array_fields: Vec = vec![]; + let mut array_fields: Vec = vec![]; + + for field in fields { + match field.ty { + Type::Array(ref outer_array) => { + // Assert that the inner type is a Vec and get T: + let inner_type = match outer_array.elem.as_ref() { + Type::Path(ref type_path) => { + if let Some(last_segment) = type_path.path.segments.last() { + if last_segment.ident != "Vec" { + panic!("Expected Vec type"); + } + if let syn::PathArguments::AngleBracketed(ref args) = last_segment.arguments { + if args.args.len() != 1 { + panic!("Expected one type argument"); + } + if let syn::GenericArgument::Type(inner_type) = args.args.first().unwrap() { + inner_type + } else { + panic!("Expected type argument"); + } + } else { + panic!("Expected angle-bracketed arguments"); + } + } else { + panic!("Expected last segment"); + } + }, + _ => panic!("Expected path"), + }; + array_fields.push(ArrayOfVecs { + name: field.ident.unwrap(), + r#type: field.ty.clone(), + outer_array_size: outer_array.len.clone(), + inner_type: inner_type.clone(), + }); + }, + Type::Path(ref type_path) => { + // Assert that the type is Vec and get T: + let r#type = match type_path.path.segments.last() { + Some(last_segment) => { + if last_segment.ident != "Vec" { + panic!("Expected Vec type"); + } + if let syn::PathArguments::AngleBracketed(ref args) = last_segment.arguments { + if args.args.len() != 1 { + panic!("Expected one type argument"); + } + if let syn::GenericArgument::Type(r#type) = args.args.first().unwrap() { + r#type + } else { + panic!("Expected type argument"); + } + } else { + panic!("Expected angle-bracketed arguments"); + } + }, + None => panic!("Expected last segment"), + }.clone(); + non_array_fields.push(PlainVec { + name: field.ident.unwrap(), + r#type, + }); + }, + _ => panic!("Expected vector or array of vectors"), + } + } + + let iterable_fields = non_array_fields + .iter() + .map(|f| f as &dyn IterableField) + .chain(array_fields.iter().map(|f| f as &dyn IterableField)) + .collect_vec(); + + let mut_chunk_name = format_ident!("{}MutChunk", struct_name); + let iter_mut_name = format_ident!("{}IterMut", struct_name); + let row_producer_name = format_ident!("{}RowProducer", struct_name); + let par_iter_mut_name = format_ident!("Par{}IterMut", struct_name); + + let field_names = iterable_fields.iter().map(|f| f.name()).collect_vec(); + let field_types = iterable_fields.iter().map(|f| f.r#type()).collect_vec(); + let uninitialized_fields = iterable_fields + .iter() + .map(|f| f.uninitialized()) + .collect_vec(); + let as_mut_slice = iterable_fields + .iter() + .map(|f| f.as_mut_slice()) + .collect_vec(); + let mut_slice_types = iterable_fields + .iter() + .map(|f| f.mut_slice_type(&Lifetime::new("'trace", Span::call_site()))) + .collect_vec(); + let field_ptr_types = iterable_fields.iter().map(|f| f.mut_ptr_type()).collect_vec(); + let split_first = iterable_fields.iter().map(|f| f.split_first()).collect_vec(); + let split_last = iterable_fields.iter().map(|f| f.split_last(&format_ident!("len"))).collect_vec(); + let split_at = iterable_fields + .iter() + .map(|f| f.split_at(format_ident!("index"))) + .collect_vec(); + let field_names_tail = field_names.iter().map(|f| format_ident!("{}_tail", f)).collect_vec(); + let length_function = iterable_fields.first().unwrap().len(); + let as_mut_ptr = iterable_fields.iter().map(|f| f.as_mut_ptr()).collect_vec(); + let mut_chunk_types = iterable_fields + .iter() + .map(|f| f.mut_chunk_type(&Lifetime::new("'trace", Span::call_site()))) + .collect_vec(); + + let expansions = quote! { + impl #struct_name { + /// # Safety + /// The caller must ensure that the trace is populated before being used. + #[allow(clippy::uninit_vec)] + pub unsafe fn uninitialized(log_size: u32) -> Self { + let length: usize = 1 << log_size; + let n_simd_elems = length.div_ceil(N_LANES); + #(#uninitialized_fields)* + Self { + #(#field_names,)* + } + } + + pub fn iter_mut(&mut self) -> #iter_mut_name<'_> { + #iter_mut_name::new( + #(self.#as_mut_slice,)* + ) + } + + pub fn par_iter_mut(&mut self) -> #par_iter_mut_name<'_> { + #par_iter_mut_name::new( + #(self.#as_mut_slice,)* + ) + } + } + + pub struct #mut_chunk_name<'trace> { + #(#field_names: #mut_chunk_types,)* + } + + pub struct #iter_mut_name<'trace> { + #(#field_names: #field_ptr_types,)* + phantom: std::marker::PhantomData<&'trace ()>, + } + + impl<'trace> #iter_mut_name<'trace> { + pub fn new( + #(#field_names: #mut_slice_types,)* + ) -> Self { + Self { + #(#field_names: #as_mut_ptr,)* + phantom: std::marker::PhantomData, + } + } + } + + impl<'trace> Iterator for #iter_mut_name<'trace> { + type Item = #mut_chunk_name<'trace>; + fn next(&mut self) -> Option { + if self.#length_function == 0 { + return None; + } + let item = unsafe { + #(#split_first)* + #mut_chunk_name { + #(#field_names,)* + } + }; + Some(item) + } + + fn size_hint(&self) -> (usize, Option) { + let len = self.#length_function; + (len, Some(len)) + } + } + + impl ExactSizeIterator for #iter_mut_name<'_> {} + impl DoubleEndedIterator for #iter_mut_name<'_> { + fn next_back(&mut self) -> Option { + let len = self.#length_function; + if len == 0 { + return None; + } + let item = unsafe { + #(#split_last)* + #mut_chunk_name { + #(#field_names,)* + } + }; + Some(item) + } + } + + pub struct #row_producer_name<'trace> { + #(#field_names: #mut_slice_types,)* + } + + impl<'trace> Producer for #row_producer_name<'trace> { + type Item = #mut_chunk_name<'trace>; + type IntoIter = #iter_mut_name<'trace>; + + #[allow(invalid_value)] + fn split_at(self, index: usize) -> (Self, Self) { + #(#split_at)* + + ( + #row_producer_name { + #(#field_names,)* + }, + #row_producer_name { + #(#field_names: #field_names_tail,)* + } + ) + + } + + fn into_iter(self) -> Self::IntoIter { + #iter_mut_name::new(#(self.#field_names),*) + } + } + + pub struct #par_iter_mut_name<'trace> { + #(#field_names: #mut_slice_types,)* + } + + impl<'trace> #par_iter_mut_name<'trace> { + pub fn new( + #(#field_names: #mut_slice_types,)* + ) -> Self { + Self { + #(#field_names,)* + } + } + } + + impl<'trace> ParallelIterator for #par_iter_mut_name<'trace> { + type Item = #mut_chunk_name<'trace>; + + fn drive_unindexed(self, consumer: D) -> D::Result + where + D: UnindexedConsumer, + { + bridge(self, consumer) + } + + fn opt_len(&self) -> Option { + Some(self.len()) + } + } + + impl IndexedParallelIterator for #par_iter_mut_name<'_> { + fn len(&self) -> usize { + self.#length_function + } + + fn drive>(self, consumer: D) -> D::Result { + bridge(self, consumer) + } + + fn with_producer>(self, callback: CB) -> CB::Output { + callback.callback( + #row_producer_name { + #(#field_names : self.#field_names,)* + } + ) + } + } + }; + + proc_macro::TokenStream::from(expansions) +}