diff --git a/internal/stats/latest_stats.csv b/internal/stats/latest_stats.csv index eb7b4efb7..b3e5b2bb9 100644 --- a/internal/stats/latest_stats.csv +++ b/internal/stats/latest_stats.csv @@ -195,14 +195,14 @@ pairing_bn254,bls24_315,plonk,0,0 pairing_bn254,bls24_317,plonk,0,0 pairing_bn254,bw6_761,plonk,0,0 pairing_bn254,bw6_633,plonk,0,0 -pairing_bw6761,bn254,groth16,3014749,4979960 +pairing_bw6761,bn254,groth16,1843705,3084217 pairing_bw6761,bls12_377,groth16,0,0 pairing_bw6761,bls12_381,groth16,0,0 pairing_bw6761,bls24_315,groth16,0,0 pairing_bw6761,bls24_317,groth16,0,0 pairing_bw6761,bw6_761,groth16,0,0 pairing_bw6761,bw6_633,groth16,0,0 -pairing_bw6761,bn254,plonk,11486969,10777222 +pairing_bw6761,bn254,plonk,6947630,6315782 pairing_bw6761,bls12_377,plonk,0,0 pairing_bw6761,bls12_381,plonk,0,0 pairing_bw6761,bls24_315,plonk,0,0 diff --git a/std/algebra/emulated/fields_bw6761/e6.go b/std/algebra/emulated/fields_bw6761/e6.go index 433b92cc5..7a7cbd015 100644 --- a/std/algebra/emulated/fields_bw6761/e6.go +++ b/std/algebra/emulated/fields_bw6761/e6.go @@ -190,9 +190,7 @@ func (e Ext6) mulFpByNonResidue(fp *curveF, x *baseEl) *baseEl { } func (e Ext6) Mul(x, y *E6) *E6 { - x = e.Reduce(x) - y = e.Reduce(y) - return e.mulToomCook6(x, y) + return e.mulDirect(x, y) } func (e Ext6) mulMontgomery6(x, y *E6) *E6 { @@ -426,6 +424,37 @@ func (e Ext6) mulMontgomery6(x, y *E6) *E6 { } } +func (e Ext6) mulDirect(x, y *E6) *E6 { + nonResidue := e.fp.NewElement(-4) + // c0 = a0b0 + β(a1b5 + a2b4 + a3b3 + a4b2 + a5b1) + c0 := e.fp.Eval([][]*baseEl{{&x.A0, &y.A0}, {nonResidue, &x.A1, &y.A5}, {nonResidue, &x.A2, &y.A4}, {nonResidue, &x.A3, &y.A3}, {nonResidue, &x.A4, &y.A2}, {nonResidue, &x.A5, &y.A1}}, + []*big.Int{big.NewInt(1), big.NewInt(1), big.NewInt(1), big.NewInt(1), big.NewInt(1), big.NewInt(1)}) + // c1 = a0b1 + a1b0 + β(a2b5 + a3b4 + a4b3 + a5b2) + c1 := e.fp.Eval([][]*baseEl{{&x.A0, &y.A1}, {&x.A1, &y.A0}, {nonResidue, &x.A2, &y.A5}, {nonResidue, &x.A3, &y.A4}, {nonResidue, &x.A4, &y.A3}, {nonResidue, &x.A5, &y.A2}}, + []*big.Int{big.NewInt(1), big.NewInt(1), big.NewInt(1), big.NewInt(1), big.NewInt(1), big.NewInt(1)}) + // c2 = a0b2 + a1b1 + a2b0 + β(a3b5 + a4b4 + a5b3) + c2 := e.fp.Eval([][]*baseEl{{&x.A0, &y.A2}, {&x.A1, &y.A1}, {&x.A2, &y.A0}, {nonResidue, &x.A3, &y.A5}, {nonResidue, &x.A4, &y.A4}, {nonResidue, &x.A5, &y.A3}}, + []*big.Int{big.NewInt(1), big.NewInt(1), big.NewInt(1), big.NewInt(1), big.NewInt(1), big.NewInt(1)}) + // c3 = a0b3 + a1b2 + a2b1 + a3b0 + β(a4b5 + a5b4) + c3 := e.fp.Eval([][]*baseEl{{&x.A0, &y.A3}, {&x.A1, &y.A2}, {&x.A2, &y.A1}, {&x.A3, &y.A0}, {nonResidue, &x.A4, &y.A5}, {nonResidue, &x.A5, &y.A4}}, + []*big.Int{big.NewInt(1), big.NewInt(1), big.NewInt(1), big.NewInt(1), big.NewInt(1), big.NewInt(1)}) + // c4 = a0b4 + a1b3 + a2b2 + a3b1 + a4b0 + βa5b5 + c4 := e.fp.Eval([][]*baseEl{{&x.A0, &y.A4}, {&x.A1, &y.A3}, {&x.A2, &y.A2}, {&x.A3, &y.A1}, {&x.A4, &y.A0}, {nonResidue, &x.A5, &y.A5}}, + []*big.Int{big.NewInt(1), big.NewInt(1), big.NewInt(1), big.NewInt(1), big.NewInt(1), big.NewInt(1)}) + // c5 = a0b5 + a1b4 + a2b3 + a3b2 + a4b1 + a5b0, + c5 := e.fp.Eval([][]*baseEl{{&x.A0, &y.A5}, {&x.A1, &y.A4}, {&x.A2, &y.A3}, {&x.A3, &y.A2}, {&x.A4, &y.A1}, {&x.A5, &y.A0}}, + []*big.Int{big.NewInt(1), big.NewInt(1), big.NewInt(1), big.NewInt(1), big.NewInt(1), big.NewInt(1)}) + + return &E6{ + A0: *c0, + A1: *c1, + A2: *c2, + A3: *c3, + A4: *c4, + A5: *c5, + } +} + func (e Ext6) mulToomCook6(x, y *E6) *E6 { // Toom-Cook 6-way multiplication: // @@ -704,6 +733,43 @@ func (e Ext6) mulToomCook6(x, y *E6) *E6 { } func (e Ext6) Square(x *E6) *E6 { + // return e.squarePolyWithRand(x, e.fp.NewElement(-1)) + return e.squareDirect(x) +} + +// squareDirect computes the square of an element in E6 using schoolbook multiplication. +func (e Ext6) squareDirect(x *E6) *E6 { + nonResidue := e.fp.NewElement(-4) + // c0 = a0b0 + β(a1b5 + a2b4 + a3b3 + a4b2 + a5b1) + c0 := e.fp.Eval([][]*baseEl{{&x.A0, &x.A0}, {nonResidue, &x.A1, &x.A5}, {nonResidue, &x.A2, &x.A4}, {nonResidue, &x.A3, &x.A3}, {nonResidue, &x.A4, &x.A2}, {nonResidue, &x.A5, &x.A1}}, + []*big.Int{big.NewInt(1), big.NewInt(1), big.NewInt(1), big.NewInt(1), big.NewInt(1), big.NewInt(1)}) + // c1 = a0b1 + a1b0 + β(a2b5 + a3b4 + a4b3 + a5b2) + c1 := e.fp.Eval([][]*baseEl{{&x.A0, &x.A1}, {&x.A1, &x.A0}, {nonResidue, &x.A2, &x.A5}, {nonResidue, &x.A3, &x.A4}, {nonResidue, &x.A4, &x.A3}, {nonResidue, &x.A5, &x.A2}}, + []*big.Int{big.NewInt(1), big.NewInt(1), big.NewInt(1), big.NewInt(1), big.NewInt(1), big.NewInt(1)}) + // c2 = a0b2 + a1b1 + a2b0 + β(a3b5 + a4b4 + a5b3) + c2 := e.fp.Eval([][]*baseEl{{&x.A0, &x.A2}, {&x.A1, &x.A1}, {&x.A2, &x.A0}, {nonResidue, &x.A3, &x.A5}, {nonResidue, &x.A4, &x.A4}, {nonResidue, &x.A5, &x.A3}}, + []*big.Int{big.NewInt(1), big.NewInt(1), big.NewInt(1), big.NewInt(1), big.NewInt(1), big.NewInt(1)}) + // c3 = a0b3 + a1b2 + a2b1 + a3b0 + β(a4b5 + a5b4) + c3 := e.fp.Eval([][]*baseEl{{&x.A0, &x.A3}, {&x.A1, &x.A2}, {&x.A2, &x.A1}, {&x.A3, &x.A0}, {nonResidue, &x.A4, &x.A5}, {nonResidue, &x.A5, &x.A4}}, + []*big.Int{big.NewInt(1), big.NewInt(1), big.NewInt(1), big.NewInt(1), big.NewInt(1), big.NewInt(1)}) + // c4 = a0b4 + a1b3 + a2b2 + a3b1 + a4b0 + βa5b5 + c4 := e.fp.Eval([][]*baseEl{{&x.A0, &x.A4}, {&x.A1, &x.A3}, {&x.A2, &x.A2}, {&x.A3, &x.A1}, {&x.A4, &x.A0}, {nonResidue, &x.A5, &x.A5}}, + []*big.Int{big.NewInt(1), big.NewInt(1), big.NewInt(1), big.NewInt(1), big.NewInt(1), big.NewInt(1)}) + // c5 = a0b5 + a1b4 + a2b3 + a3b2 + a4b1 + a5b0, + c5 := e.fp.Eval([][]*baseEl{{&x.A0, &x.A5}, {&x.A1, &x.A4}, {&x.A2, &x.A3}, {&x.A3, &x.A2}, {&x.A4, &x.A1}, {&x.A5, &x.A0}}, + []*big.Int{big.NewInt(1), big.NewInt(1), big.NewInt(1), big.NewInt(1), big.NewInt(1), big.NewInt(1)}) + + return &E6{ + A0: *c0, + A1: *c1, + A2: *c2, + A3: *c3, + A4: *c4, + A5: *c5, + } +} + +func (e Ext6) squareEmulatedTower(x *E6) *E6 { // We don't use Montgomery-6 or Toom-Cook-6 for the squaring but instead we // simulate a quadratic over cubic extension tower because Karatsuba over // Chung-Hasan SQR2 is better constraint wise. @@ -792,6 +858,45 @@ func (e Ext6) Square(x *E6) *E6 { // https://eprint.iacr.org/2010/542.pdf // Sec. 5.6 with minor modifications to fit our tower func (e Ext6) CyclotomicSquareKarabina12345(x *E6) *E6 { + return e.cyclotomicSquareKarabina12345Eval(x) +} + +// cyclotomicSquareKarabina12345Eval computes +// [Ext6.cyclotomicSquareKarabina12345] but with the non-native Eval method. +func (e Ext6) cyclotomicSquareKarabina12345Eval(x *E6) *E6 { + c := e.fp.NewElement(-4) + mone := e.fp.NewElement(-1) + g1 := x.A2 + g2 := x.A4 + g3 := x.A1 + g4 := x.A3 + g5 := x.A5 + // h1 = 3*c*g2^2 + 3*g3^2 - 2*g1 + h1 := e.fp.Eval([][]*baseEl{{c, &g2, &g2}, {&g3, &g3}, {mone, &g1}}, []*big.Int{big.NewInt(3), big.NewInt(3), big.NewInt(2)}) + // h2 = 3*c*g5^2 + 3*g1^2 - 2*g2 + h2 := e.fp.Eval([][]*baseEl{{c, &g5, &g5}, {&g1, &g1}, {mone, &g2}}, []*big.Int{big.NewInt(3), big.NewInt(3), big.NewInt(2)}) + // h3 = 6*c*g1*g5 + 2*g3 + h3 := e.fp.Eval([][]*baseEl{{c, &g1, &g5}, {&g3}}, []*big.Int{big.NewInt(6), big.NewInt(2)}) + // h4 = 3*c*g2*g5 + 3*g1*g3 - g4 + h4 := e.fp.Eval([][]*baseEl{{c, &g2, &g5}, {&g1, &g3}, {mone, &g4}}, []*big.Int{big.NewInt(3), big.NewInt(3), big.NewInt(1)}) + // h5 = 6*g2*g3 + 2*g5 + h5 := e.fp.Eval([][]*baseEl{{&g2, &g3}, {&g5}}, []*big.Int{big.NewInt(6), big.NewInt(2)}) + + return &E6{ + A0: x.A0, + A1: *h3, + A2: *h1, + A3: *h4, + A4: *h2, + A5: *h5, + } + +} + +// Karabina's compressed cyclotomic square SQR12345 +// https://eprint.iacr.org/2010/542.pdf +// Sec. 5.6 with minor modifications to fit our tower +func (e Ext6) cyclotomicSquareKarabina12345(x *E6) *E6 { x = e.Reduce(x) // h4 = -g4 + 3((g3+g5)(g1+c*g2)-g1g5-c*g3g2) @@ -852,6 +957,32 @@ func (e Ext6) CyclotomicSquareKarabina12345(x *E6) *E6 { // DecompressKarabina12345 decompresses Karabina's cyclotomic square result SQR12345 func (e Ext6) DecompressKarabina12345(x *E6) *E6 { + return e.decompressKarabina12345Eval(x) +} + +// decompressKarabina12345Eval computes [Ext6.DecompressKarabina12345] but with the non-native Eval method. +func (e Ext6) decompressKarabina12345Eval(x *E6) *E6 { + mone := e.fp.NewElement(-1) + c := e.fp.NewElement(-4) + g1 := x.A2 + g2 := x.A4 + g3 := x.A1 + g4 := x.A3 + g5 := x.A5 + // h0 = -3*c*g1*g2 + 2*c*g4^2 + c*g3*g5 + 1 + h0 := e.fp.Eval([][]*baseEl{{mone, c, &g1, &g2}, {c, &g4, &g4}, {c, &g3, &g5}, {e.fp.One()}}, []*big.Int{big.NewInt(3), big.NewInt(2), big.NewInt(1), big.NewInt(1)}) + return &E6{ + A0: *h0, + A1: g3, + A2: g1, + A3: g4, + A4: g2, + A5: g5, + } +} + +// DecompressKarabina12345 decompresses Karabina's cyclotomic square result SQR12345 +func (e Ext6) decompressKarabina12345(x *E6) *E6 { x = e.Reduce(x) // h0 = (2g4^2 + g3g5 - 3g2g1)*c + 1 diff --git a/std/algebra/emulated/fields_bw6761/e6_pairing.go b/std/algebra/emulated/fields_bw6761/e6_pairing.go index e2715f3cc..9a3e7d0dd 100644 --- a/std/algebra/emulated/fields_bw6761/e6_pairing.go +++ b/std/algebra/emulated/fields_bw6761/e6_pairing.go @@ -132,6 +132,38 @@ func (e Ext6) ExpC2(z *E6) *E6 { // // E6{A0: c0, A1: 0, A2: c1, A3: 1, A4: 0, A5: 0} func (e *Ext6) MulBy023(z *E6, c0, c1 *baseEl) *E6 { + return e.mulBy023Direct(z, c0, c1) +} + +// mulBy023Direct multiplies z by an E6 sparse element 023 using schoolbook multiplication +func (e Ext6) mulBy023Direct(z *E6, c0, c1 *baseEl) *E6 { + nonResidue := e.fp.NewElement(-4) + + // z0 = a0c0 + β(a3 + a4c1) + z0 := e.fp.Eval([][]*baseEl{{&z.A0, c0}, {nonResidue, &z.A3}, {nonResidue, &z.A4, c1}}, []*big.Int{big.NewInt(1), big.NewInt(1), big.NewInt(1)}) + // z1 = a1c0 + β(a4 + a5c1) + z1 := e.fp.Eval([][]*baseEl{{&z.A1, c0}, {nonResidue, &z.A4}, {nonResidue, &z.A5, c1}}, []*big.Int{big.NewInt(1), big.NewInt(1), big.NewInt(1)}) + // z2 = a0c1 + a2c0 + β(a5) + z2 := e.fp.Eval([][]*baseEl{{&z.A0, c1}, {&z.A2, c0}, {nonResidue, &z.A5}}, []*big.Int{big.NewInt(1), big.NewInt(1), big.NewInt(1)}) + // c3 = a0 + a1c1 + a3c0 + z3 := e.fp.Eval([][]*baseEl{{&z.A0}, {&z.A1, c1}, {&z.A3, c0}}, []*big.Int{big.NewInt(1), big.NewInt(1), big.NewInt(1)}) + // c4 = a1 + a2c1 + a4c0 + z4 := e.fp.Eval([][]*baseEl{{&z.A1}, {&z.A2, c1}, {&z.A4, c0}}, []*big.Int{big.NewInt(1), big.NewInt(1), big.NewInt(1)}) + // c5 = a2 + a3c1 + a5c0, + z5 := e.fp.Eval([][]*baseEl{{&z.A2}, {&z.A3, c1}, {&z.A5, c0}}, []*big.Int{big.NewInt(1), big.NewInt(1), big.NewInt(1)}) + + return &E6{ + A0: *z0, + A1: *z1, + A2: *z2, + A3: *z3, + A4: *z4, + A5: *z5, + } +} + +// mulBy023 multiplies z by an E6 sparse element 023 +func (e Ext6) mulBy023(z *E6, c0, c1 *baseEl) *E6 { z = e.Reduce(z) a := e.fp.Mul(&z.A0, c0) @@ -198,7 +230,7 @@ func (e *Ext6) MulBy023(z *E6, c0, c1 *baseEl) *E6 { } -// Mul023By023 multiplies two E6 sparse element of the form: +// Mul023By023 multiplies two E6 sparse element of the form: // // E6{A0: c0, A1: 0, A2: c1, A3: 1, A4: 0, A5: 0} // @@ -206,6 +238,34 @@ func (e *Ext6) MulBy023(z *E6, c0, c1 *baseEl) *E6 { // // E6{A0: c0, A1: 0, A2: c1, A3: 1, A4: 0, A5: 0} func (e Ext6) Mul023By023(d0, d1, c0, c1 *baseEl) [5]*baseEl { + return e.mul023by023Direct(d0, d1, c0, c1) +} + +// mul023by023Direct multiplies two E6 sparse element using schoolbook multiplication +func (e Ext6) mul023by023Direct(d0, d1, c0, c1 *baseEl) [5]*baseEl { + nonResidue := e.fp.NewElement(-4) + // c0 = d0c0 + β + z0 := e.fp.Eval([][]*baseEl{{d0, c0}, {nonResidue}}, []*big.Int{big.NewInt(1), big.NewInt(1)}) + // c2 = d0c1 + d1c0 + z2 := e.fp.Eval([][]*baseEl{{d0, c1}, {d1, c0}}, []*big.Int{big.NewInt(1), big.NewInt(1)}) + // c3 = d0 + c0 + z3 := e.fp.Add(d0, c0) + // c4 = d1c1 + z4 := e.fp.Mul(d1, c1) + // c5 = d1 + c1, + z5 := e.fp.Add(d1, c1) + + return [5]*baseEl{z0, z2, z3, z4, z5} +} + +// mul023By023 multiplies two E6 sparse element of the form: +// +// E6{A0: c0, A1: 0, A2: c1, A3: 1, A4: 0, A5: 0} +// +// and +// +// E6{A0: c0, A1: 0, A2: c1, A3: 1, A4: 0, A5: 0} +func (e Ext6) mul023By023(d0, d1, c0, c1 *baseEl) [5]*baseEl { x0 := e.fp.Mul(c0, d0) x1 := e.fp.Mul(c1, d1) x04 := e.fp.Add(c0, d0) @@ -224,9 +284,48 @@ func (e Ext6) Mul023By023(d0, d1, c0, c1 *baseEl) [5]*baseEl { // MulBy02345 multiplies z by an E6 sparse element of the form // -// E6{A0: y0, A1: 0, A2: y1, A3: y2, A4: y3, A5: y4}, -// } +// E6{A0: y0, A1: 0, A2: y1, A3: y2, A4: y3, A5: y4} func (e *Ext6) MulBy02345(z *E6, x [5]*baseEl) *E6 { + return e.mulBy02345Direct(z, x) +} + +// mulBy02345Direct multiplies z by an E6 sparse element using schoolbook multiplication +func (e Ext6) mulBy02345Direct(z *E6, x [5]*baseEl) *E6 { + nonResidue := e.fp.NewElement(-4) + + // c0 = a0y0 + β(a1y4 + a2y3 + a3y2 + a4y1) + c0 := e.fp.Eval([][]*baseEl{{&z.A0, x[0]}, {nonResidue, &z.A1, x[4]}, {nonResidue, &z.A2, x[3]}, {nonResidue, &z.A3, x[2]}, {nonResidue, &z.A4, x[1]}}, + []*big.Int{big.NewInt(1), big.NewInt(1), big.NewInt(1), big.NewInt(1), big.NewInt(1)}) + // c1 = a1y0 + β(a2y4 + a3y3 + a4y2 + a5y1) + c1 := e.fp.Eval([][]*baseEl{{&z.A1, x[0]}, {nonResidue, &z.A2, x[4]}, {nonResidue, &z.A3, x[3]}, {nonResidue, &z.A4, x[2]}, {nonResidue, &z.A5, x[1]}}, + []*big.Int{big.NewInt(1), big.NewInt(1), big.NewInt(1), big.NewInt(1), big.NewInt(1)}) + // c2 = a0y1 + a2y0 + β(a3y4 + a4y3 + a5y2) + c2 := e.fp.Eval([][]*baseEl{{&z.A0, x[1]}, {&z.A2, x[0]}, {nonResidue, &z.A3, x[4]}, {nonResidue, &z.A4, x[3]}, {nonResidue, &z.A5, x[2]}}, + []*big.Int{big.NewInt(1), big.NewInt(1), big.NewInt(1), big.NewInt(1), big.NewInt(1)}) + // c3 = a0y2 + a1y1 + a3y0 + β(a4y4 + a5y3) + c3 := e.fp.Eval([][]*baseEl{{&z.A0, x[2]}, {&z.A1, x[1]}, {&z.A3, x[0]}, {nonResidue, &z.A4, x[4]}, {nonResidue, &z.A5, x[3]}}, + []*big.Int{big.NewInt(1), big.NewInt(1), big.NewInt(1), big.NewInt(1), big.NewInt(1)}) + // c4 = a0y3 + a1y2 + a2y1 + a4y0 + βa5y4 + c4 := e.fp.Eval([][]*baseEl{{&z.A0, x[3]}, {&z.A1, x[2]}, {&z.A2, x[1]}, {&z.A4, x[0]}, {nonResidue, &z.A5, x[4]}}, + []*big.Int{big.NewInt(1), big.NewInt(1), big.NewInt(1), big.NewInt(1), big.NewInt(1)}) + // c5 = a0y4 + a1y3 + a2y2 + a3y1 + a5y0, + c5 := e.fp.Eval([][]*baseEl{{&z.A0, x[4]}, {&z.A1, x[3]}, {&z.A2, x[2]}, {&z.A3, x[1]}, {&z.A5, x[0]}}, + []*big.Int{big.NewInt(1), big.NewInt(1), big.NewInt(1), big.NewInt(1), big.NewInt(1)}) + + return &E6{ + A0: *c0, + A1: *c1, + A2: *c2, + A3: *c3, + A4: *c4, + A5: *c5, + } +} + +// mulBy02345 multiplies z by an E6 sparse element of the form +// +// E6{A0: y0, A1: 0, A2: y1, A3: y2, A4: y3, A5: y4}, +func (e *Ext6) mulBy02345(z *E6, x [5]*baseEl) *E6 { a0 := e.fp.Add(&z.A0, &z.A1) a1 := e.fp.Add(&z.A2, &z.A3) a2 := e.fp.Add(&z.A4, &z.A5) diff --git a/std/algebra/emulated/fields_bw6761/e6_test.go b/std/algebra/emulated/fields_bw6761/e6_test.go index b5745e77c..c011eb8b5 100644 --- a/std/algebra/emulated/fields_bw6761/e6_test.go +++ b/std/algebra/emulated/fields_bw6761/e6_test.go @@ -105,10 +105,12 @@ type e6MulVariants struct { func (circuit *e6MulVariants) Define(api frontend.API) error { e := NewExt6(api) - expected1 := *e.mulMontgomery6(&circuit.A, &circuit.B) - expected2 := *e.mulToomCook6(&circuit.A, &circuit.B) - e.AssertIsEqual(&expected1, &circuit.C) - e.AssertIsEqual(&expected2, &circuit.C) + expected1 := e.mulMontgomery6(&circuit.A, &circuit.B) + expected2 := e.mulToomCook6(&circuit.A, &circuit.B) + expected3 := e.mulDirect(&circuit.A, &circuit.B) + e.AssertIsEqual(expected1, &circuit.C) + e.AssertIsEqual(expected2, &circuit.C) + e.AssertIsEqual(expected3, &circuit.C) return nil } @@ -135,10 +137,9 @@ type e6Mul struct { } func (circuit *e6Mul) Define(api frontend.API) error { - var expected E6 e := NewExt6(api) - expected = *e.Mul(&circuit.A, &circuit.B) - e.AssertIsEqual(&expected, &circuit.C) + expected := e.Mul(&circuit.A, &circuit.B) + e.AssertIsEqual(expected, &circuit.C) return nil } @@ -160,6 +161,35 @@ func TestMulFp6(t *testing.T) { assert.NoError(err) } +type e6SquareVariants struct { + A, C E6 +} + +func (circuit *e6SquareVariants) Define(api frontend.API) error { + e := NewExt6(api) + expected1 := e.squareDirect(&circuit.A) + expected2 := e.squareEmulatedTower(&circuit.A) + e.AssertIsEqual(expected1, &circuit.C) + e.AssertIsEqual(expected2, &circuit.C) + return nil +} + +func TestSquareVariantsFp6(t *testing.T) { + assert := test.NewAssert(t) + // witness values + var a, c bw6761.E6 + _, _ = a.SetRandom() + c.Square(&a) + + witness := e6SquareVariants{ + A: FromE6(&a), + C: FromE6(&c), + } + + err := test.IsSolved(&e6SquareVariants{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) +} + type e6Square struct { A, B E6 } @@ -312,6 +342,45 @@ func TestExptFp6(t *testing.T) { assert.NoError(err) } +type e6MulBy023Variants struct { + A E6 `gnark:",public"` + W E6 + B, C baseEl +} + +func (circuit *e6MulBy023Variants) Define(api frontend.API) error { + e := NewExt6(api) + expected1 := e.mulBy023(&circuit.A, &circuit.B, &circuit.C) + expected2 := e.mulBy023Direct(&circuit.A, &circuit.B, &circuit.C) + e.AssertIsEqual(expected1, &circuit.W) + e.AssertIsEqual(expected2, &circuit.W) + return nil +} + +func TestFp6MulBy023Variants(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, w bw6761.E6 + _, _ = a.SetRandom() + var one, b, c fp.Element + one.SetOne() + _, _ = b.SetRandom() + _, _ = c.SetRandom() + w.Set(&a) + w.MulBy014(&b, &c, &one) + + witness := e6MulBy023Variants{ + A: FromE6(&a), + B: emulated.ValueOf[emulated.BW6761Fp](&b), + C: emulated.ValueOf[emulated.BW6761Fp](&c), + W: FromE6(&w), + } + + err := test.IsSolved(&e6MulBy023Variants{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) +} + type e6MulBy023 struct { A E6 `gnark:",public"` W E6 @@ -348,3 +417,104 @@ func TestFp6MulBy023(t *testing.T) { err := test.IsSolved(&e6MulBy023{}, &witness, ecc.BN254.ScalarField()) assert.NoError(err) } + +type e6Mul023By023Variants struct { + A E6 `gnark:",public"` + B E6 `gnark:",public"` +} + +func (circuit *e6Mul023By023Variants) Define(api frontend.API) error { + e := NewExt6(api) + expected1 := e.mul023By023(&circuit.A.A0, &circuit.A.A2, &circuit.B.A0, &circuit.B.A2) + expected2 := e.mul023by023Direct(&circuit.A.A0, &circuit.A.A2, &circuit.B.A0, &circuit.B.A2) + for i := range expected1 { + e.fp.AssertIsEqual(expected1[i], expected2[i]) + } + return nil +} + +func TestFp6Mul023By023Variants(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, b bw6761.E6 + _, _ = a.SetRandom() + _, _ = b.SetRandom() + + witness := e6Mul023By023Variants{ + A: FromE6(&a), + B: FromE6(&b), + } + + err := test.IsSolved(&e6Mul023By023Variants{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) +} + +type e6MulBy02345Variants struct { + A E6 `gnark:",public"` + B E6 `gnark:",public"` +} + +func (circuit *e6MulBy02345Variants) Define(api frontend.API) error { + e := NewExt6(api) + expected1 := e.mulBy02345(&circuit.A, [5]*baseEl{&circuit.B.A0, &circuit.B.A2, &circuit.B.A3, &circuit.B.A4, &circuit.B.A5}) + expected2 := e.mulBy02345Direct(&circuit.A, [5]*baseEl{&circuit.B.A0, &circuit.B.A2, &circuit.B.A3, &circuit.B.A4, &circuit.B.A5}) + e.AssertIsEqual(expected1, expected2) + return nil +} + +func TestFp6MulBy02345Variants(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, b bw6761.E6 + _, _ = a.SetRandom() + _, _ = b.SetRandom() + + witness := e6MulBy02345Variants{ + A: FromE6(&a), + B: FromE6(&b), + } + + err := test.IsSolved(&e6MulBy02345Variants{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) +} + +type e6CycolotomicSquareKarabina12345Variants struct { + A E6 `gnark:",public"` + C E6 `gnark:",public"` +} + +func (circuit *e6CycolotomicSquareKarabina12345Variants) Define(api frontend.API) error { + e := NewExt6(api) + expected1 := e.cyclotomicSquareKarabina12345(&circuit.A) + expected2 := e.cyclotomicSquareKarabina12345Eval(&circuit.A) + e.AssertIsEqual(expected1, expected2) + dec1 := e.decompressKarabina12345(expected1) + dec2 := e.decompressKarabina12345Eval(expected2) + e.AssertIsEqual(dec1, dec2) + e.fp.AssertIsEqual(&dec1.A1, &circuit.C.A1) + e.fp.AssertIsEqual(&dec1.A2, &circuit.C.A2) + // e.fp.AssertIsEqual(&dec1.A3, &circuit.C.A3) + e.fp.AssertIsEqual(&dec1.A4, &circuit.C.A4) + e.fp.AssertIsEqual(&dec1.A5, &circuit.C.A5) + + return nil +} + +func TestFp6CyclotomicSquareKarabina12345Variants(t *testing.T) { + + assert := test.NewAssert(t) + // witness values + var a, c bw6761.E6 + _, _ = a.SetRandom() + c.CyclotomicSquare(&a) + + witness := e6CycolotomicSquareKarabina12345Variants{ + A: FromE6(&a), + C: FromE6(&c), + } + + err := test.IsSolved(&e6CycolotomicSquareKarabina12345Variants{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) +} diff --git a/std/math/emulated/element_test.go b/std/math/emulated/element_test.go index 264d0effc..b5e6f147a 100644 --- a/std/math/emulated/element_test.go +++ b/std/math/emulated/element_test.go @@ -1277,3 +1277,145 @@ func TestIsZeroEdgeCases(t *testing.T) { testIsZeroEdgeCases[BN254Fr](t) testIsZeroEdgeCases[emparams.Mod1e512](t) } + +type PolyEvalCircuit[T FieldParams] struct { + Inputs []Element[T] + Terms [][]int + Coeffs []*big.Int + Expected Element[T] +} + +func (c *PolyEvalCircuit[T]) Define(api frontend.API) error { + // withEval + f, err := NewField[T](api) + if err != nil { + return err + } + terms := make([][]*Element[T], len(c.Terms)) + for i := range terms { + terms[i] = make([]*Element[T], len(c.Terms[i])) + for j := range terms[i] { + terms[i][j] = &c.Inputs[c.Terms[i][j]] + } + } + resEval := f.Eval(terms, c.Coeffs) + + // withSum + addTerms := make([]*Element[T], len(c.Terms)) + for i, term := range c.Terms { + termVal := f.One() + for j := range term { + termVal = f.Mul(termVal, &c.Inputs[term[j]]) + } + addTerms[i] = f.MulConst(termVal, c.Coeffs[i]) + } + resSum := f.Sum(addTerms...) + + // mul no reduce + addTerms2 := make([]*Element[T], len(c.Terms)) + for i, term := range c.Terms { + termVal := f.One() + for j := range term { + termVal = f.MulNoReduce(termVal, &c.Inputs[term[j]]) + } + addTerms2[i] = f.MulConst(termVal, c.Coeffs[i]) + } + resNoReduce := f.Sum(addTerms2...) + resReduced := f.Reduce(resNoReduce) + + // assertions + f.AssertIsEqual(resEval, &c.Expected) + f.AssertIsEqual(resSum, &c.Expected) + f.AssertIsEqual(resNoReduce, &c.Expected) + f.AssertIsEqual(resReduced, &c.Expected) + + return nil +} + +func TestPolyEval(t *testing.T) { + testPolyEval[Goldilocks](t) + testPolyEval[BN254Fr](t) + testPolyEval[emparams.Mod1e512](t) +} + +func testPolyEval[T FieldParams](t *testing.T) { + const nbInputs = 2 + assert := test.NewAssert(t) + var fp T + var err error + // 2*x^3 + 3*x^2 y + 4*x y^2 + 5*y^3 + terms := [][]int{{0, 0, 0}, {0, 0, 1}, {0, 1, 1}, {1, 1, 1}} + coefficients := []*big.Int{big.NewInt(2), big.NewInt(3), big.NewInt(4), big.NewInt(5)} + inputs := make([]*big.Int, nbInputs) + assignmentInput := make([]Element[T], nbInputs) + for i := range inputs { + inputs[i], err = rand.Int(rand.Reader, fp.Modulus()) + assert.NoError(err) + } + for i := range inputs { + assignmentInput[i] = ValueOf[T](inputs[i]) + } + expected := new(big.Int) + for i, term := range terms { + termVal := new(big.Int).Set(coefficients[i]) + for j := range term { + termVal.Mul(termVal, inputs[term[j]]) + } + expected.Add(expected, termVal) + } + expected.Mod(expected, fp.Modulus()) + + assignment := &PolyEvalCircuit[T]{ + Inputs: assignmentInput, + Expected: ValueOf[T](expected), + } + assert.CheckCircuit(&PolyEvalCircuit[T]{Inputs: make([]Element[T], nbInputs), Terms: terms, Coeffs: coefficients}, test.WithValidAssignment(assignment)) +} + +type PolyEvalNegativeCoefficient[T FieldParams] struct { + Inputs []Element[T] + Res Element[T] +} + +func (c *PolyEvalNegativeCoefficient[T]) Define(api frontend.API) error { + f, err := NewField[T](api) + if err != nil { + return err + } + // x - y + coefficients := []*big.Int{big.NewInt(1), big.NewInt(-1)} + res := f.Eval([][]*Element[T]{{&c.Inputs[0]}, {&c.Inputs[1]}}, coefficients) + f.AssertIsEqual(res, &c.Res) + return nil +} + +func TestPolyEvalNegativeCoefficient(t *testing.T) { + testPolyEvalNegativeCoefficient[Goldilocks](t) + testPolyEvalNegativeCoefficient[BN254Fr](t) + testPolyEvalNegativeCoefficient[emparams.Mod1e512](t) +} + +func testPolyEvalNegativeCoefficient[T FieldParams](t *testing.T) { + t.Skip("not implemented yet") + assert := test.NewAssert(t) + var fp T + fmt.Println("modulus", fp.Modulus()) + var err error + const nbInputs = 2 + inputs := make([]*big.Int, nbInputs) + assignmentInput := make([]Element[T], nbInputs) + for i := range inputs { + inputs[i], err = rand.Int(rand.Reader, fp.Modulus()) + assert.NoError(err) + } + for i := range inputs { + fmt.Println("input", i, inputs[i]) + assignmentInput[i] = ValueOf[T](inputs[i]) + } + expected := new(big.Int).Sub(inputs[0], inputs[1]) + expected.Mod(expected, fp.Modulus()) + fmt.Println("expected", expected) + assignment := &PolyEvalNegativeCoefficient[T]{Inputs: assignmentInput, Res: ValueOf[T](expected)} + err = test.IsSolved(&PolyEvalNegativeCoefficient[T]{Inputs: make([]Element[T], nbInputs)}, assignment, testCurve.ScalarField()) + assert.NoError(err) +} diff --git a/std/math/emulated/field.go b/std/math/emulated/field.go index dce4074fa..d7aae8714 100644 --- a/std/math/emulated/field.go +++ b/std/math/emulated/field.go @@ -47,7 +47,7 @@ type Field[T FieldParams] struct { constrainedLimbs map[[16]byte]struct{} checker frontend.Rangechecker - mulChecks []mulCheck[T] + deferredChecks []deferredChecker } type ctxKey[T FieldParams] struct{} @@ -103,7 +103,7 @@ func NewField[T FieldParams](native frontend.API) (*Field[T], error) { return nil, fmt.Errorf("elements with limb length %d does not fit into scalar field", f.fParams.BitsPerLimb()) } - native.Compiler().Defer(f.performMulChecks) + native.Compiler().Defer(f.performDeferredChecks) if storer, ok := native.(kvstore.Store); ok { storer.SetKeyValue(ctxKey[T]{}, f) } @@ -282,3 +282,15 @@ func max[T constraints.Ordered](a ...T) T { } return m } + +func sum[T constraints.Ordered](a ...T) T { + if len(a) == 0 { + var f T + return f + } + m := a[0] + for _, v := range a[1:] { + m += v + } + return m +} diff --git a/std/math/emulated/field_mul.go b/std/math/emulated/field_mul.go index 78785d82b..122f0f785 100644 --- a/std/math/emulated/field_mul.go +++ b/std/math/emulated/field_mul.go @@ -4,12 +4,23 @@ import ( "fmt" "math/big" "math/bits" + "slices" "github.com/consensys/gnark/frontend" limbs "github.com/consensys/gnark/std/internal/limbcomposition" "github.com/consensys/gnark/std/multicommit" ) +type deferredChecker interface { + toCommit() []frontend.Variable + maxLen() int + + evalRound1(at []frontend.Variable) + evalRound2(at []frontend.Variable) + check(api frontend.API, peval frontend.Variable, coef frontend.Variable) + cleanEvaluations() +} + // mulCheck represents a single multiplication check. Instead of doing a // multiplication exactly where called, we compute the result using hint and // return it. Additionally, we store the correctness check for later checking @@ -62,6 +73,31 @@ type mulCheck[T FieldParams] struct { p *Element[T] // modulus if non-nil } +func (mc *mulCheck[T]) toCommit() []frontend.Variable { + var toCommit []frontend.Variable + toCommit = append(toCommit, mc.a.Limbs...) + toCommit = append(toCommit, mc.b.Limbs...) + toCommit = append(toCommit, mc.r.Limbs...) + toCommit = append(toCommit, mc.k.Limbs...) + toCommit = append(toCommit, mc.c.Limbs...) + if mc.p != nil { + toCommit = append(toCommit, mc.p.Limbs...) + } + return toCommit +} + +func (mc *mulCheck[T]) maxLen() int { + maxLen := len(mc.a.Limbs) + maxLen = max(maxLen, len(mc.b.Limbs)) + maxLen = max(maxLen, len(mc.r.Limbs)) + maxLen = max(maxLen, len(mc.k.Limbs)) + maxLen = max(maxLen, len(mc.c.Limbs)) + if mc.p != nil { + maxLen = max(maxLen, len(mc.p.Limbs)) + } + return maxLen +} + // evalRound1 evaluates first c(X), r(X) and k(X) at a given random point at[0]. // In the first round we do not assume that any of them is already evaluated as // they come directly from hint. @@ -132,7 +168,7 @@ func (f *Field[T]) mulMod(a, b *Element[T], _ uint, p *Element[T]) *Element[T] { r: r, p: p, } - f.mulChecks = append(f.mulChecks, mc) + f.deferredChecks = append(f.deferredChecks, &mc) return r } @@ -156,7 +192,7 @@ func (f *Field[T]) checkZero(a *Element[T], p *Element[T]) { r: r, // expected to be zero on zero limbs. p: p, } - f.mulChecks = append(f.mulChecks, mc) + f.deferredChecks = append(f.deferredChecks, &mc) } // evalWithChallenge represents element a as a polynomial a(X) and evaluates at @@ -184,12 +220,12 @@ func (f *Field[T]) evalWithChallenge(a *Element[T], at []frontend.Variable) *Ele // performMulChecks should be deferred to actually perform all the // multiplication checks. -func (f *Field[T]) performMulChecks(api frontend.API) error { +func (f *Field[T]) performDeferredChecks(api frontend.API) error { // use given api. We are in defer and API may be different to what we have // stored. // there are no multiplication checks, nothing to do - if len(f.mulChecks) == 0 { + if len(f.deferredChecks) == 0 { return nil } @@ -201,23 +237,15 @@ func (f *Field[T]) performMulChecks(api frontend.API) error { // multi-commit and range checks are in different commitment, then we have // problem. var toCommit []frontend.Variable - for i := range f.mulChecks { - toCommit = append(toCommit, f.mulChecks[i].a.Limbs...) - toCommit = append(toCommit, f.mulChecks[i].b.Limbs...) - toCommit = append(toCommit, f.mulChecks[i].r.Limbs...) - toCommit = append(toCommit, f.mulChecks[i].k.Limbs...) - toCommit = append(toCommit, f.mulChecks[i].c.Limbs...) - if f.mulChecks[i].p != nil { - toCommit = append(toCommit, f.mulChecks[i].p.Limbs...) - } + for i := range f.deferredChecks { + toCommit = append(toCommit, f.deferredChecks[i].toCommit()...) } // we give all the inputs as inputs to obtain random verifier challenge. multicommit.WithCommitment(api, func(api frontend.API, commitment frontend.Variable) error { // for efficiency, we compute all powers of the challenge as slice at. coefsLen := int(f.fParams.NbLimbs()) - for i := range f.mulChecks { - coefsLen = max(coefsLen, len(f.mulChecks[i].a.Limbs), len(f.mulChecks[i].b.Limbs), - len(f.mulChecks[i].c.Limbs), len(f.mulChecks[i].k.Limbs)) + for i := range f.deferredChecks { + coefsLen = max(coefsLen, f.deferredChecks[i].maxLen()) } at := make([]frontend.Variable, coefsLen) at[0] = commitment @@ -225,12 +253,12 @@ func (f *Field[T]) performMulChecks(api frontend.API) error { at[i] = api.Mul(at[i-1], commitment) } // evaluate all r, k, c - for i := range f.mulChecks { - f.mulChecks[i].evalRound1(at) + for i := range f.deferredChecks { + f.deferredChecks[i].evalRound1(at) } // assuming r is input to some other multiplication, then is already evaluated - for i := range f.mulChecks { - f.mulChecks[i].evalRound2(at) + for i := range f.deferredChecks { + f.deferredChecks[i].evalRound2(at) } // evaluate p(X) at challenge pval := f.evalWithChallenge(f.Modulus(), at) @@ -239,13 +267,13 @@ func (f *Field[T]) performMulChecks(api frontend.API) error { coef.Lsh(coef, f.fParams.BitsPerLimb()) ccoef := api.Sub(coef, commitment) // verify all mulchecks - for i := range f.mulChecks { - f.mulChecks[i].check(api, pval.evaluation, ccoef) + for i := range f.deferredChecks { + f.deferredChecks[i].check(api, pval.evaluation, ccoef) } // clean cached evaluation. Helps in case we compile the same circuit // multiple times. - for i := range f.mulChecks { - f.mulChecks[i].cleanEvaluations() + for i := range f.deferredChecks { + f.deferredChecks[i].cleanEvaluations() } return nil }, toCommit...) @@ -367,30 +395,8 @@ func mulHint(field *big.Int, inputs, outputs []*big.Int) error { if err := limbs.Decompose(rem, uint(nbBits), remLimbs); err != nil { return fmt.Errorf("decompose rem: %w", err) } - xp := make([]*big.Int, nbMultiplicationResLimbs(nbALen, nbBLen)) - yp := make([]*big.Int, nbMultiplicationResLimbs(nbQuoLen, nbLimbs)) - for i := range xp { - xp[i] = new(big.Int) - } - for i := range yp { - yp[i] = new(big.Int) - } - tmp := new(big.Int) - // we know compute the schoolbook multiprecision multiplication of a*b and - // r+k*p - for i := 0; i < nbALen; i++ { - for j := 0; j < nbBLen; j++ { - tmp.Mul(alimbs[i], blimbs[j]) - xp[i+j].Add(xp[i+j], tmp) - } - } - for i := 0; i < nbLimbs; i++ { - yp[i].Add(yp[i], remLimbs[i]) - for j := 0; j < nbQuoLen; j++ { - tmp.Mul(quoLimbs[j], plimbs[i]) - yp[i+j].Add(yp[i+j], tmp) - } - } + xp := limbMul(alimbs, blimbs) + yp := limbMul(quoLimbs, plimbs) carry := new(big.Int) for i := range carryLimbs { if i < len(xp) { @@ -510,3 +516,374 @@ func (f *Field[T]) Exp(base, exp *Element[T]) *Element[T] { res = f.Select(expBts[n-1], f.Mul(base, res), res) return res } + +// multivariate represents a multivariate polynomial. It is a list of terms +// where each term is a list of exponents for each variable. The coefficients +// are stored in the same order as the terms. +type multivariate[T FieldParams] struct { + Terms [][]int + Coefficients []*big.Int +} + +// Eval evaluates the multivariate polynomial. The elements of the inner slices +// are multiplied together and then summed together with the corresponding +// coefficient. +// +// NB! This is experimental API. It does not support negative coefficients. It +// does not check that computing the term wouldn't overflow the field. +func (f *Field[T]) Eval(at [][]*Element[T], coefs []*big.Int) *Element[T] { + if len(at) != len(coefs) { + panic("terms and coefficients mismatch") + } + if len(at) == 0 { + return f.Zero() + } + for i := range coefs { + if coefs[i].Sign() < 0 { + panic("negative coefficient") + } + } + // initialize the multivariate struct from the inputs + + // it would be easier to use a map to store the elements and then use the + // map to get the inputs in the right order. However, for deterministic + // circuit compilation we need to use the same order of inputs. So we use + // slice instead. + var allElems []*Element[T] + for i := range at { + AT_INNER: + for j := range at[i] { + for k := range allElems { + if allElems[k] == at[i][j] { + continue AT_INNER + } + } + allElems = append(allElems, at[i][j]) + } + } + terms := make([][]int, 0, len(at)) + for i := range at { + term := make([]int, len(allElems)) + for j := range at[i] { + term[slices.Index(allElems, at[i][j])]++ + } + terms = append(terms, term) + } + + // ensure that all the elements have the range checks enforced on limbs. + // Necessary in case the input is a witness. + for i := range allElems { + f.enforceWidthConditional(allElems[i]) + } + + mv := &multivariate[T]{ + Terms: terms, + Coefficients: coefs, + } + + k, r, c, err := f.callPolyHint(mv, allElems) + if err != nil { + panic(err) + } + + mvc := mvCheck[T]{ + f: f, + mv: mv, + vals: allElems, + r: r, + k: k, + c: c, + } + + f.deferredChecks = append(f.deferredChecks, &mvc) + return r +} + +func (f *Field[T]) callPolyHint(mv *multivariate[T], at []*Element[T]) (quo, rem, carries *Element[T], err error) { + // first compute the length of the result so that we know how many bits we need for the quotient. + nbLimbs, nbBits := f.fParams.NbLimbs(), f.fParams.BitsPerLimb() + modBits := uint(f.fParams.Modulus().BitLen()) + quoSize := f.polyEvalQuoSize(mv, at) + nbQuoLimbs := (quoSize - modBits + nbBits) / nbBits + nbRemLimbs := nbLimbs + nbCarryLimbs := nbMultiplicationResLimbs(int(nbQuoLimbs), int(nbLimbs)) - 1 + + hintInputs := []frontend.Variable{ + nbBits, + nbLimbs, + len(mv.Terms), + len(at), + nbQuoLimbs, + nbRemLimbs, + nbCarryLimbs, + } + for i := range mv.Terms { + for j := range mv.Terms[i] { + hintInputs = append(hintInputs, mv.Terms[i][j]) + } + } + for i := range mv.Coefficients { + hintInputs = append(hintInputs, mv.Coefficients[i]) + } + hintInputs = append(hintInputs, f.Modulus().Limbs...) + for i := range at { + hintInputs = append(hintInputs, len(at[i].Limbs)) + hintInputs = append(hintInputs, at[i].Limbs...) + } + ret, err := f.api.NewHint(polyHint, int(nbQuoLimbs)+int(nbRemLimbs)+int(nbCarryLimbs), hintInputs...) + if err != nil { + err = fmt.Errorf("call hint: %w", err) + return + } + quo = f.packLimbs(ret[:nbQuoLimbs], false) + rem = f.packLimbs(ret[nbQuoLimbs:nbQuoLimbs+nbRemLimbs], true) + carries = f.newInternalElement(ret[nbQuoLimbs+nbRemLimbs:], 0) + return quo, rem, carries, nil +} + +type mvCheck[T FieldParams] struct { + f *Field[T] + mv *multivariate[T] + vals []*Element[T] + r *Element[T] // reduced result + k *Element[T] // quotient + c *Element[T] // carry +} + +func (mc *mvCheck[T]) toCommit() []frontend.Variable { + var toCommit []frontend.Variable + toCommit = append(toCommit, mc.r.Limbs...) + toCommit = append(toCommit, mc.k.Limbs...) + toCommit = append(toCommit, mc.c.Limbs...) + for j := range mc.vals { + toCommit = append(toCommit, mc.vals[j].Limbs...) + } + return toCommit +} + +func (mc *mvCheck[T]) maxLen() int { + maxLen := len(mc.r.Limbs) + maxLen = max(maxLen, len(mc.k.Limbs)) + maxLen = max(maxLen, len(mc.c.Limbs)) + for j := range mc.vals { + maxLen = max(maxLen, len(mc.vals[j].Limbs)) + } + return maxLen +} + +func (mc *mvCheck[T]) evalRound1(at []frontend.Variable) { + mc.c = mc.f.evalWithChallenge(mc.c, at) + mc.r = mc.f.evalWithChallenge(mc.r, at) + mc.k = mc.f.evalWithChallenge(mc.k, at) +} + +func (mc *mvCheck[T]) evalRound2(at []frontend.Variable) { + for i := range mc.vals { + mc.vals[i] = mc.f.evalWithChallenge(mc.vals[i], at) + } +} + +func (mc *mvCheck[T]) check(api frontend.API, peval, coef frontend.Variable) { + ls := frontend.Variable(0) + for i, term := range mc.mv.Terms { + termProd := frontend.Variable(mc.mv.Coefficients[i]) + for i, pow := range term { + for j := 0; j < pow; j++ { + termProd = api.Mul(termProd, mc.vals[i].evaluation) + } + } + ls = api.Add(ls, termProd) + } + rs := api.Add(mc.r.evaluation, api.Mul(peval, mc.k.evaluation), api.Mul(mc.c.evaluation, coef)) + api.AssertIsEqual(ls, rs) +} + +func (mc *mvCheck[T]) cleanEvaluations() { + for i := range mc.vals { + mc.vals[i].evaluation = 0 + mc.vals[i].isEvaluated = false + } + mc.r.evaluation = 0 + mc.r.isEvaluated = false + mc.k.evaluation = 0 + mc.k.isEvaluated = false + mc.c.evaluation = 0 + mc.c.isEvaluated = false +} + +func (f *Field[T]) polyEvalQuoSize(mv *multivariate[T], at []*Element[T]) (quoSize uint) { + modBits := f.fParams.Modulus().BitLen() + quoSizes := make([]uint, len(mv.Terms)) + for i, term := range mv.Terms { + var lengths []uint + for j, pow := range term { + for k := 0; k < pow; k++ { + lengths = append(lengths, uint(modBits)+at[j].overflow) + } + } + lengths = append(lengths, uint(mv.Coefficients[i].BitLen())) + quoSizes[i] = sum(lengths...) + } + quoSize = max(quoSizes...) + uint(len(quoSizes)) + return quoSize +} + +func polyHint(mod *big.Int, inputs, outputs []*big.Int) error { + if len(inputs) < 7 { + return fmt.Errorf("not enough inputs") + } + var ( + nbBits = int(inputs[0].Int64()) + nbLimbs = int(inputs[1].Int64()) + nbTerms = int(inputs[2].Int64()) + nbVars = int(inputs[3].Int64()) + nbQuoLimbs = int(inputs[4].Int64()) + nbRemLimbs = int(inputs[5].Int64()) + nbCarryLimbs = int(inputs[6].Int64()) + ) + if len(outputs) != nbQuoLimbs+nbRemLimbs+nbCarryLimbs { + return fmt.Errorf("output length mismatch") + } + outPtr := 0 + quoLimbs := outputs[outPtr : outPtr+nbQuoLimbs] + outPtr += nbQuoLimbs + remLimbs := outputs[outPtr : outPtr+nbRemLimbs] + outPtr += nbRemLimbs + carryLimbs := outputs[outPtr : outPtr+nbCarryLimbs] + terms := make([][]int, nbTerms) + ptr := 7 + for i := range terms { + terms[i] = make([]int, nbVars) + for j := range terms[i] { + terms[i][j] = int(inputs[ptr].Int64()) + ptr++ + } + } + coeffs := make([]*big.Int, nbTerms) + for i := range coeffs { + coeffs[i] = inputs[ptr] + ptr++ + } + plimbs := inputs[ptr : ptr+nbLimbs] + ptr += nbLimbs + p := new(big.Int) + if err := limbs.Recompose(plimbs, uint(nbBits), p); err != nil { + return fmt.Errorf("recompose p: %w", err) + } + varsLimbs := make([][]*big.Int, nbVars) + for i := range varsLimbs { + varsLimbs[i] = make([]*big.Int, int(inputs[ptr].Int64())) + ptr++ + for j := range varsLimbs[i] { + varsLimbs[i][j] = inputs[ptr] + ptr++ + } + } + if ptr != len(inputs) { + return fmt.Errorf("inputs not exhausted") + } + vars := make([]*big.Int, nbVars) + for i := range vars { + vars[i] = new(big.Int) + if err := limbs.Recompose(varsLimbs[i], uint(nbBits), vars[i]); err != nil { + return fmt.Errorf("recompose vars[%d]: %w", i, err) + } + } + + // compute the result on full inputs + + fullLhs := new(big.Int) + for i, term := range terms { + termRes := new(big.Int).Set(coeffs[i]) + for i, pow := range term { + for j := 0; j < pow; j++ { + termRes.Mul(termRes, vars[i]) + } + } + fullLhs.Add(fullLhs, termRes) + } + + // compute the result as r + k*p for now + var ( + quo = new(big.Int) + rem = new(big.Int) + ) + if p.Cmp(new(big.Int)) != 0 { + quo.QuoRem(fullLhs, p, rem) + } + // write the remainder and quotient to output + if err := limbs.Decompose(quo, uint(nbBits), quoLimbs); err != nil { + return fmt.Errorf("decompose quo: %w", err) + } + if err := limbs.Decompose(rem, uint(nbBits), remLimbs); err != nil { + return fmt.Errorf("decompose rem: %w", err) + } + + // compute the result on limbs + tmp := new(big.Int) + var lhs []*big.Int + for i, term := range terms { + // collect the variables to be multiplied together + var termVarLimbs [][]*big.Int + for i, pow := range term { + for j := 0; j < pow; j++ { + termVarLimbs = append(termVarLimbs, varsLimbs[i]) + } + } + if len(termVarLimbs) == 0 { + continue + } + termRes := []*big.Int{new(big.Int).Set(coeffs[i])} + for _, toMul := range termVarLimbs { + termRes = limbMul(termRes, toMul) + } + for i := len(lhs); i < len(termRes); i++ { + lhs = append(lhs, new(big.Int)) + } + for i := range termRes { + lhs[i].Add(lhs[i], termRes[i]) + } + } + + // compute the result as r + k*p on limbs + rhs := make([]*big.Int, nbMultiplicationResLimbs(nbQuoLimbs, nbLimbs)) + for i := range rhs { + rhs[i] = new(big.Int) + } + for i := 0; i < nbLimbs; i++ { + rhs[i].Add(rhs[i], remLimbs[i]) + for j := 0; j < nbQuoLimbs; j++ { + tmp.Mul(quoLimbs[j], plimbs[i]) + rhs[i+j].Add(rhs[i+j], tmp) + } + } + + // compute the carries + carry := new(big.Int) + for i := range carryLimbs { + if i < len(lhs) { + carry.Add(carry, lhs[i]) + } + if i < len(rhs) { + carry.Sub(carry, rhs[i]) + } + carry.Rsh(carry, uint(nbBits)) + carryLimbs[i] = new(big.Int).Set(carry) + } + + return nil +} + +func limbMul(lhs []*big.Int, rhs []*big.Int) []*big.Int { + tmp := new(big.Int) + res := make([]*big.Int, nbMultiplicationResLimbs(len(lhs), len(rhs))) + for i := range res { + res[i] = new(big.Int) + } + for i := 0; i < len(lhs); i++ { + for j := 0; j < len(rhs); j++ { + res[i+j].Add(res[i+j], tmp.Mul(lhs[i], rhs[j])) + } + } + return res +} diff --git a/std/math/emulated/hints.go b/std/math/emulated/hints.go index 7ebdba29e..a132a9b89 100644 --- a/std/math/emulated/hints.go +++ b/std/math/emulated/hints.go @@ -24,6 +24,7 @@ func GetHints() []solver.Hint { SqrtHint, mulHint, subPaddingHint, + polyHint, } }