From 9411356e16e851ca5ea9485cde39e818401edf61 Mon Sep 17 00:00:00 2001 From: Wojciech Zmuda Date: Tue, 14 May 2024 22:11:17 +0200 Subject: [PATCH] prover: insertion_circuit: implement EIP 4844 Co-authored-by: Marcin Kostrzewa Signed-off-by: Wojciech Zmuda --- go.mod | 1 + go.sum | 2 + prover/insertion_circuit.go | 171 +++++++++++++++++++++++-------- prover/insertion_circuit_test.go | 152 +++++++++++++++++++++++++++ 4 files changed, 285 insertions(+), 41 deletions(-) create mode 100644 prover/insertion_circuit_test.go diff --git a/go.mod b/go.mod index da5e2dc..02f313e 100644 --- a/go.mod +++ b/go.mod @@ -17,6 +17,7 @@ require ( github.com/cespare/xxhash/v2 v2.1.2 // indirect github.com/consensys/bavard v0.1.13 // indirect github.com/cpuguy83/go-md2man/v2 v2.0.2 // indirect + github.com/crate-crypto/go-kzg-4844 v1.0.0 // indirect github.com/golang/protobuf v1.5.2 // indirect github.com/google/pprof v0.0.0-20230817174616-7a8ec2ada47b // indirect github.com/ingonyama-zk/icicle v0.0.0-20230928131117-97f0079e5c71 // indirect diff --git a/go.sum b/go.sum index 8a886d8..a784a2f 100644 --- a/go.sum +++ b/go.sum @@ -64,6 +64,8 @@ github.com/consensys/gnark-crypto v0.12.2-0.20240504013751-564b6f724c3b/go.mod h github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/cpuguy83/go-md2man/v2 v2.0.2 h1:p1EgwI/C7NhT0JmVkwCD2ZBK8j4aeHQX2pMHHBfMQ6w= github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= +github.com/crate-crypto/go-kzg-4844 v1.0.0 h1:TsSgHwrkTKecKJ4kadtHi4b3xHW5dCFUDFnUp1TsawI= +github.com/crate-crypto/go-kzg-4844 v1.0.0/go.mod h1:1kMhvPgI0Ky3yIa+9lFySEBUBXkYxeOi8ZF1sYioxhc= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= diff --git a/prover/insertion_circuit.go b/prover/insertion_circuit.go index eab909a..9822de6 100644 --- a/prover/insertion_circuit.go +++ b/prover/insertion_circuit.go @@ -1,7 +1,15 @@ package prover import ( + "math" + "math/big" + "math/bits" + + "github.com/consensys/gnark/std/math/emulated" + + "worldcoin/gnark-mbu/prover/barycentric" "worldcoin/gnark-mbu/prover/keccak" + "worldcoin/gnark-mbu/prover/poseidon" "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark/frontend" @@ -10,67 +18,148 @@ import ( ) type InsertionMbuCircuit struct { - // single public input - InputHash frontend.Variable `gnark:",public"` - - // private inputs, but used as public inputs - StartIndex frontend.Variable `gnark:"input"` - PreRoot frontend.Variable `gnark:"input"` - PostRoot frontend.Variable `gnark:"input"` - IdComms []frontend.Variable `gnark:"input"` + // public inputs + InputHash frontend.Variable `gnark:",public"` + ExpectedEvaluation frontend.Variable `gnark:",public"` + Commitment4844 frontend.Variable `gnark:",public"` + StartIndex frontend.Variable `gnark:",public"` + PreRoot frontend.Variable `gnark:",public"` + PostRoot frontend.Variable `gnark:",public"` // private inputs + IdComms []frontend.Variable `gnark:"input"` MerkleProofs [][]frontend.Variable `gnark:"input"` BatchSize int Depth int } -func (circuit *InsertionMbuCircuit) Define(api frontend.API) error { - // Hash private inputs. - // We keccak hash all input to save verification gas. Inputs are arranged as follows: - // StartIndex || PreRoot || PostRoot || IdComms[0] || IdComms[1] || ... || IdComms[batchSize-1] - // 32 || 256 || 256 || 256 || 256 || ... || 256 bits - var bits []frontend.Variable - - // We convert all the inputs to the keccak hash to use big-endian (network) byte - // ordering so that it agrees with Solidity. This ensures that we don't have to - // perform the conversion inside the contract and hence save on gas. - bits_start := abstractor.Call1(api, ToReducedBigEndian{Variable: circuit.StartIndex, Size: 32}) - bits = append(bits, bits_start...) +// getMerkleTreeRoot calculates the Merkle Tree root repeatedly hashing pairs of elements in the input slice until only +// one element remains. This process effectively builds a binary tree of hashes, where each level of the tree is half +// the size of the level below it. +// At the end or the process the function returns the root value of such constructed Merkle Tree. +func getMerkleTreeRoot(api frontend.API, input []frontend.Variable) frontend.Variable { + temp := input[:] + for len(temp) > 1 { + newInput := make([]frontend.Variable, len(temp)/2) + for i := range newInput { + newInput[i] = abstractor.Call( + api, poseidon.Poseidon2{ + In1: temp[2*i], + In2: temp[2*i+1], + }, + ) + } + temp = newInput + } + return temp[0] +} - bits_pre := abstractor.Call1(api, ToReducedBigEndian{Variable: circuit.PreRoot, Size: 256}) - bits = append(bits, bits_pre...) +type Fr = emulated.BLS12381Fr + +const polynomialDegree = 4096 + +func computeOmegaToI() (*big.Int, *big.Int) { + // This function assumes BLS12381Fr field and polynomial degree 4096 + modulus, _ := new(big.Int).SetString( + "52435875175126190479447740508185965837690552500527637822603658699938581184513", 10, + ) + + // For polynomial degree d = 4096 = 2^12: + // ω^(2^32) = ω^(2^20 * 2^12) + // Calculate ω^20 starting with root of unity of 2^32 degree + omega, _ := new(big.Int).SetString( + "10238227357739495823651030575849232062558860180284477541189508159991286009131", 10, + ) + polynomialDegreeExp := int(math.Log2(float64(polynomialDegree))) + omegaExpExp := 32 // ω^(2^32) + for range omegaExpExp - polynomialDegreeExp { + omega.Mul(omega, omega) + omega.Mod(omega, modulus) + } - bits_post := abstractor.Call1(api, ToReducedBigEndian{Variable: circuit.PostRoot, Size: 256}) - bits = append(bits, bits_post...) + return omega, modulus +} - for i := 0; i < circuit.BatchSize; i++ { - bits_id := abstractor.Call1(api, ToReducedBigEndian{Variable: circuit.IdComms[i], Size: 256}) - bits = append(bits, bits_id...) +func evaluatePolynomial( + api frontend.API, interpolatingPoints []frontend.Variable, pointOfEvaluation frontend.Variable, +) (evaluationValue frontend.Variable) { + startingOmega, _ := computeOmegaToI() + omegasToI := make([]emulated.Element[Fr], polynomialDegree) + omegaToI := big.NewInt(1) + for i := range polynomialDegree { + omegasToI[bits.Reverse64(uint64(i))>>52] = emulated.ValueOf[Fr](omegaToI) + omegaToI.Mul(omegaToI, startingOmega) } - hash, err := keccak.Keccak256(api, bits) + field, err := emulated.NewField[Fr](api) if err != nil { return err } - sum := abstractor.Call(api, FromBinaryBigEndian{Variable: hash}) - // The same endianness conversion has been performed in the hash generation - // externally, so we can safely assert their equality here. - api.AssertIsEqual(circuit.InputHash, sum) + x := *field.FromBits(api.ToBinary(pointOfEvaluation)...) + w := make([]emulated.Element[Fr], len(interpolatingPoints)) + for i, p := range interpolatingPoints { + w[i] = *field.FromBits(api.ToBinary(p)...) + } + y := barycentric.CalculateBarycentricFormula(field, omegasToI, w, x) + + evaluationValue = api.FromBinary(field.ToBits(&y)...) + return +} - // Actual batch merkle proof verification. - root := abstractor.Call(api, InsertionProof{ - StartIndex: circuit.StartIndex, - PreRoot: circuit.PreRoot, - IdComms: circuit.IdComms, +func (circuit *InsertionMbuCircuit) Define(api frontend.API) error { + paddedIdComms := make([]frontend.Variable, polynomialDegree) + for i := range paddedIdComms { + paddedIdComms[i] = 0 + } + copy(paddedIdComms, circuit.IdComms) + rootHash := getMerkleTreeRoot(api, paddedIdComms) + api.AssertIsEqual(circuit.InputHash, rootHash) - MerkleProofs: circuit.MerkleProofs, + var bitsHashCommitment []frontend.Variable + // We convert all the inputs to the keccak hash to use big-endian (network) byte + // ordering so that it agrees with Solidity. This ensures that we don't have to + // perform the conversion inside the contract and hence save on gas. + bitsHash := abstractor.Call1( + api, ToReducedBigEndian{ + Variable: circuit.InputHash, + Size: 256, + }, + ) + bitsHashCommitment = append(bitsHashCommitment, bitsHash...) + bitsCommitment := abstractor.Call1( + api, ToReducedBigEndian{ + Variable: circuit.Commitment4844, + Size: 256, + }, + ) + bitsHashCommitment = append(bitsHashCommitment, bitsCommitment...) + + // Compute Fiat-Shamir challenge of input hash and 4844 commitment + hash, err := keccak.Keccak256(api, bitsHashCommitment) + if err != nil { + return err + } + challenge := abstractor.Call(api, FromBinaryBigEndian{Variable: hash}) + + // Calculate evaluation of polynomial interpolated by identities in the point x=challenge + evaluation := evaluatePolynomial(api, paddedIdComms, challenge) + api.AssertIsEqual(circuit.ExpectedEvaluation, evaluation) - BatchSize: circuit.BatchSize, - Depth: circuit.Depth, - }) + // Actual batch merkle proof verification. + root := abstractor.Call( + api, InsertionProof{ + StartIndex: circuit.StartIndex, + PreRoot: circuit.PreRoot, + IdComms: circuit.IdComms, + + MerkleProofs: circuit.MerkleProofs, + + BatchSize: circuit.BatchSize, + Depth: circuit.Depth, + }, + ) // Final root needs to match. api.AssertIsEqual(root, circuit.PostRoot) diff --git a/prover/insertion_circuit_test.go b/prover/insertion_circuit_test.go new file mode 100644 index 0000000..3e95173 --- /dev/null +++ b/prover/insertion_circuit_test.go @@ -0,0 +1,152 @@ +package prover + +import ( + "crypto/rand" + "math" + "math/big" + "testing" + + "github.com/consensys/gnark-crypto/ecc" + bn254fr "github.com/consensys/gnark-crypto/ecc/bn254/fr" + "github.com/consensys/gnark/backend" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/test" + gokzg4844 "github.com/crate-crypto/go-kzg-4844" + "github.com/iden3/go-iden3-crypto/keccak256" + "github.com/stretchr/testify/require" + + poseidon "worldcoin/gnark-mbu/poseidon_native" +) + +const ( + numGoRoutines = 0 + existingUsersCount = 1500 + batchSize = 20 + depth = 20 +) + +func TestInsertionCircuit(t *testing.T) { + incomingIds := generateRandomIdentities(batchSize) + smallTree := poseidon.NewTree(treeDepth(polynomialDegree)) + idComms := make([]frontend.Variable, batchSize) + for i, id := range incomingIds { + idComms[i] = id + smallTree.Update(i, id) + } + incomingIdsTreeRoot := smallTree.Root() + incomingIdsTreeRoot = *bytesToBn254BigInt(incomingIdsTreeRoot.Bytes()) + + ctx, err := gokzg4844.NewContext4096Secure() + require.NoError(t, err) + blob := identitiesToBlob(incomingIds) + commitment, err := ctx.BlobToKZGCommitment(blob, numGoRoutines) + require.NoError(t, err) + commitment4844 := *bytesToBn254BigInt(commitment[:]) + + challenge := bigIntsToChallenge([]big.Int{incomingIdsTreeRoot, commitment4844}) + proof, evaluation, err := ctx.ComputeKZGProof(blob, challenge, numGoRoutines) + require.NoError(t, err) + err = ctx.VerifyKZGProof(commitment, challenge, evaluation, proof) + require.NoError(t, err) + + existingIds := generateRandomIdentities(existingUsersCount) + bigTree := poseidon.NewTree(depth) + for i, id := range existingIds { + bigTree.Update(i, id) + } + preRoot := bigTree.Root() + merkleProofs := make([][]frontend.Variable, batchSize) + for i, id := range incomingIds { + mp := bigTree.Update(i+existingUsersCount, id) + merkleProofs[i] = make([]frontend.Variable, len(mp)) + for j, v := range mp { + merkleProofs[i][j] = v + } + } + postRoot := bigTree.Root() + + circuit := InsertionMbuCircuit{ + IdComms: make([]frontend.Variable, batchSize), + MerkleProofs: make([][]frontend.Variable, batchSize), + BatchSize: batchSize, + Depth: depth, + } + for i := range merkleProofs { + circuit.MerkleProofs[i] = make([]frontend.Variable, depth) + } + + assignment := InsertionMbuCircuit{ + InputHash: incomingIdsTreeRoot, + ExpectedEvaluation: evaluation[:], + Commitment4844: commitment4844, + StartIndex: existingUsersCount, + PreRoot: preRoot, + PostRoot: postRoot, + IdComms: idComms, + MerkleProofs: merkleProofs, + BatchSize: batchSize, + Depth: depth, + } + + assert := test.NewAssert(t) + assert.CheckCircuit( + &circuit, test.WithBackends(backend.GROTH16), test.WithCurves(ecc.BN254), + test.WithValidAssignment(&assignment), + ) +} + +// generateRandomIdentities generates a slice of random big integers reduced modulo BN254 FR. +func generateRandomIdentities(count int) []big.Int { + ids := make([]big.Int, count) + for i := range ids { + n, _ := rand.Int(rand.Reader, bn254fr.Modulus()) + ids[i] = *n + } + return ids +} + +// identitiesToBlob converts a slice of big.Int into a KZG 4844 Blob. +func identitiesToBlob(ids []big.Int) *gokzg4844.Blob { + if len(ids) > gokzg4844.ScalarsPerBlob { + panic("too many identities for a blob") + } + var blob gokzg4844.Blob + for i, id := range ids { + startByte := i * 32 + id.FillBytes(blob[startByte : startByte+32]) + } + return &blob +} + +// bytesToBn254BigInt converts a slice of bytes to a *big.Int and reduces it by BN254 modulus +func bytesToBn254BigInt(b []byte) *big.Int { + n := new(big.Int).SetBytes(b) + modulus := bn254fr.Modulus() + return n.Mod(n, modulus) +} + +// bigIntsToChallenge converts input bit.Ints to a challenge for a proof of knowledge of a polynomial. +// The challenge is defined as a gokzg4844.Scalar of a keccak256 hash of all input big.Ints reduced +// by BN254 modulus. +func bigIntsToChallenge(input []big.Int) (challenge gokzg4844.Scalar) { + var inputBytes []byte + for _, i := range input { + temp := make([]byte, 32) + inputBytes = append(inputBytes, i.FillBytes(temp)...) + } + + // Reduce keccak because gokzg4844 API expects that + hashBytes := bytesToBn254BigInt(keccak256.Hash(inputBytes)).Bytes() + + copy(challenge[:], hashBytes) + return challenge +} + +// treeDepth calculates the depth of a binary tree containing the given number of leaves +func treeDepth(leavesCount int) (height int) { + if leavesCount <= 0 { + return 0 + } + height = int(math.Ceil(math.Log2(float64(leavesCount)))) + return +}