diff --git a/std/hints.go b/std/hints.go index 8fc1222781..b4b16ae52f 100644 --- a/std/hints.go +++ b/std/hints.go @@ -8,6 +8,7 @@ import ( "github.com/consensys/gnark/std/algebra/sw_bls24315" "github.com/consensys/gnark/std/math/bits" "github.com/consensys/gnark/std/math/emulated" + "github.com/consensys/gnark/std/selector" ) var registerOnce sync.Once @@ -30,5 +31,7 @@ func registerHints() { hint.Register(bits.NNAF) hint.Register(bits.IthBit) hint.Register(bits.NBits) + hint.Register(selector.MuxIndicators) + hint.Register(selector.MapIndicators) hint.Register(emulated.GetHints()...) } diff --git a/std/selector/doc_map_test.go b/std/selector/doc_map_test.go new file mode 100644 index 0000000000..f70c353f61 --- /dev/null +++ b/std/selector/doc_map_test.go @@ -0,0 +1,79 @@ +package selector_test + +import ( + "fmt" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/backend/groth16" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/r1cs" + "github.com/consensys/gnark/std/selector" +) + +// MapCircuit is a minimal circuit using a selector map. +type MapCircuit struct { + QueryKey frontend.Variable + Keys [10]frontend.Variable // we use array in witness to allocate sufficiently sized vector + Values [10]frontend.Variable // we use array in witness to allocate sufficiently sized vector + ExpectedValue frontend.Variable +} + +// Define defines the arithmetic circuit. +func (c *MapCircuit) Define(api frontend.API) error { + result := selector.Map(api, c.QueryKey, c.Keys[:], c.Values[:]) + api.AssertIsEqual(result, c.ExpectedValue) + return nil +} + +// ExampleMap gives an example on how to use map selector. +func ExampleMap() { + circuit := MapCircuit{} + witness := MapCircuit{ + QueryKey: 55, + Keys: [10]frontend.Variable{0, 11, 22, 33, 44, 55, 66, 77, 88, 99}, + Values: [10]frontend.Variable{0, 2, 4, 6, 8, 10, 12, 14, 16, 18}, + ExpectedValue: 10, // element in values which corresponds to the position of 55 in keys + } + ccs, err := frontend.Compile(ecc.BN254.ScalarField(), r1cs.NewBuilder, &circuit) + if err != nil { + panic(err) + } else { + fmt.Println("compiled") + } + pk, vk, err := groth16.Setup(ccs) + if err != nil { + panic(err) + } else { + fmt.Println("setup done") + } + secretWitness, err := frontend.NewWitness(&witness, ecc.BN254.ScalarField()) + if err != nil { + panic(err) + } else { + fmt.Println("secret witness") + } + publicWitness, err := secretWitness.Public() + if err != nil { + panic(err) + } else { + fmt.Println("public witness") + } + proof, err := groth16.Prove(ccs, pk, secretWitness) + if err != nil { + panic(err) + } else { + fmt.Println("proof") + } + err = groth16.Verify(proof, vk, publicWitness) + if err != nil { + panic(err) + } else { + fmt.Println("verify") + } + // Output: compiled + // setup done + // secret witness + // public witness + // proof + // verify +} diff --git a/std/selector/doc_mux_test.go b/std/selector/doc_mux_test.go new file mode 100644 index 0000000000..5a7b75ba36 --- /dev/null +++ b/std/selector/doc_mux_test.go @@ -0,0 +1,77 @@ +package selector_test + +import ( + "fmt" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/backend/groth16" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/r1cs" + "github.com/consensys/gnark/std/selector" +) + +// MuxCircuit is a minimal circuit using a selector mux. +type MuxCircuit struct { + Selector frontend.Variable + In [10]frontend.Variable // we use array in witness to allocate sufficiently sized vector + Expected frontend.Variable +} + +// Define defines the arithmetic circuit. +func (c *MuxCircuit) Define(api frontend.API) error { + result := selector.Mux(api, c.Selector, c.In[:]...) // Note Mux takes var-arg input, need to expand the input vector + api.AssertIsEqual(result, c.Expected) + return nil +} + +// ExampleMux gives an example on how to use mux selector. +func ExampleMux() { + circuit := MuxCircuit{} + witness := MuxCircuit{ + Selector: 5, + In: [10]frontend.Variable{0, 2, 4, 6, 8, 10, 12, 14, 16, 18}, + Expected: 10, // 5-th element in vector + } + ccs, err := frontend.Compile(ecc.BN254.ScalarField(), r1cs.NewBuilder, &circuit) + if err != nil { + panic(err) + } else { + fmt.Println("compiled") + } + pk, vk, err := groth16.Setup(ccs) + if err != nil { + panic(err) + } else { + fmt.Println("setup done") + } + secretWitness, err := frontend.NewWitness(&witness, ecc.BN254.ScalarField()) + if err != nil { + panic(err) + } else { + fmt.Println("secret witness") + } + publicWitness, err := secretWitness.Public() + if err != nil { + panic(err) + } else { + fmt.Println("public witness") + } + proof, err := groth16.Prove(ccs, pk, secretWitness) + if err != nil { + panic(err) + } else { + fmt.Println("proof") + } + err = groth16.Verify(proof, vk, publicWitness) + if err != nil { + panic(err) + } else { + fmt.Println("verify") + } + // Output: compiled + // setup done + // secret witness + // public witness + // proof + // verify +} diff --git a/std/selector/multiplexer.go b/std/selector/multiplexer.go new file mode 100644 index 0000000000..8e9bec5003 --- /dev/null +++ b/std/selector/multiplexer.go @@ -0,0 +1,114 @@ +// Package selector provides a lookup table and map based on linear scan. +// +// The native [frontend.API] provides 1- and 2-bit lookups through the interface +// methods Select and Lookup2. This package extends the lookups to +// arbitrary-sized vectors. The lookups can be performed using the index of the +// elements (function [Mux]) or using a key, for which the user needs to provide +// the slice of keys (function [Map]). +// +// The implementation uses linear scan over all inputs, so the constraint count +// for every invocation of the function is C*len(values)+1, where: +// - for R1CS, C = 3 +// - for PLONK, C = 5 +package selector + +import ( + "math/big" + + "github.com/consensys/gnark/backend/hint" + "github.com/consensys/gnark/frontend" +) + +func init() { + // register hints + hint.Register(MuxIndicators) + hint.Register(MapIndicators) +} + +// Map is a key value associative array: the output will be values[i] such that keys[i] == queryKey. If keys does not +// contain queryKey, no proofs can be generated. If keys has more than one key that equals to queryKey, the output will +// be undefined, and the output could be a linear combination of all the corresponding values with that queryKey. +// +// In case keys and values do not have the same length, this function will panic. +func Map(api frontend.API, queryKey frontend.Variable, + keys []frontend.Variable, values []frontend.Variable) frontend.Variable { + // we don't need this check, but we added it to produce more informative errors and disallow + // len(keys) < len(values) which is supported by generateSelector. + if len(keys) != len(values) { + panic("The number of keys and values must be equal") + } + return generateSelector(api, false, queryKey, keys, values) +} + +// Mux is an n to 1 multiplexer: out = inputs[sel]. In other words, it selects exactly one of its +// inputs based on sel. The index of inputs starts from zero. +// +// sel needs to be between 0 and n - 1 (inclusive), where n is the number of inputs, otherwise the proof will fail. +func Mux(api frontend.API, sel frontend.Variable, inputs ...frontend.Variable) frontend.Variable { + return generateSelector(api, true, sel, nil, inputs) +} + +// generateSelector generates a circuit for a multiplexer or an associative array (map). If wantMux is true, a +// multiplexer is generated and keys are ignored. If wantMux is false, a map is generated, and we must have +// len(keys) <= len(values), or it panics. +func generateSelector(api frontend.API, wantMux bool, sel frontend.Variable, + keys []frontend.Variable, values []frontend.Variable) (out frontend.Variable) { + + var indicators []frontend.Variable + if wantMux { + indicators, _ = api.Compiler().NewHint(MuxIndicators, len(values), sel) + } else { + indicators, _ = api.Compiler().NewHint(MapIndicators, len(keys), append(keys, sel)...) + } + + out = 0 + indicatorsSum := frontend.Variable(0) + for i := 0; i < len(indicators); i++ { + // Check that all indicators for inputs that are not selected, are zero. + if wantMux { + // indicators[i] * (sel - i) == 0 + api.AssertIsEqual(api.Mul(indicators[i], api.Sub(sel, i)), 0) + } else { + // indicators[i] * (sel - keys[i]) == 0 + api.AssertIsEqual(api.Mul(indicators[i], api.Sub(sel, keys[i])), 0) + } + indicatorsSum = api.Add(indicatorsSum, indicators[i]) + // out += indicators[i] * values[i] + out = api.MulAcc(out, indicators[i], values[i]) + } + // We need to check that the indicator of the selected input is exactly 1. We used a sum constraint, because usually + // it is cheap. + api.AssertIsEqual(indicatorsSum, 1) + return out +} + +// MuxIndicators is a hint function used within [Mux] function. It must be +// provided to the prover when circuit uses it. +func MuxIndicators(_ *big.Int, inputs []*big.Int, results []*big.Int) error { + sel := inputs[0] + for i := 0; i < len(results); i++ { + // i is an int which can be int32 or int64. We convert i to int64 then to bigInt, which is safe. We should + // not convert sel to int64. + if sel.Cmp(big.NewInt(int64(i))) == 0 { + results[i].SetUint64(1) + } else { + results[i].SetUint64(0) + } + } + return nil +} + +// MapIndicators is a hint function used within [Map] function. It must be +// provided to the prover when circuit uses it. +func MapIndicators(_ *big.Int, inputs []*big.Int, results []*big.Int) error { + key := inputs[len(inputs)-1] + // We must make sure that we are initializing all elements of results + for i := 0; i < len(results); i++ { + if key.Cmp(inputs[i]) == 0 { + results[i].SetUint64(1) + } else { + results[i].SetUint64(0) + } + } + return nil +} diff --git a/std/selector/multiplexer_test.go b/std/selector/multiplexer_test.go new file mode 100644 index 0000000000..a0233205d1 --- /dev/null +++ b/std/selector/multiplexer_test.go @@ -0,0 +1,177 @@ +package selector_test + +import ( + "testing" + + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/std/selector" + "github.com/consensys/gnark/test" +) + +type muxCircuit struct { + SEL frontend.Variable + I0, I1, I2, I3, I4 frontend.Variable + OUT frontend.Variable +} + +func (c *muxCircuit) Define(api frontend.API) error { + + out := selector.Mux(api, c.SEL, c.I0, c.I1, c.I2, c.I3, c.I4) + + api.AssertIsEqual(out, c.OUT) + + return nil +} + +// The output of this circuit is ignored and the only way its proof can fail is by providing invalid inputs. +type ignoredOutputMuxCircuit struct { + SEL frontend.Variable + I0, I1, I2 frontend.Variable +} + +func (c *ignoredOutputMuxCircuit) Define(api frontend.API) error { + // We ignore the output + _ = selector.Mux(api, c.SEL, c.I0, c.I1, c.I2) + + return nil +} + +func TestMux(t *testing.T) { + assert := test.NewAssert(t) + + assert.ProverSucceeded(&muxCircuit{}, &muxCircuit{SEL: 2, I0: 10, I1: 11, I2: 12, I3: 13, I4: 14, OUT: 12}) + + assert.ProverSucceeded(&muxCircuit{}, &muxCircuit{SEL: 0, I0: 10, I1: 11, I2: 12, I3: 13, I4: 14, OUT: 10}) + + assert.ProverSucceeded(&muxCircuit{}, &muxCircuit{SEL: 4, I0: 20, I1: 21, I2: 22, I3: 23, I4: 24, OUT: 24}) + + // Failures + assert.ProverFailed(&muxCircuit{}, &muxCircuit{SEL: 5, I0: 20, I1: 21, I2: 22, I3: 23, I4: 24, OUT: 24}) + + assert.ProverFailed(&muxCircuit{}, &muxCircuit{SEL: 0, I0: 20, I1: 21, I2: 22, I3: 23, I4: 24, OUT: 21}) + + // Ignoring the circuit's output: + assert.ProverSucceeded(&ignoredOutputMuxCircuit{}, &ignoredOutputMuxCircuit{SEL: 0, I0: 0, I1: 1, I2: 2}) + + assert.ProverSucceeded(&ignoredOutputMuxCircuit{}, &ignoredOutputMuxCircuit{SEL: 2, I0: 0, I1: 1, I2: 2}) + + // Failures + assert.ProverFailed(&ignoredOutputMuxCircuit{}, &ignoredOutputMuxCircuit{SEL: 3, I0: 0, I1: 1, I2: 2}) + + assert.ProverFailed(&ignoredOutputMuxCircuit{}, &ignoredOutputMuxCircuit{SEL: -1, I0: 0, I1: 1, I2: 2}) +} + +// Map tests: +type mapCircuit struct { + SEL frontend.Variable + K0, K1, K2, K3 frontend.Variable + V0, V1, V2, V3 frontend.Variable + OUT frontend.Variable +} + +func (c *mapCircuit) Define(api frontend.API) error { + + out := selector.Map(api, c.SEL, + []frontend.Variable{c.K0, c.K1, c.K2, c.K3}, + []frontend.Variable{c.V0, c.V1, c.V2, c.V3}) + + api.AssertIsEqual(out, c.OUT) + + return nil +} + +type ignoredOutputMapCircuit struct { + SEL frontend.Variable + K0, K1 frontend.Variable + V0, V1 frontend.Variable +} + +func (c *ignoredOutputMapCircuit) Define(api frontend.API) error { + + _ = selector.Map(api, c.SEL, + []frontend.Variable{c.K0, c.K1}, + []frontend.Variable{c.V0, c.V1}) + + return nil +} + +func TestMap(t *testing.T) { + assert := test.NewAssert(t) + assert.ProverSucceeded(&mapCircuit{}, + &mapCircuit{ + SEL: 100, + K0: 100, K1: 111, K2: 222, K3: 333, + V0: 0, V1: 1, V2: 2, V3: 3, + OUT: 0, + }) + + assert.ProverSucceeded(&mapCircuit{}, + &mapCircuit{ + SEL: 222, + K0: 100, K1: 111, K2: 222, K3: 333, + V0: 0, V1: 1, V2: 2, V3: 3, + OUT: 2, + }) + + assert.ProverSucceeded(&mapCircuit{}, + &mapCircuit{ + SEL: 333, + K0: 100, K1: 111, K2: 222, K3: 333, + V0: 0, V1: 1, V2: 2, V3: 3, + OUT: 3, + }) + + // Duplicated key, success: + assert.ProverSucceeded(&mapCircuit{}, + &mapCircuit{ + SEL: 333, + K0: 222, K1: 222, K2: 222, K3: 333, + V0: 0, V1: 1, V2: 2, V3: 3, + OUT: 3, + }) + + // Duplicated key, UNDEFINED behavior: (with our hint implementation it fails) + assert.ProverFailed(&mapCircuit{}, + &mapCircuit{ + SEL: 333, + K0: 100, K1: 111, K2: 333, K3: 333, + V0: 0, V1: 1, V2: 2, V3: 3, + OUT: 3, + }) + + assert.ProverFailed(&mapCircuit{}, + &mapCircuit{ + SEL: 77, + K0: 100, K1: 111, K2: 222, K3: 333, + V0: 0, V1: 1, V2: 2, V3: 3, + OUT: 3, + }) + + assert.ProverFailed(&mapCircuit{}, + &mapCircuit{ + SEL: 111, + K0: 100, K1: 111, K2: 222, K3: 333, + V0: 0, V1: 1, V2: 2, V3: 3, + OUT: 2, + }) + + // Ignoring the circuit's output: + assert.ProverSucceeded(&ignoredOutputMapCircuit{}, + &ignoredOutputMapCircuit{SEL: 5, + K0: 5, K1: 7, + V0: 10, V1: 11, + }) + + assert.ProverFailed(&ignoredOutputMapCircuit{}, + &ignoredOutputMapCircuit{SEL: 5, + K0: 5, K1: 5, + V0: 10, V1: 11, + }) + + assert.ProverFailed(&ignoredOutputMapCircuit{}, + &ignoredOutputMapCircuit{SEL: 6, + K0: 5, K1: 7, + V0: 10, V1: 11, + }) + +}