From 77b7cdda98e4007b015570b28144c71262f485bd Mon Sep 17 00:00:00 2001 From: Alon Haramati Date: Mon, 9 Sep 2024 11:56:46 +0300 Subject: [PATCH] Fall back to cpu in small fft size. --- crates/prover/src/core/backend/simd/circle.rs | 25 +++++++++++++------ .../prover/src/core/backend/simd/quotients.rs | 7 +----- .../prover/src/core/poly/circle/evaluation.rs | 12 ++++++++- 3 files changed, 30 insertions(+), 14 deletions(-) diff --git a/crates/prover/src/core/backend/simd/circle.rs b/crates/prover/src/core/backend/simd/circle.rs index e930f77b2..7ad0ab149 100644 --- a/crates/prover/src/core/backend/simd/circle.rs +++ b/crates/prover/src/core/backend/simd/circle.rs @@ -4,12 +4,12 @@ use std::mem::transmute; use bytemuck::{cast_slice, Zeroable}; use num_traits::One; -use super::fft::{ifft, rfft, CACHED_FFT_LOG_SIZE}; +use super::fft::{ifft, rfft, CACHED_FFT_LOG_SIZE, MIN_FFT_LOG_SIZE}; use super::m31::{PackedBaseField, LOG_N_LANES, N_LANES}; use super::qm31::PackedSecureField; use super::SimdBackend; use crate::core::backend::simd::column::BaseColumn; -use crate::core::backend::{Col, CpuBackend}; +use crate::core::backend::{Col, Column, CpuBackend}; use crate::core::circle::{CirclePoint, Coset}; use crate::core::fields::m31::BaseField; use crate::core::fields::qm31::SecureField; @@ -123,7 +123,7 @@ impl SimdBackend { // Decide if and when it's ok and what to do if it's not. impl PolyOps for SimdBackend { // The twiddles type is i32, and not BaseField. This is because the fast AVX mul implementation - // requries one of the numbers to be shifted left by 1 bit. This is not a reduced + // requires one of the numbers to be shifted left by 1 bit. This is not a reduced // representation of the field. type Twiddles = Vec; @@ -143,9 +143,13 @@ impl PolyOps for SimdBackend { eval: CircleEvaluation, twiddles: &TwiddleTree, ) -> CirclePoly { - let mut values = eval.values; - let log_size = values.length.ilog2(); + let log_size = eval.values.length.ilog2(); + if log_size < MIN_FFT_LOG_SIZE { + let cpu_poly = eval.to_cpu().interpolate(); + return CirclePoly::new(cpu_poly.coeffs.into_iter().collect()); + } + let mut values = eval.values; let twiddles = domain_line_twiddles_from_tree(eval.domain, &twiddles.itwiddles); // Safe because [PackedBaseField] is aligned on 64 bytes. @@ -221,8 +225,6 @@ impl PolyOps for SimdBackend { domain: CircleDomain, twiddles: &TwiddleTree, ) -> CircleEvaluation { - // TODO(spapini): Precompute twiddles. - // TODO(spapini): Handle small cases. let log_size = domain.log_size(); let fft_log_size = poly.log_size(); assert!( @@ -230,6 +232,15 @@ impl PolyOps for SimdBackend { "Can only evaluate on larger domains" ); + if fft_log_size < MIN_FFT_LOG_SIZE { + let cpu_poly: CirclePoly = CirclePoly::new(poly.coeffs.to_cpu()); + let cpu_eval = cpu_poly.evaluate(domain); + return CircleEvaluation::new( + cpu_eval.domain, + Col::::from_iter(cpu_eval.values), + ); + } + let twiddles = domain_line_twiddles_from_tree(domain, &twiddles.twiddles); // Evaluate on a big domains by evaluating on several subdomains. diff --git a/crates/prover/src/core/backend/simd/quotients.rs b/crates/prover/src/core/backend/simd/quotients.rs index ff8e1c580..764ba8439 100644 --- a/crates/prover/src/core/backend/simd/quotients.rs +++ b/crates/prover/src/core/backend/simd/quotients.rs @@ -40,12 +40,7 @@ impl QuotientOps for SimdBackend { // Fall back to the CPU backend for small domains. let columns = columns .iter() - .map(|circle_eval| { - CircleEvaluation::::new( - circle_eval.domain, - circle_eval.values.to_cpu(), - ) - }) + .map(|circle_eval| circle_eval.to_cpu()) .collect_vec(); let eval = CpuBackend::accumulate_quotients( domain, diff --git a/crates/prover/src/core/poly/circle/evaluation.rs b/crates/prover/src/core/poly/circle/evaluation.rs index 4cf23b976..8354a8ac6 100644 --- a/crates/prover/src/core/poly/circle/evaluation.rs +++ b/crates/prover/src/core/poly/circle/evaluation.rs @@ -5,7 +5,8 @@ use educe::Educe; use super::{CanonicCoset, CircleDomain, CirclePoly, PolyOps}; use crate::core::backend::cpu::CpuCircleEvaluation; -use crate::core::backend::{Col, Column}; +use crate::core::backend::simd::SimdBackend; +use crate::core::backend::{Col, Column, CpuBackend}; use crate::core::circle::{CirclePointIndex, Coset}; use crate::core::fields::m31::BaseField; use crate::core::fields::{ExtensionOf, FieldOps}; @@ -107,6 +108,15 @@ impl, F: ExtensionOf> CircleEvaluation, EvalOrder> CircleEvaluation +where + SimdBackend: FieldOps, +{ + pub fn to_cpu(&self) -> CircleEvaluation { + CircleEvaluation::new(self.domain, self.values.to_cpu()) + } +} + impl, F: ExtensionOf, EvalOrder> Deref for CircleEvaluation {