Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support new icicle with cpu compare icicle 2.0.1 #22

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 26 additions & 22 deletions backend/groth16/bls12-377/icicle/icicle.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,13 @@ import (
fcs "github.com/consensys/gnark/frontend/cs"
"github.com/consensys/gnark/internal/utils"
"github.com/consensys/gnark/logger"
"github.com/ingonyama-zk/icicle/wrappers/golang/core"
"github.com/ingonyama-zk/icicle/wrappers/golang/cuda_runtime"
"github.com/ingonyama-zk/icicle/wrappers/golang/curves/bls12377"
"github.com/ingonyama-zk/icicle/v2/wrappers/golang/core"
"github.com/ingonyama-zk/icicle/v2/wrappers/golang/cuda_runtime"
"github.com/ingonyama-zk/icicle/v2/wrappers/golang/curves/bls12377"
"github.com/ingonyama-zk/icicle/v2/wrappers/golang/curves/bls12377/g2"
"github.com/ingonyama-zk/icicle/v2/wrappers/golang/curves/bls12377/msm"
"github.com/ingonyama-zk/icicle/v2/wrappers/golang/curves/bls12377/ntt"
"github.com/ingonyama-zk/icicle/v2/wrappers/golang/curves/bls12377/vecOps"
iciclegnark "github.com/ingonyama-zk/iciclegnark/curves/bls12377"
)

