From 056a6bcad4c5274e48690b1f253358de650d6638 Mon Sep 17 00:00:00 2001 From: Gautam Botrel Date: Mon, 11 Dec 2023 10:01:20 -0600 Subject: [PATCH] refactor: plonk.Setup takes kzg srs in canonical and lagrange form (#953) * refactor: plonk.Setup takes kzg srs in lagrange and canonical forms * fix: revert one change in previous commit * feat: added test/unsafekzg package * perf: unsafeReadFrom + bufio for kzg srs cache * fix: address PR comments and add logs * docs: update plonk setup doc --- backend/plonk/bls12-377/setup.go | 29 +- backend/plonk/bls12-381/setup.go | 29 +- backend/plonk/bls24-315/setup.go | 29 +- backend/plonk/bls24-317/setup.go | 29 +- backend/plonk/bn254/setup.go | 29 +- backend/plonk/bw6-633/setup.go | 29 +- backend/plonk/bw6-761/setup.go | 29 +- backend/plonk/plonk.go | 19 +- backend/plonk/plonk_test.go | 38 +- debug_test.go | 7 +- examples/plonk/main.go | 10 +- go.mod | 2 +- go.sum | 8 +- .../zkpschemes/plonk/plonk.setup.go.tmpl | 29 +- std/gkr/api_test.go | 6 +- std/recursion/plonk/native_doc_test.go | 7 +- std/recursion/plonk/nonnative_doc_test.go | 12 +- std/recursion/plonk/verifier_test.go | 10 +- test/assert_checkcircuit.go | 5 +- test/kzg_srs.go | 105 ----- test/unsafekzg/kzgsrs.go | 418 ++++++++++++++++++ test/unsafekzg/options.go | 66 +++ 22 files changed, 681 insertions(+), 264 deletions(-) delete mode 100644 test/kzg_srs.go create mode 100644 test/unsafekzg/kzgsrs.go create mode 100644 test/unsafekzg/options.go diff --git a/backend/plonk/bls12-377/setup.go b/backend/plonk/bls12-377/setup.go index cede3fe039..e45d620c18 100644 --- a/backend/plonk/bls12-377/setup.go +++ b/backend/plonk/bls12-377/setup.go @@ -17,7 +17,6 @@ package plonk import ( - "errors" "fmt" "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" @@ -107,7 +106,7 @@ type ProvingKey struct { } // TODO modify the signature to receive the SRS in Lagrange form (optional argument ?) -func Setup(spr *cs.SparseR1CS, kzgSrs kzg.SRS) (*ProvingKey, *VerifyingKey, error) { +func Setup(spr *cs.SparseR1CS, srs, srsLagrange kzg.SRS) (*ProvingKey, *VerifyingKey, error) { var pk ProvingKey var vk VerifyingKey @@ -120,22 +119,26 @@ func Setup(spr *cs.SparseR1CS, kzgSrs kzg.SRS) (*ProvingKey, *VerifyingKey, erro return nil, nil, fmt.Errorf("circuit has only %d constraints; unsupported by the current implementation", spr.GetNbConstraints()) } + // check the size of the kzg srs. + if len(srs.Pk.G1) < (int(pk.Domain[0].Cardinality) + 3) { // + 3 for the kzg.Open of blinded poly + return nil, nil, fmt.Errorf("kzg srs is too small: got %d, need %d", len(srs.Pk.G1), pk.Domain[0].Cardinality+3) + } + + // same for the lagrange form + if len(srsLagrange.Pk.G1) != int(pk.Domain[0].Cardinality) { + return nil, nil, fmt.Errorf("kzg srs lagrange is too small: got %d, need %d", len(srsLagrange.Pk.G1), pk.Domain[0].Cardinality) + } + // step 1: set the verifying key pk.Vk.CosetShift.Set(&pk.Domain[0].FrMultiplicativeGen) vk.Size = pk.Domain[0].Cardinality vk.SizeInv.SetUint64(vk.Size).Inverse(&vk.SizeInv) vk.Generator.Set(&pk.Domain[0].Generator) vk.NbPublicVariables = uint64(len(spr.Public)) - if len(kzgSrs.Pk.G1) < int(vk.Size)+3 { // + 3 for the kzg.Open of blinded poly - return nil, nil, errors.New("kzg srs is too small") - } - pk.Kzg.G1 = kzgSrs.Pk.G1[:int(vk.Size)+3] - var err error - pk.KzgLagrange.G1, err = kzg.ToLagrangeG1(kzgSrs.Pk.G1[:int(vk.Size)]) - if err != nil { - return nil, nil, err - } - vk.Kzg = kzgSrs.Vk + + pk.Kzg.G1 = srs.Pk.G1[:int(vk.Size)+3] + pk.KzgLagrange.G1 = srsLagrange.Pk.G1 + vk.Kzg = srs.Vk // step 2: ql, qr, qm, qo, qk, qcp in Lagrange Basis BuildTrace(spr, &pk.trace) @@ -153,7 +156,7 @@ func Setup(spr *cs.SparseR1CS, kzgSrs kzg.SRS) (*ProvingKey, *VerifyingKey, erro // All the above polynomials are expressed in canonical basis afterwards. This is why // we save lqk before, because the prover needs to complete it in Lagrange form, and // then express it on the Lagrange coset basis. - if err = commitTrace(&pk.trace, &pk); err != nil { + if err := commitTrace(&pk.trace, &pk); err != nil { return nil, nil, err } diff --git a/backend/plonk/bls12-381/setup.go b/backend/plonk/bls12-381/setup.go index 470f928e90..3fdb43cd52 100644 --- a/backend/plonk/bls12-381/setup.go +++ b/backend/plonk/bls12-381/setup.go @@ -17,7 +17,6 @@ package plonk import ( - "errors" "fmt" "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" @@ -107,7 +106,7 @@ type ProvingKey struct { } // TODO modify the signature to receive the SRS in Lagrange form (optional argument ?) -func Setup(spr *cs.SparseR1CS, kzgSrs kzg.SRS) (*ProvingKey, *VerifyingKey, error) { +func Setup(spr *cs.SparseR1CS, srs, srsLagrange kzg.SRS) (*ProvingKey, *VerifyingKey, error) { var pk ProvingKey var vk VerifyingKey @@ -120,22 +119,26 @@ func Setup(spr *cs.SparseR1CS, kzgSrs kzg.SRS) (*ProvingKey, *VerifyingKey, erro return nil, nil, fmt.Errorf("circuit has only %d constraints; unsupported by the current implementation", spr.GetNbConstraints()) } + // check the size of the kzg srs. + if len(srs.Pk.G1) < (int(pk.Domain[0].Cardinality) + 3) { // + 3 for the kzg.Open of blinded poly + return nil, nil, fmt.Errorf("kzg srs is too small: got %d, need %d", len(srs.Pk.G1), pk.Domain[0].Cardinality+3) + } + + // same for the lagrange form + if len(srsLagrange.Pk.G1) != int(pk.Domain[0].Cardinality) { + return nil, nil, fmt.Errorf("kzg srs lagrange is too small: got %d, need %d", len(srsLagrange.Pk.G1), pk.Domain[0].Cardinality) + } + // step 1: set the verifying key pk.Vk.CosetShift.Set(&pk.Domain[0].FrMultiplicativeGen) vk.Size = pk.Domain[0].Cardinality vk.SizeInv.SetUint64(vk.Size).Inverse(&vk.SizeInv) vk.Generator.Set(&pk.Domain[0].Generator) vk.NbPublicVariables = uint64(len(spr.Public)) - if len(kzgSrs.Pk.G1) < int(vk.Size)+3 { // + 3 for the kzg.Open of blinded poly - return nil, nil, errors.New("kzg srs is too small") - } - pk.Kzg.G1 = kzgSrs.Pk.G1[:int(vk.Size)+3] - var err error - pk.KzgLagrange.G1, err = kzg.ToLagrangeG1(kzgSrs.Pk.G1[:int(vk.Size)]) - if err != nil { - return nil, nil, err - } - vk.Kzg = kzgSrs.Vk + + pk.Kzg.G1 = srs.Pk.G1[:int(vk.Size)+3] + pk.KzgLagrange.G1 = srsLagrange.Pk.G1 + vk.Kzg = srs.Vk // step 2: ql, qr, qm, qo, qk, qcp in Lagrange Basis BuildTrace(spr, &pk.trace) @@ -153,7 +156,7 @@ func Setup(spr *cs.SparseR1CS, kzgSrs kzg.SRS) (*ProvingKey, *VerifyingKey, erro // All the above polynomials are expressed in canonical basis afterwards. This is why // we save lqk before, because the prover needs to complete it in Lagrange form, and // then express it on the Lagrange coset basis. - if err = commitTrace(&pk.trace, &pk); err != nil { + if err := commitTrace(&pk.trace, &pk); err != nil { return nil, nil, err } diff --git a/backend/plonk/bls24-315/setup.go b/backend/plonk/bls24-315/setup.go index bdb6c050ad..2a2756991a 100644 --- a/backend/plonk/bls24-315/setup.go +++ b/backend/plonk/bls24-315/setup.go @@ -17,7 +17,6 @@ package plonk import ( - "errors" "fmt" "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/ecc/bls24-315/fr" @@ -107,7 +106,7 @@ type ProvingKey struct { } // TODO modify the signature to receive the SRS in Lagrange form (optional argument ?) -func Setup(spr *cs.SparseR1CS, kzgSrs kzg.SRS) (*ProvingKey, *VerifyingKey, error) { +func Setup(spr *cs.SparseR1CS, srs, srsLagrange kzg.SRS) (*ProvingKey, *VerifyingKey, error) { var pk ProvingKey var vk VerifyingKey @@ -120,22 +119,26 @@ func Setup(spr *cs.SparseR1CS, kzgSrs kzg.SRS) (*ProvingKey, *VerifyingKey, erro return nil, nil, fmt.Errorf("circuit has only %d constraints; unsupported by the current implementation", spr.GetNbConstraints()) } + // check the size of the kzg srs. + if len(srs.Pk.G1) < (int(pk.Domain[0].Cardinality) + 3) { // + 3 for the kzg.Open of blinded poly + return nil, nil, fmt.Errorf("kzg srs is too small: got %d, need %d", len(srs.Pk.G1), pk.Domain[0].Cardinality+3) + } + + // same for the lagrange form + if len(srsLagrange.Pk.G1) != int(pk.Domain[0].Cardinality) { + return nil, nil, fmt.Errorf("kzg srs lagrange is too small: got %d, need %d", len(srsLagrange.Pk.G1), pk.Domain[0].Cardinality) + } + // step 1: set the verifying key pk.Vk.CosetShift.Set(&pk.Domain[0].FrMultiplicativeGen) vk.Size = pk.Domain[0].Cardinality vk.SizeInv.SetUint64(vk.Size).Inverse(&vk.SizeInv) vk.Generator.Set(&pk.Domain[0].Generator) vk.NbPublicVariables = uint64(len(spr.Public)) - if len(kzgSrs.Pk.G1) < int(vk.Size)+3 { // + 3 for the kzg.Open of blinded poly - return nil, nil, errors.New("kzg srs is too small") - } - pk.Kzg.G1 = kzgSrs.Pk.G1[:int(vk.Size)+3] - var err error - pk.KzgLagrange.G1, err = kzg.ToLagrangeG1(kzgSrs.Pk.G1[:int(vk.Size)]) - if err != nil { - return nil, nil, err - } - vk.Kzg = kzgSrs.Vk + + pk.Kzg.G1 = srs.Pk.G1[:int(vk.Size)+3] + pk.KzgLagrange.G1 = srsLagrange.Pk.G1 + vk.Kzg = srs.Vk // step 2: ql, qr, qm, qo, qk, qcp in Lagrange Basis BuildTrace(spr, &pk.trace) @@ -153,7 +156,7 @@ func Setup(spr *cs.SparseR1CS, kzgSrs kzg.SRS) (*ProvingKey, *VerifyingKey, erro // All the above polynomials are expressed in canonical basis afterwards. This is why // we save lqk before, because the prover needs to complete it in Lagrange form, and // then express it on the Lagrange coset basis. - if err = commitTrace(&pk.trace, &pk); err != nil { + if err := commitTrace(&pk.trace, &pk); err != nil { return nil, nil, err } diff --git a/backend/plonk/bls24-317/setup.go b/backend/plonk/bls24-317/setup.go index 05e3c69112..6ebb0f8ee3 100644 --- a/backend/plonk/bls24-317/setup.go +++ b/backend/plonk/bls24-317/setup.go @@ -17,7 +17,6 @@ package plonk import ( - "errors" "fmt" "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/ecc/bls24-317/fr" @@ -107,7 +106,7 @@ type ProvingKey struct { } // TODO modify the signature to receive the SRS in Lagrange form (optional argument ?) -func Setup(spr *cs.SparseR1CS, kzgSrs kzg.SRS) (*ProvingKey, *VerifyingKey, error) { +func Setup(spr *cs.SparseR1CS, srs, srsLagrange kzg.SRS) (*ProvingKey, *VerifyingKey, error) { var pk ProvingKey var vk VerifyingKey @@ -120,22 +119,26 @@ func Setup(spr *cs.SparseR1CS, kzgSrs kzg.SRS) (*ProvingKey, *VerifyingKey, erro return nil, nil, fmt.Errorf("circuit has only %d constraints; unsupported by the current implementation", spr.GetNbConstraints()) } + // check the size of the kzg srs. + if len(srs.Pk.G1) < (int(pk.Domain[0].Cardinality) + 3) { // + 3 for the kzg.Open of blinded poly + return nil, nil, fmt.Errorf("kzg srs is too small: got %d, need %d", len(srs.Pk.G1), pk.Domain[0].Cardinality+3) + } + + // same for the lagrange form + if len(srsLagrange.Pk.G1) != int(pk.Domain[0].Cardinality) { + return nil, nil, fmt.Errorf("kzg srs lagrange is too small: got %d, need %d", len(srsLagrange.Pk.G1), pk.Domain[0].Cardinality) + } + // step 1: set the verifying key pk.Vk.CosetShift.Set(&pk.Domain[0].FrMultiplicativeGen) vk.Size = pk.Domain[0].Cardinality vk.SizeInv.SetUint64(vk.Size).Inverse(&vk.SizeInv) vk.Generator.Set(&pk.Domain[0].Generator) vk.NbPublicVariables = uint64(len(spr.Public)) - if len(kzgSrs.Pk.G1) < int(vk.Size)+3 { // + 3 for the kzg.Open of blinded poly - return nil, nil, errors.New("kzg srs is too small") - } - pk.Kzg.G1 = kzgSrs.Pk.G1[:int(vk.Size)+3] - var err error - pk.KzgLagrange.G1, err = kzg.ToLagrangeG1(kzgSrs.Pk.G1[:int(vk.Size)]) - if err != nil { - return nil, nil, err - } - vk.Kzg = kzgSrs.Vk + + pk.Kzg.G1 = srs.Pk.G1[:int(vk.Size)+3] + pk.KzgLagrange.G1 = srsLagrange.Pk.G1 + vk.Kzg = srs.Vk // step 2: ql, qr, qm, qo, qk, qcp in Lagrange Basis BuildTrace(spr, &pk.trace) @@ -153,7 +156,7 @@ func Setup(spr *cs.SparseR1CS, kzgSrs kzg.SRS) (*ProvingKey, *VerifyingKey, erro // All the above polynomials are expressed in canonical basis afterwards. This is why // we save lqk before, because the prover needs to complete it in Lagrange form, and // then express it on the Lagrange coset basis. - if err = commitTrace(&pk.trace, &pk); err != nil { + if err := commitTrace(&pk.trace, &pk); err != nil { return nil, nil, err } diff --git a/backend/plonk/bn254/setup.go b/backend/plonk/bn254/setup.go index dcd0481581..4a51f25b79 100644 --- a/backend/plonk/bn254/setup.go +++ b/backend/plonk/bn254/setup.go @@ -17,7 +17,6 @@ package plonk import ( - "errors" "fmt" "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/ecc/bn254/fr" @@ -107,7 +106,7 @@ type ProvingKey struct { } // TODO modify the signature to receive the SRS in Lagrange form (optional argument ?) -func Setup(spr *cs.SparseR1CS, kzgSrs kzg.SRS) (*ProvingKey, *VerifyingKey, error) { +func Setup(spr *cs.SparseR1CS, srs, srsLagrange kzg.SRS) (*ProvingKey, *VerifyingKey, error) { var pk ProvingKey var vk VerifyingKey @@ -120,22 +119,26 @@ func Setup(spr *cs.SparseR1CS, kzgSrs kzg.SRS) (*ProvingKey, *VerifyingKey, erro return nil, nil, fmt.Errorf("circuit has only %d constraints; unsupported by the current implementation", spr.GetNbConstraints()) } + // check the size of the kzg srs. + if len(srs.Pk.G1) < (int(pk.Domain[0].Cardinality) + 3) { // + 3 for the kzg.Open of blinded poly + return nil, nil, fmt.Errorf("kzg srs is too small: got %d, need %d", len(srs.Pk.G1), pk.Domain[0].Cardinality+3) + } + + // same for the lagrange form + if len(srsLagrange.Pk.G1) != int(pk.Domain[0].Cardinality) { + return nil, nil, fmt.Errorf("kzg srs lagrange is too small: got %d, need %d", len(srsLagrange.Pk.G1), pk.Domain[0].Cardinality) + } + // step 1: set the verifying key pk.Vk.CosetShift.Set(&pk.Domain[0].FrMultiplicativeGen) vk.Size = pk.Domain[0].Cardinality vk.SizeInv.SetUint64(vk.Size).Inverse(&vk.SizeInv) vk.Generator.Set(&pk.Domain[0].Generator) vk.NbPublicVariables = uint64(len(spr.Public)) - if len(kzgSrs.Pk.G1) < int(vk.Size)+3 { // + 3 for the kzg.Open of blinded poly - return nil, nil, errors.New("kzg srs is too small") - } - pk.Kzg.G1 = kzgSrs.Pk.G1[:int(vk.Size)+3] - var err error - pk.KzgLagrange.G1, err = kzg.ToLagrangeG1(kzgSrs.Pk.G1[:int(vk.Size)]) - if err != nil { - return nil, nil, err - } - vk.Kzg = kzgSrs.Vk + + pk.Kzg.G1 = srs.Pk.G1[:int(vk.Size)+3] + pk.KzgLagrange.G1 = srsLagrange.Pk.G1 + vk.Kzg = srs.Vk // step 2: ql, qr, qm, qo, qk, qcp in Lagrange Basis BuildTrace(spr, &pk.trace) @@ -153,7 +156,7 @@ func Setup(spr *cs.SparseR1CS, kzgSrs kzg.SRS) (*ProvingKey, *VerifyingKey, erro // All the above polynomials are expressed in canonical basis afterwards. This is why // we save lqk before, because the prover needs to complete it in Lagrange form, and // then express it on the Lagrange coset basis. - if err = commitTrace(&pk.trace, &pk); err != nil { + if err := commitTrace(&pk.trace, &pk); err != nil { return nil, nil, err } diff --git a/backend/plonk/bw6-633/setup.go b/backend/plonk/bw6-633/setup.go index 3be6c6fd5d..81f2735a32 100644 --- a/backend/plonk/bw6-633/setup.go +++ b/backend/plonk/bw6-633/setup.go @@ -17,7 +17,6 @@ package plonk import ( - "errors" "fmt" "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/ecc/bw6-633/fr" @@ -107,7 +106,7 @@ type ProvingKey struct { } // TODO modify the signature to receive the SRS in Lagrange form (optional argument ?) -func Setup(spr *cs.SparseR1CS, kzgSrs kzg.SRS) (*ProvingKey, *VerifyingKey, error) { +func Setup(spr *cs.SparseR1CS, srs, srsLagrange kzg.SRS) (*ProvingKey, *VerifyingKey, error) { var pk ProvingKey var vk VerifyingKey @@ -120,22 +119,26 @@ func Setup(spr *cs.SparseR1CS, kzgSrs kzg.SRS) (*ProvingKey, *VerifyingKey, erro return nil, nil, fmt.Errorf("circuit has only %d constraints; unsupported by the current implementation", spr.GetNbConstraints()) } + // check the size of the kzg srs. + if len(srs.Pk.G1) < (int(pk.Domain[0].Cardinality) + 3) { // + 3 for the kzg.Open of blinded poly + return nil, nil, fmt.Errorf("kzg srs is too small: got %d, need %d", len(srs.Pk.G1), pk.Domain[0].Cardinality+3) + } + + // same for the lagrange form + if len(srsLagrange.Pk.G1) != int(pk.Domain[0].Cardinality) { + return nil, nil, fmt.Errorf("kzg srs lagrange is too small: got %d, need %d", len(srsLagrange.Pk.G1), pk.Domain[0].Cardinality) + } + // step 1: set the verifying key pk.Vk.CosetShift.Set(&pk.Domain[0].FrMultiplicativeGen) vk.Size = pk.Domain[0].Cardinality vk.SizeInv.SetUint64(vk.Size).Inverse(&vk.SizeInv) vk.Generator.Set(&pk.Domain[0].Generator) vk.NbPublicVariables = uint64(len(spr.Public)) - if len(kzgSrs.Pk.G1) < int(vk.Size)+3 { // + 3 for the kzg.Open of blinded poly - return nil, nil, errors.New("kzg srs is too small") - } - pk.Kzg.G1 = kzgSrs.Pk.G1[:int(vk.Size)+3] - var err error - pk.KzgLagrange.G1, err = kzg.ToLagrangeG1(kzgSrs.Pk.G1[:int(vk.Size)]) - if err != nil { - return nil, nil, err - } - vk.Kzg = kzgSrs.Vk + + pk.Kzg.G1 = srs.Pk.G1[:int(vk.Size)+3] + pk.KzgLagrange.G1 = srsLagrange.Pk.G1 + vk.Kzg = srs.Vk // step 2: ql, qr, qm, qo, qk, qcp in Lagrange Basis BuildTrace(spr, &pk.trace) @@ -153,7 +156,7 @@ func Setup(spr *cs.SparseR1CS, kzgSrs kzg.SRS) (*ProvingKey, *VerifyingKey, erro // All the above polynomials are expressed in canonical basis afterwards. This is why // we save lqk before, because the prover needs to complete it in Lagrange form, and // then express it on the Lagrange coset basis. - if err = commitTrace(&pk.trace, &pk); err != nil { + if err := commitTrace(&pk.trace, &pk); err != nil { return nil, nil, err } diff --git a/backend/plonk/bw6-761/setup.go b/backend/plonk/bw6-761/setup.go index e19f17f151..cd54cf19e7 100644 --- a/backend/plonk/bw6-761/setup.go +++ b/backend/plonk/bw6-761/setup.go @@ -17,7 +17,6 @@ package plonk import ( - "errors" "fmt" "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/ecc/bw6-761/fr" @@ -107,7 +106,7 @@ type ProvingKey struct { } // TODO modify the signature to receive the SRS in Lagrange form (optional argument ?) -func Setup(spr *cs.SparseR1CS, kzgSrs kzg.SRS) (*ProvingKey, *VerifyingKey, error) { +func Setup(spr *cs.SparseR1CS, srs, srsLagrange kzg.SRS) (*ProvingKey, *VerifyingKey, error) { var pk ProvingKey var vk VerifyingKey @@ -120,22 +119,26 @@ func Setup(spr *cs.SparseR1CS, kzgSrs kzg.SRS) (*ProvingKey, *VerifyingKey, erro return nil, nil, fmt.Errorf("circuit has only %d constraints; unsupported by the current implementation", spr.GetNbConstraints()) } + // check the size of the kzg srs. + if len(srs.Pk.G1) < (int(pk.Domain[0].Cardinality) + 3) { // + 3 for the kzg.Open of blinded poly + return nil, nil, fmt.Errorf("kzg srs is too small: got %d, need %d", len(srs.Pk.G1), pk.Domain[0].Cardinality+3) + } + + // same for the lagrange form + if len(srsLagrange.Pk.G1) != int(pk.Domain[0].Cardinality) { + return nil, nil, fmt.Errorf("kzg srs lagrange is too small: got %d, need %d", len(srsLagrange.Pk.G1), pk.Domain[0].Cardinality) + } + // step 1: set the verifying key pk.Vk.CosetShift.Set(&pk.Domain[0].FrMultiplicativeGen) vk.Size = pk.Domain[0].Cardinality vk.SizeInv.SetUint64(vk.Size).Inverse(&vk.SizeInv) vk.Generator.Set(&pk.Domain[0].Generator) vk.NbPublicVariables = uint64(len(spr.Public)) - if len(kzgSrs.Pk.G1) < int(vk.Size)+3 { // + 3 for the kzg.Open of blinded poly - return nil, nil, errors.New("kzg srs is too small") - } - pk.Kzg.G1 = kzgSrs.Pk.G1[:int(vk.Size)+3] - var err error - pk.KzgLagrange.G1, err = kzg.ToLagrangeG1(kzgSrs.Pk.G1[:int(vk.Size)]) - if err != nil { - return nil, nil, err - } - vk.Kzg = kzgSrs.Vk + + pk.Kzg.G1 = srs.Pk.G1[:int(vk.Size)+3] + pk.KzgLagrange.G1 = srsLagrange.Pk.G1 + vk.Kzg = srs.Vk // step 2: ql, qr, qm, qo, qk, qcp in Lagrange Basis BuildTrace(spr, &pk.trace) @@ -153,7 +156,7 @@ func Setup(spr *cs.SparseR1CS, kzgSrs kzg.SRS) (*ProvingKey, *VerifyingKey, erro // All the above polynomials are expressed in canonical basis afterwards. This is why // we save lqk before, because the prover needs to complete it in Lagrange form, and // then express it on the Lagrange coset basis. - if err = commitTrace(&pk.trace, &pk); err != nil { + if err := commitTrace(&pk.trace, &pk); err != nil { return nil, nil, err } diff --git a/backend/plonk/plonk.go b/backend/plonk/plonk.go index 2fd11b18aa..66264bd1dd 100644 --- a/backend/plonk/plonk.go +++ b/backend/plonk/plonk.go @@ -96,23 +96,26 @@ type VerifyingKey interface { } // Setup prepares the public data associated to a circuit + public inputs. -func Setup(ccs constraint.ConstraintSystem, kzgSrs kzg.SRS) (ProvingKey, VerifyingKey, error) { +// The kzg SRS must be provided in canonical and lagrange form. +// For test purposes, see test/unsafekzg package. With an existing SRS generated through MPC in canonical form, +// gnark-crypto offers the ToLagrangeG1 method to convert it to lagrange form. +func Setup(ccs constraint.ConstraintSystem, srs, srsLagrange kzg.SRS) (ProvingKey, VerifyingKey, error) { switch tccs := ccs.(type) { case *cs_bn254.SparseR1CS: - return plonk_bn254.Setup(tccs, *kzgSrs.(*kzg_bn254.SRS)) + return plonk_bn254.Setup(tccs, *srs.(*kzg_bn254.SRS), *srsLagrange.(*kzg_bn254.SRS)) case *cs_bls12381.SparseR1CS: - return plonk_bls12381.Setup(tccs, *kzgSrs.(*kzg_bls12381.SRS)) + return plonk_bls12381.Setup(tccs, *srs.(*kzg_bls12381.SRS), *srsLagrange.(*kzg_bls12381.SRS)) case *cs_bls12377.SparseR1CS: - return plonk_bls12377.Setup(tccs, *kzgSrs.(*kzg_bls12377.SRS)) + return plonk_bls12377.Setup(tccs, *srs.(*kzg_bls12377.SRS), *srsLagrange.(*kzg_bls12377.SRS)) case *cs_bw6761.SparseR1CS: - return plonk_bw6761.Setup(tccs, *kzgSrs.(*kzg_bw6761.SRS)) + return plonk_bw6761.Setup(tccs, *srs.(*kzg_bw6761.SRS), *srsLagrange.(*kzg_bw6761.SRS)) case *cs_bls24317.SparseR1CS: - return plonk_bls24317.Setup(tccs, *kzgSrs.(*kzg_bls24317.SRS)) + return plonk_bls24317.Setup(tccs, *srs.(*kzg_bls24317.SRS), *srsLagrange.(*kzg_bls24317.SRS)) case *cs_bls24315.SparseR1CS: - return plonk_bls24315.Setup(tccs, *kzgSrs.(*kzg_bls24315.SRS)) + return plonk_bls24315.Setup(tccs, *srs.(*kzg_bls24315.SRS), *srsLagrange.(*kzg_bls24315.SRS)) case *cs_bw6633.SparseR1CS: - return plonk_bw6633.Setup(tccs, *kzgSrs.(*kzg_bw6633.SRS)) + return plonk_bw6633.Setup(tccs, *srs.(*kzg_bw6633.SRS), *srsLagrange.(*kzg_bw6633.SRS)) default: panic("unrecognized SparseR1CS curve type") } diff --git a/backend/plonk/plonk_test.go b/backend/plonk/plonk_test.go index bce07d8a1a..5d1fa1207e 100644 --- a/backend/plonk/plonk_test.go +++ b/backend/plonk/plonk_test.go @@ -15,6 +15,7 @@ import ( "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/frontend/cs/scs" "github.com/consensys/gnark/test" + "github.com/consensys/gnark/test/unsafekzg" "github.com/stretchr/testify/require" ) @@ -29,14 +30,14 @@ func TestProver(t *testing.T) { var b1, b2 bytes.Buffer assert := require.New(t) - ccs, _solution, srs := referenceCircuit(curve) + ccs, _solution, srs, srsLagrange := referenceCircuit(curve) fullWitness, err := frontend.NewWitness(_solution, curve.ScalarField()) assert.NoError(err) publicWitness, err := fullWitness.Public() assert.NoError(err) - pk, vk, err := plonk.Setup(ccs, srs) + pk, vk, err := plonk.Setup(ccs, srs, srsLagrange) assert.NoError(err) // write the PK to ensure it is not mutated @@ -75,9 +76,10 @@ func TestCustomHashToField(t *testing.T) { assert.Run(func(assert *test.Assert) { ccs, err := frontend.Compile(curve.ScalarField(), scs.NewBuilder, &commitmentCircuit{}) assert.NoError(err) - srs, err := test.NewKZGSRS(ccs) + srs, srsLagrange, err := unsafekzg.NewSRS(ccs) assert.NoError(err) - pk, vk, err := plonk.Setup(ccs, srs) + + pk, vk, err := plonk.Setup(ccs, srs, srsLagrange) assert.NoError(err) witness, err := frontend.NewWitness(assignment, curve.ScalarField()) assert.NoError(err) @@ -114,9 +116,10 @@ func TestCustomChallengeHash(t *testing.T) { assert.Run(func(assert *test.Assert) { ccs, err := frontend.Compile(curve.ScalarField(), scs.NewBuilder, &smallCircuit{}) assert.NoError(err) - srs, err := test.NewKZGSRS(ccs) + srs, srsLagrange, err := unsafekzg.NewSRS(ccs) assert.NoError(err) - pk, vk, err := plonk.Setup(ccs, srs) + + pk, vk, err := plonk.Setup(ccs, srs, srsLagrange) assert.NoError(err) witness, err := frontend.NewWitness(assignment, curve.ScalarField()) assert.NoError(err) @@ -156,9 +159,10 @@ func TestCustomKZGFoldingHash(t *testing.T) { assert.Run(func(assert *test.Assert) { ccs, err := frontend.Compile(curve.ScalarField(), scs.NewBuilder, &smallCircuit{}) assert.NoError(err) - srs, err := test.NewKZGSRS(ccs) + srs, srsLagrange, err := unsafekzg.NewSRS(ccs) assert.NoError(err) - pk, vk, err := plonk.Setup(ccs, srs) + + pk, vk, err := plonk.Setup(ccs, srs, srsLagrange) assert.NoError(err) witness, err := frontend.NewWitness(assignment, curve.ScalarField()) assert.NoError(err) @@ -193,10 +197,10 @@ func TestCustomKZGFoldingHash(t *testing.T) { func BenchmarkSetup(b *testing.B) { for _, curve := range getCurves() { b.Run(curve.String(), func(b *testing.B) { - ccs, _, srs := referenceCircuit(curve) + ccs, _, srs, srsLagrange := referenceCircuit(curve) b.ResetTimer() for i := 0; i < b.N; i++ { - _, _, _ = plonk.Setup(ccs, srs) + _, _, _ = plonk.Setup(ccs, srs, srsLagrange) } }) } @@ -205,12 +209,12 @@ func BenchmarkSetup(b *testing.B) { func BenchmarkProver(b *testing.B) { for _, curve := range getCurves() { b.Run(curve.String(), func(b *testing.B) { - ccs, _solution, srs := referenceCircuit(curve) + ccs, _solution, srs, srsLagrange := referenceCircuit(curve) fullWitness, err := frontend.NewWitness(_solution, curve.ScalarField()) if err != nil { b.Fatal(err) } - pk, _, err := plonk.Setup(ccs, srs) + pk, _, err := plonk.Setup(ccs, srs, srsLagrange) if err != nil { b.Fatal(err) } @@ -225,7 +229,7 @@ func BenchmarkProver(b *testing.B) { func BenchmarkVerifier(b *testing.B) { for _, curve := range getCurves() { b.Run(curve.String(), func(b *testing.B) { - ccs, _solution, srs := referenceCircuit(curve) + ccs, _solution, srs, srsLagrange := referenceCircuit(curve) fullWitness, err := frontend.NewWitness(_solution, curve.ScalarField()) if err != nil { b.Fatal(err) @@ -235,7 +239,7 @@ func BenchmarkVerifier(b *testing.B) { b.Fatal(err) } - pk, vk, err := plonk.Setup(ccs, srs) + pk, vk, err := plonk.Setup(ccs, srs, srsLagrange) if err != nil { b.Fatal(err) } @@ -266,7 +270,7 @@ func (circuit *refCircuit) Define(api frontend.API) error { return nil } -func referenceCircuit(curve ecc.ID) (constraint.ConstraintSystem, frontend.Circuit, kzg.SRS) { +func referenceCircuit(curve ecc.ID) (constraint.ConstraintSystem, frontend.Circuit, kzg.SRS, kzg.SRS) { const nbConstraints = (1 << 12) - 3 circuit := refCircuit{ nbConstraints: nbConstraints, @@ -286,11 +290,11 @@ func referenceCircuit(curve ecc.ID) (constraint.ConstraintSystem, frontend.Circu expectedY.Exp(expectedY, exp, curve.ScalarField()) good.Y = expectedY - srs, err := test.NewKZGSRS(ccs) + srs, srsLagrange, err := unsafekzg.NewSRS(ccs) if err != nil { panic(err) } - return ccs, &good, srs + return ccs, &good, srs, srsLagrange } type commitmentCircuit struct { diff --git a/debug_test.go b/debug_test.go index 620e9860d0..e171942779 100644 --- a/debug_test.go +++ b/debug_test.go @@ -14,7 +14,7 @@ import ( "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/frontend/cs/r1cs" "github.com/consensys/gnark/frontend/cs/scs" - "github.com/consensys/gnark/test" + "github.com/consensys/gnark/test/unsafekzg" "github.com/rs/zerolog" "github.com/stretchr/testify/require" ) @@ -159,11 +159,12 @@ func getPlonkTrace(circuit, w frontend.Circuit) (string, error) { return "", err } - srs, err := test.NewKZGSRS(ccs) + srs, srsLagrange, err := unsafekzg.NewSRS(ccs) if err != nil { return "", err } - pk, _, err := plonk.Setup(ccs, srs) + + pk, _, err := plonk.Setup(ccs, srs, srsLagrange) if err != nil { return "", err } diff --git a/examples/plonk/main.go b/examples/plonk/main.go index b537c226a9..ac28303cf5 100644 --- a/examples/plonk/main.go +++ b/examples/plonk/main.go @@ -22,9 +22,9 @@ import ( "github.com/consensys/gnark/backend/plonk" cs "github.com/consensys/gnark/constraint/bn254" "github.com/consensys/gnark/frontend/cs/scs" - "github.com/consensys/gnark/test" "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/test/unsafekzg" ) // In this example we show how to use PLONK with KZG commitments. The circuit that is @@ -83,8 +83,8 @@ func main() { // has been run before. // The size of the data in KZG should be the closest power of 2 bounding // // above max(nbConstraints, nbVariables). - _r1cs := ccs.(*cs.SparseR1CS) - srs, err := test.NewKZGSRS(_r1cs) + scs := ccs.(*cs.SparseR1CS) + srs, srsLagrange, err := unsafekzg.NewSRS(scs) if err != nil { panic(err) } @@ -111,7 +111,7 @@ func main() { // public data consists of the polynomials describing the constants involved // in the constraints, the polynomial describing the permutation ("grand // product argument"), and the FFT domains. - pk, vk, err := plonk.Setup(ccs, srs) + pk, vk, err := plonk.Setup(ccs, srs, srsLagrange) //_, err := plonk.Setup(r1cs, kate, &publicWitness) if err != nil { log.Fatal(err) @@ -152,7 +152,7 @@ func main() { // public data consists of the polynomials describing the constants involved // in the constraints, the polynomial describing the permutation ("grand // product argument"), and the FFT domains. - pk, vk, err := plonk.Setup(ccs, srs) + pk, vk, err := plonk.Setup(ccs, srs, srsLagrange) //_, err := plonk.Setup(r1cs, kate, &publicWitness) if err != nil { log.Fatal(err) diff --git a/go.mod b/go.mod index 1e19b391d5..cb5c1b6a85 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,7 @@ require ( github.com/blang/semver/v4 v4.0.0 github.com/consensys/bavard v0.1.13 github.com/consensys/compress v0.1.0 - github.com/consensys/gnark-crypto v0.12.2-0.20231117165148-e77308824822 + github.com/consensys/gnark-crypto v0.12.2-0.20231208203441-d4eab6ddd2af github.com/fxamacker/cbor/v2 v2.5.0 github.com/google/go-cmp v0.5.9 github.com/google/pprof v0.0.0-20230817174616-7a8ec2ada47b diff --git a/go.sum b/go.sum index 7da0c99404..c0d3ccb857 100644 --- a/go.sum +++ b/go.sum @@ -4,12 +4,12 @@ github.com/blang/semver/v4 v4.0.0 h1:1PFHFE6yCCTv8C1TeyNNarDzntLi7wMI5i/pzqYIsAM github.com/blang/semver/v4 v4.0.0/go.mod h1:IbckMUScFkM3pff0VJDNKRiT6TG/YpiHIM2yvyW5YoQ= github.com/consensys/bavard v0.1.13 h1:oLhMLOFGTLdlda/kma4VOJazblc7IM5y5QPd2A/YjhQ= github.com/consensys/bavard v0.1.13/go.mod h1:9ItSMtA/dXMAiL7BG6bqW2m3NdSEObYWoH223nGHukI= -github.com/consensys/compress v0.0.0-20231201231747-b7f0ad98d697 h1:Ar/NyBmxGYeKekc7a7sdpkKgZ6OO6P5Wc5aNH+DxlXE= -github.com/consensys/compress v0.0.0-20231201231747-b7f0ad98d697/go.mod h1:Ne8+cGKjqgjF1dlHapZx38pHzWpaBYhsKxQa+JPl0zM= github.com/consensys/compress v0.1.0 h1:fczDaganmx2198GudPo4+5VX3eBvKy/bEJfmNotbr70= github.com/consensys/compress v0.1.0/go.mod h1:Ne8+cGKjqgjF1dlHapZx38pHzWpaBYhsKxQa+JPl0zM= -github.com/consensys/gnark-crypto v0.12.2-0.20231117165148-e77308824822 h1:PvEjRgB/U4bv0jl9w65Wy9g0nIdkkW7vkNoR8Vq/als= -github.com/consensys/gnark-crypto v0.12.2-0.20231117165148-e77308824822/go.mod h1:v2Gy7L/4ZRosZ7Ivs+9SfUDr0f5UlG+EM5t7MPHiLuY= +github.com/consensys/gnark-crypto v0.12.2-0.20231207224154-754e2ff331e8 h1:HOlb31u9qzF/3RmGN+mfwN008aWcVnrSaFEuGamFwmQ= +github.com/consensys/gnark-crypto v0.12.2-0.20231207224154-754e2ff331e8/go.mod h1:v2Gy7L/4ZRosZ7Ivs+9SfUDr0f5UlG+EM5t7MPHiLuY= +github.com/consensys/gnark-crypto v0.12.2-0.20231208203441-d4eab6ddd2af h1:QbTpU3l/2wEFLtF4DQgApTXCDEtd9Jb8olP84VxvP4E= +github.com/consensys/gnark-crypto v0.12.2-0.20231208203441-d4eab6ddd2af/go.mod h1:v2Gy7L/4ZRosZ7Ivs+9SfUDr0f5UlG+EM5t7MPHiLuY= github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= diff --git a/internal/generator/backend/template/zkpschemes/plonk/plonk.setup.go.tmpl b/internal/generator/backend/template/zkpschemes/plonk/plonk.setup.go.tmpl index b41b6c745a..bba64e9a40 100644 --- a/internal/generator/backend/template/zkpschemes/plonk/plonk.setup.go.tmpl +++ b/internal/generator/backend/template/zkpschemes/plonk/plonk.setup.go.tmpl @@ -1,5 +1,4 @@ import ( - "errors" {{- template "import_kzg" . }} {{- template "import_fr" . }} {{- template "import_fft" . }} @@ -89,7 +88,7 @@ type ProvingKey struct { } // TODO modify the signature to receive the SRS in Lagrange form (optional argument ?) -func Setup(spr *cs.SparseR1CS, kzgSrs kzg.SRS) (*ProvingKey, *VerifyingKey, error) { +func Setup(spr *cs.SparseR1CS, srs, srsLagrange kzg.SRS) (*ProvingKey, *VerifyingKey, error) { var pk ProvingKey var vk VerifyingKey @@ -102,22 +101,26 @@ func Setup(spr *cs.SparseR1CS, kzgSrs kzg.SRS) (*ProvingKey, *VerifyingKey, erro return nil, nil, fmt.Errorf("circuit has only %d constraints; unsupported by the current implementation", spr.GetNbConstraints()) } + // check the size of the kzg srs. + if len(srs.Pk.G1) < (int(pk.Domain[0].Cardinality) + 3) { // + 3 for the kzg.Open of blinded poly + return nil, nil, fmt.Errorf("kzg srs is too small: got %d, need %d", len(srs.Pk.G1), pk.Domain[0].Cardinality+3) + } + + // same for the lagrange form + if len(srsLagrange.Pk.G1) != int(pk.Domain[0].Cardinality) { + return nil, nil, fmt.Errorf("kzg srs lagrange is too small: got %d, need %d", len(srsLagrange.Pk.G1), pk.Domain[0].Cardinality) + } + // step 1: set the verifying key pk.Vk.CosetShift.Set(&pk.Domain[0].FrMultiplicativeGen) vk.Size = pk.Domain[0].Cardinality vk.SizeInv.SetUint64(vk.Size).Inverse(&vk.SizeInv) vk.Generator.Set(&pk.Domain[0].Generator) vk.NbPublicVariables = uint64(len(spr.Public)) - if len(kzgSrs.Pk.G1) < int(vk.Size) +3 { // + 3 for the kzg.Open of blinded poly - return nil, nil, errors.New("kzg srs is too small") - } - pk.Kzg.G1 = kzgSrs.Pk.G1[:int(vk.Size)+3] - var err error - pk.KzgLagrange.G1, err = kzg.ToLagrangeG1(kzgSrs.Pk.G1[:int(vk.Size)]) - if err != nil { - return nil, nil, err - } - vk.Kzg = kzgSrs.Vk + + pk.Kzg.G1 = srs.Pk.G1[:int(vk.Size)+3] + pk.KzgLagrange.G1 = srsLagrange.Pk.G1 + vk.Kzg = srs.Vk // step 2: ql, qr, qm, qo, qk, qcp in Lagrange Basis BuildTrace(spr, &pk.trace) @@ -135,7 +138,7 @@ func Setup(spr *cs.SparseR1CS, kzgSrs kzg.SRS) (*ProvingKey, *VerifyingKey, erro // All the above polynomials are expressed in canonical basis afterwards. This is why // we save lqk before, because the prover needs to complete it in Lagrange form, and // then express it on the Lagrange coset basis. - if err = commitTrace(&pk.trace, &pk); err != nil { + if err := commitTrace(&pk.trace, &pk); err != nil { return nil, nil, err } diff --git a/std/gkr/api_test.go b/std/gkr/api_test.go index a0d1d4edbb..5caf1121dc 100644 --- a/std/gkr/api_test.go +++ b/std/gkr/api_test.go @@ -11,7 +11,6 @@ import ( "github.com/consensys/gnark-crypto/kzg" "github.com/consensys/gnark/backend/plonk" bn254r1cs "github.com/consensys/gnark/constraint/bn254" - "github.com/consensys/gnark/test" "github.com/stretchr/testify/require" "github.com/consensys/gnark-crypto/ecc" @@ -27,6 +26,7 @@ import ( stdHash "github.com/consensys/gnark/std/hash" "github.com/consensys/gnark/std/hash/mimc" test_vector_utils "github.com/consensys/gnark/std/utils/test_vectors_utils" + "github.com/consensys/gnark/test/unsafekzg" ) // compressThreshold --> if linear expressions are larger than this, the frontend will introduce @@ -416,9 +416,9 @@ func testPlonk(t *testing.T, circuit, assignment frontend.Circuit) { require.NoError(t, err) publicWitness, err = fullWitness.Public() require.NoError(t, err) - kzgSrs, err = test.NewKZGSRS(cs) + kzgSrs, srsLagrange, err := unsafekzg.NewSRS(cs) require.NoError(t, err) - pk, vk, err = plonk.Setup(cs, kzgSrs) + pk, vk, err = plonk.Setup(cs, kzgSrs, srsLagrange) require.NoError(t, err) proof, err = plonk.Prove(cs, pk, fullWitness) require.NoError(t, err) diff --git a/std/recursion/plonk/native_doc_test.go b/std/recursion/plonk/native_doc_test.go index 89e7f1aa32..2a840eb59d 100644 --- a/std/recursion/plonk/native_doc_test.go +++ b/std/recursion/plonk/native_doc_test.go @@ -7,7 +7,7 @@ import ( "github.com/consensys/gnark/frontend/cs/scs" "github.com/consensys/gnark/std/algebra/native/sw_bls12377" "github.com/consensys/gnark/std/recursion/plonk" - "github.com/consensys/gnark/test" + "github.com/consensys/gnark/test/unsafekzg" ) // Example of verifying recursively BLS12-371 PLONK proof in BW6-761 PLONK circuit using field emulation @@ -46,12 +46,13 @@ func Example_native() { } // NB! UNSAFE! Use MPC. - srs, err := test.NewKZGSRS(innerCcs) + srs, srsLagrange, err := unsafekzg.NewSRS(innerCcs) if err != nil { panic(err) } + // create PLONK setup. NB! UNSAFE - pk, vk, err := native_plonk.Setup(ccs, srs) // UNSAFE! Use MPC + pk, vk, err := native_plonk.Setup(ccs, srs, srsLagrange) // UNSAFE! Use MPC if err != nil { panic("setup failed: " + err.Error()) } diff --git a/std/recursion/plonk/nonnative_doc_test.go b/std/recursion/plonk/nonnative_doc_test.go index d9d34e326b..ccfbc369ce 100644 --- a/std/recursion/plonk/nonnative_doc_test.go +++ b/std/recursion/plonk/nonnative_doc_test.go @@ -14,7 +14,7 @@ import ( "github.com/consensys/gnark/std/algebra/emulated/sw_bw6761" "github.com/consensys/gnark/std/math/emulated" "github.com/consensys/gnark/std/recursion/plonk" - "github.com/consensys/gnark/test" + "github.com/consensys/gnark/test/unsafekzg" ) // InnerCircuitNative is the definition of the inner circuit we want to @@ -44,11 +44,12 @@ func computeInnerProof(field, outer *big.Int) (constraint.ConstraintSystem, nati panic(err) } // NB! UNSAFE! Use MPC. - srs, err := test.NewKZGSRS(innerCcs) + srs, srsLagrange, err := unsafekzg.NewSRS(innerCcs) if err != nil { panic(err) } - innerPK, innerVK, err := native_plonk.Setup(innerCcs, srs) + + innerPK, innerVK, err := native_plonk.Setup(innerCcs, srs, srsLagrange) if err != nil { panic(err) } @@ -131,12 +132,13 @@ func Example_emulated() { } // NB! UNSAFE! Use MPC. - srs, err := test.NewKZGSRS(innerCcs) + srs, srsLagrange, err := unsafekzg.NewSRS(innerCcs) if err != nil { panic(err) } + // create PLONK setup. NB! UNSAFE - pk, vk, err := native_plonk.Setup(ccs, srs) // UNSAFE! Use MPC + pk, vk, err := native_plonk.Setup(ccs, srs, srsLagrange) // UNSAFE! Use MPC if err != nil { panic("setup failed: " + err.Error()) } diff --git a/std/recursion/plonk/verifier_test.go b/std/recursion/plonk/verifier_test.go index 924ac141f7..242c9f373b 100644 --- a/std/recursion/plonk/verifier_test.go +++ b/std/recursion/plonk/verifier_test.go @@ -16,6 +16,7 @@ import ( "github.com/consensys/gnark/std/algebra/native/sw_bls12377" "github.com/consensys/gnark/std/math/emulated" "github.com/consensys/gnark/test" + "github.com/consensys/gnark/test/unsafekzg" ) //----------------------------------------------------------------- @@ -35,9 +36,10 @@ func (c *InnerCircuitNativeWoCommit) Define(api frontend.API) error { func getInnerWoCommit(assert *test.Assert, field, outer *big.Int) (constraint.ConstraintSystem, plonk.VerifyingKey, witness.Witness, plonk.Proof) { innerCcs, err := frontend.Compile(field, scs.NewBuilder, &InnerCircuitNativeWoCommit{}) assert.NoError(err) - srs, err := test.NewKZGSRS(innerCcs) + srs, srsLagrange, err := unsafekzg.NewSRS(innerCcs) assert.NoError(err) - innerPK, innerVK, err := plonk.Setup(innerCcs, srs) + + innerPK, innerVK, err := plonk.Setup(innerCcs, srs, srsLagrange) assert.NoError(err) // inner proof @@ -157,10 +159,10 @@ func getInnerCommit(assert *test.Assert, field, outer *big.Int) (constraint.Cons innerCcs, err := frontend.Compile(field, scs.NewBuilder, &InnerCircuitCommit{}) assert.NoError(err) - srs, err := test.NewKZGSRS(innerCcs) + srs, srsLagrange, err := unsafekzg.NewSRS(innerCcs) assert.NoError(err) - innerPK, innerVK, err := plonk.Setup(innerCcs, srs) + innerPK, innerVK, err := plonk.Setup(innerCcs, srs, srsLagrange) assert.NoError(err) // inner proof diff --git a/test/assert_checkcircuit.go b/test/assert_checkcircuit.go index 3ae327c886..a2e0a8c0e0 100644 --- a/test/assert_checkcircuit.go +++ b/test/assert_checkcircuit.go @@ -10,6 +10,7 @@ import ( "github.com/consensys/gnark/constraint" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/frontend/schema" + "github.com/consensys/gnark/test/unsafekzg" ) // CheckCircuit performs a series of check on the provided circuit. @@ -256,11 +257,11 @@ var ( pk, vk any, pkBuilder, vkBuilder, proofBuilder func() any, err error) { - srs, err := NewKZGSRS(ccs) + srs, srsLagrange, err := unsafekzg.NewSRS(ccs) if err != nil { return nil, nil, nil, nil, nil, err } - pk, vk, err = plonk.Setup(ccs, srs) + pk, vk, err = plonk.Setup(ccs, srs, srsLagrange) return pk, vk, func() any { return plonk.NewProvingKey(curve) }, func() any { return plonk.NewVerifyingKey(curve) }, func() any { return plonk.NewProof(curve) }, err }, prove: func(ccs constraint.ConstraintSystem, pk any, fullWitness witness.Witness, opts ...backend.ProverOption) (proof any, err error) { diff --git a/test/kzg_srs.go b/test/kzg_srs.go deleted file mode 100644 index 84ab658eab..0000000000 --- a/test/kzg_srs.go +++ /dev/null @@ -1,105 +0,0 @@ -/* -Copyright © 2021 ConsenSys Software Inc. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package test - -import ( - "crypto/rand" - "sync" - - "github.com/consensys/gnark-crypto/ecc" - "github.com/consensys/gnark-crypto/kzg" - "github.com/consensys/gnark/constraint" - "github.com/consensys/gnark/internal/utils" - - kzg_bls12377 "github.com/consensys/gnark-crypto/ecc/bls12-377/kzg" - kzg_bls12381 "github.com/consensys/gnark-crypto/ecc/bls12-381/kzg" - kzg_bls24315 "github.com/consensys/gnark-crypto/ecc/bls24-315/kzg" - kzg_bls24317 "github.com/consensys/gnark-crypto/ecc/bls24-317/kzg" - kzg_bn254 "github.com/consensys/gnark-crypto/ecc/bn254/kzg" - kzg_bw6633 "github.com/consensys/gnark-crypto/ecc/bw6-633/kzg" - kzg_bw6761 "github.com/consensys/gnark-crypto/ecc/bw6-761/kzg" -) - -const srsCachedSize = (1 << 14) + 3 - -// NewKZGSRS uses ccs nb variables and nb constraints to initialize a kzg srs -// for sizes < 2¹⁵, returns a pre-computed cached SRS -// -// /!\ warning /!\: this method is here for convenience only: in production, a SRS generated through MPC should be used. -func NewKZGSRS(ccs constraint.ConstraintSystem) (kzg.SRS, error) { - - nbConstraints := ccs.GetNbConstraints() - sizeSystem := nbConstraints + ccs.GetNbPublicVariables() - kzgSize := ecc.NextPowerOfTwo(uint64(sizeSystem)) + 3 - - if kzgSize <= srsCachedSize { - return getCachedSRS(ccs) - } - - return newKZGSRS(utils.FieldToCurve(ccs.Field()), kzgSize) -} - -var srsCache map[ecc.ID]kzg.SRS -var lock sync.Mutex - -func init() { - srsCache = make(map[ecc.ID]kzg.SRS) -} -func getCachedSRS(ccs constraint.ConstraintSystem) (kzg.SRS, error) { - lock.Lock() - defer lock.Unlock() - - curveID := utils.FieldToCurve(ccs.Field()) - - if srs, ok := srsCache[curveID]; ok { - return srs, nil - } - - srs, err := newKZGSRS(curveID, srsCachedSize) - if err != nil { - return nil, err - } - srsCache[curveID] = srs - return srs, nil -} - -func newKZGSRS(curve ecc.ID, kzgSize uint64) (kzg.SRS, error) { - - alpha, err := rand.Int(rand.Reader, curve.ScalarField()) - if err != nil { - return nil, err - } - - switch curve { - case ecc.BN254: - return kzg_bn254.NewSRS(kzgSize, alpha) - case ecc.BLS12_381: - return kzg_bls12381.NewSRS(kzgSize, alpha) - case ecc.BLS12_377: - return kzg_bls12377.NewSRS(kzgSize, alpha) - case ecc.BW6_761: - return kzg_bw6761.NewSRS(kzgSize, alpha) - case ecc.BLS24_317: - return kzg_bls24317.NewSRS(kzgSize, alpha) - case ecc.BLS24_315: - return kzg_bls24315.NewSRS(kzgSize, alpha) - case ecc.BW6_633: - return kzg_bw6633.NewSRS(kzgSize, alpha) - default: - panic("unrecognized R1CS curve type") - } -} diff --git a/test/unsafekzg/kzgsrs.go b/test/unsafekzg/kzgsrs.go new file mode 100644 index 0000000000..0fc2d3d6b8 --- /dev/null +++ b/test/unsafekzg/kzgsrs.go @@ -0,0 +1,418 @@ +// Package unsafekzg is a convenience package (to be use for test purposes only) +// to generate and cache SRS for the kzg scheme (and indirectly for PlonK setup). +// +// Functions in this package are thread safe. +package unsafekzg + +import ( + "bufio" + "crypto/rand" + "fmt" + "math/big" + "os" + "path/filepath" + "regexp" + "sync" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark-crypto/kzg" + "github.com/consensys/gnark/constraint" + "github.com/consensys/gnark/internal/utils" + "github.com/consensys/gnark/logger" + + kzg_bls12377 "github.com/consensys/gnark-crypto/ecc/bls12-377/kzg" + kzg_bls12381 "github.com/consensys/gnark-crypto/ecc/bls12-381/kzg" + kzg_bls24315 "github.com/consensys/gnark-crypto/ecc/bls24-315/kzg" + kzg_bls24317 "github.com/consensys/gnark-crypto/ecc/bls24-317/kzg" + kzg_bn254 "github.com/consensys/gnark-crypto/ecc/bn254/kzg" + kzg_bw6633 "github.com/consensys/gnark-crypto/ecc/bw6-633/kzg" + kzg_bw6761 "github.com/consensys/gnark-crypto/ecc/bw6-761/kzg" + + fft_bls12377 "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/fft" + fft_bls12381 "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/fft" + fft_bls24315 "github.com/consensys/gnark-crypto/ecc/bls24-315/fr/fft" + fft_bls24317 "github.com/consensys/gnark-crypto/ecc/bls24-317/fr/fft" + fft_bn254 "github.com/consensys/gnark-crypto/ecc/bn254/fr/fft" + fft_bw6633 "github.com/consensys/gnark-crypto/ecc/bw6-633/fr/fft" + fft_bw6761 "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/fft" + + fr_bls12377 "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" + fr_bls12381 "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" + fr_bls24315 "github.com/consensys/gnark-crypto/ecc/bls24-315/fr" + fr_bls24317 "github.com/consensys/gnark-crypto/ecc/bls24-317/fr" + fr_bn254 "github.com/consensys/gnark-crypto/ecc/bn254/fr" + fr_bw6633 "github.com/consensys/gnark-crypto/ecc/bw6-633/fr" + fr_bw6761 "github.com/consensys/gnark-crypto/ecc/bw6-761/fr" + + bls12377 "github.com/consensys/gnark-crypto/ecc/bls12-377" + bls12381 "github.com/consensys/gnark-crypto/ecc/bls12-381" + bls24315 "github.com/consensys/gnark-crypto/ecc/bls24-315" + bls24317 "github.com/consensys/gnark-crypto/ecc/bls24-317" + "github.com/consensys/gnark-crypto/ecc/bn254" + bw6633 "github.com/consensys/gnark-crypto/ecc/bw6-633" + bw6761 "github.com/consensys/gnark-crypto/ecc/bw6-761" +) + +var ( + cache = make(map[string]cacheEntry) + reCacheKey = regexp.MustCompile(`kzgsrs-(.*?)-\d+`) + memLock, fsLock sync.RWMutex +) + +// NewSRS returns a pair of kzg.SRS; one in canonical form, the other in lagrange form. +// Default options use a memory cache, see Option for more details & options. +func NewSRS(ccs constraint.ConstraintSystem, opts ...Option) (canonical kzg.SRS, lagrange kzg.SRS, err error) { + + nbConstraints := ccs.GetNbConstraints() + sizeSystem := nbConstraints + ccs.GetNbPublicVariables() + + sizeLagrange := ecc.NextPowerOfTwo(uint64(sizeSystem)) + sizeCanonical := sizeLagrange + 3 + + curveID := utils.FieldToCurve(ccs.Field()) + + log := logger.Logger().With().Str("package", "kzgsrs").Int("size", int(sizeCanonical)).Str("curve", curveID.String()).Logger() + + cfg, err := options(opts...) + if err != nil { + return nil, nil, err + } + + key := cacheKey(curveID, sizeCanonical) + log.Debug().Str("key", key).Msg("fetching SRS from mem cache") + memLock.RLock() + entry, ok := cache[key] + memLock.RUnlock() + if ok { + log.Debug().Msg("SRS found in mem cache") + return entry.canonical, entry.lagrange, nil + } + log.Debug().Msg("SRS not found in mem cache") + + if cfg.fsCache { + log.Debug().Str("key", key).Str("cacheDir", cfg.cacheDir).Msg("fetching SRS from fs cache") + fsLock.RLock() + entry, err = fsRead(key, cfg.cacheDir) + fsLock.RUnlock() + if err == nil { + log.Debug().Str("key", key).Msg("SRS found in fs cache") + canonical, lagrange = entry.canonical, entry.lagrange + memLock.Lock() + cache[key] = cacheEntry{canonical, lagrange} + memLock.Unlock() + return + } else { + log.Debug().Str("key", key).Err(err).Msg("SRS not found in fs cache") + panic(err) + } + } + + log.Debug().Msg("SRS not found in cache, generating") + + // not in cache, generate + canonical, lagrange, err = newSRS(curveID, sizeCanonical) + if err != nil { + return nil, nil, err + } + + // cache it + memLock.Lock() + cache[key] = cacheEntry{canonical, lagrange} + memLock.Unlock() + + if cfg.fsCache { + log.Debug().Str("key", key).Str("cacheDir", cfg.cacheDir).Msg("writing SRS to fs cache") + fsLock.Lock() + fsWrite(key, cfg.cacheDir, canonical, lagrange) + fsLock.Unlock() + } + + return canonical, lagrange, nil +} + +type cacheEntry struct { + canonical kzg.SRS + lagrange kzg.SRS +} + +func cacheKey(curveID ecc.ID, size uint64) string { + return fmt.Sprintf("kzgsrs-%s-%d", curveID.String(), size) +} + +func extractCurveID(key string) (ecc.ID, error) { + matches := reCacheKey.FindStringSubmatch(key) + + if len(matches) < 2 { + return ecc.UNKNOWN, fmt.Errorf("no curveID found in key") + } + return ecc.IDFromString(matches[1]) +} + +func newSRS(curveID ecc.ID, size uint64) (kzg.SRS, kzg.SRS, error) { + + tau, err := rand.Int(rand.Reader, curveID.ScalarField()) + if err != nil { + return nil, nil, err + } + + var srs kzg.SRS + + switch curveID { + case ecc.BN254: + srs, err = kzg_bn254.NewSRS(size, tau) + case ecc.BLS12_381: + srs, err = kzg_bls12381.NewSRS(size, tau) + case ecc.BLS12_377: + srs, err = kzg_bls12377.NewSRS(size, tau) + case ecc.BW6_761: + srs, err = kzg_bw6761.NewSRS(size, tau) + case ecc.BLS24_317: + srs, err = kzg_bls24317.NewSRS(size, tau) + case ecc.BLS24_315: + srs, err = kzg_bls24315.NewSRS(size, tau) + case ecc.BW6_633: + srs, err = kzg_bw6633.NewSRS(size, tau) + default: + panic("unrecognized R1CS curve type") + } + + if err != nil { + return nil, nil, err + } + + return srs, toLagrange(srs, tau), nil +} + +func toLagrange(canonicalSRS kzg.SRS, tau *big.Int) kzg.SRS { + + var lagrangeSRS kzg.SRS + + switch srs := canonicalSRS.(type) { + case *kzg_bn254.SRS: + newSRS := &kzg_bn254.SRS{Vk: srs.Vk} + size := uint64(len(srs.Pk.G1)) - 3 + + // instead of using ToLagrangeG1 we can directly do a fft on the powers of alpha + // since we know the randomness in test. + pAlpha := make([]fr_bn254.Element, size) + pAlpha[0].SetUint64(1) + pAlpha[1].SetBigInt(tau) + for i := 2; i < len(pAlpha); i++ { + pAlpha[i].Mul(&pAlpha[i-1], &pAlpha[1]) + } + // do a fft on this. + d := fft_bn254.NewDomain(size) + d.FFTInverse(pAlpha, fft_bn254.DIF) + fft_bn254.BitReverse(pAlpha) + + // bath scalar mul + _, _, g1gen, _ := bn254.Generators() + newSRS.Pk.G1 = bn254.BatchScalarMultiplicationG1(&g1gen, pAlpha) + + lagrangeSRS = newSRS + case *kzg_bls12381.SRS: + newSRS := &kzg_bls12381.SRS{Vk: srs.Vk} + size := uint64(len(srs.Pk.G1)) - 3 + + // instead of using ToLagrangeG1 we can directly do a fft on the powers of alpha + // since we know the randomness in test. + pAlpha := make([]fr_bls12381.Element, size) + pAlpha[0].SetUint64(1) + pAlpha[1].SetBigInt(tau) + for i := 2; i < len(pAlpha); i++ { + pAlpha[i].Mul(&pAlpha[i-1], &pAlpha[1]) + } + // do a fft on this. + d := fft_bls12381.NewDomain(size) + d.FFTInverse(pAlpha, fft_bls12381.DIF) + fft_bls12381.BitReverse(pAlpha) + + // bath scalar mul + _, _, g1gen, _ := bls12381.Generators() + newSRS.Pk.G1 = bls12381.BatchScalarMultiplicationG1(&g1gen, pAlpha) + + lagrangeSRS = newSRS + case *kzg_bls12377.SRS: + newSRS := &kzg_bls12377.SRS{Vk: srs.Vk} + size := uint64(len(srs.Pk.G1)) - 3 + + // instead of using ToLagrangeG1 we can directly do a fft on the powers of alpha + // since we know the randomness in test. + pAlpha := make([]fr_bls12377.Element, size) + pAlpha[0].SetUint64(1) + pAlpha[1].SetBigInt(tau) + for i := 2; i < len(pAlpha); i++ { + pAlpha[i].Mul(&pAlpha[i-1], &pAlpha[1]) + } + // do a fft on this. + d := fft_bls12377.NewDomain(size) + d.FFTInverse(pAlpha, fft_bls12377.DIF) + fft_bls12377.BitReverse(pAlpha) + + // bath scalar mul + _, _, g1gen, _ := bls12377.Generators() + newSRS.Pk.G1 = bls12377.BatchScalarMultiplicationG1(&g1gen, pAlpha) + + lagrangeSRS = newSRS + case *kzg_bw6761.SRS: + newSRS := &kzg_bw6761.SRS{Vk: srs.Vk} + size := uint64(len(srs.Pk.G1)) - 3 + + // instead of using ToLagrangeG1 we can directly do a fft on the powers of alpha + // since we know the randomness in test. + pAlpha := make([]fr_bw6761.Element, size) + pAlpha[0].SetUint64(1) + pAlpha[1].SetBigInt(tau) + for i := 2; i < len(pAlpha); i++ { + pAlpha[i].Mul(&pAlpha[i-1], &pAlpha[1]) + } + + // do a fft on this. + d := fft_bw6761.NewDomain(size) + d.FFTInverse(pAlpha, fft_bw6761.DIF) + fft_bw6761.BitReverse(pAlpha) + + // bath scalar mul + _, _, g1gen, _ := bw6761.Generators() + newSRS.Pk.G1 = bw6761.BatchScalarMultiplicationG1(&g1gen, pAlpha) + + lagrangeSRS = newSRS + case *kzg_bls24317.SRS: + newSRS := &kzg_bls24317.SRS{Vk: srs.Vk} + size := uint64(len(srs.Pk.G1)) - 3 + + // instead of using ToLagrangeG1 we can directly do a fft on the powers of alpha + // since we know the randomness in test. + pAlpha := make([]fr_bls24317.Element, size) + pAlpha[0].SetUint64(1) + pAlpha[1].SetBigInt(tau) + for i := 2; i < len(pAlpha); i++ { + pAlpha[i].Mul(&pAlpha[i-1], &pAlpha[1]) + } + + // do a fft on this. + d := fft_bls24317.NewDomain(size) + d.FFTInverse(pAlpha, fft_bls24317.DIF) + fft_bls24317.BitReverse(pAlpha) + + // bath scalar mul + _, _, g1gen, _ := bls24317.Generators() + newSRS.Pk.G1 = bls24317.BatchScalarMultiplicationG1(&g1gen, pAlpha) + + lagrangeSRS = newSRS + case *kzg_bls24315.SRS: + newSRS := &kzg_bls24315.SRS{Vk: srs.Vk} + size := uint64(len(srs.Pk.G1)) - 3 + + // instead of using ToLagrangeG1 we can directly do a fft on the powers of alpha + // since we know the randomness in test. + pAlpha := make([]fr_bls24315.Element, size) + pAlpha[0].SetUint64(1) + pAlpha[1].SetBigInt(tau) + for i := 2; i < len(pAlpha); i++ { + pAlpha[i].Mul(&pAlpha[i-1], &pAlpha[1]) + } + + // do a fft on this. + d := fft_bls24315.NewDomain(size) + d.FFTInverse(pAlpha, fft_bls24315.DIF) + fft_bls24315.BitReverse(pAlpha) + + // bath scalar mul + _, _, g1gen, _ := bls24315.Generators() + newSRS.Pk.G1 = bls24315.BatchScalarMultiplicationG1(&g1gen, pAlpha) + + lagrangeSRS = newSRS + case *kzg_bw6633.SRS: + newSRS := &kzg_bw6633.SRS{Vk: srs.Vk} + size := uint64(len(srs.Pk.G1)) - 3 + + // instead of using ToLagrangeG1 we can directly do a fft on the powers of alpha + // since we know the randomness in test. + pAlpha := make([]fr_bw6633.Element, size) + pAlpha[0].SetUint64(1) + pAlpha[1].SetBigInt(tau) + for i := 2; i < len(pAlpha); i++ { + pAlpha[i].Mul(&pAlpha[i-1], &pAlpha[1]) + } + + // do a fft on this. + d := fft_bw6633.NewDomain(size) + d.FFTInverse(pAlpha, fft_bw6633.DIF) + fft_bw6633.BitReverse(pAlpha) + + // bath scalar mul + _, _, g1gen, _ := bw6633.Generators() + newSRS.Pk.G1 = bw6633.BatchScalarMultiplicationG1(&g1gen, pAlpha) + + lagrangeSRS = newSRS + default: + panic("unrecognized curve") + } + + return lagrangeSRS +} + +func fsRead(key string, cacheDir string) (cacheEntry, error) { + filePath := filepath.Join(cacheDir, key) + + // if file does not exist, return false + if _, err := os.Stat(filePath); os.IsNotExist(err) { + return cacheEntry{}, fmt.Errorf("file %s does not exist", filePath) + } + + // else open file and read the srs. + f, err := os.Open(filePath) + if err != nil { + return cacheEntry{}, err + } + defer f.Close() + + r := bufio.NewReaderSize(f, 1<<20) + + curveID, err := extractCurveID(key) + if err != nil { + return cacheEntry{}, err + } + cacheEntry := cacheEntry{ + canonical: kzg.NewSRS(curveID), + lagrange: kzg.NewSRS(curveID), + } + _, err = cacheEntry.canonical.UnsafeReadFrom(r) + if err != nil { + return cacheEntry, err + } + _, err = cacheEntry.lagrange.UnsafeReadFrom(r) + if err != nil { + return cacheEntry, err + } + + return cacheEntry, nil +} + +func fsWrite(key string, cacheDir string, canonical kzg.SRS, lagrange kzg.SRS) { + // if file exist, return. + filePath := filepath.Join(cacheDir, key) + if _, err := os.Stat(filePath); err == nil { + return + } + + // else open file and write the srs. + f, err := os.Create(filePath) + if err != nil { + return + } + defer f.Close() + + w := bufio.NewWriterSize(f, 1<<20) + + if _, err = canonical.WriteRawTo(w); err != nil { + return + } + + if _, err = lagrange.WriteRawTo(w); err != nil { + return + } + + w.Flush() +} diff --git a/test/unsafekzg/options.go b/test/unsafekzg/options.go new file mode 100644 index 0000000000..3639505fde --- /dev/null +++ b/test/unsafekzg/options.go @@ -0,0 +1,66 @@ +package unsafekzg + +import ( + "os" + "path/filepath" + + "github.com/consensys/gnark/logger" +) + +type Option func(*config) error + +// WithCacheDir enables the filesystem cache and sets the cache directory +// to ~/.gnark/kzg by default. +func WithFSCache() Option { + return func(opt *config) error { + opt.fsCache = true + return nil + } +} + +type config struct { + fsCache bool + cacheDir string +} + +// default options +func options(opts ...Option) (config, error) { + var opt config + + // apply user provided options. + for _, option := range opts { + err := option(&opt) + if err != nil { + return opt, err + } + } + + // default value for cacheDir is ~/.gnark/kzg + if opt.fsCache { + if opt.cacheDir == "" { + homeDir, err := os.UserHomeDir() + if err != nil { + panic(err) + } + opt.cacheDir = filepath.Join(homeDir, ".gnark", "kzg") + } + initCache(opt.cacheDir) + } + + return opt, nil +} + +func initCache(cacheDir string) { + // get gnark logger + log := logger.Logger() + + // populate cache from disk + log.Warn().Str("cacheDir", cacheDir).Msg("using kzg srs cache") + + if _, err := os.Stat(cacheDir); os.IsNotExist(err) { + err := os.MkdirAll(cacheDir, 0700) + if err != nil { + panic(err) + } + } +}