diff --git a/backend/groth16/bw6-761/prove.go b/backend/groth16/bw6-761/prove.go index 179eacc065..858380cfe9 100644 --- a/backend/groth16/bw6-761/prove.go +++ b/backend/groth16/bw6-761/prove.go @@ -106,9 +106,10 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b // H (witness reduction / FFT part) var hOnDevice unsafe.Pointer + var h_err error chHDone := make(chan struct{}, 1) go func() { - hOnDevice = computeHOnDevice(solution.A, solution.B, solution.C, pk) + hOnDevice, h_err = computeHOnDevice(solution.A, solution.B, solution.C, pk) solution.A = nil solution.B = nil solution.C = nil @@ -118,47 +119,59 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b // we need to copy and filter the wireValues for each multi exp // as pk.G1.A, pk.G1.B and pk.G2.B may have (a significant) number of point at infinity var wireValuesADevice, wireValuesBDevice OnDeviceData + var wireValuesADeviceErr, wireValuesBDeviceErr error chWireValuesA, chWireValuesB := make(chan struct{}, 1), make(chan struct{}, 1) go func() { - wireValuesA := make([]fr.Element, len(wireValues)-int(pk.NbInfinityA)) - for i, j := 0, 0; j < len(wireValuesA); i++ { - if pk.InfinityA[i] { - continue - } - wireValuesA[j] = wireValues[i] - j++ - } - - wireValuesASize := len(wireValuesA) - scalarBytes := wireValuesASize * fr.Bytes - wireValuesADevicePtr, _ := goicicle.CudaMalloc(scalarBytes) - goicicle.CudaMemCpyHtoD[fr.Element](wireValuesADevicePtr, wireValuesA, scalarBytes) - MontConvOnDevice(wireValuesADevicePtr, wireValuesASize, false) - wireValuesADevice = OnDeviceData{wireValuesADevicePtr, wireValuesASize} - + wireValuesADevice, wireValuesADeviceErr = PrepareWireValueOnDevice(wireValues, pk.NbInfinityA, pk.InfinityA) close(chWireValuesA) }() - go func() { - wireValuesB := make([]fr.Element, len(wireValues)-int(pk.NbInfinityB)) - for i, j := 0, 0; j < len(wireValuesB); i++ { - if pk.InfinityB[i] { - continue - } - wireValuesB[j] = wireValues[i] - j++ - } - wireValuesBSize := len(wireValuesB) - scalarBytes := wireValuesBSize * fr.Bytes - wireValuesBDevicePtr, _ := goicicle.CudaMalloc(scalarBytes) - goicicle.CudaMemCpyHtoD[fr.Element](wireValuesBDevicePtr, wireValuesB, scalarBytes) - MontConvOnDevice(wireValuesBDevicePtr, wireValuesBSize, false) - wireValuesBDevice = OnDeviceData{wireValuesBDevicePtr, wireValuesBSize} + // go func() { + // wireValuesA := make([]fr.Element, len(wireValues)-int(pk.NbInfinityA)) + // for i, j := 0, 0; j < len(wireValuesA); i++ { + // if pk.InfinityA[i] { + // continue + // } + // wireValuesA[j] = wireValues[i] + // j++ + // } + + // wireValuesASize := len(wireValuesA) + // scalarBytes := wireValuesASize * fr.Bytes + // wireValuesADevicePtr, _ := goicicle.CudaMalloc(scalarBytes) + // goicicle.CudaMemCpyHtoD[fr.Element](wireValuesADevicePtr, wireValuesA, scalarBytes) + // MontConvOnDevice(wireValuesADevicePtr, wireValuesASize, false) + // wireValuesADevice = OnDeviceData{wireValuesADevicePtr, wireValuesASize} + + // close(chWireValuesA) + // }() + go func() { + wireValuesBDevice, wireValuesBDeviceErr = PrepareWireValueOnDevice(wireValues, pk.NbInfinityB, pk.InfinityB) close(chWireValuesB) }() + // go func() { + // wireValuesB := make([]fr.Element, len(wireValues)-int(pk.NbInfinityB)) + // for i, j := 0, 0; j < len(wireValuesB); i++ { + // if pk.InfinityB[i] { + // continue + // } + // wireValuesB[j] = wireValues[i] + // j++ + // } + + // wireValuesBSize := len(wireValuesB) + // scalarBytes := wireValuesBSize * fr.Bytes + // wireValuesBDevicePtr, _ := goicicle.CudaMalloc(scalarBytes) + // goicicle.CudaMemCpyHtoD[fr.Element](wireValuesBDevicePtr, wireValuesB, scalarBytes) + // MontConvOnDevice(wireValuesBDevicePtr, wireValuesBSize, false) + // wireValuesBDevice = OnDeviceData{wireValuesBDevicePtr, wireValuesBSize} + + // close(chWireValuesB) + // }() + // sample random r and s var r, s big.Int var _r, _s, _kr fr.Element @@ -266,13 +279,15 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b // wait for FFT to end, as it uses all our CPUs <-chHDone - // schedule our proof part computations - // go computeKRS() - // go computeAR1() - // go computeBS1() - // if err = computeBS2(); err != nil { - // return nil, err - // } + if h_err != nil { + return nil, h_err + } + if wireValuesADeviceErr != nil { + return nil, wireValuesADeviceErr + } + if wireValuesBDeviceErr != nil { + return nil, wireValuesBDeviceErr + } startMSM := time.Now() computeBS1() @@ -361,7 +376,7 @@ func computeH(a, b, c []fr.Element, domain *fft.Domain) []fr.Element { return a } -func computeHOnDevice(a, b, c []fr.Element, pk *ProvingKey) unsafe.Pointer { +func computeHOnDevice(a, b, c []fr.Element, pk *ProvingKey) (unsafe.Pointer, error) { // H part of Krs // Compute H (hz=ab-c, where z=-2 on ker X^n+1 (z(x)=x^n-1)) // 1 - _a = ifft(a), _b = ifft(b), _c = ifft(c) @@ -442,5 +457,30 @@ func computeHOnDevice(a, b, c []fr.Element, pk *ProvingKey) unsafe.Pointer { } log.Debug().Dur("took", time.Since(computeHTime)).Msg("Icicle API: computeH") - return h + return h, nil +} + +func PrepareWireValueOnDevice(wireValues []fr.Element, nbInfinityA uint64, infinityA []bool) (data OnDeviceData, err error) { + wireValuesA := make([]fr.Element, len(wireValues)-int(nbInfinityA)) + for i, j := 0, 0; j < len(wireValuesA); i++ { + if infinityA[i] { + continue + } + wireValuesA[j] = wireValues[i] + j++ + } + + data.size = len(wireValuesA) + scalarBytes := data.size * fr.Bytes + if data.p, err = goicicle.CudaMalloc(scalarBytes); err != nil { + return + } + if ret := goicicle.CudaMemCpyHtoD[fr.Element](data.p, wireValuesA, scalarBytes); ret != 0 { + err = fmt.Errorf("CudaMemCpyHtoD in PrepareWireValueOnDevice %d", ret) + return + } + if err = MontConvOnDevice(data.p, data.size, false); err != nil { + return + } + return }