Skip to content

Commit

Permalink
width 2 lookup for skyscraper
Browse files Browse the repository at this point in the history
  • Loading branch information
kustosz committed Dec 3, 2024
1 parent 2c95c3b commit fef3626
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 21 deletions.
29 changes: 25 additions & 4 deletions example/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@ package main

import (
"fmt"
"github.com/reilabs/gnark-nimue/hash"

"github.com/consensys/gnark-crypto/ecc"
"github.com/consensys/gnark/backend/groth16"
"github.com/consensys/gnark/frontend"
"github.com/consensys/gnark/frontend/cs/r1cs"
"github.com/consensys/gnark/std/lookup/logderivlookup"
"github.com/consensys/gnark/std/math/uints"
gnark_nimue "github.com/reilabs/gnark-nimue"
"github.com/reilabs/gnark-nimue/hash"
"math/bits"
)

type TestCircuit struct {
Expand Down Expand Up @@ -193,7 +194,7 @@ type Manhattan struct {
func (c *Manhattan) Define(api frontend.API) error {
s := hash.NewSkyscraper(api)
a := c.I
for range 2000 {
for range 3000 {
a = s.Compress(a, a)
}
api.AssertIsEqual(c.O, a)
Expand All @@ -220,8 +221,28 @@ func ExampleManhattan() {
fmt.Printf("%v\n", vErr)
}

type TestLookup struct {
In frontend.Variable
}

func (c *TestLookup) Define(api frontend.API) error {
table := logderivlookup.New(api)
for i := range 256 {
table.Insert(bits.RotateLeft8(uint8(i), 3))
}
c0 := c.In
for range 256 {
c0 = table.Lookup(c0)[0]
}
api.AssertIsEqual(c0, c.In)
return nil
}

func main() {
ccs, _ := frontend.Compile(ecc.BN254.ScalarField(), r1cs.NewBuilder, &TestLookup{})
fmt.Printf("constraints: %d\n", ccs.GetNbConstraints())

//Example1()
//ExampleWhir()
ExampleManhattan()
//ExampleManhattan()
}
43 changes: 26 additions & 17 deletions hash/skyscraper.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,15 @@ func bytesBeHint(field *big.Int, inputs []*big.Int, outputs []*big.Int) error {
if len(inputs) != 1 {
return fmt.Errorf("bytesHint: expected 1 input, got %d", len(inputs))
}
if len(outputs) != 32 {
if len(outputs) != 16 {
return fmt.Errorf("bytesHint: expected 32 outputs, got %d", len(outputs))
}
bytes := make([]byte, 32)
inputs[0].FillBytes(bytes)
for i, o := range outputs {
o.SetUint64(uint64(bytes[i]))
o.SetUint64(uint64(bytes[2*i]))
o.Mul(o, big.NewInt(256))
o.Add(o, big.NewInt(int64(bytes[2*i+1])))
}
return nil
}
Expand Down Expand Up @@ -58,14 +60,21 @@ type Skyscraper struct {
api frontend.API
}

func sboxByte(b byte) byte {
x := bits.RotateLeft8(^b, 1)
y := bits.RotateLeft8(b, 2)
z := bits.RotateLeft8(b, 3)
return bits.RotateLeft8(b^(x&y&z), 1)
}

func initSbox(api frontend.API) *logderivlookup.Table {
t := logderivlookup.New(api)
for i := range 256 {
b := uint8(i)
x := bits.RotateLeft8(^b, 1)
y := bits.RotateLeft8(b, 2)
z := bits.RotateLeft8(b, 3)
t.Insert(bits.RotateLeft8(b^(x&y&z), 1))
for i := range 65536 {
w := uint16(i)
b1 := byte(w & 0xff)
b2 := byte(w >> 8)
r := uint16(sboxByte(b1)) | (uint16(sboxByte(b2)) << 8)
t.Insert(r)
}
return t
}
Expand Down Expand Up @@ -103,7 +112,7 @@ func (s *Skyscraper) square(v frontend.Variable) frontend.Variable {
func (s *Skyscraper) varFromBytesBe(bytes []frontend.Variable) frontend.Variable {
result := frontend.Variable(0)
for _, b := range bytes {
result = s.api.Mul(result, 256)
result = s.api.Mul(result, 65536)
result = s.api.Add(result, b)
}
return result
Expand All @@ -128,21 +137,21 @@ func (s *Skyscraper) assertLessThanModulus(hi, lo frontend.Variable) {
}

// the result is NOT rangechecked, but if it is in range, it is canonical
func (s *Skyscraper) canonicalDecompose(v frontend.Variable) [32]frontend.Variable {
o, _ := s.api.Compiler().NewHint(bytesBeHint, 32, v)
result := [32]frontend.Variable{}
func (s *Skyscraper) canonicalDecompose(v frontend.Variable) [16]frontend.Variable {
o, _ := s.api.Compiler().NewHint(bytesBeHint, 16, v)
result := [16]frontend.Variable{}
copy(result[:], o)
s.api.AssertIsEqual(s.varFromBytesBe(result[:]), v)
s.assertLessThanModulus(s.varFromBytesBe(result[:16]), s.varFromBytesBe(result[16:]))
s.assertLessThanModulus(s.varFromBytesBe(result[:8]), s.varFromBytesBe(result[8:]))
return result
}

func (s *Skyscraper) bar(v frontend.Variable) frontend.Variable {
bytes := s.canonicalDecompose(v)
tmp := [16]frontend.Variable{}
copy(tmp[:], bytes[:16])
copy(bytes[:], bytes[16:])
copy(bytes[16:], tmp[:])
tmp := [8]frontend.Variable{}
copy(tmp[:], bytes[:8])
copy(bytes[:], bytes[8:])
copy(bytes[8:], tmp[:])
for i := range bytes {
// sbox implicitly rangechecks the input
bytes[i] = s.sbox(bytes[i])
Expand Down

0 comments on commit fef3626

Please sign in to comment.