Skip to content

Commit

Permalink
Support new icicle with cpu compare icicle 2.0.1 (#22)
Browse files Browse the repository at this point in the history
  • Loading branch information
liuxiaobleach authored Apr 28, 2024
1 parent 42f15e7 commit 515678f
Show file tree
Hide file tree
Showing 9 changed files with 375 additions and 313 deletions.
48 changes: 26 additions & 22 deletions backend/groth16/bls12-377/icicle/icicle.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,13 @@ import (
fcs "github.com/consensys/gnark/frontend/cs"
"github.com/consensys/gnark/internal/utils"
"github.com/consensys/gnark/logger"
"github.com/ingonyama-zk/icicle/wrappers/golang/core"
"github.com/ingonyama-zk/icicle/wrappers/golang/cuda_runtime"
"github.com/ingonyama-zk/icicle/wrappers/golang/curves/bls12377"
"github.com/ingonyama-zk/icicle/v2/wrappers/golang/core"
"github.com/ingonyama-zk/icicle/v2/wrappers/golang/cuda_runtime"
"github.com/ingonyama-zk/icicle/v2/wrappers/golang/curves/bls12377"
"github.com/ingonyama-zk/icicle/v2/wrappers/golang/curves/bls12377/g2"
"github.com/ingonyama-zk/icicle/v2/wrappers/golang/curves/bls12377/msm"
"github.com/ingonyama-zk/icicle/v2/wrappers/golang/curves/bls12377/ntt"
"github.com/ingonyama-zk/icicle/v2/wrappers/golang/curves/bls12377/vecOps"
iciclegnark "github.com/ingonyama-zk/iciclegnark/curves/bls12377"
)

Expand Down Expand Up @@ -60,7 +64,7 @@ func (pk *ProvingKey) setupDevicePointers() error {
gen, _ := fft.Generator(2 * pk.Domain.Cardinality)
genBits := gen.Bits()
s.FromLimbs(core.ConvertUint64ArrToUint32Arr(genBits[:]))
bls12377.InitDomain(s, ctx, false)
ntt.InitDomain(s, ctx, false)

/************************* Start G1 Device Setup ***************************/
/************************* A ***************************/
Expand Down Expand Up @@ -265,7 +269,7 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b

var bs1, ar curve.G1Jac

cfg := bls12377.GetDefaultMSMConfig()
cfg := msm.GetDefaultMSMConfig()
cfg.Ctx.Stream = &stream
cfg.IsAsync = true

Expand All @@ -277,7 +281,7 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b
wireValuesBhost := iciclegnark.HostSliceFromScalars(wireValuesB)
var wireValuesBdevice core.DeviceSlice
wireValuesBhost.CopyToDeviceAsync(&wireValuesBdevice, stream, true)
gerr := bls12377.Msm(wireValuesBdevice, pk.G1Device.B, &cfg, out)
gerr := msm.Msm(wireValuesBdevice, pk.G1Device.B, &cfg, out)
if gerr != cuda_runtime.CudaSuccess {
return nil, fmt.Errorf("error in MSM b: %v", gerr)
}
Expand All @@ -290,10 +294,10 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b
// Bs2 (1 multi exp G2 - size = len(wires))
var Bs, deltaS curve.G2Jac

outHostG2 := make(core.HostSlice[bls12377.G2Projective], 1)
outHostG2 := make(core.HostSlice[g2.G2Projective], 1)
var outG2 core.DeviceSlice
outG2.MallocAsync(outHostG2.SizeOfElement(), outHostG2.SizeOfElement(), stream)
gerr = bls12377.G2Msm(wireValuesBdevice, pk.G2Device.B, &cfg, outG2)
gerr = g2.G2Msm(wireValuesBdevice, pk.G2Device.B, &cfg, outG2)
if gerr != cuda_runtime.CudaSuccess {
return nil, fmt.Errorf("error in MSM g2 b: %v", gerr)
}
Expand All @@ -312,7 +316,7 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b

<-chWireValuesA
wireValuesAhost := iciclegnark.HostSliceFromScalars(wireValuesA)
gerr = bls12377.Msm(wireValuesAhost, pk.G1Device.A, &cfg, out)
gerr = msm.Msm(wireValuesAhost, pk.G1Device.A, &cfg, out)
if gerr != cuda_runtime.CudaSuccess {
return nil, fmt.Errorf("error in MSM a: %v", gerr)
}
Expand All @@ -325,7 +329,7 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b
proof.Ar.FromJacobian(&ar)

var krs, krs2, p1 curve.G1Jac
gerr = bls12377.Msm(h_device, pk.G1Device.Z, &cfg, out)
gerr = msm.Msm(h_device, pk.G1Device.Z, &cfg, out)
if gerr != cuda_runtime.CudaSuccess {
return nil, fmt.Errorf("error in MSM z: %v", gerr)
}
Expand All @@ -340,7 +344,7 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b

<-chWireValues
_wireValuesHost := iciclegnark.HostSliceFromScalars(_wireValues)
gerr = bls12377.Msm(_wireValuesHost, pk.G1Device.K, &cfg, out)
gerr = msm.Msm(_wireValuesHost, pk.G1Device.K, &cfg, out)
if gerr != cuda_runtime.CudaSuccess {
return nil, fmt.Errorf("error in MSM k: %v", gerr)
}
Expand Down Expand Up @@ -446,7 +450,7 @@ func computeHonDevice(a, b, c []fr.Element, domain *fft.Domain, stream cuda_runt
configCosetGenRaw := core.ConvertUint64ArrToUint32Arr(cosetBits[:])
copy(configCosetGen[:], configCosetGenRaw[:8])

cfg := bls12377.GetDefaultNttConfig()
cfg := ntt.GetDefaultNttConfig()
cfg.IsAsync = true
cfg.Ctx.Stream = &stream

Expand All @@ -470,16 +474,16 @@ func computeHonDevice(a, b, c []fr.Element, domain *fft.Domain, stream cuda_runt

cfg.Ordering = core.KNM

bls12377.Ntt(a_device, core.KInverse, &cfg, a_device)
bls12377.Ntt(b_device, core.KInverse, &cfg, b_device)
bls12377.Ntt(c_device, core.KInverse, &cfg, c_device)
ntt.Ntt(a_device, core.KInverse, &cfg, a_device)
ntt.Ntt(b_device, core.KInverse, &cfg, b_device)
ntt.Ntt(c_device, core.KInverse, &cfg, c_device)

cfg.CosetGen = configCosetGen
cfg.Ordering = core.KMN

bls12377.Ntt(a_device, core.KForward, &cfg, a_device)
bls12377.Ntt(b_device, core.KForward, &cfg, b_device)
bls12377.Ntt(c_device, core.KForward, &cfg, c_device)
ntt.Ntt(a_device, core.KForward, &cfg, a_device)
ntt.Ntt(b_device, core.KForward, &cfg, b_device)
ntt.Ntt(c_device, core.KForward, &cfg, c_device)

var den, one fr.Element
one.SetOne()
Expand All @@ -495,14 +499,14 @@ func computeHonDevice(a, b, c []fr.Element, domain *fft.Domain, stream cuda_runt
vcfg.Ctx.Stream = &stream

// h = ifft_coset(ca o cb - cc)
bls12377.VecOp(a_device, b_device, a_device, vcfg, core.Mul)
bls12377.VecOp(a_device, c_device, a_device, vcfg, core.Sub)
vecOps.VecOp(a_device, b_device, a_device, vcfg, core.Mul)
vecOps.VecOp(a_device, c_device, a_device, vcfg, core.Sub)
den_host.CopyToDeviceAsync(&b_device, stream, false)
bls12377.VecOp(a_device, b_device, a_device, vcfg, core.Mul)
vecOps.VecOp(a_device, b_device, a_device, vcfg, core.Mul)
cfg.Ordering = core.KNR

// ifft_coset
bls12377.Ntt(a_device, core.KInverse, &cfg, a_device)
ntt.Ntt(a_device, core.KInverse, &cfg, a_device)
b_device.FreeAsync(stream)
c_device.FreeAsync(stream)
return a_device
Expand Down
2 changes: 1 addition & 1 deletion backend/groth16/bls12-377/icicle/provingkey.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package icicle_bls12377
import (
groth16_bls12377 "github.com/consensys/gnark/backend/groth16/bls12-377"
cs "github.com/consensys/gnark/constraint/bls12-377"
"github.com/ingonyama-zk/icicle/wrappers/golang/core"
"github.com/ingonyama-zk/icicle/v2/wrappers/golang/core"
"io"
)

Expand Down
Loading

0 comments on commit 515678f

Please sign in to comment.