Skip to content

Commit

Permalink
[3/4 EIP 4844 in inserter circuit] Add gadget to polynomial evaluatio…
Browse files Browse the repository at this point in the history
…n using barycentric formula (#13)

prover: barycentric: implement barycentric formula gadget

The new `barycentric` package adds `CalculateBarycentricFormula`. The function
implements the evaluation of a polynomial in evaluation form at a point outside
the domain, using barycentric interpolation. This function follows
implementation by Dankrad Feist, as described in his blog post:
https://dankradfeist.de/ethereum/2021/06/18/pcs-multiproofs.html.

Another helper package is added - `field_utils`. It is a place for helper
gadgets for field elements manipulations. It contains one function `Exp`
to calculate field element's power of n, where n is an integer (not
a field element).

Signed-off-by: Wojciech Zmuda <[email protected]>
  • Loading branch information
wzmuda authored Jun 21, 2024
1 parent 9666200 commit 4245085
Show file tree
Hide file tree
Showing 5 changed files with 285 additions and 0 deletions.
45 changes: 45 additions & 0 deletions prover/barycentric/barycentric.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
package barycentric

import (
"github.com/consensys/gnark/std/math/emulated"

"worldcoin/gnark-mbu/prover/field_utils"
)

// CalculateBarycentricFormula implements the evaluation of a polynomial in evaluation form at a point outside the
// domain, using barycentric interpolation. This function follows the formulation by Dankrad Feist, as described
// in his blog post: https://dankradfeist.de/ethereum/2021/06/18/pcs-multiproofs.html.
//
// The formula used for calculation is: ((z^d - 1) / d) * Σ((f_i * ω^i) / (z - ω^i)) for i=0 to d-1,
// where z is the target evaluation point, d is the degree of the polynomial, f_i are the polynomial coefficients,
// and ω^i are the domain elements.
//
// field - reference to the emulated field operations structure, used for arithmetic operations within the specified
// field.
// omegasToI - slice containing the powers of the primitive root of unity ω, raised to the power of index i,
// representing the domain elements.
// yNodes - slice containing the polynomial coefficients or the values of the polynomial at the domain elements.
// targetPoint - point outside the domain at which the polynomial is to be evaluated.
func CalculateBarycentricFormula[T emulated.FieldParams](
field *emulated.Field[T], omegasToI, yNodes []emulated.Element[T], targetPoint emulated.Element[T],
) emulated.Element[T] {

polynomialDegree := len(yNodes)

// First term: (z^d - 1) / d
zToD := field_utils.Exp(field, &targetPoint, polynomialDegree)
firstTerm := *field.Sub(zToD, field.One())
d := emulated.ValueOf[T](polynomialDegree)
firstTerm = *field.Div(&firstTerm, &d)

// Second term: Σ(f_i * ω^i)/(z - ω^i) from i=0 to d-1
secondTerm := field.Zero()
for i := range polynomialDegree {
numerator := *field.Mul(&yNodes[i], &omegasToI[i])
denominator := *field.Sub(&targetPoint, &omegasToI[i])
term := *field.Div(&numerator, &denominator)
secondTerm = field.Add(secondTerm, &term)
}

return *field.Mul(&firstTerm, secondTerm)
}
149 changes: 149 additions & 0 deletions prover/barycentric/barycentric_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
package barycentric

import (
"math"
"math/big"
"testing"

"github.com/consensys/gnark-crypto/ecc"
"github.com/consensys/gnark/backend"
"github.com/consensys/gnark/frontend"
"github.com/consensys/gnark/std/math/emulated"
"github.com/consensys/gnark/test"
)

type BarycentricCircuit[T emulated.FieldParams] struct {
Omega big.Int // ω
PolynomialDegree int

// Inputs (private)
YNodes []emulated.Element[T] // len(YNodes) == PolynomialDegree
TargetPoint emulated.Element[T]

// Output
InterpolatedPoint emulated.Element[T] `gnark:",public"`
}

func (circuit *BarycentricCircuit[T]) Define(api frontend.API) error {
field, err := emulated.NewField[T](api)
if err != nil {
return err
}

api.AssertIsEqual(len(circuit.YNodes), circuit.PolynomialDegree)

omegasToI := make([]emulated.Element[T], circuit.PolynomialDegree)
omegaToI := big.NewInt(1)
for i := range circuit.PolynomialDegree {
omegasToI[i] = emulated.ValueOf[T](omegaToI)
omegaToI.Mul(omegaToI, &circuit.Omega)
}

// Method under test
interpolatedPointCalculated := CalculateBarycentricFormula[T](field, omegasToI, circuit.YNodes, circuit.TargetPoint)

field.AssertIsEqual(&circuit.InterpolatedPoint, &interpolatedPointCalculated)

return nil
}

func setupTestEnvironment(polynomialDegree int) (*big.Int, *big.Int) {
// The test assumes BLS12381Fr field and a certain polynomial degree
modulus, _ := new(big.Int).SetString(
"52435875175126190479447740508185965837690552500527637822603658699938581184513", 10,
)

// For polynomial degree d = 4096 = 2^12:
// ω^(2^32) = ω^(2^20 * 2^12)
// Calculate ω^20 starting with root of unity of 2^32 degree
omega, _ := new(big.Int).SetString(
"10238227357739495823651030575849232062558860180284477541189508159991286009131", 10,
)
polynomialDegreeExp := int(math.Log2(float64(polynomialDegree)))
omegaExpExp := 32 // ω^(2^32)
for range omegaExpExp - polynomialDegreeExp {
omega.Mul(omega, omega)
omega.Mod(omega, modulus)
}

return omega, modulus
}