Expand Down Expand Up @@ -60,7 +64,7 @@ func (pk *ProvingKey) setupDevicePointers() error {
gen, _ := fft.Generator(2 * pk.Domain.Cardinality)
genBits := gen.Bits()
s.FromLimbs(core.ConvertUint64ArrToUint32Arr(genBits[:]))
bls12377.InitDomain(s, ctx, false)
ntt.InitDomain(s, ctx, false)

/************************* Start G1 Device Setup ***************************/
/************************* A ***************************/
Expand Down Expand Up @@ -265,7 +269,7 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b

var bs1, ar curve.G1Jac

cfg := bls12377.GetDefaultMSMConfig()
cfg := msm.GetDefaultMSMConfig()
cfg.Ctx.Stream = &stream
cfg.IsAsync = true

Expand All @@ -277,7 +281,7 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b
wireValuesBhost := iciclegnark.HostSliceFromScalars(wireValuesB)
var wireValuesBdevice core.DeviceSlice
wireValuesBhost.CopyToDeviceAsync(&wireValuesBdevice, stream, true)
gerr := bls12377.Msm(wireValuesBdevice, pk.G1Device.B, &cfg, out)
gerr := msm.Msm(wireValuesBdevice, pk.G1Device.B, &cfg, out)
if gerr != cuda_runtime.CudaSuccess {
return nil, fmt.Errorf("error in MSM b: %v", gerr)
}
Expand All @@ -290,10 +294,10 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b
// Bs2 (1 multi exp G2 - size = len(wires))
var Bs, deltaS curve.G2Jac

outHostG2 := make(core.HostSlice[bls12377.G2Projective], 1)
outHostG2 := make(core.HostSlice[g2.G2Projective], 1)
var outG2 core.DeviceSlice
outG2.MallocAsync(outHostG2.SizeOfElement(), outHostG2.SizeOfElement(), stream)
gerr = bls12377.G2Msm(wireValuesBdevice, pk.G2Device.B, &cfg, outG2)
gerr = g2.G2Msm(wireValuesBdevice, pk.G2Device.B, &cfg, outG2)
if gerr != cuda_runtime.CudaSuccess {
return nil, fmt.Errorf("error in MSM g2 b: %v", gerr)
}
Expand All @@ -312,7 +316,7 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b

<-chWireValuesA
wireValuesAhost := iciclegnark.HostSliceFromScalars(wireValuesA)
gerr = bls12377.Msm(wireValuesAhost, pk.G1Device.A, &cfg, out)
gerr = msm.Msm(wireValuesAhost, pk.G1Device.A, &cfg, out)
if gerr != cuda_runtime.CudaSuccess {
return nil, fmt.Errorf("error in MSM a: %v", gerr)
}
Expand All @@ -325,7 +329,7 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b
proof.Ar.FromJacobian(&ar)

var krs, krs2, p1 curve.G1Jac
gerr = bls12377.Msm(h_device, pk.G1Device.Z, &cfg, out)
gerr = msm.Msm(h_device, pk.G1Device.Z, &cfg, out)
if gerr != cuda_runtime.CudaSuccess {
return nil, fmt.Errorf("error in MSM z: %v", gerr)
}
Expand All @@ -340,7 +344,7 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b

<-chWireValues
_wireValuesHost := iciclegnark.HostSliceFromScalars(_wireValues)
gerr = bls12377.Msm(_wireValuesHost, pk.G1Device.K, &cfg, out)
gerr = msm.Msm(_wireValuesHost, pk.G1Device.K, &cfg, out)
if gerr != cuda_runtime.CudaSuccess {
return nil, fmt.Errorf("error in MSM k: %v", gerr)
}
Expand Down Expand Up @@ -446,7 +450,7 @@ func computeHonDevice(a, b, c []fr.Element, domain *fft.Domain, stream cuda_runt
configCosetGenRaw := core.ConvertUint64ArrToUint32Arr(cosetBits[:])
copy(configCosetGen[:], configCosetGenRaw[:8])

cfg := bls12377.GetDefaultNttConfig()
cfg := ntt.GetDefaultNttConfig()
cfg.IsAsync = true
cfg.Ctx.Stream = &stream

Expand All @@ -470,16 +474,16 @@ func computeHonDevice(a, b, c []fr.Element, domain *fft.Domain, stream cuda_runt

cfg.Ordering = core.KNM

bls12377.Ntt(a_device, core.KInverse, &cfg, a_device)
bls12377.Ntt(b_device, core.KInverse, &cfg, b_device)
bls12377.Ntt(c_device, core.KInverse, &cfg, c_device)
ntt.Ntt(a_device, core.KInverse, &cfg, a_device)
ntt.Ntt(b_device, core.KInverse, &cfg, b_device)
ntt.Ntt(c_device, core.KInverse, &cfg, c_device)

cfg.CosetGen = configCosetGen
cfg.Ordering = core.KMN

bls12377.Ntt(a_device, core.KForward, &cfg, a_device)
bls12377.Ntt(b_device, core.KForward, &cfg, b_device)
bls12377.Ntt(c_device, core.KForward, &cfg, c_device)
ntt.Ntt(a_device, core.KForward, &cfg, a_device)
ntt.Ntt(b_device, core.KForward, &cfg, b_device)
ntt.Ntt(c_device, core.KForward, &cfg, c_device)

var den, one fr.Element
one.SetOne()
Expand All @@ -495,14 +499,14 @@ func computeHonDevice(a, b, c []fr.Element, domain *fft.Domain, stream cuda_runt
vcfg.Ctx.Stream = &stream

// h = ifft_coset(ca o cb - cc)
bls12377.VecOp(a_device, b_device, a_device, vcfg, core.Mul)
bls12377.VecOp(a_device, c_device, a_device, vcfg, core.Sub)
vecOps.VecOp(a_device, b_device, a_device, vcfg, core.Mul)
vecOps.VecOp(a_device, c_device, a_device, vcfg, core.Sub)
den_host.CopyToDeviceAsync(&b_device, stream, false)
bls12377.VecOp(a_device, b_device, a_device, vcfg, core.Mul)
vecOps.VecOp(a_device, b_device, a_device, vcfg, core.Mul)
cfg.Ordering = core.KNR

// ifft_coset
bls12377.Ntt(a_device, core.KInverse, &cfg, a_device)
ntt.Ntt(a_device, core.KInverse, &cfg, a_device)
b_device.FreeAsync(stream)
c_device.FreeAsync(stream)
return a_device
Expand Down
2 changes: 1 addition & 1 deletion backend/groth16/bls12-377/icicle/provingkey.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package icicle_bls12377
import (
groth16_bls12377 "github.com/consensys/gnark/backend/groth16/bls12-377"
cs "github.com/consensys/gnark/constraint/bls12-377"
"github.com/ingonyama-zk/icicle/wrappers/golang/core"
"github.com/ingonyama-zk/icicle/v2/wrappers/golang/core"
"io"
)

Expand Down
Loading
Loading