diff --git a/backend/groth16/bn254/icicle/icicle.go b/backend/groth16/bn254/icicle/icicle.go index 2a069154e5..990399136f 100644 --- a/backend/groth16/bn254/icicle/icicle.go +++ b/backend/groth16/bn254/icicle/icicle.go @@ -3,14 +3,11 @@ package icicle_bn254 import ( - "errors" "fmt" "math/big" - "runtime" "sync" "time" - "github.com/consensys/gnark-crypto/ecc" curve "github.com/consensys/gnark-crypto/ecc/bn254" "github.com/consensys/gnark-crypto/ecc/bn254/fp" "github.com/consensys/gnark-crypto/ecc/bn254/fr" @@ -28,9 +25,9 @@ import ( "github.com/consensys/gnark/internal/utils" "github.com/consensys/gnark/logger" "github.com/ingonyama-zk/icicle/wrappers/golang/core" - cr "github.com/ingonyama-zk/icicle/wrappers/golang/cuda_runtime" - iciclewrapper_bn254 "github.com/ingonyama-zk/icicle/wrappers/golang/curves/bn254" - iciclegnark_bn254 "github.com/ingonyama-zk/iciclegnark/curves/bn254" + "github.com/ingonyama-zk/icicle/wrappers/golang/cuda_runtime" + "github.com/ingonyama-zk/icicle/wrappers/golang/curves/bn254" + iciclegnark "github.com/ingonyama-zk/iciclegnark/curves/bn254" ) const HasIcicle = true @@ -53,64 +50,56 @@ func (pk *ProvingKey) setupDevicePointers() error { } pk.deviceInfo = &deviceInfo{} - // copy pk A to device - fmt.Printf("start copy pk A \n") + // ntt config + ctx, _ := cuda_runtime.GetDefaultDeviceContext() + var s bn254.ScalarField + + // domain.Generator + gen, _ := fft.Generator(4 * pk.Domain.Cardinality) + genBits := gen.Bits() + s.FromLimbs(core.ConvertUint64ArrToUint32Arr(genBits[:])) + bn254.InitDomain(s, ctx, true) + + /************************* Start G1 Device Setup ***************************/ + /************************* A ***************************/ copyADone := make(chan core.DeviceSlice, 1) - go iciclegnark_bn254.CopyPointsToDevice(pk.G1.A, copyADone) // Make a function for points - pk.G1Device.A = <-copyADone - fmt.Printf("end copy pk A \n") + go iciclegnark.CopyPointsToDevice(pk.G1.A, copyADone) // Make a function for points - // opcy pk B to device - fmt.Printf("start copy pk B \n") + /************************* B ***************************/ copyBDone := make(chan core.DeviceSlice, 1) - go iciclegnark_bn254.CopyPointsToDevice(pk.G1.B, copyBDone) // Make a function for points - pk.G1Device.B = <-copyBDone - fmt.Printf("end copy pk B \n") + go iciclegnark.CopyPointsToDevice(pk.G1.B, copyBDone) // Make a function for points + + /************************* K ***************************/ + var pointsNoInfinity []curve.G1Affine + for i, gnarkPoint := range pk.G1.K { + if gnarkPoint.IsInfinity() { + pk.InfinityPointIndicesK = append(pk.InfinityPointIndicesK, i) + } else { + pointsNoInfinity = append(pointsNoInfinity, gnarkPoint) + } + } - fmt.Printf("start copy pk K \n") copyKDone := make(chan core.DeviceSlice, 1) - go iciclegnark_bn254.CopyPointsToDevice(pk.G1.K, copyKDone) // Make a function for points - pk.G1Device.K = <-copyKDone - fmt.Printf("end copy pk K \n") + go iciclegnark.CopyPointsToDevice(pointsNoInfinity, copyKDone) // Make a function for points - fmt.Printf("start copy pk Z \n") + /************************* Z ***************************/ copyZDone := make(chan core.DeviceSlice, 1) - go iciclegnark_bn254.CopyPointsToDevice(pk.G1.Z, copyZDone) // Make a function for points + padding := make([]curve.G1Affine, 1) + // padding[0] = curve.G1Affine.generator() + Z_plus_point := append(pk.G1.Z, padding...) + go iciclegnark.CopyPointsToDevice(Z_plus_point, copyZDone) // Make a function for points + + /************************* End G1 Device Setup ***************************/ + pk.G1Device.A = <-copyADone + pk.G1Device.B = <-copyBDone + pk.G1Device.K = <-copyKDone pk.G1Device.Z = <-copyZDone - fmt.Printf("end copy pk Z \n") - fmt.Printf("start copy pk G2 B \n") - copyG2BDone := make(chan core.DeviceSlice, 1) + /************************* Start G2 Device Setup ***************************/ pointsBytesB2 := len(pk.G2.B) * fp.Bytes * 4 - go iciclegnark_bn254.CopyG2PointsToDevice(pk.G2.B, pointsBytesB2, copyG2BDone) // Make a function for points + copyG2BDone := make(chan core.DeviceSlice, 1) + go iciclegnark.CopyG2PointsToDevice(pk.G2.B, pointsBytesB2, copyG2BDone) // Make a function for points pk.G2Device.B = <-copyG2BDone - fmt.Printf("end copy pk G2 B \n") - - // ntt config - cfg := iciclewrapper_bn254.GetDefaultNttConfig() - var s iciclewrapper_bn254.ScalarField - - // set pk.Domain.CosetTable[1] - cosetTable, err := pk.Domain.CosetTable() - if err != nil { - return err - } - coset := cosetTable[1] - cosetBits := coset.Bits() - var configCosetGen [8]uint32 - configCosetGenRaw := core.ConvertUint64ArrToUint32Arr(cosetBits[:]) - if len(configCosetGenRaw) != 8 { - return fmt.Errorf("len mismatch: %d != 8", len(configCosetGenRaw)) - } - copy(configCosetGen[:], configCosetGenRaw[:8]) - cfg.CosetGen = configCosetGen - - // domain.Generator - genBits := pk.Domain.Generator.Bits() - s.FromLimbs(core.ConvertUint64ArrToUint32Arr(genBits[:])) - fmt.Printf("start init icicle domain \n") - //iciclewrapper_bn254.InitDomain(s, cfg.Ctx, true) - fmt.Printf("end init icicle domain \n") return nil } @@ -197,16 +186,12 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b return nil, err } + stream, _ := cuda_runtime.CreateStream() + ctx, _ := cuda_runtime.GetDefaultDeviceContext() + ctx.Stream = &stream + // H (witness reduction / FFT part) - var h []fr.Element - chHDone := make(chan struct{}, 1) - go func() { - h = computeH(solution.A, solution.B, solution.C, &pk.Domain) - solution.A = nil - solution.B = nil - solution.C = nil - chHDone <- struct{}{} - }() + h_device := computeHonDevice(solution.A, solution.B, solution.C, &pk.Domain, stream) // we need to copy and filter the wireValues for each multi exp // as pk.G1.A, pk.G1.B and pk.G2.B may have (a significant) number of point at infinity @@ -257,205 +242,101 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b var bs1, ar curve.G1Jac - n := runtime.NumCPU() - - chBs1Done := make(chan error, 1) - computeBS1 := func() { - <-chWireValuesB - if _, merr := bs1.MultiExp(pk.G1.B, wireValuesB, ecc.MultiExpConfig{NbTasks: n / 2}); merr != nil { - chBs1Done <- merr - close(chBs1Done) - return - } + cfg := bn254.GetDefaultMSMConfig() + cfg.Ctx.Stream = &stream + cfg.IsAsync = true - bs1InGpu, gerr := MsmOnDevice(pk.G1Device.B, wireValuesB) - if gerr != nil { - chBs1Done <- gerr - close(chBs1Done) - return - } - var bs1JacInGpu curve.G1Jac - bs1JacInGpu.FromAffine(bs1InGpu) - if bs1JacInGpu.Equal(&bs1) { - fmt.Printf("bs1JacInGpu equal \n") - } else { - fmt.Printf("bs1JacInGpu not equal \n") - } - - bs1.AddMixed(&pk.G1.Beta) - bs1.AddMixed(&deltas[1]) - chBs1Done <- nil + outHost := make(core.HostSlice[bn254.Projective], 1) + var out core.DeviceSlice + out.MallocAsync(outHost.SizeOfElement(), outHost.SizeOfElement(), stream) + + wireValuesBhost := iciclegnark.HostSliceFromScalars(wireValuesB) + var wireValuesBdevice core.DeviceSlice + wireValuesBhost.CopyToDeviceAsync(&wireValuesBdevice, stream, true) + gerr := bn254.Msm(wireValuesBdevice, pk.G1Device.B, &cfg, out) + if gerr != cuda_runtime.CudaSuccess { + fmt.Errorf("Error in MSM: ", gerr) } + outHost.CopyFromDeviceAsync(&out, stream) - chArDone := make(chan error, 1) - computeAR1 := func() { - <-chWireValuesA - if _, merr := ar.MultiExp(pk.G1.A, wireValuesA, ecc.MultiExpConfig{NbTasks: n / 2}); merr != nil { - chArDone <- merr - close(chArDone) - return - } - - arInGpu, gerr := MsmOnDevice(pk.G1Device.A, wireValuesA) - if gerr != nil { - chArDone <- gerr - close(chArDone) - return - } - var arJacInGpu curve.G1Jac - arJacInGpu.FromAffine(arInGpu) - if arJacInGpu.Equal(&ar) { - fmt.Printf("arJacInGpu equal \n") - } else { - fmt.Printf("arJacInGpu not equal \n") - } + bs1 = *iciclegnark.G1ProjectivePointToGnarkJac(&outHost[0]) + bs1.AddMixed(&pk.G1.Beta) + bs1.AddMixed(&deltas[1]) - ar.AddMixed(&pk.G1.Alpha) - ar.AddMixed(&deltas[0]) - proof.Ar.FromJacobian(&ar) - chArDone <- nil + wireValuesAhost := iciclegnark.HostSliceFromScalars(wireValuesA) + gerr = bn254.Msm(wireValuesAhost, pk.G1Device.A, &cfg, out) + if gerr != cuda_runtime.CudaSuccess { + fmt.Errorf("Error in MSM: ", gerr) } + outHost.CopyFromDeviceAsync(&out, stream) - chKrsDone := make(chan error, 1) - computeKRS := func() { - // we could NOT split the Krs multiExp in 2, and just append pk.G1.K and pk.G1.Z - // however, having similar lengths for our tasks helps with parallelism - - var krs, krs2, p1 curve.G1Jac - chKrs2Done := make(chan error, 1) - sizeH := int(pk.Domain.Cardinality - 1) // comes from the fact the deg(H)=(n-1)+(n-1)-n=n-2 - go func() { - _, kerr := krs2.MultiExp(pk.G1.Z, h[:sizeH], ecc.MultiExpConfig{NbTasks: n / 2}) - - krs2InGpu, gerr := MsmOnDevice(pk.G1Device.Z, h[:sizeH]) - if gerr != nil { - chKrsDone <- gerr - return - } - - var krs2JacInGpu curve.G1Jac - krs2JacInGpu.FromAffine(krs2InGpu) - if krs2JacInGpu.Equal(&krs2) { - fmt.Printf("krs2JacInGpu equal \n") - } else { - fmt.Printf("krs2JacInGpu not equal \n") - } - - chKrs2Done <- kerr - }() + ar = *iciclegnark.G1ProjectivePointToGnarkJac(&outHost[0]) - // filter the wire values if needed - // TODO Perf @Tabaie worst memory allocation offender - toRemove := commitmentInfo.GetPrivateCommitted() - toRemove = append(toRemove, commitmentInfo.CommitmentIndexes()) - _wireValues := filterHeap(wireValues[r1cs.GetNbPublicVariables():], r1cs.GetNbPublicVariables(), internal.ConcatAll(toRemove...)) + ar.AddMixed(&pk.G1.Alpha) + ar.AddMixed(&deltas[0]) + proof.Ar.FromJacobian(&ar) - if _, merr := krs.MultiExp(pk.G1.K, _wireValues, ecc.MultiExpConfig{NbTasks: n / 2}); merr != nil { - chKrsDone <- merr - return - } + var krs, krs2, p1 curve.G1Jac + gerr = bn254.Msm(h_device, pk.G1Device.Z, &cfg, out) + if gerr != cuda_runtime.CudaSuccess { + fmt.Errorf("Error in MSM: ", gerr) + } + outHost.CopyFromDeviceAsync(&out, stream) + h_device.FreeAsync(stream) - // TODO - // filter zero/infinity points since icicle doesn't handle them - // See https://github.com/ingonyama-zk/icicle/issues/169 for more info - krsInGpu, gerr := MsmOnDevice(pk.G1Device.K, _wireValues) - if gerr != nil { - chKrsDone <- gerr - return - } + solution.A = nil + solution.B = nil + solution.C = nil - var krsJacInGpu curve.G1Jac - krsJacInGpu.FromAffine(krsInGpu) - if krsJacInGpu.Equal(&krs) { - fmt.Printf("krsJacInGpu equal \n") - } else { - fmt.Printf("krsJacInGpu not equal \n") - } + krs2 = *iciclegnark.G1ProjectivePointToGnarkJac(&outHost[0]) - krs.AddMixed(&deltas[2]) - n := 3 - for n != 0 { - select { - case err := <-chKrs2Done: - if err != nil { - chKrsDone <- err - return - } - krs.AddAssign(&krs2) - case err := <-chArDone: - if err != nil { - chKrsDone <- err - return - } - p1.ScalarMultiplication(&ar, &s) - krs.AddAssign(&p1) - case err := <-chBs1Done: - if err != nil { - chKrsDone <- err - return - } - p1.ScalarMultiplication(&bs1, &r) - krs.AddAssign(&p1) - } - n-- - } + // filter the wire values if needed + // TODO Perf @Tabaie worst memory allocation offender + toRemove := commitmentInfo.GetPrivateCommitted() + toRemove = append(toRemove, commitmentInfo.CommitmentIndexes()) + _wireValues := filterHeap(wireValues[r1cs.GetNbPublicVariables():], r1cs.GetNbPublicVariables(), internal.ConcatAll(toRemove...)) - proof.Krs.FromJacobian(&krs) - chKrsDone <- nil + _wireValuesHost := iciclegnark.HostSliceFromScalars(_wireValues) + gerr = bn254.Msm(_wireValuesHost, pk.G1Device.K, &cfg, out) + if gerr != cuda_runtime.CudaSuccess { + fmt.Errorf("Error in MSM: ", gerr) } + outHost.CopyFromDeviceAsync(&out, stream) + out.FreeAsync(stream) - computeBS2 := func() error { - // Bs2 (1 multi exp G2 - size = len(wires)) - var Bs, deltaS curve.G2Jac + krs = *iciclegnark.G1ProjectivePointToGnarkJac(&outHost[0]) - nbTasks := n - if nbTasks <= 16 { - // if we don't have a lot of CPUs, this may artificially split the MSM - nbTasks *= 2 - } - <-chWireValuesB - - if _, merr := Bs.MultiExp(pk.G2.B, wireValuesB, ecc.MultiExpConfig{NbTasks: nbTasks}); merr != nil { - return merr - } + krs.AddMixed(&deltas[2]) + krs.AddAssign(&krs2) + p1.ScalarMultiplication(&ar, &s) + krs.AddAssign(&p1) + p1.ScalarMultiplication(&bs1, &r) + krs.AddAssign(&p1) - bsInGpu, gerr := G2MsmOnDevice(pk.G2Device.B, wireValuesB) - if gerr != nil { - return gerr - } + proof.Krs.FromJacobian(&krs) - var bsJacInGpu curve.G2Jac - bsJacInGpu.FromAffine(bsInGpu) - if bsJacInGpu.Equal(&Bs) { - fmt.Printf("bsJacInGpu equal \n") - } else { - fmt.Printf("bsJacInGpu not equal \n") - } + // Bs2 (1 multi exp G2 - size = len(wires)) + var Bs, deltaS curve.G2Jac - deltaS.FromAffine(&pk.G2.Delta) - deltaS.ScalarMultiplication(&deltaS, &s) - Bs.AddAssign(&deltaS) - Bs.AddMixed(&pk.G2.Beta) - - proof.Bs.FromJacobian(&Bs) - return nil + outHostG2 := make(core.HostSlice[bn254.G2Projective], 1) + var outG2 core.DeviceSlice + outG2.MallocAsync(outHostG2.SizeOfElement(), outHostG2.SizeOfElement(), stream) + gerr = bn254.G2Msm(wireValuesBdevice, pk.G2Device.B, &cfg, outG2) + if gerr != cuda_runtime.CudaSuccess { + fmt.Errorf("Error in MSM: ", gerr) } + outHostG2.CopyFromDeviceAsync(&outG2, stream) + outG2.FreeAsync(stream) + wireValuesBdevice.FreeAsync(stream) - // wait for FFT to end, as it uses all our CPUs - <-chHDone + Bs = *iciclegnark.G2PointToGnarkJac(&outHostG2[0]) - // schedule our proof part computations - go computeKRS() - go computeAR1() - go computeBS1() - if err := computeBS2(); err != nil { - return nil, err - } + deltaS.FromAffine(&pk.G2.Delta) + deltaS.ScalarMultiplication(&deltaS, &s) + Bs.AddAssign(&deltaS) + Bs.AddMixed(&pk.G2.Beta) - // wait for all parts of the proof to be computed. - if err := <-chKrsDone; err != nil { - return nil, err - } + proof.Bs.FromJacobian(&Bs) log.Debug().Dur("took", time.Since(start)).Msg("prover done") @@ -536,49 +417,72 @@ func computeH(a, b, c []fr.Element, domain *fft.Domain) []fr.Element { return a } -func computeHonDevice(a, b, c []fr.Element, domain *fft.Domain) []fr.Element { +func computeHonDevice(a, b, c []fr.Element, domain *fft.Domain, stream cuda_runtime.Stream) core.DeviceSlice { + cosetGen, _ := fft.Generator(2 * domain.Cardinality) + cosetBits := cosetGen.Bits() + var configCosetGen [8]uint32 + configCosetGenRaw := core.ConvertUint64ArrToUint32Arr(cosetBits[:]) + copy(configCosetGen[:], configCosetGenRaw[:8]) - return nil -} + cfg := bn254.GetDefaultNttConfig() + cfg.IsAsync = true + cfg.Ctx.Stream = &stream -func MsmOnDevice(gnarkPoints core.DeviceSlice, gnarkScalars []fr.Element) (*curve.G1Affine, error) { - fmt.Println("MsmOnDevice with g1 on device") - icicleScalars := iciclegnark_bn254.HostSliceFromScalars(gnarkScalars) + n := len(a) - cfg := core.GetDefaultMSMConfig() - var p iciclewrapper_bn254.Projective - var out core.DeviceSlice - _, e := out.Malloc(p.Size(), p.Size()) - if e != cr.CudaSuccess { - return nil, errors.New("cannot allocate") - } - e = iciclewrapper_bn254.Msm(icicleScalars, gnarkPoints, &cfg, out) - if e != cr.CudaSuccess { - return nil, errors.New("msm failed") - } - outHost := make(core.HostSlice[iciclewrapper_bn254.Projective], 1) - outHost.CopyFromDevice(&out) - out.Free() - return iciclegnark_bn254.ProjectiveToGnarkAffine(&outHost[0]), nil -} + padding := make([]fr.Element, int(domain.Cardinality)-n) + a = append(a, padding...) + b = append(b, padding...) + c = append(c, padding...) + n = len(a) + a_host := iciclegnark.HostSliceFromScalars(a) + b_host := iciclegnark.HostSliceFromScalars(b) + c_host := iciclegnark.HostSliceFromScalars(c) -func G2MsmOnDevice(gnarkPoints core.DeviceSlice, gnarkScalars []fr.Element) (*curve.G2Affine, error) { - fmt.Println("MsmOnDevice with g2 on device") - icicleScalars := core.HostSliceFromElements(iciclegnark_bn254.BatchConvertFromFrGnark(gnarkScalars)) + var a_device core.DeviceSlice + var b_device core.DeviceSlice + var c_device core.DeviceSlice + a_host.CopyToDeviceAsync(&a_device, stream, true) + b_host.CopyToDeviceAsync(&b_device, stream, true) + c_host.CopyToDeviceAsync(&c_device, stream, true) - cfg := core.GetDefaultMSMConfig() - var p iciclewrapper_bn254.G2Projective - var out core.DeviceSlice - _, e := out.Malloc(p.Size(), p.Size()) - if e != cr.CudaSuccess { - return nil, errors.New("Cannot allocate g2") - } - e = iciclewrapper_bn254.G2Msm(icicleScalars, gnarkPoints, &cfg, out) - if e != cr.CudaSuccess { - return nil, errors.New("Msm g2 failed") + cfg.Ordering = core.KNM + + bn254.Ntt(a_device, core.KInverse, &cfg, a_device) + bn254.Ntt(b_device, core.KInverse, &cfg, b_device) + bn254.Ntt(c_device, core.KInverse, &cfg, c_device) + + cfg.CosetGen = configCosetGen + cfg.Ordering = core.KMN + + bn254.Ntt(a_device, core.KForward, &cfg, a_device) + bn254.Ntt(b_device, core.KForward, &cfg, b_device) + bn254.Ntt(c_device, core.KForward, &cfg, c_device) + + var den, one fr.Element + one.SetOne() + den.Exp(cosetGen, big.NewInt(int64(domain.Cardinality))) + den.Sub(&den, &one).Inverse(&den) + den_repeated := make([]fr.Element, n) + for i := 0; i < n; i++ { + den_repeated[i] = den } - outHost := make(core.HostSlice[iciclewrapper_bn254.G2Projective], 1) - outHost.CopyFromDevice(&out) - out.Free() - return iciclegnark_bn254.G2PointToGnarkAffine(&outHost[0]), nil + den_host := iciclegnark.HostSliceFromScalars(den_repeated) + + vcfg := core.DefaultVecOpsConfig() + cfg.IsAsync = true + vcfg.Ctx.Stream = &stream + + // h = ifft_coset(ca o cb - cc) + bn254.VecOp(a_device, b_device, a_device, vcfg, core.Mul) + bn254.VecOp(a_device, c_device, a_device, vcfg, core.Sub) + den_host.CopyToDeviceAsync(&b_device, stream, false) + bn254.VecOp(a_device, b_device, a_device, vcfg, core.Mul) + cfg.Ordering = core.KNR + + // ifft_coset + bn254.Ntt(a_device, core.KInverse, &cfg, a_device) + b_device.FreeAsync(stream) + c_device.FreeAsync(stream) + return a_device } diff --git a/go.mod b/go.mod index 603c90b050..7b1830cc3d 100644 --- a/go.mod +++ b/go.mod @@ -12,6 +12,7 @@ require ( github.com/google/go-cmp v0.5.9 github.com/google/pprof v0.0.0-20230817174616-7a8ec2ada47b github.com/icza/bitio v1.1.0 + github.com/ingonyama-zk/icicle v1.9.1 github.com/ingonyama-zk/iciclegnark v0.1.2-0.20240329204201-5a05ea507886 github.com/leanovate/gopter v0.2.9 github.com/rs/zerolog v1.30.0 @@ -23,7 +24,6 @@ require ( require ( github.com/davecgh/go-spew v1.1.1 // indirect - github.com/ingonyama-zk/icicle v1.9.1 // indirect github.com/mattn/go-colorable v0.1.13 // indirect github.com/mattn/go-isatty v0.0.19 // indirect github.com/mmcloughlin/addchain v0.4.0 // indirect diff --git a/testgpu/gpu_test.go b/testgpu/gpu_test.go index a398e2b945..e821abc06c 100644 --- a/testgpu/gpu_test.go +++ b/testgpu/gpu_test.go @@ -1,6 +1,10 @@ package testgpu import ( + "math/big" + "os" + "testing" + "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark/backend" "github.com/consensys/gnark/backend/groth16" @@ -12,9 +16,6 @@ import ( regroth16 "github.com/consensys/gnark/std/recursion/groth16" "github.com/consensys/gnark/test" "github.com/rs/zerolog" - "math/big" - "os" - "testing" ) func TestBn254Gpu(t *testing.T) {