Skip to content

Commit

Permalink
properly parametrize skyscrapers lookup table size
Browse files Browse the repository at this point in the history
  • Loading branch information
kustosz committed Dec 3, 2024
1 parent be9fa81 commit bc6a8cd
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 98 deletions.
39 changes: 3 additions & 36 deletions example/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,9 @@ import (
"github.com/consensys/gnark/backend/groth16"
"github.com/consensys/gnark/frontend"
"github.com/consensys/gnark/frontend/cs/r1cs"
"github.com/consensys/gnark/std/lookup/logderivlookup"
"github.com/consensys/gnark/std/math/uints"
gnark_nimue "github.com/reilabs/gnark-nimue"
"github.com/reilabs/gnark-nimue/hash"
"math/bits"
)

type TestCircuit struct {
Expand Down Expand Up @@ -192,7 +190,7 @@ type Manhattan struct {
}

func (c *Manhattan) Define(api frontend.API) error {
s := hash.NewSkyscraper(api)
s := hash.NewSkyscraper(api, 1)
a := c.I
for range 3000 {
a = s.Compress(a, a)
Expand All @@ -208,41 +206,10 @@ func ExampleManhattan() {
fmt.Println(err)
return
}
pk, vk, _ := groth16.Setup(ccs)
assignment := Manhattan{
I: 1,
O: 1000,
}
witness, _ := frontend.NewWitness(&assignment, ecc.BN254.ScalarField())
publicWitness, _ := witness.Public()

proof, _ := groth16.Prove(ccs, pk, witness)
vErr := groth16.Verify(proof, vk, publicWitness)
fmt.Printf("%v\n", vErr)
}

type TestLookup struct {
In frontend.Variable
}
fmt.Printf("constraints: %d\n", ccs.GetNbConstraints())

func (c *TestLookup) Define(api frontend.API) error {
table := logderivlookup.New(api)
for i := range 256 {
table.Insert(bits.RotateLeft8(uint8(i), 3))
}
c0 := c.In
for range 256 {
c0 = table.Lookup(c0)[0]
}
api.AssertIsEqual(c0, c.In)
return nil
}

func main() {
ccs, _ := frontend.Compile(ecc.BN254.ScalarField(), r1cs.NewBuilder, &TestLookup{})
fmt.Printf("constraints: %d\n", ccs.GetNbConstraints())

//Example1()
//ExampleWhir()
//ExampleManhattan()
ExampleManhattan()
}
88 changes: 49 additions & 39 deletions hash/skyscraper.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,22 +11,25 @@ import (
"math/bits"
)

func bytesBeHint(field *big.Int, inputs []*big.Int, outputs []*big.Int) error {
func wordsBeHint(field *big.Int, inputs []*big.Int, outputs []*big.Int) error {
if field.Cmp(ecc.BN254.ScalarField()) != 0 {
return fmt.Errorf("bytesHint: expected BN254 Fr, got %s", field)
}
if len(inputs) != 1 {
return fmt.Errorf("bytesHint: expected 1 input, got %d", len(inputs))
if len(inputs) != 2 {
return fmt.Errorf("bytesHint: expected 2 inputs, got %d", len(inputs))
}
if len(outputs) != 16 {
return fmt.Errorf("bytesHint: expected 32 outputs, got %d", len(outputs))
wordLen := int(inputs[0].Int64())
if len(outputs) != 32/wordLen {
return fmt.Errorf("bytesHint: expected %d outputs, got %d", 32/wordLen, len(outputs))
}
bytes := make([]byte, 32)
inputs[0].FillBytes(bytes)
inputs[1].FillBytes(bytes)
for i, o := range outputs {
o.SetUint64(uint64(bytes[2*i]))
o.Mul(o, big.NewInt(256))
o.Add(o, big.NewInt(int64(bytes[2*i+1])))
o.SetUint64(0)
for j := range wordLen {
o.Mul(o, big.NewInt(256))
o.Add(o, big.NewInt(int64(bytes[wordLen*i+j])))
}
}
return nil
}
Expand All @@ -48,16 +51,17 @@ func gtHint(_ *big.Int, inputs []*big.Int, outputs []*big.Int) error {
}

func init() {
solver.RegisterHint(bytesBeHint)
solver.RegisterHint(wordsBeHint)
solver.RegisterHint(gtHint)
}

type Skyscraper struct {
rc [8]big.Int
sigma big.Int
sboxT *logderivlookup.Table
rchk frontend.Rangechecker
api frontend.API
rc [8]big.Int
sigma big.Int
sboxT *logderivlookup.Table
rchk frontend.Rangechecker
wordSize int
api frontend.API
}

func sboxByte(b byte) byte {
Expand All @@ -67,19 +71,22 @@ func sboxByte(b byte) byte {
return bits.RotateLeft8(b^(x&y&z), 1)
}

func initSbox(api frontend.API) *logderivlookup.Table {
func initSbox(api frontend.API, wordSize int) *logderivlookup.Table {
t := logderivlookup.New(api)
for i := range 65536 {
w := uint16(i)
b1 := byte(w & 0xff)
b2 := byte(w >> 8)
r := uint16(sboxByte(b1)) | (uint16(sboxByte(b2)) << 8)
tableSize := 1 << (8 * wordSize)
for i := range tableSize {
r := uint64(0)
for j := range wordSize {
shiftSize := j * 8
inpByte := byte((i >> shiftSize) & 0xff)
r |= uint64(sboxByte(inpByte)) << shiftSize
}
t.Insert(r)
}
return t
}

func NewSkyscraper(api frontend.API) *Skyscraper {
func NewSkyscraper(api frontend.API, wordSize int) *Skyscraper {
rc := [8]big.Int{}
rc[0].SetString("17829420340877239108687448009732280677191990375576158938221412342251481978692", 10)
rc[1].SetString("5852100059362614845584985098022261541909346143980691326489891671321030921585", 10)
Expand All @@ -95,8 +102,9 @@ func NewSkyscraper(api frontend.API) *Skyscraper {
return &Skyscraper{
rc,
sigma,
initSbox(api),
initSbox(api, wordSize),
rangecheck.New(api),
wordSize,
api,
}
}
Expand All @@ -109,10 +117,10 @@ func (s *Skyscraper) square(v frontend.Variable) frontend.Variable {
return s.api.Mul(s.api.Mul(v, v), s.sigma)
}

func (s *Skyscraper) varFromBytesBe(bytes []frontend.Variable) frontend.Variable {
func (s *Skyscraper) varFromWordsBe(words []frontend.Variable) frontend.Variable {
result := frontend.Variable(0)
for _, b := range bytes {
result = s.api.Mul(result, 65536)
for _, b := range words {
result = s.api.Mul(result, 1<<(8*s.wordSize))
result = s.api.Add(result, b)
}
return result
Expand All @@ -137,26 +145,28 @@ func (s *Skyscraper) assertLessThanModulus(hi, lo frontend.Variable) {
}

// the result is NOT rangechecked, but if it is in range, it is canonical
func (s *Skyscraper) canonicalDecompose(v frontend.Variable) [16]frontend.Variable {
o, _ := s.api.Compiler().NewHint(bytesBeHint, 16, v)
result := [16]frontend.Variable{}
func (s *Skyscraper) canonicalDecompose(v frontend.Variable) []frontend.Variable {
wordsPerFelt := 32 / s.wordSize
o, _ := s.api.Compiler().NewHint(wordsBeHint, wordsPerFelt, s.wordSize, v)
result := make([]frontend.Variable, wordsPerFelt)
copy(result[:], o)
s.api.AssertIsEqual(s.varFromBytesBe(result[:]), v)
s.assertLessThanModulus(s.varFromBytesBe(result[:8]), s.varFromBytesBe(result[8:]))
s.api.AssertIsEqual(s.varFromWordsBe(result[:]), v)
s.assertLessThanModulus(s.varFromWordsBe(result[:wordsPerFelt/2]), s.varFromWordsBe(result[wordsPerFelt/2:]))
return result
}

func (s *Skyscraper) bar(v frontend.Variable) frontend.Variable {
bytes := s.canonicalDecompose(v)
tmp := [8]frontend.Variable{}
copy(tmp[:], bytes[:8])
copy(bytes[:], bytes[8:])
copy(bytes[8:], tmp[:])
for i := range bytes {
words := s.canonicalDecompose(v)
wordsPerFelt := 32 / s.wordSize
tmp := make([]frontend.Variable, wordsPerFelt/2)
copy(tmp[:], words[:wordsPerFelt/2])
copy(words[:], words[wordsPerFelt/2:])
copy(words[wordsPerFelt/2:], tmp[:])
for i := range words {
// sbox implicitly rangechecks the input
bytes[i] = s.sbox(bytes[i])
words[i] = s.sbox(words[i])
}
return s.varFromBytesBe(bytes[:])
return s.varFromWordsBe(words[:])
}

func (s *Skyscraper) Permute(state *[2]frontend.Variable) {
Expand Down
59 changes: 36 additions & 23 deletions hash/skyscraper_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package hash

import (
"fmt"
"github.com/consensys/gnark-crypto/ecc"
"github.com/consensys/gnark/backend"
"github.com/consensys/gnark/frontend"
Expand All @@ -16,31 +17,35 @@ func bigIntFromString(s string) frontend.Variable {
}

type TestSboxC struct {
In, Out frontend.Variable
WordSize int
In, Out frontend.Variable
}

func (c *TestSboxC) Define(api frontend.API) error {
s := NewSkyscraper(api)
s := NewSkyscraper(api, c.WordSize)
api.AssertIsEqual(s.sbox(c.In), c.Out)
return nil
}

func TestSbox(t *testing.T) {
assert := test.NewAssert(t)
assert.CheckCircuit(&TestSboxC{}, test.WithCurves(ecc.BN254), test.WithBackends(backend.GROTH16),
test.WithValidAssignment(&TestSboxC{0xcd, 0xd3}),
test.WithValidAssignment(&TestSboxC{0x17, 0x0e}),
test.WithInvalidAssignment(&TestSboxC{0x17, 0x0f}),
test.WithInvalidAssignment(&TestSboxC{0x1234, 0x0f}))

for wordSize := 1; wordSize <= 2; wordSize++ {
assert.CheckCircuit(&TestSboxC{WordSize: wordSize}, test.WithCurves(ecc.BN254), test.WithBackends(backend.GROTH16),
test.WithValidAssignment(&TestSboxC{wordSize, 0xcd, 0xd3}),
test.WithValidAssignment(&TestSboxC{wordSize, 0x17, 0x0e}),
test.WithInvalidAssignment(&TestSboxC{wordSize, 0x17, 0x0f}),
test.WithInvalidAssignment(&TestSboxC{wordSize, 0x1234, 0x0f}))
}
assert.CheckCircuit(&TestSboxC{WordSize: 2}, test.WithCurves(ecc.BN254), test.WithBackends(backend.GROTH16),
test.WithValidAssignment(&TestSboxC{2, 0xcd17, 0xd30e}))
}

type TestSquareC struct {
In, Out frontend.Variable
}

func (c *TestSquareC) Define(api frontend.API) error {
s := NewSkyscraper(api)
s := NewSkyscraper(api, 1)
s.sbox(123) // needed to silence an error about unused lookup tables
api.AssertIsEqual(s.square(c.In), c.Out)
return nil
Expand All @@ -57,40 +62,48 @@ func TestSquare(t *testing.T) {
}

type TestBarC struct {
In, Out frontend.Variable
WordSize int
In, Out frontend.Variable
}

func (c *TestBarC) Define(api frontend.API) error {
s := NewSkyscraper(api)
s := NewSkyscraper(api, c.WordSize)
api.AssertIsEqual(s.bar(c.In), c.Out)
return nil
}

func TestBar(t *testing.T) {
assert := test.NewAssert(t)
assert.CheckCircuit(&TestBarC{}, test.WithCurves(ecc.BN254), test.WithBackends(backend.GROTH16),
test.WithValidAssignment(&TestBarC{0, 0}),
test.WithValidAssignment(&TestBarC{1, bigIntFromString("680564733841876926926749214863536422912")}),
test.WithValidAssignment(&TestBarC{2, bigIntFromString("1361129467683753853853498429727072845824")}),
test.WithValidAssignment(&TestBarC{bigIntFromString("4111585712030104139416666328230194227848755236259444667527487224433891325648"), bigIntFromString("18867677047139790809471719918880601980605904427073186248909139907505620573990")}))

for wordSize := 1; wordSize <= 2; wordSize++ {
fmt.Printf("wordSize: %d\n", wordSize)
assert.CheckCircuit(&TestBarC{WordSize: wordSize}, test.WithCurves(ecc.BN254), test.WithBackends(backend.GROTH16),
test.WithValidAssignment(&TestBarC{wordSize, 0, 0}),
test.WithValidAssignment(&TestBarC{wordSize, 1, bigIntFromString("680564733841876926926749214863536422912")}),
test.WithValidAssignment(&TestBarC{wordSize, 2, bigIntFromString("1361129467683753853853498429727072845824")}),
test.WithValidAssignment(&TestBarC{wordSize, bigIntFromString("4111585712030104139416666328230194227848755236259444667527487224433891325648"), bigIntFromString("18867677047139790809471719918880601980605904427073186248909139907505620573990")}))

}
}

type TestCompressC struct {
WordSize int
In1, In2, Out frontend.Variable
}

func (c *TestCompressC) Define(api frontend.API) error {
s := NewSkyscraper(api)
s := NewSkyscraper(api, c.WordSize)
api.AssertIsEqual(s.Compress(c.In1, c.In2), c.Out)
return nil
}

func TestCompress(t *testing.T) {
assert := test.NewAssert(t)
assert.CheckCircuit(&TestCompressC{}, test.WithCurves(ecc.BN254), test.WithBackends(backend.GROTH16),
test.WithValidAssignment(&TestCompressC{
bigIntFromString("21614608883591910674239883101354062083890746690626773887530227216615498812963"),
bigIntFromString("9813154100006487150380270585621895148484502414032888228750638800367218873447"),
bigIntFromString("3583228880285179354728993622328037400470978495633822008876840172083178912457")}))
for wordSize := 1; wordSize <= 2; wordSize++ {
assert.CheckCircuit(&TestCompressC{WordSize: wordSize}, test.WithCurves(ecc.BN254), test.WithBackends(backend.GROTH16),
test.WithValidAssignment(&TestCompressC{wordSize,
bigIntFromString("21614608883591910674239883101354062083890746690626773887530227216615498812963"),
bigIntFromString("9813154100006487150380270585621895148484502414032888228750638800367218873447"),
bigIntFromString("3583228880285179354728993622328037400470978495633822008876840172083178912457")}))
}

}

0 comments on commit bc6a8cd

Please sign in to comment.