Skip to content

Commit

Permalink
C & Rust elliptic bindings: batchAffine, scalarMul, serial MSM and 32…
Browse files Browse the repository at this point in the history
…-bit Rust bindings (#332)

* bindings: batchAffine, scalarMul, serial MSM

* bindings: Rust supports 32-bit platforms

* workaround nim v2 gensym in templates: no {.noInit.} and no G2 scalarmul

* workaround nim v1.6 input shadowing issue
  • Loading branch information
mratsim committed Jan 5, 2024
1 parent 077cfb2 commit df9034b
Show file tree
Hide file tree
Showing 16 changed files with 10,056 additions and 25 deletions.
63 changes: 60 additions & 3 deletions bindings/c_curve_decls.nim
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,11 @@

import
../constantine/math/config/curves,
../constantine/curves_primitives
../constantine/curves_primitives,

export curves, curves_primitives
../constantine/math/extension_fields # generic sandwich

export curves, curves_primitives, extension_fields

# Overview
# ------------------------------------------------------------
Expand Down Expand Up @@ -303,7 +305,9 @@ template genBindings_EC_ShortW_Affine*(ECP, Field: untyped) =

{.pop.}

template genBindings_EC_ShortW_NonAffine*(ECP, ECP_Aff: untyped) =
template genBindings_EC_ShortW_NonAffine*(ECP, ECP_Aff, ScalarBig, ScalarField: untyped) =
# TODO: remove the need of explicit ScalarBig and ScalarField

when appType == "lib":
{.push noconv, dynlib, exportc, raises: [].} # No exceptions allowed
else:
Expand Down Expand Up @@ -352,4 +356,57 @@ template genBindings_EC_ShortW_NonAffine*(ECP, ECP_Aff: untyped) =
func `ctt _ ECP _ from_affine`(dst: var ECP, src: ECP_Aff) =
dst.fromAffine(src)

func `ctt _ ECP _ batch_affine`(dst: ptr UncheckedArray[ECP_Aff], src: ptr UncheckedArray[ECP], n: csize_t) =
dst.batchAffine(src, cast[int](n))

when ECP.G == G1:
# Workaround gensym issue in templates like mulCheckSparse
# for {.noInit.} temporaries and probably generic sandwich

func `ctt _ ECP _ scalar_mul_big_coef`(
P: var ECP, scalar: ScalarBig) =

P.scalarMul(scalar)

func `ctt _ ECP _ scalar_mul_fr_coef`(
P: var ECP, scalar: ScalarField) =

var big: ScalarBig # TODO: {.noInit.}
big.fromField(scalar)
P.scalarMul(big)

func `ctt _ ECP _ scalar_mul_big_coef_vartime`(
P: var ECP, scalar: ScalarBig) =

P.scalarMul_vartime(scalar)

func `ctt _ ECP _ scalar_mul_fr_coef_vartime`(
P: var ECP, scalar: ScalarField) =

var big: ScalarBig # TODO: {.noInit.}
big.fromField(scalar)
P.scalarMul_vartime(big)

func `ctt _ ECP _ multi_scalar_mul_big_coefs_vartime`(
r: var ECP,
coefs: ptr UncheckedArray[ScalarBig],
points: ptr UncheckedArray[ECP_Aff],
len: csize_t) =
r.multiScalarMul_vartime(coefs, points, cast[int](len))

func `ctt _ ECP _ multi_scalar_mul_fr_coefs_vartime`(
r: var ECP,
coefs: ptr UncheckedArray[ScalarField],
points: ptr UncheckedArray[ECP_Aff],
len: csize_t)=

let n = cast[int](len)
let coefs_fr = allocHeapArrayAligned(ScalarBig, n, alignment = 64)

for i in 0 ..< n:
coefs_fr[i].fromField(coefs[i])
r.multiScalarMul_vartime(coefs_fr, points, n)

freeHeapAligned(coefs_fr)

{.pop.}
2 changes: 2 additions & 0 deletions bindings/c_curve_decls_parallel.nim
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ import
export curves_primitives_parallel

template genParallelBindings_EC_ShortW_NonAffine*(ECP, ECP_Aff, ScalarField: untyped) =
# TODO: remove the need of explicit ScalarField

# For some unknown reason {.push noconv.}
# would overwrite the threadpool {.nimcall.}
# in the parallel for-loop `generateClosure`
Expand Down
3 changes: 2 additions & 1 deletion bindings/c_typedefs.nim
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,8 @@ proc toCparam*(name: string, typ: NimNode): string =
if typ.kind == nnkCall:
typ[0].expectKind(nnkOpenSymChoice)
doAssert typ[0][0].eqIdent"[]"
doAssert typ[1].eqIdent"openArray"
doAssert typ[1].eqIdent"openArray", block:
typ.treeRepr()
let sTyp = $typ[2]
if sTyp in TypeMap:
"const " & TypeMap[sTyp] & " " & name & "[], ptrdiff_t " & name & "_len"
Expand Down
28 changes: 16 additions & 12 deletions bindings/lib_curves.nim
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ import
./c_curve_decls_parallel
export c_curve_decls, c_curve_decls_parallel

type
big254 = BigInt[254]
big255 = BigInt[255]

# ----------------------------------------------------------

type
Expand All @@ -38,11 +42,11 @@ collectBindings(cBindings_bls12_381):
genBindingsExtField(bls12_381_fp2)
genBindingsExtFieldSqrt(bls12_381_fp2)
genBindings_EC_ShortW_Affine(bls12_381_g1_aff, bls12_381_fp)
genBindings_EC_ShortW_NonAffine(bls12_381_g1_jac, bls12_381_g1_aff)
genBindings_EC_ShortW_NonAffine(bls12_381_g1_prj, bls12_381_g1_aff)
genBindings_EC_ShortW_NonAffine(bls12_381_g1_jac, bls12_381_g1_aff, big255, bls12_381_fr)
genBindings_EC_ShortW_NonAffine(bls12_381_g1_prj, bls12_381_g1_aff, big255, bls12_381_fr)
genBindings_EC_ShortW_Affine(bls12_381_g2_aff, bls12_381_fp2)
genBindings_EC_ShortW_NonAffine(bls12_381_g2_jac, bls12_381_g2_aff)
genBindings_EC_ShortW_NonAffine(bls12_381_g2_prj, bls12_381_g2_aff)
genBindings_EC_ShortW_NonAffine(bls12_381_g2_jac, bls12_381_g2_aff, big255, bls12_381_fr)
genBindings_EC_ShortW_NonAffine(bls12_381_g2_prj, bls12_381_g2_aff, big255, bls12_381_fr)

collectBindings(cBindings_bls12_381_parallel):
genParallelBindings_EC_ShortW_NonAffine(bls12_381_g1_jac, bls12_381_g1_aff, bls12_381_fr)
Expand All @@ -67,11 +71,11 @@ collectBindings(cBindings_bn254_snarks):
genBindingsExtField(bn254_snarks_fp2)
genBindingsExtFieldSqrt(bn254_snarks_fp2)
genBindings_EC_ShortW_Affine(bn254_snarks_g1_aff, bn254_snarks_fp)
genBindings_EC_ShortW_NonAffine(bn254_snarks_g1_jac, bn254_snarks_g1_aff)
genBindings_EC_ShortW_NonAffine(bn254_snarks_g1_prj, bn254_snarks_g1_aff)
genBindings_EC_ShortW_NonAffine(bn254_snarks_g1_jac, bn254_snarks_g1_aff, big254, bn254_snarks_fr)
genBindings_EC_ShortW_NonAffine(bn254_snarks_g1_prj, bn254_snarks_g1_aff, big254, bn254_snarks_fr)
genBindings_EC_ShortW_Affine(bn254_snarks_g2_aff, bn254_snarks_fp2)
genBindings_EC_ShortW_NonAffine(bn254_snarks_g2_jac, bn254_snarks_g2_aff)
genBindings_EC_ShortW_NonAffine(bn254_snarks_g2_prj, bn254_snarks_g2_aff)
genBindings_EC_ShortW_NonAffine(bn254_snarks_g2_jac, bn254_snarks_g2_aff, big254, bn254_snarks_fr)
genBindings_EC_ShortW_NonAffine(bn254_snarks_g2_prj, bn254_snarks_g2_aff, big254, bn254_snarks_fr)

collectBindings(cBindings_bn254_snarks_parallel):
genParallelBindings_EC_ShortW_NonAffine(bn254_snarks_g1_jac, bn254_snarks_g1_aff, bn254_snarks_fr)
Expand All @@ -91,8 +95,8 @@ collectBindings(cBindings_pallas):
genBindingsField(pallas_fp)
genBindingsFieldSqrt(pallas_fp)
genBindings_EC_ShortW_Affine(pallas_ec_aff, pallas_fp)
genBindings_EC_ShortW_NonAffine(pallas_ec_jac, pallas_ec_aff)
genBindings_EC_ShortW_NonAffine(pallas_ec_prj, pallas_ec_aff)
genBindings_EC_ShortW_NonAffine(pallas_ec_jac, pallas_ec_aff, big255, pallas_fr)
genBindings_EC_ShortW_NonAffine(pallas_ec_prj, pallas_ec_aff, big255, pallas_fr)

collectBindings(cBindings_pallas_parallel):
genParallelBindings_EC_ShortW_NonAffine(pallas_ec_jac, pallas_ec_aff, pallas_fr)
Expand All @@ -110,8 +114,8 @@ collectBindings(cBindings_vesta):
genBindingsField(vesta_fp)
genBindingsFieldSqrt(vesta_fp)
genBindings_EC_ShortW_Affine(vesta_ec_aff, vesta_fp)
genBindings_EC_ShortW_NonAffine(vesta_ec_jac, vesta_ec_aff)
genBindings_EC_ShortW_NonAffine(vesta_ec_prj, vesta_ec_aff)
genBindings_EC_ShortW_NonAffine(vesta_ec_jac, vesta_ec_aff, big255, vesta_fr)
genBindings_EC_ShortW_NonAffine(vesta_ec_prj, vesta_ec_aff, big255, vesta_fr)

collectBindings(cBindings_vesta_parallel):
genParallelBindings_EC_ShortW_NonAffine(vesta_ec_jac, vesta_ec_aff, vesta_fr)
Expand Down
4 changes: 2 additions & 2 deletions bindings/lib_headers.nim
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ proc writeHeader_classicCurve(filepath: string, curve: string, modBits, orderBit
header &= curve_decls
header &= '\n'

header = "\n" & genCpp(header)
header = "#include \"constantine/curves/bigints.h\"\n\n" & genCpp(header)
header = genHeaderGuardAndInclude(curve.toUpperASCII(), header)
header = genHeaderLicense() & header

Expand Down Expand Up @@ -63,7 +63,7 @@ proc writeHeader_pairingFriendly(filepath: string, curve: string, modBits, order
header &= curve_decls
header &= '\n'

header = "\n" & genCpp(header)
header = "#include \"constantine/curves/bigints.h\"\n\n" & genCpp(header)
header = genHeaderGuardAndInclude(curve.toUpperASCII(), header)
header = genHeaderLicense() & header

Expand Down
Loading

0 comments on commit df9034b

Please sign in to comment.