diff --git a/frontend/cs/r1cs/api_assertions.go b/frontend/cs/r1cs/api_assertions.go index b75f200f7a..530a5fe412 100644 --- a/frontend/cs/r1cs/api_assertions.go +++ b/frontend/cs/r1cs/api_assertions.go @@ -28,6 +28,15 @@ import ( // AssertIsEqual adds an assertion in the constraint builder (i1 == i2) func (builder *builder) AssertIsEqual(i1, i2 frontend.Variable) { + c1, i1Constant := builder.constantValue(i1) + c2, i2Constant := builder.constantValue(i2) + + if i1Constant && i2Constant { + if c1 != c2 { + panic("non-equal constant values") + } + return + } // encoded 1 * i1 == i2 r := builder.getLinearExpression(builder.toVariable(i1)) o := builder.getLinearExpression(builder.toVariable(i2)) diff --git a/frontend/cs/r1cs/builder.go b/frontend/cs/r1cs/builder.go index df1306b4c4..a070a6c7ad 100644 --- a/frontend/cs/r1cs/builder.go +++ b/frontend/cs/r1cs/builder.go @@ -308,6 +308,9 @@ func (builder *builder) constantValue(v frontend.Variable) (constraint.Element, // and are always reduced to one element. may not always be true? return constraint.Element{}, false } + if _v[0].Coeff.IsZero() { + return constraint.Element{}, true + } if !(_v[0].WireID() == 0) { // public ONE WIRE return constraint.Element{}, false } diff --git a/frontend/cs/r1cs/r1cs_test.go b/frontend/cs/r1cs/r1cs_test.go index ceb965322c..02762db1e0 100644 --- a/frontend/cs/r1cs/r1cs_test.go +++ b/frontend/cs/r1cs/r1cs_test.go @@ -160,3 +160,23 @@ func TestPreCompileHook(t *testing.T) { t.Error("callback not called") } } + +type subSameNoConstraintCircuit struct { + A frontend.Variable +} + +func (c *subSameNoConstraintCircuit) Define(api frontend.API) error { + r := api.Sub(c.A, c.A) + api.AssertIsEqual(r, 0) + return nil +} + +func TestSubSameNoConstraint(t *testing.T) { + ccs, err := frontend.Compile(ecc.BN254.ScalarField(), NewBuilder, &subSameNoConstraintCircuit{}) + if err != nil { + t.Fatal(err) + } + if ccs.GetNbConstraints() != 0 { + t.Fatal("expected 0 constraints") + } +} diff --git a/frontend/cs/scs/api_test.go b/frontend/cs/scs/api_test.go index 40c967203d..1f81e31c15 100644 --- a/frontend/cs/scs/api_test.go +++ b/frontend/cs/scs/api_test.go @@ -238,3 +238,22 @@ func TestMulAccFastTrack(t *testing.T) { assert.NoError(err) _ = solution } + +type subSameNoConstraintCircuit struct { + A frontend.Variable +} + +func (c *subSameNoConstraintCircuit) Define(api frontend.API) error { + r := api.Sub(c.A, c.A) + api.AssertIsEqual(r, 0) + return nil +} + +func TestSubSameNoConstraint(t *testing.T) { + assert := test.NewAssert(t) + ccs, err := frontend.Compile(ecc.BN254.ScalarField(), scs.NewBuilder, &subSameNoConstraintCircuit{}) + assert.NoError(err) + if ccs.GetNbConstraints() != 0 { + t.Fatal("expected 0 constraints") + } +} diff --git a/frontend/cs/scs/builder.go b/frontend/cs/scs/builder.go index 16a9ddd289..0ca3666fac 100644 --- a/frontend/cs/scs/builder.go +++ b/frontend/cs/scs/builder.go @@ -17,6 +17,7 @@ limitations under the License. package scs import ( + "fmt" "math/big" "reflect" "sort" @@ -312,7 +313,10 @@ func (builder *builder) ConstantValue(v frontend.Variable) (*big.Int, bool) { } func (builder *builder) constantValue(v frontend.Variable) (constraint.Element, bool) { - if _, ok := v.(expr.Term); ok { + if vv, ok := v.(expr.Term); ok { + if vv.Coeff.IsZero() { + return constraint.Element{}, true + } return constraint.Element{}, false } return builder.cs.FromInterface(v), true @@ -686,3 +690,63 @@ func (builder *builder) ToCanonicalVariable(v frontend.Variable) frontend.Canoni return term } } + +// GetWireConstraints returns the pairs (constraintID, wireLocation) for the +// given wires in the compiled constraint system: +// - constraintID is the index of the constraint in the constraint system. +// - wireLocation is the location of the wire in the constraint (0=xA or 1=xB). +// +// If the argument addMissing is true, then the function will add a new +// constraint for each wire that is not found in the constraint system. This may +// happen when getting the constraint for a witness which is not used in +// constraints. Otherwise, when addMissing is false, the function returns an +// error if a wire is not found in the constraint system. +// +// The method only returns a single pair (constraintID, wireLocation) for every +// unique wire (removing duplicates). The order of the returned pairs is not the +// same as for the given arguments. +func (builder *builder) GetWireConstraints(wires []frontend.Variable, addMissing bool) ([][2]int, error) { + // construct a lookup table table for later quick access when iterating over instructions + lookup := make(map[int]struct{}) + for _, w := range wires { + ww, ok := w.(expr.Term) + if !ok { + panic("input wire is not a Term") + } + lookup[ww.WireID()] = struct{}{} + } + res := make([][2]int, 0, len(wires)) + iterator := builder.cs.GetSparseR1CIterator() + for c, constraintIdx := iterator.Next(), 0; c != nil; c, constraintIdx = iterator.Next(), constraintIdx+1 { + if _, ok := lookup[int(c.XA)]; ok { + res = append(res, [2]int{constraintIdx, 0}) + delete(lookup, int(c.XA)) + continue + } + if _, ok := lookup[int(c.XB)]; ok { + res = append(res, [2]int{constraintIdx, 1}) + delete(lookup, int(c.XB)) + continue + } + } + if addMissing { + nbWitnessWires := builder.cs.GetNbPublicVariables() + builder.cs.GetNbSecretVariables() + for k := range lookup { + if k >= nbWitnessWires { + return nil, fmt.Errorf("addMissing is true, but wire %d is not a witness", k) + } + constraintIdx := builder.cs.AddSparseR1C(constraint.SparseR1C{ + XA: uint32(k), + XC: uint32(k), + QL: constraint.CoeffIdOne, + QO: constraint.CoeffIdMinusOne, + }, builder.genericGate) + res = append(res, [2]int{constraintIdx, 0}) + delete(lookup, k) + } + } + if len(lookup) > 0 { + return nil, fmt.Errorf("constraint with wire not found in circuit") + } + return res, nil +}