Skip to content

Commit

Permalink
Fall back to cpu in small fft size.
Browse files Browse the repository at this point in the history
  • Loading branch information
alonh5 committed Sep 17, 2024
1 parent a51f630 commit 77b7cdd
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 14 deletions.
25 changes: 18 additions & 7 deletions crates/prover/src/core/backend/simd/circle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@ use std::mem::transmute;
use bytemuck::{cast_slice, Zeroable};
use num_traits::One;

use super::fft::{ifft, rfft, CACHED_FFT_LOG_SIZE};
use super::fft::{ifft, rfft, CACHED_FFT_LOG_SIZE, MIN_FFT_LOG_SIZE};
use super::m31::{PackedBaseField, LOG_N_LANES, N_LANES};
use super::qm31::PackedSecureField;
use super::SimdBackend;
use crate::core::backend::simd::column::BaseColumn;
use crate::core::backend::{Col, CpuBackend};
use crate::core::backend::{Col, Column, CpuBackend};
use crate::core::circle::{CirclePoint, Coset};
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
Expand Down Expand Up @@ -123,7 +123,7 @@ impl SimdBackend {
// Decide if and when it's ok and what to do if it's not.
impl PolyOps for SimdBackend {
// The twiddles type is u32, and not BaseField. This is because the fast AVX mul implementation
// requries one of the numbers to be shifted left by 1 bit. This is not a reduced
// requires one of the numbers to be shifted left by 1 bit. This is not a reduced
// representation of the field.
type Twiddles = Vec<u32>;

Expand All @@ -143,9 +143,13 @@ impl PolyOps for SimdBackend {
eval: CircleEvaluation<Self, BaseField, BitReversedOrder>,
twiddles: &TwiddleTree<Self>,
) -> CirclePoly<Self> {
let mut values = eval.values;
let log_size = values.length.ilog2();
let log_size = eval.values.length.ilog2();
if log_size < MIN_FFT_LOG_SIZE {
let cpu_poly = eval.to_cpu().interpolate();
return CirclePoly::new(cpu_poly.coeffs.into_iter().collect());
}

let mut values = eval.values;
let twiddles = domain_line_twiddles_from_tree(eval.domain, &twiddles.itwiddles);

// Safe because [PackedBaseField] is aligned on 64 bytes.
Expand Down Expand Up @@ -221,15 +225,22 @@ impl PolyOps for SimdBackend {
domain: CircleDomain,
twiddles: &TwiddleTree<Self>,
) -> CircleEvaluation<Self, BaseField, BitReversedOrder> {
// TODO(spapini): Precompute twiddles.
// TODO(spapini): Handle small cases.
let log_size = domain.log_size();
let fft_log_size = poly.log_size();
assert!(
log_size >= fft_log_size,
"Can only evaluate on larger domains"
);

if fft_log_size < MIN_FFT_LOG_SIZE {
let cpu_poly: CirclePoly<CpuBackend> = CirclePoly::new(poly.coeffs.to_cpu());
let cpu_eval = cpu_poly.evaluate(domain);
return CircleEvaluation::new(
cpu_eval.domain,
Col::<SimdBackend, BaseField>::from_iter(cpu_eval.values),
);
}

let twiddles = domain_line_twiddles_from_tree(domain, &twiddles.twiddles);

// Evaluate on a big domain by evaluating on several subdomains.
Expand Down
7 changes: 1 addition & 6 deletions crates/prover/src/core/backend/simd/quotients.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,7 @@ impl QuotientOps for SimdBackend {
// Fall back to the CPU backend for small domains.
let columns = columns
.iter()
.map(|circle_eval| {
CircleEvaluation::<CpuBackend, _, BitReversedOrder>::new(
circle_eval.domain,
circle_eval.values.to_cpu(),
)
})
.map(|circle_eval| circle_eval.to_cpu())
.collect_vec();
let eval = CpuBackend::accumulate_quotients(
domain,
Expand Down
12 changes: 11 additions & 1 deletion crates/prover/src/core/poly/circle/evaluation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@ use educe::Educe;

use super::{CanonicCoset, CircleDomain, CirclePoly, PolyOps};
use crate::core::backend::cpu::CpuCircleEvaluation;
use crate::core::backend::{Col, Column};
use crate::core::backend::simd::SimdBackend;
use crate::core::backend::{Col, Column, CpuBackend};
use crate::core::circle::{CirclePointIndex, Coset};
use crate::core::fields::m31::BaseField;
use crate::core::fields::{ExtensionOf, FieldOps};
Expand Down Expand Up @@ -107,6 +108,15 @@ impl<B: FieldOps<F>, F: ExtensionOf<BaseField>> CircleEvaluation<B, F, BitRevers
}
}

impl<F: ExtensionOf<BaseField>, EvalOrder> CircleEvaluation<SimdBackend, F, EvalOrder>
where
    SimdBackend: FieldOps<F>,
{
    /// Builds a [`CpuBackend`] evaluation from this [`SimdBackend`] evaluation.
    ///
    /// The result is constructed over the same `domain`, with its values obtained by calling
    /// `to_cpu` on this evaluation's `values` column. The evaluation-order marker `EvalOrder`
    /// is carried over unchanged, and `self` is only borrowed.
    pub fn to_cpu(&self) -> CircleEvaluation<CpuBackend, F, EvalOrder> {
        let cpu_values = self.values.to_cpu();
        CircleEvaluation::new(self.domain, cpu_values)
    }
}

impl<B: FieldOps<F>, F: ExtensionOf<BaseField>, EvalOrder> Deref
for CircleEvaluation<B, F, EvalOrder>
{
Expand Down

0 comments on commit 77b7cdd

Please sign in to comment.