Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prover performance and memory use optimised #20

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
135 changes: 53 additions & 82 deletions backend/groth16/bn254/icicle/icicle.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,6 @@ func (pk *ProvingKey) setupDevicePointers() error {
// Prove generates the proof of knowledge of a r1cs with full witness (secret + public part).
func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...backend.ProverOption) (*groth16_bn254.Proof, error) {
lg := logger.Logger().With().Str("curve", r1cs.CurveID().String()).Str("acceleration", "icicle").Int("nbConstraints", r1cs.GetNbConstraints()).Str("backend", "groth16").Logger()
//n := runtime.NumCPU()
lg.Info().Msg("start prove")
opt, err := backend.NewProverConfig(opts...)
if err != nil {
Expand Down Expand Up @@ -185,33 +184,30 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b
for i := range commitmentInfo {
copy(commitmentsSerialized[fr.Bytes*i:], wireValues[commitmentInfo[i].CommitmentIndex].Marshal())
}

var errPedersen error
chPedersenDone := make(chan struct{}, 1)
go func() {
// TODO: after the bottleneck of HostSlice creation is solved, this function that's currently executed on CPU
// might become the bottleneck, especially for relatively weak CPUs
if proof.CommitmentPok, err = pedersen.BatchProve(pk.CommitmentKeys, privateCommittedValues, commitmentsSerialized); err != nil {
errPedersen = err
}

if proof.CommitmentPok, err = pedersen.BatchProve(pk.CommitmentKeys, privateCommittedValues, commitmentsSerialized); err != nil {
return nil, err
close(chPedersenDone)
}()
if errPedersen != nil {
return nil, errPedersen
}

stream, _ := cuda_runtime.CreateStream()
ctx, _ := cuda_runtime.GetDefaultDeviceContext()
ctx.Stream = &stream

// H (witness reduction / FFT part)
h_device := computeHonDevice(solution.A, solution.B, solution.C, &pk.Domain, stream)
// cpu calculate h
/*var h []fr.Element
chHDone := make(chan struct{}, 1)
go func() {
h = computeH(solution.A, solution.B, solution.C, &pk.Domain)
solution.A = nil
solution.B = nil
solution.C = nil
lg.Debug().Msg(fmt.Sprintf("h len: %d", len(h)))
chHDone <- struct{}{}
}()*/

// we need to copy and filter the wireValues for each multi exp
// as pk.G1.A, pk.G1.B and pk.G2.B may have (a significant) number of point at infinity
var wireValuesA, wireValuesB []fr.Element
chWireValuesA, chWireValuesB := make(chan struct{}, 1), make(chan struct{}, 1)
var wireValuesA, wireValuesB, _wireValues []fr.Element
chWireValuesA, chWireValuesB, chWireValues := make(chan struct{}, 1), make(chan struct{}, 1), make(chan struct{}, 1)

go func() {
wireValuesA = make([]fr.Element, len(wireValues)-int(pk.NbInfinityA))
Expand All @@ -237,6 +233,18 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b

close(chWireValuesB)
}()
go func() {
// filter the wire values if needed
// TODO Perf @Tabaie worst memory allocation offender
toRemove := commitmentInfo.GetPrivateCommitted()
toRemove = append(toRemove, commitmentInfo.CommitmentIndexes())
_wireValues = filterHeap(wireValues[r1cs.GetNbPublicVariables():], r1cs.GetNbPublicVariables(), internal.ConcatAll(toRemove...))

close(chWireValues)
}()

// H (witness reduction / FFT part)
h_device := computeHonDevice(solution.A, solution.B, solution.C, &pk.Domain, stream)

// sample random r and s
var r, s big.Int
Expand Down Expand Up @@ -275,16 +283,33 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b
}
outHost.CopyFromDeviceAsync(&out, stream)

/*var cpuBs1 curve.G1Jac
_, err = cpuBs1.MultiExp(pk.G1.B, wireValuesB, ecc.MultiExpConfig{NbTasks: n / 2})
if err != nil {
return nil, fmt.Errorf("error in cpu MultiExp bs1: %v", err)
}*/
bs1 = *iciclegnark.G1ProjectivePointToGnarkJac(&outHost[0])
//lg.Debug().Msg(fmt.Sprintf("gpu bs1 equal cpu bs1: %v", cpuBs1.Equal(&bs1)))
bs1.AddMixed(&pk.G1.Beta)
bs1.AddMixed(&deltas[1])

// Bs2 (1 multi exp G2 - size = len(wires))
var Bs, deltaS curve.G2Jac

outHostG2 := make(core.HostSlice[bn254.G2Projective], 1)
var outG2 core.DeviceSlice
outG2.MallocAsync(outHostG2.SizeOfElement(), outHostG2.SizeOfElement(), stream)
gerr = bn254.G2Msm(wireValuesBdevice, pk.G2Device.B, &cfg, outG2)
if gerr != cuda_runtime.CudaSuccess {
return nil, fmt.Errorf("error in MSM g2 b: %v", gerr)
}
outHostG2.CopyFromDeviceAsync(&outG2, stream)
outG2.FreeAsync(stream)
wireValuesBdevice.FreeAsync(stream)

Bs = *iciclegnark.G2PointToGnarkJac(&outHostG2[0])

deltaS.FromAffine(&pk.G2.Delta)
deltaS.ScalarMultiplication(&deltaS, &s)
Bs.AddAssign(&deltaS)
Bs.AddMixed(&pk.G2.Beta)

proof.Bs.FromJacobian(&Bs)

<-chWireValuesA
wireValuesAhost := iciclegnark.HostSliceFromScalars(wireValuesA)
gerr = bn254.Msm(wireValuesAhost, pk.G1Device.A, &cfg, out)
Expand All @@ -293,20 +318,12 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b
}
outHost.CopyFromDeviceAsync(&out, stream)

/*var cpuAr curve.G1Jac
_, err = cpuAr.MultiExp(pk.G1.A, wireValuesA, ecc.MultiExpConfig{NbTasks: n / 2})
if err != nil {
return nil, fmt.Errorf("error in cpu MultiExp ar: %v", err)
}*/
ar = *iciclegnark.G1ProjectivePointToGnarkJac(&outHost[0])
//lg.Debug().Msg(fmt.Sprintf("gpu ar equal cpu bs1: %v", cpuAr.Equal(&ar)))

ar.AddMixed(&pk.G1.Alpha)
ar.AddMixed(&deltas[0])
proof.Ar.FromJacobian(&ar)

//<-chHDone

var krs, krs2, p1 curve.G1Jac
gerr = bn254.Msm(h_device, pk.G1Device.Z, &cfg, out)
if gerr != cuda_runtime.CudaSuccess {
Expand All @@ -319,21 +336,9 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b
solution.B = nil
solution.C = nil

/*var cpuKrs2 curve.G1Jac
sizeH := int(pk.Domain.Cardinality - 1)
_, err = cpuKrs2.MultiExp(pk.G1.Z, h[:sizeH], ecc.MultiExpConfig{NbTasks: n / 2})
if err != nil {
return nil, fmt.Errorf("error in cpu MultiExp cpuKrs2: %v", err)
}*/
krs2 = *iciclegnark.G1ProjectivePointToGnarkJac(&outHost[0])
//lg.Debug().Msg(fmt.Sprintf("gpu ar equal cpu krs2: %v", cpuKrs2.Equal(&krs2)))

// filter the wire values if needed
// TODO Perf @Tabaie worst memory allocation offender
toRemove := commitmentInfo.GetPrivateCommitted()
toRemove = append(toRemove, commitmentInfo.CommitmentIndexes())
_wireValues := filterHeap(wireValues[r1cs.GetNbPublicVariables():], r1cs.GetNbPublicVariables(), internal.ConcatAll(toRemove...))

<-chWireValues
_wireValuesHost := iciclegnark.HostSliceFromScalars(_wireValues)
gerr = bn254.Msm(_wireValuesHost, pk.G1Device.K, &cfg, out)
if gerr != cuda_runtime.CudaSuccess {
Expand All @@ -342,13 +347,7 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b
outHost.CopyFromDeviceAsync(&out, stream)
out.FreeAsync(stream)

/*var cpuKrs curve.G1Jac
_, err = cpuKrs.MultiExp(pk.G1.K, _wireValues, ecc.MultiExpConfig{NbTasks: n / 2})
if err != nil {
return nil, fmt.Errorf("error in cpu MultiExp krs: %v", err)
}*/
krs = *iciclegnark.G1ProjectivePointToGnarkJac(&outHost[0])
//lg.Debug().Msg(fmt.Sprintf("gpu ar equal cpu krs: %v", cpuKrs.Equal(&krs)))

krs.AddMixed(&deltas[2])
krs.AddAssign(&krs2)
Expand All @@ -359,34 +358,7 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b

proof.Krs.FromJacobian(&krs)

// Bs2 (1 multi exp G2 - size = len(wires))
var Bs, deltaS curve.G2Jac

outHostG2 := make(core.HostSlice[bn254.G2Projective], 1)
var outG2 core.DeviceSlice
outG2.MallocAsync(outHostG2.SizeOfElement(), outHostG2.SizeOfElement(), stream)
gerr = bn254.G2Msm(wireValuesBdevice, pk.G2Device.B, &cfg, outG2)
if gerr != cuda_runtime.CudaSuccess {
return nil, fmt.Errorf("error in MSM g2 b: %v", gerr)
}
outHostG2.CopyFromDeviceAsync(&outG2, stream)
outG2.FreeAsync(stream)
wireValuesBdevice.FreeAsync(stream)

/*var cpuBs curve.G2Jac
_, err = cpuBs.MultiExp(pk.G2.B, wireValuesB, ecc.MultiExpConfig{NbTasks: n})
if err != nil {
return nil, fmt.Errorf("error in cpu G2 MultiExp Bs: %v", err)
}*/
Bs = *iciclegnark.G2PointToGnarkJac(&outHostG2[0])
//lg.Debug().Msg(fmt.Sprintf("gpu ar equal cpu Bs: %v", cpuBs.Equal(&Bs)))

deltaS.FromAffine(&pk.G2.Delta)
deltaS.ScalarMultiplication(&deltaS, &s)
Bs.AddAssign(&deltaS)
Bs.AddMixed(&pk.G2.Beta)

proof.Bs.FromJacobian(&Bs)
<-chPedersenDone

lg.Debug().Dur("took", time.Since(start)).Msg("prover done")

Expand Down Expand Up @@ -520,7 +492,6 @@ func computeHonDevice(a, b, c []fr.Element, domain *fft.Domain, stream cuda_runt
den_host := iciclegnark.HostSliceFromScalars(den_repeated)

vcfg := core.DefaultVecOpsConfig()
cfg.IsAsync = true
vcfg.Ctx.Stream = &stream

// h = ifft_coset(ca o cb - cc)
Expand All @@ -535,4 +506,4 @@ func computeHonDevice(a, b, c []fr.Element, domain *fft.Domain, stream cuda_runt
b_device.FreeAsync(stream)
c_device.FreeAsync(stream)
return a_device
}
}
125 changes: 121 additions & 4 deletions testgpu/gpu_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,128 @@ import (
"github.com/consensys/gnark/frontend"
"github.com/consensys/gnark/frontend/cs/r1cs"
"github.com/consensys/gnark/logger"
cs_254 "github.com/consensys/gnark/constraint/bn254"
regroth16 "github.com/consensys/gnark/std/recursion/groth16"
"github.com/consensys/gnark/test"
"github.com/rs/zerolog"
)

func ReadProvingKey(filename string, pk groth16.ProvingKey) error {
f, err := os.Open(filename)
if err != nil {
return err
}
defer f.Close()

_, err = pk.UnsafeReadFrom(f)
return err
}

func WriteProvingKey(pk groth16.ProvingKey, filename string) {
f, err := os.Create(filename)
if err != nil {
fmt.Errorf("pk writing open failed... ")
}
_, err = pk.WriteTo(f)
if err != nil {
fmt.Errorf("pk writing failed... ")
}
}

func ReadVerifyingKey(filename string, vk groth16.VerifyingKey) error {
f, err := os.Open(filename)
if err != nil {
return err
}
defer f.Close()

_, err = vk.UnsafeReadFrom(f)
return err
}

func WriteVerifyingKey(vk groth16.VerifyingKey, filename string) {
f, err := os.Create(filename)
if err != nil {
fmt.Errorf("vk writing failed... ")
}

_, err = vk.WriteTo(f)
if err != nil {
fmt.Errorf("vk writing failed... ")
}
}

func WriteCcs(ccs constraint.ConstraintSystem, filename string) error {
f, err := os.Create(filename)
if err != nil {
return err
}
defer f.Close()

_, err = ccs.WriteTo(f)
if err != nil {
return err
}
return nil
}

func ReadCcs(filename string, ccs constraint.ConstraintSystem) error {
f, err := os.Open(filename)
if err != nil {
return err
}
defer f.Close()

_, err = ccs.ReadFrom(f)
return err
}

func LoadOrGenPkVkForTest(ccs constraint.ConstraintSystem, curveID ecc.ID, name string) (groth16.ProvingKey, groth16.VerifyingKey) {
fmt.Printf("Start to setup pk \n")
var err error
pkFileName := fmt.Sprintf("%s.pk", name)
vkFileName := fmt.Sprintf("%s.vk", name)
var pk = groth16.NewProvingKey(curveID)
var vk = groth16.NewVerifyingKey(curveID)
err1 := ReadProvingKey(pkFileName, pk)
err2 := ReadVerifyingKey(vkFileName, vk)
if err1 != nil || err2 != nil {
fmt.Printf("Failed to read pk and vk, and try create, %v, %v \n", err1, err2)
pk, vk, err = groth16.Setup(ccs)
if err != nil {
fmt.Errorf("e: %v", err)
}
WriteProvingKey(pk, pkFileName)
WriteVerifyingKey(vk, vkFileName)
}
return pk, vk
}

func LoadOrGenCcsBN254ForTest(filename string, circuit frontend.Circuit) *cs_254.R1CS {
filename = fmt.Sprintf("%s.ccs", filename)
loadCcs := new(cs_254.R1CS)
err := ReadCcs(filename, loadCcs)
if err == nil {
fmt.Printf("load 254 ccs success: %s \n", filename)
return loadCcs
}
ccs, err := frontend.Compile(ecc.BN254.ScalarField(), r1cs.NewBuilder, circuit)
if err != nil {
fmt.Errorf("e: %v", err)
}

err = WriteCcs(ccs, filename)
if err != nil {
fmt.Errorf("e: %v", err)
}

err = ReadCcs(filename, loadCcs)
if err != nil {
fmt.Errorf("e: %v", err)
}
return loadCcs
}

func TestBn254Gpu(t *testing.T) {
logger.Set(zerolog.New(zerolog.ConsoleWriter{Out: os.Stdout, TimeFormat: "15:04:05"}).With().Timestamp().Logger())
assert := test.NewAssert(t)
Expand Down Expand Up @@ -74,10 +191,10 @@ func TestBn254VerifyBw6761(t *testing.T) {
err = test.IsSolved(circuit, assigment, ecc.BN254.ScalarField())
assert.NoError(err)

outerCcs, err := frontend.Compile(ecc.BN254.ScalarField(), r1cs.NewBuilder, circuit)
assert.NoError(err)
outerPK, outerVK, err := groth16.Setup(outerCcs)
assert.NoError(err)
fileName := "outer_circuit"

outerCcs := LoadOrGenCcsBN254ForTest(fileName, circuit)
outerPK, outerVK := LoadOrGenPkVkForTest(outerCcs, ecc.BN254, fileName)
outerWitness, err := frontend.NewWitness(assigment, ecc.BN254.ScalarField())
assert.NoError(err)
for i := 0; i < 2; i++ {
Expand Down
Loading