Skip to content

Commit

Permalink
implement StwoAdd and StwoData
Browse files Browse the repository at this point in the history
  • Loading branch information
raphaelDkhn committed Jan 6, 2025
1 parent 0222d6e commit 53c2599
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 18 deletions.
Empty file removed backends/stwo/src/air/data.rs
Empty file.
1 change: 0 additions & 1 deletion backends/stwo/src/air/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ use stwo_prover::core::{

pub mod add;
pub mod tensor;
pub mod data;

pub trait Circuit<B: Backend> {
type Component;
Expand Down
24 changes: 18 additions & 6 deletions backends/stwo/src/compiler/data.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,22 @@
use stwo_prover::core::backend::simd::m31::PackedBaseField;
use luminal::prelude::*;
use stwo_prover::core::backend::simd::m31::PackedBaseField;
use std::{fmt::Debug, sync::Arc};

#[derive(Clone, Debug)]
pub struct StwoData(pub Arc<Vec<PackedBaseField>>);

// #[derive(Debug)]
// pub struct StwoData(pub Vec<PackedBaseField>);
impl StwoData {
pub fn as_slice(&self) -> &[PackedBaseField] {
&self.0
}
}

impl Data for StwoData {
fn as_any(&self) -> &dyn std::any::Any {
self
}

// fn get_buffer_from_tensor<'a>(tensor: &'a InputTensor) -> &'a Vec<PackedBaseField> {
// &tensor.borrowed().downcast_ref::<StwoData>().unwrap().0
// }
fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
self
}
}
32 changes: 21 additions & 11 deletions backends/stwo/src/compiler/prim.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
use std::any::{Any, TypeId};
use std::{any::{Any, TypeId}, sync::Arc};

use luminal::prelude::*;
use stwo_prover::core::backend::simd::SimdBackend;
use stwo_prover::core::backend::simd::{m31::LOG_N_LANES, SimdBackend};

use crate::air::add::trace::TensorAddTracer;
use crate::{air::{add::trace::TensorAddTracer, tensor::AirTensor}, compiler::data::StwoData};

#[derive(Debug)]
pub struct PrimitiveCompiler {}
Expand All @@ -17,17 +17,27 @@ impl Operator for StwoAdd {
panic!("Add operator requires exactly two input tensors.");
}

// TODO:
// - convert inputs A and B into AirTensors for SIMD backend
// - calculate log_size
let (a_tensor, a_shape) = &inp[0];
let (b_tensor, b_shape) = &inp[1];

let (_trace, c) = SimdBackend::generate_trace(todo!(), todo!(), todo!());
// Extract data from tensors
let a_data = a_tensor.borrowed().downcast_ref::<StwoData>().unwrap();
let b_data = b_tensor.borrowed().downcast_ref::<StwoData>().unwrap();

// TODO:
// - convert result C into luminal Tensor
// - return vec![c_tensor]
// Create AirTensors
let a = AirTensor::new(a_data.as_slice(), a_shape.shape_usize());
let b = AirTensor::new(b_data.as_slice(), b_shape.shape_usize());

todo!()
// Calculate required trace size based on tensor dimensions
let max_size = a.size().max(b.size());
let required_log_size = ((max_size + (1 << LOG_N_LANES) - 1) >> LOG_N_LANES)
.next_power_of_two()
.trailing_zeros() + LOG_N_LANES;

// Generate trace and get result tensor
let (_trace, c) = SimdBackend::generate_trace(required_log_size, &a, &b);

vec![Tensor::new(StwoData(Arc::new(c.data().to_vec())))]
}
}

Expand Down

0 comments on commit 53c2599

Please sign in to comment.