func TestCalculateBarycentricFormula(t *testing.T) {
type Fr = emulated.BLS12381Fr
const polynomialDegree = 4096
omega, modulus := setupTestEnvironment(polynomialDegree)

// Test cases
type PolynomialTestCase[T emulated.FieldParams] struct {
Name string
CalculateYNodes func(omega *big.Int, modulus *big.Int, polynomialDegree int) []emulated.Element[T]
TargetPoint int64
InterpolatedPoint int64
}
tests := []PolynomialTestCase[Fr]{
{
Name: "f(x) = x^3",
CalculateYNodes: func(omega *big.Int, modulus *big.Int, polynomialDegree int) []emulated.Element[Fr] {
y := make([]emulated.Element[Fr], polynomialDegree)
for i := range y {
res := new(big.Int).Exp(omega, big.NewInt(int64(i*3)), modulus)
y[i] = emulated.ValueOf[Fr](res)
}
return y
},
TargetPoint: 3,
InterpolatedPoint: 27,
},
{
Name: "f(x) = 3x^7 + 2x^4 + 4x + 20",
CalculateYNodes: func(omega *big.Int, modulus *big.Int, polynomialDegree int) []emulated.Element[Fr] {
y := make([]emulated.Element[Fr], polynomialDegree)
for i := range y {
a := new(big.Int).Exp(omega, big.NewInt(int64(i*7)), modulus)
a.Mul(a, big.NewInt(3))

b := new(big.Int).Exp(omega, big.NewInt(int64(i*4)), modulus)
b.Mul(b, big.NewInt(2))

c := new(big.Int).Exp(omega, big.NewInt(int64(i)), modulus)
c.Mul(c, big.NewInt(4))

res := new(big.Int).Add(a, b)
res.Add(res, c)
res.Add(res, big.NewInt(20))
res.Mod(res, modulus)

y[i] = emulated.ValueOf[Fr](res)
}
return y
},
TargetPoint: 3,
InterpolatedPoint: 6755,
},
}

for _, tc := range tests {
assert := test.NewAssert(t)
assert.Run(
func(a *test.Assert) {
circuit := BarycentricCircuit[Fr]{
Omega: *omega,
PolynomialDegree: polynomialDegree,
YNodes: make([]emulated.Element[Fr], polynomialDegree),
}

assignment := BarycentricCircuit[Fr]{
YNodes: tc.CalculateYNodes(omega, modulus, polynomialDegree),
TargetPoint: emulated.ValueOf[Fr](tc.TargetPoint),
InterpolatedPoint: emulated.ValueOf[Fr](tc.InterpolatedPoint),
}

assert.CheckCircuit(
&circuit, test.WithBackends(backend.GROTH16), test.WithCurves(ecc.BN254),
test.WithValidAssignment(&assignment),
)
}, tc.Name,
)
}
}
3 changes: 3 additions & 0 deletions prover/field_utils/doc.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
package field_utils

// Package field_utils contains convenience functions to manipulate prime field elements.
22 changes: 22 additions & 0 deletions prover/field_utils/field_exp.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package field_utils

import "github.com/consensys/gnark/std/math/emulated"

// Exp raises base to the exponent in the given prime field.
//
// field - given prime field where base and the result belong.
// base - the number in the field to be risen to the given exponent.
// exponent - an integer to rise base to.
func Exp[T emulated.FieldParams](
field *emulated.Field[T], base *emulated.Element[T], exponent int,
) *emulated.Element[T] {
res := field.One()
for exponent > 0 {
if exponent%2 == 1 {
res = field.Mul(res, base)
}
base = field.Mul(base, base)
exponent /= 2
}
return res
}
66 changes: 66 additions & 0 deletions prover/field_utils/field_exp_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
package field_utils

import (
"fmt"
"math/big"
"math/rand"
"testing"

"github.com/consensys/gnark-crypto/ecc"
"github.com/consensys/gnark/backend"
"github.com/consensys/gnark/frontend"
"github.com/consensys/gnark/std/math/emulated"
"github.com/consensys/gnark/test"
)

type ExpCircuit[T emulated.FieldParams] struct {
Base emulated.Element[T]
Exp int
Res emulated.Element[T]
}

func (c *ExpCircuit[T]) Define(api frontend.API) error {
field, err := emulated.NewField[T](api)
if err != nil {
return err
}

// Function under test
calculatedRes := Exp[T](field, &c.Base, c.Exp)

field.AssertIsEqual(calculatedRes, &c.Res)

return nil
}

func randomPower() (int, int, *big.Int) {
base := rand.Intn(16)
exponent := rand.Intn(16)
result := new(big.Int).Exp(big.NewInt(int64(base)), big.NewInt(int64(exponent)), nil)
return base, exponent, result
}

func TestExp(t *testing.T) {
assert := test.NewAssert(t)

for range 16 { // Arbitrary choice of number of tests
base, exp, want := randomPower()
circuit := ExpCircuit[emulated.BLS12381Fr]{
Exp: exp,
}

assignment := ExpCircuit[emulated.BLS12381Fr]{
Base: emulated.ValueOf[emulated.BLS12381Fr](base),
Res: emulated.ValueOf[emulated.BLS12381Fr](want),
}

t.Run(
fmt.Sprintf("%d^%d", base, exp), func(t *testing.T) {
assert.CheckCircuit(
&circuit, test.WithBackends(backend.GROTH16), test.WithCurves(ecc.BN254),
test.WithValidAssignment(&assignment),
)
},
)
}
}

0 comments on commit 4245085

Please sign in to comment.