From 6d9bbf04f7ac3538fe751c22177778f701f5e741 Mon Sep 17 00:00:00 2001 From: Pedro Fontana Date: Fri, 29 Sep 2023 10:28:08 -0300 Subject: [PATCH 1/7] Implement Usort hints (#244) * Add hint codes * Implement usort_enter_scope Hint * Handle cast error * fmt * Implement USORT_BODY hint * Implement USORT_VERIFY hint * Implement USORT_VERIFY_MULTIPLICITY_ASSERT hint * Implement USORT_VERIFY_MULTIPLICITY_BODY hint * hint fixes * integration tests * refactor * unit test * unit test * unit test * add unit test * add unit test * add unit test * move file pkg/hints/usort_hint_codes.go -> pkg/hints/hint_codes/usort_hint_codes.go * Fix doc * Add symlink * CamelCase * typos * Handle ids.Insert errors * Handle Memory.Insert errors --- cairo_programs/proof_programs/usort.cairo | 1 + cairo_programs/usort.cairo | 22 ++ pkg/hints/hint_codes/usort_hint_codes.go | 32 +++ pkg/hints/hint_processor.go | 10 + pkg/hints/usort_hints.go | 266 ++++++++++++++++++++++ pkg/hints/usort_hints_test.go | 240 +++++++++++++++++++ pkg/vm/cairo_run/cairo_run_test.go | 8 + 7 files changed, 579 insertions(+) create mode 120000 cairo_programs/proof_programs/usort.cairo create mode 100644 cairo_programs/usort.cairo create mode 100644 pkg/hints/hint_codes/usort_hint_codes.go create mode 100644 pkg/hints/usort_hints.go create mode 100644 pkg/hints/usort_hints_test.go diff --git a/cairo_programs/proof_programs/usort.cairo b/cairo_programs/proof_programs/usort.cairo new file mode 120000 index 00000000..4a7e19e6 --- /dev/null +++ b/cairo_programs/proof_programs/usort.cairo @@ -0,0 +1 @@ +../usort.cairo \ No newline at end of file diff --git a/cairo_programs/usort.cairo b/cairo_programs/usort.cairo new file mode 100644 index 00000000..e5859b29 --- /dev/null +++ b/cairo_programs/usort.cairo @@ -0,0 +1,22 @@ +%builtins range_check +from starkware.cairo.common.usort import usort +from starkware.cairo.common.alloc import alloc + +func main{range_check_ptr}() -> () { + alloc_locals; + let (input_array: felt*) = alloc(); + assert input_array[0] = 2; + assert input_array[1] = 1; + assert input_array[2] = 0; + + let (output_len, output, multiplicities) = usort(input_len=3, input=input_array); + + assert output_len = 3; + assert output[0] = 0; + assert output[1] = 1; + assert output[2] = 2; + assert multiplicities[0] = 1; + assert multiplicities[1] = 1; + assert multiplicities[2] = 1; + return (); +} diff --git a/pkg/hints/hint_codes/usort_hint_codes.go b/pkg/hints/hint_codes/usort_hint_codes.go new file mode 100644 index 00000000..7dd3e944 --- /dev/null +++ b/pkg/hints/hint_codes/usort_hint_codes.go @@ -0,0 +1,32 @@ +package hint_codes + +const USORT_ENTER_SCOPE = "vm_enter_scope(dict(__usort_max_size = globals().get('__usort_max_size')))" + +const USORT_BODY = `from collections import defaultdict + +input_ptr = ids.input +input_len = int(ids.input_len) +if __usort_max_size is not None: + assert input_len <= __usort_max_size, ( + f"usort() can only be used with input_len<={__usort_max_size}. " + f"Got: input_len={input_len}." + ) + +positions_dict = defaultdict(list) +for i in range(input_len): + val = memory[input_ptr + i] + positions_dict[val].append(i) + +output = sorted(positions_dict.keys()) +ids.output_len = len(output) +ids.output = segments.gen_arg(output) +ids.multiplicities = segments.gen_arg([len(positions_dict[k]) for k in output])` + +const USORT_VERIFY = `last_pos = 0 +positions = positions_dict[ids.value][::-1]` + +const USORT_VERIFY_MULTIPLICITY_ASSERT = "assert len(positions) == 0" + +const USORT_VERIFY_MULTIPLICITY_BODY = `current_pos = positions.pop() +ids.next_item_index = current_pos - last_pos +last_pos = current_pos + 1` diff --git a/pkg/hints/hint_processor.go b/pkg/hints/hint_processor.go index 6ed7562f..a43c76ef 100644 --- a/pkg/hints/hint_processor.go +++ b/pkg/hints/hint_processor.go @@ -106,6 +106,16 @@ func (p *CairoVmHintProcessor) ExecuteHint(vm *vm.VirtualMachine, hintData *any, return memset_step_loop(data.Ids, vm, execScopes, "continue_loop") case VM_ENTER_SCOPE: return vm_enter_scope(execScopes) + case USORT_ENTER_SCOPE: + return usortEnterScope(execScopes) + case USORT_BODY: + return usortBody(data.Ids, execScopes, vm) + case USORT_VERIFY: + return usortVerify(data.Ids, execScopes, vm) + case USORT_VERIFY_MULTIPLICITY_ASSERT: + return usortVerifyMultiplicityAssert(execScopes) + case USORT_VERIFY_MULTIPLICITY_BODY: + return usortVerifyMultiplicityBody(data.Ids, execScopes, vm) case SET_ADD: return setAdd(data.Ids, vm) case FIND_ELEMENT: diff --git a/pkg/hints/usort_hints.go b/pkg/hints/usort_hints.go new file mode 100644 index 00000000..1dc0f364 --- /dev/null +++ b/pkg/hints/usort_hints.go @@ -0,0 +1,266 @@ +package hints + +import ( + "fmt" + "sort" + + "github.com/lambdaclass/cairo-vm.go/pkg/lambdaworks" + "github.com/lambdaclass/cairo-vm.go/pkg/types" + . "github.com/lambdaclass/cairo-vm.go/pkg/vm" + "github.com/lambdaclass/cairo-vm.go/pkg/vm/memory" + + . "github.com/lambdaclass/cairo-vm.go/pkg/hints/hint_utils" + "github.com/pkg/errors" +) + +// SortFelt implements sort.Interface for []lambdaworks.Felt +type SortFelt []lambdaworks.Felt + +func (s SortFelt) Len() int { return len(s) } +func (s SortFelt) Swap(i, j int) { s[i], s[j] = s[j], s[i] } +func (s SortFelt) Less(i, j int) bool { + a, b := s[i], s[j] + + return a.Cmp(b) == -1 +} + +// Implements hint: +// %{ vm_enter_scope(dict(__usort_max_size = globals().get('__usort_max_size'))) %} +func usortEnterScope(executionScopes *types.ExecutionScopes) error { + usort_max_size_interface, err := executionScopes.Get("usort_max_size") + + if err != nil { + executionScopes.EnterScope(make(map[string]interface{})) + return nil + } + + usort_max_size, cast_ok := usort_max_size_interface.(uint64) + + if !cast_ok { + return errors.New("Error casting usort_max_size into a uint64") + } + + scope := make(map[string]interface{}) + scope["usort_max_size"] = usort_max_size + executionScopes.EnterScope(scope) + + return nil +} + +func usortBody(ids IdsManager, executionScopes *types.ExecutionScopes, vm *VirtualMachine) error { + + input_ptr, err := ids.GetRelocatable("input", vm) + if err != nil { + return err + } + + input_len, err := ids.GetFelt("input_len", vm) + + if err != nil { + return err + } + input_len_u64, err := input_len.ToU64() + + if err != nil { + return err + } + + usort_max_size, err := executionScopes.Get("usort_max_size") + + if err == nil { + usort_max_size_u64, cast_ok := usort_max_size.(uint64) + + if !cast_ok { + return errors.New("Error casting usort_max_size into a uint64") + } + + if input_len_u64 > usort_max_size_u64 { + return errors.New(fmt.Sprintf("usort() can only be used with input_len<= %v. Got: input_len=%v.", usort_max_size_u64, input_len_u64)) + } + } + + positions_dict := make(map[lambdaworks.Felt][]uint64) + + for i := uint64(0); i < input_len_u64; i++ { + + val, err := vm.Segments.Memory.GetFelt(input_ptr.AddUint(uint(i))) + + if err != nil { + return err + } + + positions_dict[val] = append(positions_dict[val], i) + } + executionScopes.AssignOrUpdateVariable("positions_dict", positions_dict) + + output := make([]lambdaworks.Felt, 0, len(positions_dict)) + + for key := range positions_dict { + output = append(output, key) + } + + sort.Sort(SortFelt(output)) + + output_len := memory.NewMaybeRelocatableFelt(lambdaworks.FeltFromUint64(uint64((len(output))))) + err = ids.Insert("output_len", output_len, vm) + + if err != nil { + return err + } + + output_base := vm.Segments.AddSegment() + + for i := range output { + err = vm.Segments.Memory.Insert(output_base.AddUint(uint(i)), memory.NewMaybeRelocatableFelt(output[i])) + + if err != nil { + return err + } + } + + multiplicities_base := vm.Segments.AddSegment() + + multiplicities := make([]uint64, 0, len(output)) + + for key := range output { + multiplicities = append(multiplicities, uint64(len(positions_dict[output[key]]))) + } + + for i := range multiplicities { + err = vm.Segments.Memory.Insert(multiplicities_base.AddUint(uint(i)), memory.NewMaybeRelocatableFelt(lambdaworks.FeltFromUint64(multiplicities[i]))) + + if err != nil { + return err + } + } + + err = ids.Insert("output", memory.NewMaybeRelocatableRelocatable(output_base), vm) + + if err != nil { + return err + } + + err = ids.Insert("multiplicities", memory.NewMaybeRelocatableRelocatable(multiplicities_base), vm) + + if err != nil { + return err + } + + return nil +} + +// Implements hint: +// +// %{ +// last_pos = 0 +// positions = positions_dict[ids.value][::-1] +// %} +func usortVerify(ids IdsManager, executionScopes *types.ExecutionScopes, vm *VirtualMachine) error { + + executionScopes.AssignOrUpdateVariable("last_pos", uint64(0)) + + positions_dict_interface, err := executionScopes.Get("positions_dict") + + if err != nil { + return err + } + + positions_dict, cast_ok := positions_dict_interface.(map[lambdaworks.Felt][]uint64) + + if !cast_ok { + return errors.New("Error casting positions_dict") + } + + value, err := ids.GetFelt("value", vm) + if err != nil { + return err + } + + if err != nil { + return err + } + + positions := positions_dict[value] + + for i, j := 0, len(positions)-1; i < j; i, j = i+1, j-1 { + positions[i], positions[j] = positions[j], positions[i] + } + + executionScopes.AssignOrUpdateVariable("positions", positions) + + return nil +} + +// Implements hint: +// %{ assert len(positions) == 0 %} +func usortVerifyMultiplicityAssert(executionScopes *types.ExecutionScopes) error { + + positions_interface, err := executionScopes.Get("positions") + + if err != nil { + return err + } + + positions, cast_ok := positions_interface.([]uint64) + + if !cast_ok { + return errors.New("Error casting positions to []uint64") + } + + if len(positions) != 0 { + return errors.New("Assertion failed: len(positions) == 0") + } + + return nil + +} + +// Implements hint: +// +// %{ +// current_pos = positions.pop() +// ids.next_item_index = current_pos - last_pos +// last_pos = current_pos + 1 +// %} +func usortVerifyMultiplicityBody(ids IdsManager, executionScopes *types.ExecutionScopes, vm *VirtualMachine) error { + + positions_interface, err := executionScopes.Get("positions") + + if err != nil { + return err + } + + positions, cast_ok := positions_interface.([]uint64) + + if !cast_ok { + return errors.New("Error casting positions to []uint64") + } + + last_pos_interface, err := executionScopes.Get("last_pos") + + if err != nil { + return err + } + + last_pos, cast_ok := last_pos_interface.(uint64) + + if !cast_ok { + return errors.New("Error casting last_pos to uint64") + } + + current_pos := positions[len(positions)-1] + + executionScopes.AssignOrUpdateVariable("positions", positions[:len(positions)-1]) + + next_item_index := current_pos - last_pos + + err = ids.Insert("next_item_index", memory.NewMaybeRelocatableFelt(lambdaworks.FeltFromUint64(next_item_index)), vm) + + if err != nil { + return err + } + + executionScopes.AssignOrUpdateVariable("last_pos", current_pos+1) + + return nil +} diff --git a/pkg/hints/usort_hints_test.go b/pkg/hints/usort_hints_test.go new file mode 100644 index 00000000..c390869e --- /dev/null +++ b/pkg/hints/usort_hints_test.go @@ -0,0 +1,240 @@ +package hints_test + +import ( + "reflect" + "sort" + "testing" + + . "github.com/lambdaclass/cairo-vm.go/pkg/hints" + . "github.com/lambdaclass/cairo-vm.go/pkg/hints/hint_codes" + . "github.com/lambdaclass/cairo-vm.go/pkg/hints/hint_utils" + "github.com/lambdaclass/cairo-vm.go/pkg/lambdaworks" + "github.com/lambdaclass/cairo-vm.go/pkg/types" + . "github.com/lambdaclass/cairo-vm.go/pkg/vm" + . "github.com/lambdaclass/cairo-vm.go/pkg/vm/memory" +) + +func TestSortFeltArray(t *testing.T) { + array := []lambdaworks.Felt{lambdaworks.FeltFromUint(6), lambdaworks.FeltFromUint(0), lambdaworks.FeltFromUint(100), lambdaworks.FeltFromUint(1), lambdaworks.FeltFromUint(50)} + + sort.Sort(SortFelt(array)) + + sortedarray := []lambdaworks.Felt{lambdaworks.FeltFromUint(0), lambdaworks.FeltFromUint(1), lambdaworks.FeltFromUint(6), lambdaworks.FeltFromUint(50), lambdaworks.FeltFromUint(100)} + + if !reflect.DeepEqual(array, sortedarray) { + t.Errorf("Error sorting felt array") + } + +} + +func TestUsortWithMaxSize(t *testing.T) { + vm := NewVirtualMachine() + scopes := types.NewExecutionScopes() + scopes.AssignOrUpdateVariable("usort_max_size", uint64(1)) + idsManager := SetupIdsForTest( + map[string][]*MaybeRelocatable{}, + vm, + ) + hintProcessor := CairoVmHintProcessor{} + hintData := any(HintData{ + Ids: idsManager, + Code: USORT_ENTER_SCOPE, + }) + err := hintProcessor.ExecuteHint(vm, &hintData, nil, scopes) + if err != nil { + t.Errorf("USORT_ENTER_SCOPE hint execution failed") + } + + usort_max_size_interface, err := scopes.Get("usort_max_size") + + if err != nil { + t.Errorf("Error assigning usort_max_size") + } + + usort_max_size := usort_max_size_interface.(uint64) + + if usort_max_size != uint64(1) { + t.Errorf("Error assigning usort_max_size") + } + +} +func TestUsortOutOfRange(t *testing.T) { + vm := NewVirtualMachine() + vm.Segments.AddSegment() + vm.Segments.AddSegment() + vm.Segments.AddSegment() + scopes := types.NewExecutionScopes() + scopes.AssignOrUpdateVariable("usort_max_size", uint64(1)) + idsManager := SetupIdsForTest( + map[string][]*MaybeRelocatable{ + "input": {NewMaybeRelocatableRelocatable(NewRelocatable(2, 1))}, + "input_len": {NewMaybeRelocatableFelt(lambdaworks.FeltFromUint64(5))}, + }, + vm, + ) + hintProcessor := CairoVmHintProcessor{} + hintData := any(HintData{ + Ids: idsManager, + Code: USORT_BODY, + }) + err := hintProcessor.ExecuteHint(vm, &hintData, nil, scopes) + if err == nil { + t.Errorf("USORT_BODY hint should have failed") + } + +} + +func TestUsortVerify(t *testing.T) { + vm := NewVirtualMachine() + vm.Segments.AddSegment() + vm.Segments.AddSegment() + vm.Segments.AddSegment() + scopes := types.NewExecutionScopes() + positions_dict := make(map[lambdaworks.Felt][]uint64) + positions_dict[lambdaworks.FeltFromUint64(0)] = []uint64{2} + positions_dict[lambdaworks.FeltFromUint64(1)] = []uint64{1} + positions_dict[lambdaworks.FeltFromUint64(2)] = []uint64{0} + + scopes.AssignOrUpdateVariable("positions_dict", positions_dict) + + idsManager := SetupIdsForTest( + map[string][]*MaybeRelocatable{ + "value": {NewMaybeRelocatableFelt(lambdaworks.FeltFromUint64(0))}, + }, + vm, + ) + hintProcessor := CairoVmHintProcessor{} + hintData := any(HintData{ + Ids: idsManager, + Code: USORT_VERIFY, + }) + err := hintProcessor.ExecuteHint(vm, &hintData, nil, scopes) + if err != nil { + t.Errorf("USORT_VERIFY failed") + } + + positions_interface, err := scopes.Get("positions") + + if err != nil { + t.Errorf("Error assigning positions_interface") + } + + positions := positions_interface.([]uint64) + + if !reflect.DeepEqual(positions, []uint64{2}) { + t.Errorf("Error assigning positions") + } + + last_pos_interface, err := scopes.Get("last_pos") + + if err != nil { + t.Errorf("Error assigning last_pos") + } + + last_pos := last_pos_interface.(uint64) + + if last_pos != uint64(0) { + t.Errorf("Error assigning last_pos") + } + +} + +func TestUsortVerifyMultiplicityAssert(t *testing.T) { + vm := NewVirtualMachine() + scopes := types.NewExecutionScopes() + + idsManager := SetupIdsForTest( + map[string][]*MaybeRelocatable{}, + vm, + ) + hintProcessor := CairoVmHintProcessor{} + hintData := any(HintData{ + Ids: idsManager, + Code: USORT_VERIFY_MULTIPLICITY_ASSERT, + }) + err := hintProcessor.ExecuteHint(vm, &hintData, nil, scopes) + if err == nil { + t.Errorf("USORT_VERIFY_MULTIPLICITY_ASSERT should have failed") + } + + positions := []uint64{0} + + scopes.AssignOrUpdateVariable("positions", positions) + + err = hintProcessor.ExecuteHint(vm, &hintData, nil, scopes) + if err == nil { + t.Errorf("USORT_VERIFY_MULTIPLICITY_ASSERT should have failed") + } + + positions = []uint64{} + + scopes.AssignOrUpdateVariable("positions", positions) + + err = hintProcessor.ExecuteHint(vm, &hintData, nil, scopes) + if err != nil { + t.Errorf("USORT_VERIFY_MULTIPLICITY_ASSERT failed") + } + +} + +func TestUsortVerifyMultiplicityBody(t *testing.T) { + vm := NewVirtualMachine() + vm.Segments.AddSegment() + scopes := types.NewExecutionScopes() + + scopes.AssignOrUpdateVariable("positions", []uint64{1, 0, 4, 7, 10}) + scopes.AssignOrUpdateVariable("last_pos", uint64(3)) + + idsManager := SetupIdsForTest( + map[string][]*MaybeRelocatable{ + "next_item_index": {nil}, + }, + vm, + ) + hintProcessor := CairoVmHintProcessor{} + hintData := any(HintData{ + Ids: idsManager, + Code: USORT_VERIFY_MULTIPLICITY_BODY, + }) + err := hintProcessor.ExecuteHint(vm, &hintData, nil, scopes) + if err != nil { + t.Errorf("USORT_VERIFY_MULTIPLICITY_BODY failed") + } + + // Check scopes variables + positions_interface, err := scopes.Get("positions") + + if err != nil { + t.Errorf("Error assigning positions_interface") + } + + positions := positions_interface.([]uint64) + + if !reflect.DeepEqual(positions, []uint64{1, 0, 4, 7}) { + t.Errorf("Error assigning positions") + } + + last_pos_interface, err := scopes.Get("last_pos") + + if err != nil { + t.Errorf("Error assigning last_pos") + } + + last_pos := last_pos_interface.(uint64) + + if last_pos != uint64(11) { + t.Errorf("Error assigning last_pos") + } + + // Check VM inserts + next_item_index, err := idsManager.GetFelt("next_item_index", vm) + + if err != nil { + t.Errorf("Error assigning next_item_index") + } + + if next_item_index != lambdaworks.FeltFromUint(7) { + t.Errorf("Error assigning next_item_index") + } + +} diff --git a/pkg/vm/cairo_run/cairo_run_test.go b/pkg/vm/cairo_run/cairo_run_test.go index 3f54eacd..f2921821 100644 --- a/pkg/vm/cairo_run/cairo_run_test.go +++ b/pkg/vm/cairo_run/cairo_run_test.go @@ -309,6 +309,14 @@ func TestSplitFeltHint(t *testing.T) { testProgram("split_felt", t) } +func TestUsort(t *testing.T) { + testProgram("usort", t) +} + +func TestUsortProofMode(t *testing.T) { + testProgramProof("usort", t) +} + func TestSplitFeltHintProofMode(t *testing.T) { testProgramProof("split_felt", t) } From 51db1d4c9a4abbde348f1b40de18b0db20a7cea9 Mon Sep 17 00:00:00 2001 From: Juan-M-V <102986292+Juan-M-V@users.noreply.github.com> Date: Fri, 29 Sep 2023 10:51:11 -0300 Subject: [PATCH 2/7] Add double assign hint (#287) * Add ec hints * Implement hints * Add the hints to the processor * Test pack86 function * Test hint * Delete debug info, Test ec negative op * Second hint test * Test embedded hint * Change to Camel case * Implement slope hints * Fix format * Delete github conflict string * Tests hints * Tests hints slopes * Rename misleading name function * Fix function name * Fix error in function call * Delete debug info * Delete unused import * Secp hints * Secpr21 * Add it to the hint processor * Hints secp * bigint3 nondet * Zero verify * Merge main * Add hint to hint processor * Add double assign hint * Debug info * Remove integration test * Prints * Add unit tests * Test verify with unit test * Debug unit test * Test verify zero with debug * Non det big 3 test * Modify test to use ids manager * debug info * Fix broken test * Move file from hints_utils and rename * Delete debug * Move integration test to cairo_run_test.go * Return error of IdsData.Insert * Change to camel case * Add Integration test --------- Co-authored-by: Milton Co-authored-by: mmsc2 <88055861+mmsc2@users.noreply.github.com> Co-authored-by: Mariano A. Nicolini Co-authored-by: juan.mv Co-authored-by: Pedro Fontana --- cairo_programs/cairo_keccak.cairo | 1 + cairo_programs/ec_double_assign.cairo | 32 +++++++++++++ pkg/hints/ec_hint.go | 43 +++++++++++++++++ pkg/hints/ec_hint_test.go | 69 +++++++++++++++++++++++++++ pkg/hints/hint_codes/ec_op_hints.go | 33 ++++++++++++- pkg/hints/hint_processor.go | 2 + pkg/vm/cairo_run/cairo_run_test.go | 4 ++ 7 files changed, 182 insertions(+), 2 deletions(-) create mode 100644 cairo_programs/ec_double_assign.cairo diff --git a/cairo_programs/cairo_keccak.cairo b/cairo_programs/cairo_keccak.cairo index 8adcd515..b5575e7c 100644 --- a/cairo_programs/cairo_keccak.cairo +++ b/cairo_programs/cairo_keccak.cairo @@ -27,3 +27,4 @@ func main{range_check_ptr: felt, bitwise_ptr: BitwiseBuiltin*}() { return (); } + diff --git a/cairo_programs/ec_double_assign.cairo b/cairo_programs/ec_double_assign.cairo new file mode 100644 index 00000000..16419e72 --- /dev/null +++ b/cairo_programs/ec_double_assign.cairo @@ -0,0 +1,32 @@ +%builtins range_check +from starkware.cairo.common.cairo_secp.bigint import BigInt3, nondet_bigint3 +struct EcPoint { + x: BigInt3, + y: BigInt3, +} + +func ec_double{range_check_ptr}(point: EcPoint, slope: BigInt3) -> (res: BigInt3) { + %{ + from starkware.cairo.common.cairo_secp.secp_utils import pack + SECP_P = 2**255-19 + + slope = pack(ids.slope, PRIME) + x = pack(ids.point.x, PRIME) + y = pack(ids.point.y, PRIME) + + value = new_x = (pow(slope, 2, SECP_P) - 2 * x) % SECP_P + %} + + let (new_x: BigInt3) = nondet_bigint3(); + return (res=new_x); +} + +func main{range_check_ptr}() { + let p = EcPoint(BigInt3(1,2,3), BigInt3(4,5,6)); + let s = BigInt3(7,8,9); + let (res) = ec_double(p, s); + assert res.d0 = 21935; + assert res.d1 = 12420; + assert res.d2 = 184; + return (); +} diff --git a/pkg/hints/ec_hint.go b/pkg/hints/ec_hint.go index 005523e0..04f0587c 100644 --- a/pkg/hints/ec_hint.go +++ b/pkg/hints/ec_hint.go @@ -7,6 +7,7 @@ import ( "github.com/lambdaclass/cairo-vm.go/pkg/builtins" "github.com/lambdaclass/cairo-vm.go/pkg/hints/hint_utils" . "github.com/lambdaclass/cairo-vm.go/pkg/hints/hint_utils" + . "github.com/lambdaclass/cairo-vm.go/pkg/lambdaworks" "github.com/lambdaclass/cairo-vm.go/pkg/types" . "github.com/lambdaclass/cairo-vm.go/pkg/types" "github.com/lambdaclass/cairo-vm.go/pkg/vm" @@ -184,6 +185,48 @@ func computeSlope(vm *VirtualMachine, execScopes ExecutionScopes, idsData IdsMan return nil } +// Implements hint: +// from starkware.cairo.common.cairo_secp.secp_utils import SECP_P, pack +// +// slope = pack(ids.slope, PRIME) +// x = pack(ids.point.x, PRIME) +// y = pack(ids.point.y, PRIME) +// +// value = new_x = (pow(slope, 2, SECP_P) - 2 * x) % SECP_P +func ecDoubleAssignNewX(vm *VirtualMachine, execScopes ExecutionScopes, ids IdsManager, secpP big.Int) error { + execScopes.AssignOrUpdateVariable("SECP_P", secpP) + + slope3, err := BigInt3FromVarName("slope", ids, vm) + if err != nil { + return err + } + packedSlope := slope3.Pack86() + slope := new(big.Int).Mod(&packedSlope, Prime()) + point, err := EcPointFromVarName("point", vm, ids) + if err != nil { + return err + } + + xPacked := point.X.Pack86() + x := new(big.Int).Mod(&xPacked, Prime()) + yPacked := point.Y.Pack86() + y := new(big.Int).Mod(&yPacked, Prime()) + + value := new(big.Int).Mul(slope, slope) + value = value.Mod(value, &secpP) + + value = value.Sub(value, x) + value = value.Sub(value, x) + value = value.Mod(value, &secpP) + + execScopes.AssignOrUpdateVariable("slope", slope) + execScopes.AssignOrUpdateVariable("x", x) + execScopes.AssignOrUpdateVariable("y", y) + execScopes.AssignOrUpdateVariable("value", *value) + execScopes.AssignOrUpdateVariable("new_x", *value) + return nil +} + /* Implements hint: %{ from starkware.cairo.common.cairo_secp.secp256r1_utils import SECP256R1_ALPHA as ALPHA %} diff --git a/pkg/hints/ec_hint_test.go b/pkg/hints/ec_hint_test.go index 3cf08a82..974576b4 100644 --- a/pkg/hints/ec_hint_test.go +++ b/pkg/hints/ec_hint_test.go @@ -235,6 +235,75 @@ func TestRunComputeSlopeOk(t *testing.T) { } } +func TestEcDoubleAssignNewXOk(t *testing.T) { + vm := NewVirtualMachine() + vm.Segments.AddSegment() + idsManager := SetupIdsForTest( + map[string][]*MaybeRelocatable{ + "slope": { + NewMaybeRelocatableFelt(FeltFromUint64(3)), + NewMaybeRelocatableFelt(FeltFromUint64(0)), + NewMaybeRelocatableFelt(FeltFromUint64(0)), + }, + "point": { + // X + NewMaybeRelocatableFelt(FeltFromUint64(2)), + NewMaybeRelocatableFelt(FeltFromUint64(0)), + NewMaybeRelocatableFelt(FeltFromUint64(0)), + // Y + NewMaybeRelocatableFelt(FeltFromUint64(4)), + NewMaybeRelocatableFelt(FeltFromUint64(0)), + NewMaybeRelocatableFelt(FeltFromUint64(0)), + }, + }, + vm, + ) + + hintProcessor := CairoVmHintProcessor{} + hintData := any(HintData{ + Ids: idsManager, + Code: EC_DOUBLE_ASSIGN_NEW_X_V1, + }) + + execScopes := types.NewExecutionScopes() + err := hintProcessor.ExecuteHint(vm, &hintData, nil, execScopes) + + if err != nil { + t.Errorf("EC_DOUBLE_ASSIGN_NEW_X hint failed with error: %s", err) + } + + slopeUncast, _ := execScopes.Get("slope") + slope := slopeUncast.(*big.Int) + xUncast, _ := execScopes.Get("x") + x := xUncast.(*big.Int) + yUncast, _ := execScopes.Get("y") + y := yUncast.(*big.Int) + valueUncast, _ := execScopes.Get("value") + value := valueUncast.(big.Int) + new_xUncast, _ := execScopes.Get("new_x") + new_x := new_xUncast.(big.Int) + + if value.Cmp(&new_x) != 0 { + t.Errorf("EC_DOUBLE_ASSIGN_NEW_X hint failed: new_x != value. %v != %v", new_x, value) + } + expectedRes := big.NewInt(5) + if value.Cmp(expectedRes) != 0 { + t.Errorf("EC_DOUBLE_ASSIGN_NEW_X hint failed: expected value (%v) to be 6", value) + } + expectedSlope := big.NewInt(3) + if slope.Cmp(expectedSlope) != 0 { + t.Errorf("EC_DOUBLE_ASSIGN_NEW_X hint failed: expected slope (%v) to be 3", slope) + } + expectedX := big.NewInt(2) + if x.Cmp(expectedX) != 0 { + t.Errorf("EC_DOUBLE_ASSIGN_NEW_X hint failed: expected x (%v) to be 2", x) + } + expectedY := big.NewInt(4) + if y.Cmp(expectedY) != 0 { + t.Errorf("EC_DOUBLE_ASSIGN_NEW_X hint failed: expected y (%v) to be 4", y) + } +} + func TestRunComputeSlopeV2Ok(t *testing.T) { vm := NewVirtualMachine() diff --git a/pkg/hints/hint_codes/ec_op_hints.go b/pkg/hints/hint_codes/ec_op_hints.go index 471246db..e8ab665c 100644 --- a/pkg/hints/hint_codes/ec_op_hints.go +++ b/pkg/hints/hint_codes/ec_op_hints.go @@ -4,10 +4,39 @@ const EC_NEGATE = "from starkware.cairo.common.cairo_secp.secp_utils import SECP const EC_NEGATE_EMBEDDED_SECP = "from starkware.cairo.common.cairo_secp.secp_utils import pack\nSECP_P = 2**255-19\n\ny = pack(ids.point.y, PRIME) % SECP_P\n# The modulo operation in python always returns a nonnegative number.\nvalue = (-y) % SECP_P" const EC_DOUBLE_SLOPE_V1 = "from starkware.cairo.common.cairo_secp.secp_utils import SECP_P, pack\nfrom starkware.python.math_utils import ec_double_slope\n\n# Compute the slope.\nx = pack(ids.point.x, PRIME)\ny = pack(ids.point.y, PRIME)\nvalue = slope = ec_double_slope(point=(x, y), alpha=0, p=SECP_P)" const COMPUTE_SLOPE_V1 = "from starkware.cairo.common.cairo_secp.secp_utils import SECP_P, pack\nfrom starkware.python.math_utils import line_slope\n\n# Compute the slope.\nx0 = pack(ids.point0.x, PRIME)\ny0 = pack(ids.point0.y, PRIME)\nx1 = pack(ids.point1.x, PRIME)\ny1 = pack(ids.point1.y, PRIME)\nvalue = slope = line_slope(point1=(x0, y0), point2=(x1, y1), p=SECP_P)" -const EC_DOUBLE_SLOPE_EXTERNAL_CONSTS = "from starkware.cairo.common.cairo_secp.secp_utils import pack\nfrom starkware.python.math_utils import ec_double_slope\n\n# Compute the slope.\nx = pack(ids.point.x, PRIME)\ny = pack(ids.point.y, PRIME)\nvalue = slope = ec_double_slope(point=(x, y), alpha=ALPHA, p=SECP_P)" -const NONDET_BIGINT3_V1 = "from starkware.cairo.common.cairo_secp.secp_utils import split\n\nsegments.write_arg(ids.res.address_, split(value))" +const EC_DOUBLE_ASSIGN_NEW_X_V1 = `from starkware.cairo.common.cairo_secp.secp_utils import SECP_P, pack + +slope = pack(ids.slope, PRIME) +x = pack(ids.point.x, PRIME) +y = pack(ids.point.y, PRIME) + +value = new_x = (pow(slope, 2, SECP_P) - 2 * x) % SECP_P` +const EC_DOUBLE_ASSIGN_NEW_X_V2 = `from starkware.cairo.common.cairo_secp.secp_utils import pack + +slope = pack(ids.slope, PRIME) +x = pack(ids.point.x, PRIME) +y = pack(ids.point.y, PRIME) + +value = new_x = (pow(slope, 2, SECP_P) - 2 * x) % SECP_P` +const EC_DOUBLE_ASSIGN_NEW_X_V3 = `from starkware.cairo.common.cairo_secp.secp_utils import pack +SECP_P = 2**255-19 + +slope = pack(ids.slope, PRIME) +x = pack(ids.point.x, PRIME) +y = pack(ids.point.y, PRIME) + +value = new_x = (pow(slope, 2, SECP_P) - 2 * x) % SECP_P` +const EC_DOUBLE_ASSIGN_NEW_X_V4 = `from starkware.cairo.common.cairo_secp.secp_utils import SECP_P, pack + +slope = pack(ids.slope, PRIME) +x = pack(ids.pt.x, PRIME) +y = pack(ids.pt.y, PRIME) + +value = new_x = (pow(slope, 2, SECP_P) - 2 * x) % SECP_P` const COMPUTE_SLOPE_V2 = "from starkware.python.math_utils import line_slope\nfrom starkware.cairo.common.cairo_secp.secp_utils import pack\nSECP_P = 2**255-19\n# Compute the slope.\nx0 = pack(ids.point0.x, PRIME)\ny0 = pack(ids.point0.y, PRIME)\nx1 = pack(ids.point1.x, PRIME)\ny1 = pack(ids.point1.y, PRIME)\nvalue = slope = line_slope(point1=(x0, y0), point2=(x1, y1), p=SECP_P)" const COMPUTE_SLOPE_WHITELIST = "from starkware.cairo.common.cairo_secp.secp_utils import SECP_P, pack\nfrom starkware.python.math_utils import div_mod\n\n# Compute the slope.\nx0 = pack(ids.pt0.x, PRIME)\ny0 = pack(ids.pt0.y, PRIME)\nx1 = pack(ids.pt1.x, PRIME)\ny1 = pack(ids.pt1.y, PRIME)\nvalue = slope = div_mod(y0 - y1, x0 - x1, SECP_P)" +const EC_DOUBLE_SLOPE_EXTERNAL_CONSTS = "from starkware.cairo.common.cairo_secp.secp_utils import pack\nfrom starkware.python.math_utils import ec_double_slope\n\n# Compute the slope.\nx = pack(ids.point.x, PRIME)\ny = pack(ids.point.y, PRIME)\nvalue = slope = ec_double_slope(point=(x, y), alpha=ALPHA, p=SECP_P)" +const NONDET_BIGINT3_V1 = "from starkware.cairo.common.cairo_secp.secp_utils import split\n\nsegments.write_arg(ids.res.address_, split(value))" const COMPUTE_SLOPE_SECP256R1 = "from starkware.cairo.common.cairo_secp.secp_utils import pack\nfrom starkware.python.math_utils import line_slope\n\n# Compute the slope.\nx0 = pack(ids.point0.x, PRIME)\ny0 = pack(ids.point0.y, PRIME)\nx1 = pack(ids.point1.x, PRIME)\ny1 = pack(ids.point1.y, PRIME)\nvalue = slope = line_slope(point1=(x0, y0), point2=(x1, y1), p=SECP_P)" const FAST_EC_ADD_ASSIGN_NEW_X = `"from starkware.cairo.common.cairo_secp.secp_utils import SECP_P, pack diff --git a/pkg/hints/hint_processor.go b/pkg/hints/hint_processor.go index a43c76ef..ccac7bd2 100644 --- a/pkg/hints/hint_processor.go +++ b/pkg/hints/hint_processor.go @@ -92,6 +92,8 @@ func (p *CairoVmHintProcessor) ExecuteHint(vm *vm.VirtualMachine, hintData *any, return ecNegateImportSecpP(vm, *execScopes, data.Ids) case EC_NEGATE_EMBEDDED_SECP: return ecNegateEmbeddedSecpP(vm, *execScopes, data.Ids) + case EC_DOUBLE_ASSIGN_NEW_X_V1, EC_DOUBLE_ASSIGN_NEW_X_V2, EC_DOUBLE_ASSIGN_NEW_X_V3, EC_DOUBLE_ASSIGN_NEW_X_V4: + return ecDoubleAssignNewX(vm, *execScopes, data.Ids, SECP_P_V2()) case POW: return pow(data.Ids, vm) case SQRT: diff --git a/pkg/vm/cairo_run/cairo_run_test.go b/pkg/vm/cairo_run/cairo_run_test.go index f2921821..845a3076 100644 --- a/pkg/vm/cairo_run/cairo_run_test.go +++ b/pkg/vm/cairo_run/cairo_run_test.go @@ -329,6 +329,10 @@ func TestSplitIntHintProofMode(t *testing.T) { testProgramProof("split_int", t) } +func TestEcDoubleAssign(t *testing.T) { + testProgram("ec_double_assign", t) +} + func TestIntegrationEcDoubleSlope(t *testing.T) { testProgram("ec_double_slope", t) } From 4f3d9e713de0ce40ab661bceb608c3cee8036ce5 Mon Sep 17 00:00:00 2001 From: fmoletta <99273364+fmoletta@users.noreply.github.com> Date: Fri, 29 Sep 2023 16:51:16 +0300 Subject: [PATCH 3/7] Add generic function to fetch scope variables with a generic type (#293) * Add generic way to fetch scope variables * Use more specific error --------- Co-authored-by: Pedro Fontana --- pkg/types/exec_scope_test.go | 18 ++++++++++++++++++ pkg/types/exec_scopes.go | 21 +++++++++++++++++++++ 2 files changed, 39 insertions(+) diff --git a/pkg/types/exec_scope_test.go b/pkg/types/exec_scope_test.go index 119f266d..6eb1d97e 100644 --- a/pkg/types/exec_scope_test.go +++ b/pkg/types/exec_scope_test.go @@ -258,3 +258,21 @@ func TestErrExitMainScope(t *testing.T) { t.Errorf("TestErrExitMainScope should fail with error: %s and fails with: %s", types.ErrCannotExitMainScop, err) } } + +func TestFetchScopeVar(t *testing.T) { + scope := make(map[string]interface{}) + scope["k"] = lambdaworks.FeltOne() + + scopes := types.NewExecutionScopes() + scopes.EnterScope(scope) + + result, err := types.FetchScopeVar[lambdaworks.Felt]("k", scopes) + if err != nil { + t.Errorf("TestGetLocalVariables failed with error: %s", err) + + } + expected := lambdaworks.FeltOne() + if expected != result { + t.Errorf("TestGetLocalVariables failed, expected: %s, got: %s", expected.ToSignedFeltString(), result.ToSignedFeltString()) + } +} diff --git a/pkg/types/exec_scopes.go b/pkg/types/exec_scopes.go index 357a1cec..bb372f40 100644 --- a/pkg/types/exec_scopes.go +++ b/pkg/types/exec_scopes.go @@ -18,6 +18,10 @@ func ErrVariableNotInScope(varName string) error { return ExecutionScopesError(errors.Errorf("Variable %s not in scope", varName)) } +func ErrVariableHasWrongType(varName string) error { + return ExecutionScopesError(errors.Errorf("Scope variable %s has wrong type", varName)) +} + func NewExecutionScopes() *ExecutionScopes { data := make([]map[string]interface{}, 1) data[0] = make(map[string]interface{}) @@ -82,3 +86,20 @@ func (es *ExecutionScopes) Get(varName string) (interface{}, error) { } return val, nil } + +// Generic version of ExecutionScopes.Get which also handles casting +func FetchScopeVar[T interface{}](varName string, scopes *ExecutionScopes) (T, error) { + locals, err := scopes.GetLocalVariables() + if err != nil { + return *new(T), err + } + valAny, prs := locals[varName] + if !prs { + return *new(T), ErrVariableNotInScope(varName) + } + val, ok := valAny.(T) + if !ok { + return *new(T), ErrVariableHasWrongType(varName) + } + return val, nil +} From 8c01a734eb39616ff94d060c542326ea04f2cf71 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Javier=20Rodr=C3=ADguez=20Chatruc?= <49622509+jrchatruc@users.noreply.github.com> Date: Fri, 29 Sep 2023 12:42:36 -0300 Subject: [PATCH 4/7] Hints and proof mode documentation (#286) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * [WIP] Hints and proof mode documentation * Add temporary segments to the TODO list * Update README.md Co-authored-by: fmoletta <99273364+fmoletta@users.noreply.github.com> * Small improvement * Draft:Move to code walkthrough? * Fix title hierarchy * Structure section * Reorder sections following a more comfortable implementation order * Fixes * Start ids * Update README.md Co-authored-by: Antonio Calvín García * Add documentation on references * Add computing references section * Add IdsManager * Add CompileHint docu * Add constants dpcu * Add ExecScopes docu * Fix subtitles * Add example with ids * Add exec scopes example * Reorder todos * Improvements & Corrections * fix typo * fix fmt * Improve langage * language --------- Co-authored-by: fmoletta <99273364+fmoletta@users.noreply.github.com> Co-authored-by: Federica Co-authored-by: Antonio Calvín García --- README.md | 894 +++++++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 885 insertions(+), 9 deletions(-) diff --git a/README.md b/README.md index 3b0b14c3..b62e1646 100644 --- a/README.md +++ b/README.md @@ -1548,7 +1548,7 @@ func (vm *VirtualMachine) UpdateAp(instruction *Instruction, operands *Operands) } ``` -### CairoRunner +#### CairoRunner Now that can can execute cairo steps, lets look at the VM's initialization step. We will begin by creating our `CairoRunner`: @@ -1589,7 +1589,7 @@ func (r *CairoRunner) Initialize() (memory.Relocatable, error) { } ``` -#### InitializeSegments +##### InitializeSegments This method will create our program and execution segments @@ -1602,7 +1602,7 @@ func (r *CairoRunner) initializeSegments() { } ``` -#### initializeMainEntrypoint +##### initializeMainEntrypoint This method will initialize the memory and initial register values to begin execution from the main entrypoint, and return the final pc @@ -1614,7 +1614,7 @@ func (r *CairoRunner) initializeMainEntrypoint() (memory.Relocatable, error) { } ``` -#### initializeFunctionEntrypoint +##### initializeFunctionEntrypoint This method will initialize the memory and initial register values to execute a cairo function given its offset within the program segment (aka entrypoint) and return the final pc. In our case, this function will be the main entrypoint, but later on we will be able to use this method to run starknet contract entrypoints. The stack will then be loaded into the execution segment in the next method. For now, the stack will be empty, but later on it will contain the builtin bases (which are the arguments for the main function), and the function arguments when running a function from a starknet contract. @@ -1631,7 +1631,7 @@ func (r *CairoRunner) initializeFunctionEntrypoint(entrypoint uint, stack *[]mem } ``` -#### InitializeState +##### InitializeState This method will be in charge of loading the program data into the program segment and the stack into the execution segment @@ -1648,7 +1648,7 @@ func (r *CairoRunner) initializeState(entrypoint uint, stack *[]memory.MaybeRelo } ``` -#### initializeVm +##### initializeVm This method will set the values of the VM's `RunContext` with our `CairoRunner`'s initial values @@ -1662,7 +1662,7 @@ func (r *CairoRunner) initializeVM() { With `CairoRunner.Initialize()` now complete we can move on to the execution step: -#### RunUntilPc +##### RunUntilPc This method will continuously execute cairo steps until the end pc, returned by 'CairoRunner.Initialize()' is reached @@ -2438,7 +2438,7 @@ func (vm *VirtualMachine) ComputeOperands(instruction Instruction) (Operands, er With all of our builtin logic integrated into the codebase, we can implement any builtin and use it in our cairo programs while worrying only about implementing the `BuiltinRunner` interface and creating the builtin in the `NewCairoRunner` function. -##### RangeCheck +#### RangeCheck The `RangeCheck` builtin does a very simple thing: it asserts that a given number is in the range $[0, 2^{128})$, i.e., that it's greater than zero and less than $2^{128}$. This might seem superficial but it is used for a lot of different things in Cairo, including comparing numbers. Whenever a program asserts that some number is less than other, the range check builtin is being called underneath. @@ -2517,7 +2517,7 @@ func (r *RangeCheckBuiltinRunner) AddValidationRule(mem *memory.Memory) { } `````` -##### Output +#### Output TODO @@ -2968,5 +2968,881 @@ TODO #### Hints +So far we have been thinking about the VM mostly abstracted from the prover and verifier it's meant to feed its results to. The last main feature we need to talk about, however, requires keeping this proving/verifying logic in mind. +As a reminder, the whole point of the Cairo VM is to output a trace/memory file so that a `prover` can then create a cryptographic proof that the execution of the program was done correctly. A `verifier` can then take that proof and verify it in much less time than it would have taken to re-execute the entire program. +In this model, the one actually using the VM to run a cairo program is *always the prover*. The verifier does not use the VM in any way, as that would defeat the entire purpose of validity proofs; they just get the program being run and the proof generated by the prover and run some cryptographic algorithm to check it. + +While the verifier does not execute the code, they do *check it*. As an example, if a cairo program computes a fibonacci number like this: + +``` +func main() { + // Call fib(1, 1, 10). + let result: felt = fib(1, 1, 10); +} +``` + +the verifier won't *run* this, but they will reject any incorrect execution of the call to `fib`. The correct value for `result` in this case is `144` (it's the 10th fibonacci number); any attempt by the prover to convince the verifier that `result` is not `144` will fail, because the call to the `fib` function is *being proven* and thus *seen* by the verifier. + +A `Hint` is a piece of code that is not proven, and therefore not seen by the verifier. If `fib` above were a hint, then the prover could convince the verifier that `result` is $144$, $0$, $1000$ or any other number. + +In cairo 0, hints are code written in `Python` and are surrounded by curly brackets. Here's an example from the `alloc` function, provided by the Cairo common library + +``` +func alloc() -> (ptr: felt*) { + %{ memory[ap] = segments.add() %} + ap += 1; + return (ptr=cast([ap - 1], felt*)); +} +``` + +The first line of the function, + +``` +%{ memory[ap] = segments.add() %} +``` + +is a hint called `ADD_SEGMENT`. All it does is create a new memory segment, then write its base to the current value of `ap`. This is python code that is being run in the context of the VM's execution; thus `memory` refers to the VM's current memory and `segments.add()` is just a function provided by the VM to allocate a new segment. + +At this point you might be wondering: why run code that's not being proven? Isn't the whole point of Cairo to prove correct execution? There are (at least) two reasons for hints to exist. + +##### Nothing to prove + +For some operations there's simply nothing to prove, as they are just convenient things one wants to do during execution. The `ADD_SEGMENT` hint shown above is a good example of that. When proving execution, the program's memory is presented as one relocated continuous segment, it does not matter at all which segment a cell was in, or when that segment was added. The verifier doesn't care. + +Because of this, there's no reason to make `ADD_SEGMENT` a part of the cairo language and have an instruction for it. + +##### Optimization + +Certain operations can be very expensive, in the sense that they might involve a huge amount of instructions or memory usage, and therefore contribute heavily to the proving time. For certain calculations, there are two ways to convince the verifier that it was done correctly: + +- Write the entire calculation in Cairo/Cairo Assembly. This makes it show up in the trace and therefore get proven. +- *Present the result of the calculation to the verifier through a hint*, then show said result indeed satisfies the relevant condition that makes it the actual result. + +To make this less abstract, let's show two examples. + +##### Square root + +Let's say the calculation in question is to compute the square root of a number `x`. The two ways to do it then become: + +- Write the usual square root algorithm in Cairo to compute `sqrt(x)`. +- Write a hint that computes `sqrt(x)`, then immediately after calling the hint show __in Cairo__ that `(sqrt(x))^2 = x`. + +The second approach is exactly what the `sqrt` function in the Cairo common library does: + +``` +// Returns the floor value of the square root of the given value. +// Assumptions: 0 <= value < 2**250. +func sqrt{range_check_ptr}(value) -> felt { + alloc_locals; + local root: felt; + + %{ + from starkware.python.math_utils import isqrt + value = ids.value % PRIME + assert value < 2 ** 250, f"value={value} is outside of the range [0, 2**250)." + assert 2 ** 250 < PRIME + ids.root = isqrt(value) + %} + + assert_nn_le(root, 2 ** 125 - 1); + tempvar root_plus_one = root + 1; + assert_in_range(value, root * root, root_plus_one * root_plus_one); + + return root; +} +``` + +If you read it carefully, you'll see that the hint in this function computes the square root in python, then this line + +``` +assert_in_range(value, root * root, root_plus_one * root_plus_one); +``` + +asserts __in Cairo__ that `(sqrt(x))^2 = x`. + +This is done this way because it is much cheaper, in terms of the generated trace (and thus proving time), to square a number than compute its square root. + +Notice that the last assert is absolutely mandatory to make this safe. If you forget to write it, the square root calculation does not get proven, and anyone could convince the verifier that the result of `sqrt(x)` is any number they like. + +##### Linear search turned into an O(1) lookup + +This example is taken from the [Cairo documentation](https://docs.cairo-lang.org/0.12.0/hello_cairo/program_input.html). + +Given a list of `(key, value)` pairs, if we want to write a `get_value_by_key` function that returns the value associated to a given key, there are two ways to do it: + +- Write a linear search in Cairo, iterating over each key until you find the requested one. +- Do that exact same linear search *inside a hint*, find the result, then show that the result's key is the one requested. + +Again, the second approach makes the resulting trace and proving much faster, because it's just a lookup; there's no linear search. Notice this only applies to proving, the VM has to execute the hint, so there's still a linear search when executing to generate the trace. In fact, the second approach is more expensive for the VM than the first one. It has to do both a linear search and a lookup. This is a tradeoff in favor of proving time. + +Also note that, as in the square root example, when writing this logic you need to remember to show the hint's result is the correct one in Cairo. If you don't, your code is not being proven. + +##### Non-determinism + +The Cairo paper and documentation refers to this second approach to calculating things through hints as *non-determinism*. The reason for this is that sometimes there is more than one result that satisfies a certain condition. This means that cairo execution becomes non deterministic; a hint could output multiple values, and in principle there is no way to know which one it's going to be. Running the same code multiple times could give different results. + +The square root is an easy example of this. The condition `(sqrt(x))^2 = x` is not unique, there are two solutions to it. Without the hint, this is non-deterministic, `x` could have multiple values; the hint resolves that by choosing a specific value when being run. + +##### Common Library and Hints + +As explained above, using hints in your code is highly unsafe. Forgetting to add a check after calling them can make your code vulnerable to any sorts of attacks, as your program will not prove what you think it proves. + +Because of this, most hints in Cairo 0 are wrapped around or used by functions in the Cairo common library that do the checks for you, thus making them safe to use. Ideally, Cairo developers should not be using hints on their own; only transparently through Cairo library functions they call. + +##### Whitelisted Hints + +In Cairo, a hint could be any Python code you like. In the context of it as just another language someone might want to use, this is fine. In the context of Cairo as a programming language used to write smart contracts deployed on a blockchain, it's not. Users could deploy contracts with hints that simply do + +```python +while true: + pass +``` + +and grind the network down to a halt, as nodes get stuck executing an infinite loop when calling the contract. + +To address this, the starknet network maintains a list of *whitelisted* hints, which are the only ones that can be used in starknet contracts. These are the ones implemented in this VM. + +#### Implementing Hints + +Hints are essentially logic that is executed in each cairo step, before the next instruction, and which may interact with and modify the vm. We will first look into the broad execution loop and the dive into the different types of interaction hints can have with the vm. +While the original cairo-lang implementation executes these hints in python, we will instead be implementing their logic in go and matching each string of python code to a function in the vm's code. We will also be using an interface to abstract the hint processing part of the vm and allow greater flexibility when using the vm in other contexts. + +##### The HintProcessor interface + +This `HintProcessor` interface will consist of two methods: `CompileHint`, which receives hint data from the compiled program and transforms it into whatever format is more convenient for hint execution, and `ExecuteHint`, which will receive this data and use it to execute the hint. + +```go +type HintProcessor interface { + // Transforms hint data outputted by the VM into whichever format will be later used by ExecuteHint + CompileHint(hintParams *parser.HintParams, referenceManager *parser.ReferenceManager) (any, error) + // Executes the hint which's data is provided by a dynamic structure previously created by CompileHint + ExecuteHint(vm *VirtualMachine, hintData *any, constants *map[string]lambdaworks.Felt, execScopes *types.ExecutionScopes) error +} +``` + +We will first look at how hint processing ties into the core vm execution loop, and then look into how this vm's implementation of the `HintProcessor` interface works: + +##### VM execution loop + +Before we begin executing steps, we will feed the hint-related information from the compiled program to the `HintProcessor`, and obtain what we call `HintData`, which will be later on used to execute the hint. As we can see, the compiled json stores the hint information in a map which connects pc offsets (at which pc offset the hint should be executed) to a list of hints (yes, more than one hint can be executed as a given pc), and we will use a similar structure to hold the compiled `HintData`. +```go +func (r *CairoRunner) BuildHintDataMap(hintProcessor vm.HintProcessor) (map[uint][]any, error) { + hintDataMap := make(map[uint][]any) + for pc, hintsParams := range r.Program.Hints { + hintDatas := make([]any, 0, len(hintsParams)) + for _, hintParam := range hintsParams { + data, err := hintProcessor.CompileHint(&hintParam) + if err != nil { + return nil, err + } + hintDatas = append(hintDatas, data) + } + hintDataMap[pc] = hintDatas + } + + return hintDataMap, nil +} +``` + +Once we have our map of `HintData`s we can start executing cairo steps. Before fetching the next instruction, we will check if we have hints to run for the current pc, and if we do, the `HintProcessor` will execute each hint using the corresponding `HintData`. + +```go +func (v *VirtualMachine) Step(hintProcessor HintProcessor, hintDataMap *map[uint][]any) error { + // Run Hint + hintDatas, ok := (*hintDataMap)[v.RunContext.Pc.Offset] + if ok { + for i := 0; i < len(hintDatas); i++ { + err := hintProcessor.ExecuteHint(v, &hintDatas[i]) + if err != nil { + return err + } + } + } + + // Run Instruction + encoded_instruction, err := v.Segments.Memory.Get(v.RunContext.Pc) +``` + +##### Implementing a HintProcessor: ExecuteHint + +This method will receive a `HintData`, and match its `Code` field, which contains the python code as a string, to a go function that implements its logic: + +```go +func (p *CairoVmHintProcessor) ExecuteHint(vm *vm.VirtualMachine, hintData *any) error { + data, ok := (*hintData).(HintData) + if !ok { + return errors.New("Wrong Hint Data") + } + switch data.Code { + case ADD_SEGMENT: + return addSegment(vm) + default: + return errors.Errorf("Unknown Hint: %s", data.Code) + } +} +``` + +Where `ADD_SEGMENT` is a constant with the python code of the hint + +```go +const ADD_SEGMENT = "memory[ap] = segments.add()" +``` + +And the function `addSegment` implements its logic, which is to add a segment to the vm's memory: + +```go +func addSegment(vm *VirtualMachine) error { + newSegmentBase := vm.Segments.AddSegment() + return vm.Segments.Memory.Insert(vm.RunContext.Ap, NewMaybeRelocatableRelocatable(newSegmentBase)) +} +``` + +Before we implement the `CompileHint` method, lets look at this crucial part of hint interaction with the vm: + +##### Hint Interaction: Ids + +Ids are hints' way to interact with variables in a cairo program. For example, if I declare a variable `n` in my cairo code, I can access that `n` variable inside a hint using `ids.n`. + +The following cairo snippet would print the number 17 + +```py + let n = 17 + %{ print(ids.n) %} +``` + +To access these variables when implementing our hints in go we will implement the `IdsManager`, which will allow us to read and write cairo variables. +But interacting with cairo variables is not as easy as it sounds. In order to access them, we must first compute their address from a `Reference` + +###### References + +As cairo variables are created during the vm's execution, we can't know their value beforehand. In order to solve this, the compiled program provides us with references for cairo variables available to hints. These references are instructions on where we can find a specific cairo variable in memory. For example, they might tell us to take the current value of the fp register, substract 1 from it, and access the memory value at that new address. + +As these references come in string format, we need to parse them into a struct that we can efficiently use to compute addresses: + +```go +type HintReference struct { + Offset1 OffsetValue + Offset2 OffsetValue + Dereference bool + ApTrackingData parser.ApTrackingData + ValueType string +} +``` + +This struct matches the canonical string format for references: `"cast(Offset1 + Offset2, ValueType)"` (or `"[cast(Offset1 + Offset2, ValueType)]"`, in the case of Dereference being true ). +The first two fields: Offset1 and Offset2 will lead us to a particular memory value, the Dereference field will tell us if the value of the ids is that memory value we found (in case of false), or if we should use that value as an address to fetch the ids value from memory (in case of true), and the ValueType tells us what type the variable has (be it a felt, felt*, struct, etc). As we already know the context of the hints, we can ignore the ValueType. + +Now lets look at what an `OffsetValue` is: + +```go +type OffsetValue struct { + ValueType offsetValueType + Immediate Felt + Value int + Register vm.Register + Dereference bool +} + +type offsetValueType uint + +const ( + Value offsetValueType = 0 + Immediate offsetValueType = 1 + Reference offsetValueType = 2 +) +``` + +There are three types of `OffsetValue`: + +* Inmediate: Contains the value of the ids as a literal, for example `"cast(17, felt)"` is a reference to a felt with literal value 17. Only Offset1 can be of Immediate type, and the reference can't have Dereference = true + +* Reference: It is made up of a Register (AP or FP) and a Value, it will tell us the location of an ids in memory by pointing to a memory cell relative to a register. For example `"cast(fp + (-1), felt*)"` is a reference with Offset1 of type Reference, with register FP and Value -1, and it leads us to an felt* value obtained from subtracting 1 from the current fp value. OffsetValues of type Reference can also have Dereference, for example: `"cast([fp + (-1)], felt)"` will lead us to a felt value located one cell before the one at the current register value. Both OffsetValues can be of type Reference in the same Reference + +* Value: Only Offset2 can be of type value, it consists of a single field value and acts as a modifier to the first OffsetValue (which will always be of type Reference for this case). For example, we can add second OffsetValue of Value type with Value = 1 to the first Reference type example: `"cast(fp + (-1) + 2), felt*)"`, this will tell us to subtract 1 from fp, and then add 2 to it, and that will be our ids value. + +When an offset doesn't exist in the reference, we use an OffsetValue of type Value with Value 0, which essentially does nothing, to represent it. This allows us to use go's zero value by default to make our code (and life) a bit simpler. + +This can be a bit hard to grasp at first so lets look at some examples: + +* Immediate Reference + String Reference: `cast(17, felt)` + Struct Reference: {Offset1: {ValueType: Immediate, Immedate: 17}, ValueType: "felt"} + Reference in words: The value of the ids is 17 + +* Dereference with one offset of Type Reference + String Reference: `[cast(ap + 1, felt)]` + Struct Reference: {Offset1: {ValueType: Reference, Register: AP, Value: 1}, Dereference: true, ValueType: "felt"} + Reference in words: Take the current value of ap, add 1 to it and then fetch the memory value at that address + +* Two offsets of type Reference, Value + String Reference: `"cast(ap + 1 + (-2), felt*)"` + Struct Reference: {Offset1: {ValueType: Reference, Register: AP, Value: 1}, Offset2: {ValueType: Value, Value: -2}, ValueType: "felt*"} + Reference in words: Take the current value of ap, add 1 to it and then subtract 2 from it + +* Two offsets of type Reference (with Dereference), Value + String Reference: `"cast([ap + (-1)] + (-2), felt*)"` + Struct Reference: {Offset1: {ValueType: Reference, Register: AP, Value: 1, Dereference: true}, Offset2: {ValueType: Value, Value: -2}, ValueType: "felt*"} + Reference in words: Take the current value of ap, add 1 to it, fetch the memory value at that address and then subtract 2 from it + +* Two offsets of type Reference (with Dereference) + String Reference: `"cast([ap + (-1)] + [ap], felt)"` + Struct Reference: {Offset1: {ValueType: Reference, Register: AP, Value: 1, Dereference: true}, Offset2: {ValueType: Reference, Register: AP, Value: 0, Dereference: true}, ValueType: "felt*"} + Reference in words: Take the current value of ap, subtract 1 to it, fetch the memory value at that address. Take the current value of ap, fetch the memory value at that address. Add the two values we obtained. + +Now all thats left to analyze is in the reference is the `ApTracking`: + +```go +type ApTrackingData struct { + Group int `json:"group"` + Offset int `json:"offset"` +} +``` + +As the value of AP is constantly changing with each instruction executed, its not that simple to track variables who's references are based on ap. ApTracking is used to calculate the difference between the value of ap at the moment the variable was created/ enterted the scope of the function (and hence, the hint) and the value of ap at the moment the hint is executed. Each hint and each reference has its own ApTracking. + +###### Computing addresses using References + +The function used to fetch the value from an ids variable using a reference works as follows: +1. Check if the refeference has type Immediate, if this is true, return the Immediate field +2. Calculate the address of the ids variable using the reference (we will see how this works soon) +3. Check the Dereference field of the reference, if false, return the address we obtained in 2, if true, fetch the memory value at that address and return it. + +```go +func getValueFromReference(reference *HintReference, apTracking parser.ApTrackingData, vm *VirtualMachine) (*MaybeRelocatable, bool) { + // Handle the case of immediate + if reference.Offset1.ValueType == Immediate { + return NewMaybeRelocatableFelt(reference.Offset1.Immediate), true + } + addr, ok := getAddressFromReference(reference, apTracking, vm) + if ok { + if reference.Dereference { + val, err := vm.Segments.Memory.Get(addr) + if err == nil { + return val, true + } + } else { + return NewMaybeRelocatableRelocatable(addr), true + } + } + return nil, false +} +``` + +In order to extract the value of an ids variable, we will first compute its address, this works as follows: +1. Check that the Offset1 is a Reference +2. Compute the value of Offset1 +3. Add the value of Offet2. By either calculating it in the case of a Reference type, or just using the Value field in the case of a Value type. +4. Return the result obtained in step 3. +```go +func getAddressFromReference(reference *HintReference, apTracking parser.ApTrackingData, vm *VirtualMachine) (Relocatable, bool) { + if reference.Offset1.ValueType != Reference { + return Relocatable{}, false + } + offset1 := getOffsetValueReference(reference.Offset1, reference.ApTrackingData, apTracking, vm) + if offset1 != nil { + offset1_rel, is_rel := offset1.GetRelocatable() + if is_rel { + switch reference.Offset2.ValueType { + case Reference: + offset2 := getOffsetValueReference(reference.Offset2, reference.ApTrackingData, apTracking, vm) + if offset2 != nil { + res, err := offset1_rel.AddMaybeRelocatable(*offset2) + if err == nil { + return res, true + } + } + case Value: + res, err := offset1_rel.AddInt(reference.Offset2.Value) + if err == nil { + return res, true + } + } + } + } + return Relocatable{}, false + +} +``` + +Now lets see how computing the value of an OffsetValue of type Reference works: +1. Determine a base address by checking the Register field of the OffsetValue. If the register is FP, use the current value of fp. If the register is AP, apply the necessary ap tracking corrections to ap and use it as base address. +2. Add the field Value of the OffsetValue to the base address +3. Check the Dereference field of the OffsetValue. If its false, return the address we obtained in 2. If is true, fetch the memory value at that address and return it + +```go +func getOffsetValueReference(offsetValue OffsetValue, refApTracking parser.ApTrackingData, hintApTracking parser.ApTrackingData, vm *VirtualMachine) *MaybeRelocatable { + var baseAddr Relocatable + ok := true + switch offsetValue.Register { + case FP: + baseAddr = vm.RunContext.Fp + case AP: + baseAddr, ok = applyApTrackingCorrection(vm.RunContext.Ap, refApTracking, hintApTracking) + } + if ok { + baseAddr, err := baseAddr.AddInt(offsetValue.Value) + if err == nil { + if offsetValue.Dereference { + // val will be nil if err is not nil, so we can ignore it + val, _ := vm.Segments.Memory.Get(baseAddr) + return val + } else { + return NewMaybeRelocatableRelocatable(baseAddr) + } + } + } + return nil +} +``` + +Finally, the last thing we need is to know how ap tracking corrections work. +This function will receive an address (the current value of ap), the ap tracking data of the reference (unique to each reference) and the hint's ap tracking data (unique to each hint) and perform the following steps: +1. Assert that both ap tracking datas belong to the same group (aka their Group fields match) +2. Subtract the difference between the hint's ap tracking data's Offset field and the reference's ap tracking data's Offset field from the address (ap) +3. Return the value obtained in 2 + +```go +func applyApTrackingCorrection(addr Relocatable, refApTracking parser.ApTrackingData, hintApTracking parser.ApTrackingData) (Relocatable, bool) { + // Reference & Hint ApTracking must belong to the same group + if refApTracking.Group == hintApTracking.Group { + addr, err := addr.SubUint(uint(hintApTracking.Offset - refApTracking.Offset)) + if err == nil { + return addr, true + } + } + return Relocatable{}, false +} +``` + +###### Implement the IdsManager + +Now that we have tackled reference management, we can implement the `IdsManager`, which will allow us to "forget" what references are when implementing hints. + +The IdsManager has the following structure: + +* References: A map of all the ids variables the hint has access to, it maps the name of the cairo varaible to a HintReference (the parsed version of the compiled program's Reference) +* HintAptracking: The ap tracking data unique to the hint + +```go +type IdsManager struct { + References map[string]HintReference + HintApTracking parser.ApTrackingData +} +``` + +And we can also implement friendlier versions of the functions we implemented in the previous section, that take the name of the ids variable, instead of the reference and hint ap tracking data: + +```go +// Returns the value of an identifier as a MaybeRelocatable +func (ids *IdsManager) Get(name string, vm *VirtualMachine) (*MaybeRelocatable, error) { + reference, ok := ids.References[name] + if ok { + val, ok := getValueFromReference(&reference, ids.HintApTracking, vm) + if ok { + return val, nil + } + } + return nil, ErrUnknownIdentifier(name) +} + +// Returns the address of an identifier given its name +func (ids *IdsManager) GetAddr(name string, vm *VirtualMachine) (Relocatable, error) { + reference, ok := ids.References[name] + if ok { + addr, ok := getAddressFromReference(&reference, ids.HintApTracking, vm) + if ok { + return addr, nil + } + } + return Relocatable{}, ErrUnknownIdentifier(name) +} +``` + +We can also make more specialized versions of the Get method, that will also handle conversions to Felt or Relocatable, as we will almost always know which type of value we are expecting when implementing hints: + +```go +// Returns the value of an identifier as a Felt +func (ids *IdsManager) GetFelt(name string, vm *VirtualMachine) (lambdaworks.Felt, error) { + val, err := ids.Get(name, vm) + if err != nil { + return lambdaworks.Felt{}, err + } + felt, is_felt := val.GetFelt() + if !is_felt { + return lambdaworks.Felt{}, ErrIdentifierNotFelt(name) + } + return felt, nil +} + +// Returns the value of an identifier as a Relocatable +func (ids *IdsManager) GetRelocatable(name string, vm *VirtualMachine) (Relocatable, error) { + val, err := ids.Get(name, vm) + if err != nil { + return Relocatable{}, err + } + relocatable, is_relocatable := val.GetRelocatable() + if !is_relocatable { + return Relocatable{}, errors.Errorf("Identifier %s is not a Relocatable", name) + } + return relocatable, nil +} +``` + +Lastly, we can also implement a method to insert a value into an ids variable (as we already know how to calculate their address) + +```go +// Inserts value into memory given its identifier name +func (ids *IdsManager) Insert(name string, value *MaybeRelocatable, vm *VirtualMachine) error { + + addr, err := ids.GetAddr(name, vm) + if err != nil { + return err + } + return vm.Segments.Memory.Insert(addr, value) +} +``` + +##### Implementing a HintProcessor: CompileHint + + +The `CompileHint` method will be in charge of converting the hint-related data from the compiled json into a format that our processor can use to execute each hint. For our `CairoVmHintProcessor` we will use the following struct: + +```go +type HintData struct { + Ids IdsManager + Code string +} +``` +Where IdsManager is the struct we just saw in the previous section, a struct which manages all kinds of interaction between the hint implemented in go and the cairo variables available to it, and Code is the python code of the hint. + +And we will implement a `CompileHint` method which receives the hint's data from the compiled program in the form of `HintParams`, and a reference to the compiled json's `ReferenceManager`, a list of references to all ids variables in the program. And performs the following steps: + +1. Create a map from variable name to HintReference +2. Iterate over the hintParams's `ReferenceIds` field (a map from an ids name to an index in the ReferenceManager). For each iteration: + + 1. Remove the path from the reference's name (shortening full paths such a "__main__.a" to just the variable name "a"), + 2. Fetch the reference from the ReferenceManager (using the index from the ReferenceIds) + 3. Parse the Reference into a `HintReference` + 4. Insert the parsed reference into the map we created in 1, using the shortened name (from 2.1) as a key + +3. Create an IdsManager using the map from 1, and the hintParam's ap tracking data +4. Create a `HintData` struct with the IdsManager and the hintParam's Code + +```go +func (p *CairoVmHintProcessor) CompileHint(hintParams *parser.HintParams, referenceManager *parser.ReferenceManager) (any, error) { + references := make(map[string]HintReference, 0) + for name, n := range hintParams.FlowTrackingData.ReferenceIds { + if int(n) >= len(referenceManager.References) { + return nil, errors.New("Reference not found in ReferenceManager") + } + split := strings.Split(name, ".") + name = split[len(split)-1] + references[name] = ParseHintReference(referenceManager.References[n]) + } + ids := NewIdsManager(references, hintParams.FlowTrackingData.APTracking) + return HintData{Ids: ids, Code: hintParams.Code}, nil +} +``` + +##### Hint Interaction: Constants + +###### How are Constants handled by hints and the cairo compiler + +Hints can also access constant variables using the ids syntax, for example, a hint can access the `MAX_SIZE` constant from a cairo program using `ids.MAX_SIZE`. While the behaviour from the hint's standpoint is identical to regular ids variables, they are handled differently by both the compiler and the vm. + +They are part of the compiled program's `Idenfifiers` field, and can be identified by the `const` type. We may also find aliases for them in the `Identifiers` section, aliases happen when a cairo file imports constants from another cairo file, in such cases we will have an identifier of type `const` under the file where the constant was declared's path, and an identifier of type `alias` under the file where the constant was imported's path, pointing to the original constant's identifier. For example: + +```json +"starkware.cairo.common.cairo_keccak.keccak.BLOCK_SIZE": { + "destination": "starkware.cairo.common.cairo_keccak.packed_keccak.BLOCK_SIZE", + "type": "alias" + }, +"starkware.cairo.common.cairo_keccak.packed_keccak.BLOCK_SIZE": { + "type": "const", + "value": 3 + }, +``` + +This is an extract from a compiled cairo program, where we can see that there is a constant `BLOCK_SIZE`, with value 3, declared in packed_keccak.cairo file, that was then imported by the keccak.cairo file. + +###### How does the vm extract the constants for hint execution + +As constants are not unique to any specific hint, they are not provided to the HintProcessor's `CompileHint` method, but are instead provided directly to the `ExecuteHint` method. Before providing these constants, we need to first extract them from the Identifiers field of the compiled program. This works as follows: + +1. Create a map to store the constants, maping full path constant names to their Felt value +2. Iterate over the program's `Identifiers` field, and check the type of each identifier. If the identifier is of type `const`, add its value to the map created in 1. If the identifier is of type `alias`, search for the identifier at its destination (we will see how to do this next), and if its of type `const`, add it to the map created in 1 under the alias' name. +3. Return the map created in 1 + +```go +func (p *Program) ExtractConstants() map[string]lambdaworks.Felt { + constants := make(map[string]lambdaworks.Felt) + for name, identifier := range p.Identifiers { + switch identifier.Type { + case "const": + constants[name] = identifier.Value + case "alias": + val, ok := searchConstFromAlias(identifier.Destination, &p.Identifiers) + if ok { + constants[name] = val + } + } + } + return constants +} +``` + +In order to search for the aliased identifier we need to do so recursively, as constants can be imported form file A into file B, then from file B into file C and so on. +To do so we use a recursive function which receives the destination field of an alias type identifier and a reference to the identifiers map. It will then look for the identifier using the received destination. If the new identifier is a constant, it wil return its value, if it is an alias it will call itself again with the new alias' destintation, and if its none, it will return false, indicating that the alias was not pointing to a constant. + +```go +func searchConstFromAlias(destination string, identifiers *map[string]Identifier) (lambdaworks.Felt, bool) { + identifier, ok := (*identifiers)[destination] + if ok { + switch identifier.Type { + case "const": + return identifier.Value, true + case "alias": + return searchConstFromAlias(identifier.Destination, identifiers) + } + } + return lambdaworks.Felt{}, false +} +``` + +###### How does the IdsManager handle constants + +Before looking into how the IdsManager handles constants, we'll have to add a new field to it: + +```go +type IdsManager struct { + References map[string]HintReference + HintApTracking parser.ApTrackingData + AccessibleScopes []string +} +``` +AccessibleScopes is a list of paths that a hint has access to, for example, if we were to write a hint in a function `foo` of a cairo program called `program`, that hint's accessible scopes will look something along the likes of `["program", "program.foo"]`. This list is taken directly from the `HintParams`' `AccessibleScopes` field in the compiled json. + +We can use this accessible scopes to determine the correct path for a cairo constant when implementing a hint. To do so, we will be searching for a constant in the map of constants provided by the vm, using the name of the constant in the hint and the possible paths in the accessible scopes, going from innermost (in the example, "program.foo"), to outermost (in the example, "program"). +We will be adding this behaviour to the `IdsManager`, by adding a function that will return the value of a constant given its name (without its full path) and the map of constants, following these steps: + +1. Iterate over the list of accessible scopes in reverse order +2. For each path in accessible scopes, append the name of the constant to get the full-path constant's name +3. Using the full-path constant names, try to fetch from the constants map +4. Once a match is found, return the value from the constant map + +```go +func (ids *IdsManager) GetConst(name string, constants *map[string]lambdaworks.Felt) (lambdaworks.Felt, error) { + // Hints should always have accessible scopes + if len(ids.AccessibleScopes) != 0 { + // Accessible scopes are listed from outer to inner + for i := len(ids.AccessibleScopes) - 1; i >= 0; i-- { + constant, ok := (*constants)[ids.AccessibleScopes[i]+"."+name] + if ok { + return constant, nil + } + } + } + return lambdaworks.FeltZero(), errors.Errorf("Missing constant %s", name) +} +``` + +##### Hint Interaction: ExecutionScopes + +Up until now we saw how hints can interact with the vm and the cairo variables, but what about the interaction between hints themselves? +To answer this question, we will introduce the concept of `Execution Scopes`, they consist of a stack of dictionaries that can hold any kind of variables. These scopes are accessible to all hints, allowing data to be shared between hints without the cairo program being aware of them. As it consists of a stack of dictionaries (from now on referred to as scopes), hints will only be able to interact with the last (or top level) scope. Hints can also remove and create new scopes, we will call these operations `ExitScope` and `EnterScope`. To better illustrate this behaviour, lets make a generic example: + +* HINT A: Adds variable n = 3 (Scopes = [{n: 3}]) +* HINT B: Fetches variable n and updates its value to 5 (Scopes = [{n: 5}]) +* HINT C: Uses the EnterScope operation (Scopes = [{n: 5}, {}]) +* HINT D: Adds variable n = 3 (Scopes = [{n: 5}, {n: 3}]) +* HINT E: Prints the value of n (3), then used the ExitScope operation (Scopes = [{n: 5}]) +* HINT F: Prints the value of n (5) + +Now that we know how execution scopes work, implementing them is quite simple: + +```go +type ExecutionScopes struct { + data []map[string]interface{} +} +``` +We have a stack (represented as a slice), of maps that connect a varaible's name, to its value, accepting any kind of variables as value + +We should also note that when creating an `ExecutionScopes`, it comes with one initial scope (called main scope), which can't be exited + +```go +func NewExecutionScopes() *ExecutionScopes { + data := make([]map[string]interface{}, 1) + data[0] = make(map[string]interface{}) + return &ExecutionScopes{data} +} +``` + +With this struct we can implement the basic operations: + +*EnterScope* + +Adds a new scope to the stack, which is received by the method +```go +func (es *ExecutionScopes) EnterScope(newScopeLocals map[string]interface{}) { + es.data = append(es.data, newScopeLocals) + +} +``` + +*ExitScope* + +Removes the last scope from the stack, guards that the main scope is not removed by the operation. + +```go +func (es *ExecutionScopes) ExitScope() error { + if len(es.data) < 2 { + return ErrCannotExitMainScop + } + es.data = es.data[len(es.data) - 1] + + return nil +} +``` + +*AssignOrUpdateVariable* + +Inserts a variable to the current scope (aka the top one in the stack), overwitting the previous value if it exists + +```go +func (es *ExecutionScopes) AssignOrUpdateVariable(varName string, varValue interface{}) { + locals, err := es.getLocalVariablesMut() + if err != nil { + return + } + (*locals)[varName] = varValue +} +``` + +*Get* + +Fetches a variable from the current scope +```go +func (es *ExecutionScopes) Get(varName string) (interface{}, error) { + locals, err := es.GetLocalVariables() + if err != nil { + return nil, err + } + val, prs := locals[varName] + if !prs { + return nil, ErrVariableNotInScope(varName) + } + return val, nil +} +``` + +*DeleteVariable* + +Removes a variable from the current scope + +```go +func (es *ExecutionScopes) DeleteVariable(varName string) { + locals, err := es.getLocalVariablesMut() + if err != nil { + return + } + delete(*locals, varName) +} +``` + +And the helper methods for these methods: + +```go +func (es *ExecutionScopes) getLocalVariablesMut() (*map[string]interface{}, error) { + locals, err := es.GetLocalVariables() + if err != nil { + return nil, err + } + return &locals, nil +} + +func (es *ExecutionScopes) GetLocalVariables() (map[string]interface{}, error) { + if len(es.data) > 0 { + return es.data[len(es.data)-1], nil + } + return nil, ExecutionScopesError(errors.Errorf("Every enter_scope() requires a corresponding exit_scope().")) +} +``` + +##### Hint Implementation Examples + +Now that we have all the necessary tools to begin implementing hints, lets look at some examples: + +###### IS_LE_FELT + +The python code we have to implement is the following: + +"memory[ap] = 0 if (ids.a % PRIME) <= (ids.b % PRIME) else 1" + +The first thing we notice is that its uses the ids variables "a" and "b" so this gives as an opportunity to use our `IdsManager`. We can also look at the context of this hint, in this case the common library function is_le_felt (in the math_cmp module) to see that ids.a and ids.b are both felt values. + +We can divide the hint into the following steps: + +1. Fetch ids.a as a Felt +2. Fetch ids.b as a Felt +3. Compare the values of a and b (we don't need to perform % PRIME, as our Felt type already takes care of it) +4. Insert either 0 or 1 at the current value of ap depending on the comparison in 3 + +And implement the hint: + +```go +func isLeFelt(ids IdsManager, vm *VirtualMachine) error { + a, err := ids.GetFelt("a", vm) + if err != nil { + return err + } + b, err := ids.GetFelt("b", vm) + if err != nil { + return err + } + if a.Cmp(b) != 1 { + return vm.Segments.Memory.Insert(vm.RunContext.Ap, NewMaybeRelocatableFelt(FeltZero())) + } + return vm.Segments.Memory.Insert(vm.RunContext.Ap, NewMaybeRelocatableFelt(FeltOne())) +} +``` + +###### ASSERT_LE_FELT_EXCLUDED_0 + +The python code we have to implement is the following: + +"memory[ap] = 1 if excluded != 0 else 0" + +This hint is quite similar to the previous example, except that instead of comparing ids variables it uses this "excluded" variable. As this variable is neither an ids, nor is it created during the hint, we can tell that it is a variable created by a previous hint, shared through the current execution scope. With this knowledge, we can divide the hint into the following set of steps: + +1. Fetch excluded from the execution scopes +2. Cast the excluded variable to a concrete type. In this case, as we have previously implemented the hint that creates this variable, we know its type is 'int' +3. Compare the values of excluded vs 0 +4. Insert either 0 or 1 at the current value of ap depending on the comparison in 3 + +```go +func assertLeFeltExcluded0(vm *VirtualMachine, scopes *ExecutionScopes) error { + // Fetch scope var + excludedAny, err := scopes.Get("excluded") + if err != nil { + return err + } + excluded, ok := excludedAny.(int) + if !ok { + return errors.New("excluded not in scope") + } + if excluded == 0 { + return vm.Segments.Memory.Insert(vm.RunContext.Ap, NewMaybeRelocatableFelt(FeltZero())) + } + return vm.Segments.Memory.Insert(vm.RunContext.Ap, NewMaybeRelocatableFelt(FeltOne())) +} +``` + +#### Proof mode + +TODO + +#### Temporary Segments + +TODO From 78c0a0897967ad99aeda5c3354f6496dbf3dde58 Mon Sep 17 00:00:00 2001 From: fmoletta <99273364+fmoletta@users.noreply.github.com> Date: Fri, 29 Sep 2023 21:37:52 +0300 Subject: [PATCH 5/7] Add `ExecutionResources` + fix `CairoRunner` methods that received vm as an argument (#300) * Add ExecutionResources struct * Fix: Remove vm argument from CairoRunner methods * Finish func + remove todos * Add unit test * Add `RunUntilPc` & `GetReturnValues` (#302) * Implement GenArg * Remove recursive processing * Add unit tests * Fix test values * Start fn * Add RunFromEntryPoint * Add test for RunFromEntryPoint * Add unit tests * Add comments --------- Co-authored-by: Pedro Fontana --- pkg/runners/cairo_runner.go | 171 +++++++++++++++++++---------- pkg/runners/cairo_runner_test.go | 104 ++++++++++++------ pkg/runners/execution_resources.go | 7 ++ pkg/vm/cairo_run/cairo_run.go | 6 +- pkg/vm/memory/segments.go | 38 +++++++ pkg/vm/memory/segments_test.go | 49 +++++++++ pkg/vm/vm_core.go | 9 ++ pkg/vm/vm_test.go | 28 +++++ 8 files changed, 321 insertions(+), 91 deletions(-) create mode 100644 pkg/runners/execution_resources.go diff --git a/pkg/runners/cairo_runner.go b/pkg/runners/cairo_runner.go index e9baf5dc..32cb37d4 100644 --- a/pkg/runners/cairo_runner.go +++ b/pkg/runners/cairo_runner.go @@ -70,11 +70,11 @@ func NewCairoRunner(program vm.Program, layoutName string, proofMode bool) (*Cai // Performs the initialization step, returns the end pointer (pc upon which execution should stop) func (r *CairoRunner) Initialize() (memory.Relocatable, error) { - err := r.initializeBuiltins() + err := r.InitializeBuiltins() if err != nil { return memory.Relocatable{}, errors.New(err.Error()) } - r.initializeSegments() + r.InitializeSegments() end, err := r.initializeMainEntrypoint() if err == nil { err = r.initializeVM() @@ -84,7 +84,7 @@ func (r *CairoRunner) Initialize() (memory.Relocatable, error) { // Initializes builtin runners in accordance to the specified layout and // the builtins present in the running program. -func (r *CairoRunner) initializeBuiltins() error { +func (r *CairoRunner) InitializeBuiltins() error { var builtinRunners []builtins.BuiltinRunner programBuiltins := map[string]struct{}{} for _, builtin := range r.Program.Builtins { @@ -113,7 +113,7 @@ func (r *CairoRunner) initializeBuiltins() error { } // Creates program, execution and builtin segments -func (r *CairoRunner) initializeSegments() { +func (r *CairoRunner) InitializeSegments() { // Program Segment r.ProgramBase = r.Vm.Segments.AddSegment() // Execution Segment @@ -144,9 +144,9 @@ func (r *CairoRunner) initializeState(entrypoint uint, stack *[]memory.MaybeRelo // Initializes memory, initial register values & returns the end pointer (final pc) to run from a given pc offset // (entrypoint) -func (r *CairoRunner) initializeFunctionEntrypoint(entrypoint uint, stack *[]memory.MaybeRelocatable, return_fp memory.Relocatable) (memory.Relocatable, error) { +func (r *CairoRunner) initializeFunctionEntrypoint(entrypoint uint, stack *[]memory.MaybeRelocatable, return_fp memory.MaybeRelocatable) (memory.Relocatable, error) { end := r.Vm.Segments.AddSegment() - *stack = append(*stack, *memory.NewMaybeRelocatableRelocatable(return_fp), *memory.NewMaybeRelocatableRelocatable(end)) + *stack = append(*stack, return_fp, *memory.NewMaybeRelocatableRelocatable(end)) r.initialFp = r.executionBase r.initialFp.Offset += uint(len(*stack)) r.initialAp = r.initialFp @@ -188,7 +188,7 @@ func (r *CairoRunner) initializeMainEntrypoint() (memory.Relocatable, error) { return memory.NewRelocatable(r.ProgramBase.SegmentIndex, r.ProgramBase.Offset+r.Program.End), nil } - return_fp := r.Vm.Segments.AddSegment() + return_fp := *memory.NewMaybeRelocatableRelocatable(r.Vm.Segments.AddSegment()) return r.initializeFunctionEntrypoint(r.mainOffset, &stack, return_fp) } @@ -237,7 +237,7 @@ func (r *CairoRunner) RunUntilPC(end memory.Relocatable, hintProcessor vm.HintPr return nil } -func (runner *CairoRunner) EndRun(disableTracePadding bool, disableFinalizeAll bool, vm *vm.VirtualMachine, hintProcessor vm.HintProcessor) error { +func (runner *CairoRunner) EndRun(disableTracePadding bool, disableFinalizeAll bool, hintProcessor vm.HintProcessor) error { if runner.RunEnded { return ErrRunnerCalledTwice } @@ -245,7 +245,7 @@ func (runner *CairoRunner) EndRun(disableTracePadding bool, disableFinalizeAll b // TODO: This seems to have to do with temporary segments // vm.Segments.Memory.RelocateMemory() - err := vm.EndRun() + err := runner.Vm.EndRun() if err != nil { return err } @@ -254,15 +254,15 @@ func (runner *CairoRunner) EndRun(disableTracePadding bool, disableFinalizeAll b return nil } - vm.Segments.ComputeEffectiveSizes() + runner.Vm.Segments.ComputeEffectiveSizes() if runner.ProofMode && !disableTracePadding { - err := runner.RunUntilNextPowerOfTwo(vm, hintProcessor) + err := runner.RunUntilNextPowerOfTwo(hintProcessor) if err != nil { return err } for true { - err := runner.CheckUsedCells(vm) + err := runner.CheckUsedCells() if errors.Unwrap(err) == memory.ErrInsufficientAllocatedCells { } else if err != nil { return err @@ -270,12 +270,12 @@ func (runner *CairoRunner) EndRun(disableTracePadding bool, disableFinalizeAll b break } - err = runner.RunForSteps(1, vm, hintProcessor) + err = runner.RunForSteps(1, hintProcessor) if err != nil { return err } - err = runner.RunUntilNextPowerOfTwo(vm, hintProcessor) + err = runner.RunUntilNextPowerOfTwo(hintProcessor) if err != nil { return err } @@ -286,7 +286,7 @@ func (runner *CairoRunner) EndRun(disableTracePadding bool, disableFinalizeAll b return nil } -func (r *CairoRunner) FinalizeSegments(virtualMachine vm.VirtualMachine) error { +func (r *CairoRunner) FinalizeSegments() error { if r.SegmentsFinalized { return nil } @@ -304,7 +304,7 @@ func (r *CairoRunner) FinalizeSegments(virtualMachine vm.VirtualMachine) error { publicMemory = append(publicMemory, i) } - virtualMachine.Segments.Finalize(size, uint(r.ProgramBase.SegmentIndex), &publicMemory) + r.Vm.Segments.Finalize(size, uint(r.ProgramBase.SegmentIndex), &publicMemory) publicMemory = make([]uint, 0) execBase := r.executionBase @@ -316,9 +316,9 @@ func (r *CairoRunner) FinalizeSegments(virtualMachine vm.VirtualMachine) error { publicMemory = append(publicMemory, elem+execBase.Offset) } - virtualMachine.Segments.Finalize(nil, uint(execBase.SegmentIndex), &publicMemory) - for _, builtin := range virtualMachine.BuiltinRunners { - _, size, err := builtin.GetUsedCellsAndAllocatedSizes(&virtualMachine.Segments, virtualMachine.CurrentStep) + r.Vm.Segments.Finalize(nil, uint(execBase.SegmentIndex), &publicMemory) + for _, builtin := range r.Vm.BuiltinRunners { + _, size, err := builtin.GetUsedCellsAndAllocatedSizes(&r.Vm.Segments, r.Vm.CurrentStep) if err != nil { return err } @@ -329,9 +329,9 @@ func (r *CairoRunner) FinalizeSegments(virtualMachine vm.VirtualMachine) error { for i = 0; i < size; i++ { publicMemory = append(publicMemory, i) } - virtualMachine.Segments.Finalize(&size, uint(builtin.Base().SegmentIndex), &publicMemory) + r.Vm.Segments.Finalize(&size, uint(builtin.Base().SegmentIndex), &publicMemory) } else { - virtualMachine.Segments.Finalize(&size, uint(builtin.Base().SegmentIndex), nil) + r.Vm.Segments.Finalize(&size, uint(builtin.Base().SegmentIndex), nil) } } @@ -339,15 +339,15 @@ func (r *CairoRunner) FinalizeSegments(virtualMachine vm.VirtualMachine) error { return nil } -func (r *CairoRunner) ReadReturnValues(virtualMachine *vm.VirtualMachine) error { +func (r *CairoRunner) ReadReturnValues() error { if !r.RunEnded { return errors.New("Tried to read return values before run ended") } - pointer := virtualMachine.RunContext.Ap + pointer := r.Vm.RunContext.Ap - for i := len(virtualMachine.BuiltinRunners) - 1; i >= 0; i-- { - newPointer, err := virtualMachine.BuiltinRunners[i].FinalStack(&virtualMachine.Segments, pointer) + for i := len(r.Vm.BuiltinRunners) - 1; i >= 0; i-- { + newPointer, err := r.Vm.BuiltinRunners[i].FinalStack(&r.Vm.Segments, pointer) if err != nil { return err } @@ -363,7 +363,7 @@ func (r *CairoRunner) ReadReturnValues(virtualMachine *vm.VirtualMachine) error execBase := r.executionBase begin := pointer.Offset - execBase.Offset - ap := virtualMachine.RunContext.Ap + ap := r.Vm.RunContext.Ap end := ap.Offset - execBase.Offset var publicMemoryExtension []uint @@ -379,26 +379,26 @@ func (r *CairoRunner) ReadReturnValues(virtualMachine *vm.VirtualMachine) error } -func (runner *CairoRunner) CheckUsedCells(virtualMachine *vm.VirtualMachine) error { - for _, builtin := range virtualMachine.BuiltinRunners { +func (runner *CairoRunner) CheckUsedCells() error { + for _, builtin := range runner.Vm.BuiltinRunners { // I guess we call this just in case it errors out, even though later on we also call it? - _, _, err := builtin.GetUsedCellsAndAllocatedSizes(&virtualMachine.Segments, virtualMachine.CurrentStep) + _, _, err := builtin.GetUsedCellsAndAllocatedSizes(&runner.Vm.Segments, runner.Vm.CurrentStep) if err != nil { return err } } - err := runner.CheckRangeCheckUsage(virtualMachine) + err := runner.CheckRangeCheckUsage() if err != nil { return err } - err = runner.CheckMemoryUsage(virtualMachine) + err = runner.CheckMemoryUsage() if err != nil { return err } - err = runner.CheckDilutedCheckUsage(virtualMachine) + err = runner.CheckDilutedCheckUsage() if err != nil { return err } @@ -406,13 +406,13 @@ func (runner *CairoRunner) CheckUsedCells(virtualMachine *vm.VirtualMachine) err return nil } -func (runner *CairoRunner) CheckMemoryUsage(virtualMachine *vm.VirtualMachine) error { +func (runner *CairoRunner) CheckMemoryUsage() error { instance := runner.Layout var builtinsMemoryUnits uint = 0 - for _, builtin := range virtualMachine.BuiltinRunners { - result, err := builtin.GetAllocatedMemoryUnits(&virtualMachine.Segments, virtualMachine.CurrentStep) + for _, builtin := range runner.Vm.BuiltinRunners { + result, err := builtin.GetAllocatedMemoryUnits(&runner.Vm.Segments, runner.Vm.CurrentStep) if err != nil { return err } @@ -420,7 +420,7 @@ func (runner *CairoRunner) CheckMemoryUsage(virtualMachine *vm.VirtualMachine) e builtinsMemoryUnits += result } - totalMemoryUnits := instance.MemoryUnitsPerStep * virtualMachine.CurrentStep + totalMemoryUnits := instance.MemoryUnitsPerStep * runner.Vm.CurrentStep publicMemoryUnits := totalMemoryUnits / instance.PublicMemoryFraction remainder := totalMemoryUnits % instance.PublicMemoryFraction @@ -428,10 +428,10 @@ func (runner *CairoRunner) CheckMemoryUsage(virtualMachine *vm.VirtualMachine) e return errors.Errorf("Total Memory units was not divisible by the Public Memory Fraction. TotalMemoryUnits: %d PublicMemoryFraction: %d", totalMemoryUnits, instance.PublicMemoryFraction) } - instructionMemoryUnits := 4 * virtualMachine.CurrentStep + instructionMemoryUnits := 4 * runner.Vm.CurrentStep unusedMemoryUnits := totalMemoryUnits - (publicMemoryUnits + instructionMemoryUnits + builtinsMemoryUnits) - memoryAddressHoles, err := runner.GetMemoryHoles(virtualMachine) + memoryAddressHoles, err := runner.GetMemoryHoles() if err != nil { return err } @@ -443,11 +443,11 @@ func (runner *CairoRunner) CheckMemoryUsage(virtualMachine *vm.VirtualMachine) e return nil } -func (runner *CairoRunner) GetMemoryHoles(virtualMachine *vm.VirtualMachine) (uint, error) { - return virtualMachine.Segments.GetMemoryHoles(uint(len(virtualMachine.BuiltinRunners))) +func (runner *CairoRunner) GetMemoryHoles() (uint, error) { + return runner.Vm.Segments.GetMemoryHoles(uint(len(runner.Vm.BuiltinRunners))) } -func (runner *CairoRunner) CheckDilutedCheckUsage(virtualMachine *vm.VirtualMachine) error { +func (runner *CairoRunner) CheckDilutedCheckUsage() error { dilutedPoolInstance := runner.Layout.DilutedPoolInstance if dilutedPoolInstance == nil { return nil @@ -455,14 +455,14 @@ func (runner *CairoRunner) CheckDilutedCheckUsage(virtualMachine *vm.VirtualMach var usedUnitsByBuiltins uint = 0 - for _, builtin := range virtualMachine.BuiltinRunners { + for _, builtin := range runner.Vm.BuiltinRunners { usedUnits := builtin.GetUsedDilutedCheckUnits(dilutedPoolInstance.Spacing, dilutedPoolInstance.NBits) ratio := builtin.Ratio() if ratio == 0 { ratio = 1 } - multiplier, err := utils.SafeDiv(virtualMachine.CurrentStep, ratio) + multiplier, err := utils.SafeDiv(runner.Vm.CurrentStep, ratio) if err != nil { return err @@ -471,7 +471,7 @@ func (runner *CairoRunner) CheckDilutedCheckUsage(virtualMachine *vm.VirtualMach usedUnitsByBuiltins += usedUnits * multiplier } - var dilutedUnits uint = dilutedPoolInstance.UnitsPerStep * virtualMachine.CurrentStep + var dilutedUnits uint = dilutedPoolInstance.UnitsPerStep * runner.Vm.CurrentStep var unusedDilutedUnits uint = dilutedUnits - usedUnitsByBuiltins var dilutedUsageUpperBound uint = 1 << dilutedPoolInstance.NBits @@ -483,7 +483,7 @@ func (runner *CairoRunner) CheckDilutedCheckUsage(virtualMachine *vm.VirtualMach return nil } -func (runner *CairoRunner) CheckRangeCheckUsage(virtualMachine *vm.VirtualMachine) error { +func (runner *CairoRunner) CheckRangeCheckUsage() error { var rcMin, rcMax *uint for _, builtin := range runner.Vm.BuiltinRunners { @@ -513,7 +513,7 @@ func (runner *CairoRunner) CheckRangeCheckUsage(virtualMachine *vm.VirtualMachin var rcUnitsUsedByBuiltins uint = 0 for _, builtin := range runner.Vm.BuiltinRunners { - usedUnits, err := builtin.GetUsedPermRangeCheckLimits(&virtualMachine.Segments, virtualMachine.CurrentStep) + usedUnits, err := builtin.GetUsedPermRangeCheckLimits(&runner.Vm.Segments, runner.Vm.CurrentStep) if err != nil { return err } @@ -521,7 +521,7 @@ func (runner *CairoRunner) CheckRangeCheckUsage(virtualMachine *vm.VirtualMachin rcUnitsUsedByBuiltins += usedUnits } - unusedRcUnits := (runner.Layout.RcUnits-3)*virtualMachine.CurrentStep - uint(rcUnitsUsedByBuiltins) + unusedRcUnits := (runner.Layout.RcUnits-3)*runner.Vm.CurrentStep - uint(rcUnitsUsedByBuiltins) if unusedRcUnits < (*rcMax - *rcMin) { return memory.InsufficientAllocatedCellsError(unusedRcUnits, *rcMax-*rcMin) @@ -530,8 +530,7 @@ func (runner *CairoRunner) CheckRangeCheckUsage(virtualMachine *vm.VirtualMachin return nil } -// TODO: Add hint processor when it's done -func (runner *CairoRunner) RunForSteps(steps uint, virtualMachine *vm.VirtualMachine, hintProcessor vm.HintProcessor) error { +func (runner *CairoRunner) RunForSteps(steps uint, hintProcessor vm.HintProcessor) error { hintDataMap, err := runner.BuildHintDataMap(hintProcessor) if err != nil { return err @@ -539,11 +538,11 @@ func (runner *CairoRunner) RunForSteps(steps uint, virtualMachine *vm.VirtualMac constants := runner.Program.ExtractConstants() var remainingSteps int for remainingSteps = int(steps); remainingSteps > 0; remainingSteps-- { - if runner.finalPc != nil && *runner.finalPc == virtualMachine.RunContext.Pc { + if runner.finalPc != nil && *runner.finalPc == runner.Vm.RunContext.Pc { return &vm.VirtualMachineError{Msg: fmt.Sprintf("EndOfProgram: %d", remainingSteps)} } - err := virtualMachine.Step(hintProcessor, &hintDataMap, &constants, &runner.execScopes) + err := runner.Vm.Step(hintProcessor, &hintDataMap, &constants, &runner.execScopes) if err != nil { return err } @@ -552,12 +551,70 @@ func (runner *CairoRunner) RunForSteps(steps uint, virtualMachine *vm.VirtualMac return nil } -// TODO: Add hint processor when it's done -func (runner *CairoRunner) RunUntilSteps(steps uint, virtualMachine *vm.VirtualMachine, hintProcessor vm.HintProcessor) error { - return runner.RunForSteps(steps-virtualMachine.CurrentStep, virtualMachine, hintProcessor) +func (runner *CairoRunner) RunUntilSteps(steps uint, hintProcessor vm.HintProcessor) error { + return runner.RunForSteps(steps-runner.Vm.CurrentStep, hintProcessor) +} + +func (runner *CairoRunner) RunUntilNextPowerOfTwo(hintProcessor vm.HintProcessor) error { + return runner.RunUntilSteps(utils.NextPowOf2(runner.Vm.CurrentStep), hintProcessor) } -// TODO: Add hint processor when it's done -func (runner *CairoRunner) RunUntilNextPowerOfTwo(virtualMachine *vm.VirtualMachine, hintProcessor vm.HintProcessor) error { - return runner.RunUntilSteps(utils.NextPowOf2(virtualMachine.CurrentStep), virtualMachine, hintProcessor) +func (runner *CairoRunner) GetExecutionResources() (ExecutionResources, error) { + nSteps := uint(len(runner.Vm.Trace)) + if nSteps == 0 { + nSteps = runner.Vm.CurrentStep + } + nMemoryHoles, err := runner.GetMemoryHoles() + if err != nil { + return ExecutionResources{}, err + } + builtinInstaceCounter := make(map[string]uint) + for i := 0; i < len(runner.Vm.BuiltinRunners); i++ { + builtinInstaceCounter[runner.Vm.BuiltinRunners[i].Name()], err = runner.Vm.BuiltinRunners[i].GetUsedInstances(&runner.Vm.Segments) + if err != nil { + return ExecutionResources{}, err + } + } + return ExecutionResources{ + NSteps: nSteps, + NMemoryHoles: nMemoryHoles, + BuiltinsInstanceCounter: builtinInstaceCounter, + }, nil +} + +// TODO: Add verifySecure once its implemented +/* +Runs a cairo program from a give entrypoint, indicated by its pc offset, with the given arguments. +If `verifySecure` is set to true, [verifySecureRunner] will be called to run extra verifications. +`programSegmentSize` is only used by the [verifySecureRunner] function and will be ignored if `verifySecure` is set to false. +Each arg can be either MaybeRelocatable, []MaybeRelocatable or [][]MaybeRelocatable +*/ +func (runner *CairoRunner) RunFromEntrypoint(entrypoint uint, args []any, hintProcessor vm.HintProcessor) error { + stack := make([]memory.MaybeRelocatable, 0) + for _, arg := range args { + val, err := runner.Vm.Segments.GenArg(arg) + if err != nil { + return err + } + stack = append(stack, val) + } + returnFp := *memory.NewMaybeRelocatableFelt(lambdaworks.FeltZero()) + end, err := runner.initializeFunctionEntrypoint(entrypoint, &stack, returnFp) + if err != nil { + return err + } + err = runner.initializeVM() + if err != nil { + return err + } + err = runner.RunUntilPC(end, hintProcessor) + if err != nil { + return err + } + err = runner.EndRun(false, false, hintProcessor) + if err != nil { + return err + } + // TODO: verifySecureRunner + return nil } diff --git a/pkg/runners/cairo_runner_test.go b/pkg/runners/cairo_runner_test.go index e19367d1..1c5b7130 100644 --- a/pkg/runners/cairo_runner_test.go +++ b/pkg/runners/cairo_runner_test.go @@ -567,11 +567,10 @@ func TestCheckRangeCheckUsagePermRangeLimitsNone(t *testing.T) { if err != nil { t.Error("Could not initialize Cairo Runner") } - virtualMachine := vm.NewVirtualMachine() - virtualMachine.Trace = make([]vm.TraceEntry, 0) + runner.Vm.Trace = make([]vm.TraceEntry, 0) - err = runner.CheckRangeCheckUsage(virtualMachine) + err = runner.CheckRangeCheckUsage() if err != nil { t.Errorf("Check Range Usage Failed With Error %s", err) } @@ -584,18 +583,17 @@ func TestCheckRangeCheckUsageWithoutBuiltins(t *testing.T) { if err != nil { t.Error("Could not initialize Cairo Runner") } - virtualMachine := vm.NewVirtualMachine() - virtualMachine.Trace = make([]vm.TraceEntry, 0) - virtualMachine.CurrentStep = 1000 - virtualMachine.Segments.Memory.Insert( + runner.Vm.Trace = make([]vm.TraceEntry, 0) + runner.Vm.CurrentStep = 1000 + runner.Vm.Segments.Memory.Insert( memory.NewRelocatable(0, 0), memory.NewMaybeRelocatableFelt(lambdaworks.FeltFromHex("0x80FF80000530")), ) - virtualMachine.Trace = make([]vm.TraceEntry, 1) - virtualMachine.Trace[0] = vm.TraceEntry{Pc: memory.NewRelocatable(0, 0), Ap: memory.NewRelocatable(0, 0), Fp: memory.NewRelocatable(0, 0)} - err = runner.CheckRangeCheckUsage(virtualMachine) + runner.Vm.Trace = make([]vm.TraceEntry, 1) + runner.Vm.Trace[0] = vm.TraceEntry{Pc: memory.NewRelocatable(0, 0), Ap: memory.NewRelocatable(0, 0), Fp: memory.NewRelocatable(0, 0)} + err = runner.CheckRangeCheckUsage() if err != nil { t.Errorf("Check Range Usage Failed With Error %s", err) } @@ -620,7 +618,7 @@ func TestCheckRangeUsageInsufficientAllocatedCells(t *testing.T) { runner.Vm.Trace = make([]vm.TraceEntry, 1) runner.Vm.Trace[0] = vm.TraceEntry{Pc: memory.NewRelocatable(0, 0), Ap: memory.NewRelocatable(0, 0), Fp: memory.NewRelocatable(0, 0)} runner.Vm.Segments.ComputeEffectiveSizes() - err = runner.CheckRangeCheckUsage(&runner.Vm) + err = runner.CheckRangeCheckUsage() if err == nil { t.Error("Check Range Usage Should Have Failed With Insufficient Allocated Cells Error") } @@ -633,11 +631,10 @@ func TestCheckDilutedCheckUsageWithoutPoolInstance(t *testing.T) { if err != nil { t.Error("Could not initialize Cairo Runner") } - virtualMachine := vm.NewVirtualMachine() runner.Layout.DilutedPoolInstance = nil - err = runner.CheckDilutedCheckUsage(virtualMachine) + err = runner.CheckDilutedCheckUsage() if err != nil { t.Errorf("Check Diluted Check Usage Failed With Error %s", err) } @@ -650,12 +647,11 @@ func TestCheckDilutedCheckUsageWithoutBuiltinRunners(t *testing.T) { if err != nil { t.Error("Could not initialize Cairo Runner") } - virtualMachine := vm.NewVirtualMachine() - virtualMachine.CurrentStep = 10000 - virtualMachine.BuiltinRunners = make([]builtins.BuiltinRunner, 0) + runner.Vm.CurrentStep = 10000 + runner.Vm.BuiltinRunners = make([]builtins.BuiltinRunner, 0) - err = runner.CheckDilutedCheckUsage(virtualMachine) + err = runner.CheckDilutedCheckUsage() if err != nil { t.Errorf("Check Diluted Check Usage Failed With Error %s", err) } @@ -668,12 +664,11 @@ func TestCheckDilutedCheckUsageInsufficientAllocatedCells(t *testing.T) { if err != nil { t.Error("Could not initialize Cairo Runner") } - virtualMachine := vm.NewVirtualMachine() - virtualMachine.CurrentStep = 100 - virtualMachine.BuiltinRunners = make([]builtins.BuiltinRunner, 0) + runner.Vm.CurrentStep = 100 + runner.Vm.BuiltinRunners = make([]builtins.BuiltinRunner, 0) - err = runner.CheckDilutedCheckUsage(virtualMachine) + err = runner.CheckDilutedCheckUsage() if err == nil { t.Errorf("Check Diluted Check Usage Should Have failed With Insufficient Allocated Cells Error") } @@ -686,13 +681,12 @@ func TestCheckDilutedCheckUsage(t *testing.T) { if err != nil { t.Error("Could not initialize Cairo Runner") } - virtualMachine := vm.NewVirtualMachine() - virtualMachine.CurrentStep = 8192 - virtualMachine.BuiltinRunners = make([]builtins.BuiltinRunner, 0) - virtualMachine.BuiltinRunners = append(virtualMachine.BuiltinRunners, builtins.NewBitwiseBuiltinRunner(256)) + runner.Vm.CurrentStep = 8192 + runner.Vm.BuiltinRunners = make([]builtins.BuiltinRunner, 0) + runner.Vm.BuiltinRunners = append(runner.Vm.BuiltinRunners, builtins.NewBitwiseBuiltinRunner(256)) - err = runner.CheckDilutedCheckUsage(virtualMachine) + err = runner.CheckDilutedCheckUsage() if err != nil { t.Errorf("Check Diluted Check Usage Failed With Error %s", err) } @@ -706,14 +700,62 @@ func TestCheckUsedCellsDilutedCheckUsageError(t *testing.T) { if err != nil { t.Error("Could not initialize Cairo Runner") } - virtualMachine := vm.NewVirtualMachine() - virtualMachine.Segments.SegmentUsedSizes = make(map[uint]uint) - virtualMachine.Segments.SegmentUsedSizes[0] = 4 - virtualMachine.Trace = []vm.TraceEntry{} + runner.Vm.Segments.SegmentUsedSizes = make(map[uint]uint) + runner.Vm.Segments.SegmentUsedSizes[0] = 4 + runner.Vm.Trace = []vm.TraceEntry{} - err = runner.CheckUsedCells(virtualMachine) + err = runner.CheckUsedCells() if err == nil { t.Errorf("Check Used Cells Should Have failed With Insufficient Allocated Cells Error") } } + +func TestRunFibonacciGetExecutionResources(t *testing.T) { + cairoRunConfig := cairo_run.CairoRunConfig{Layout: "all_cairo", ProofMode: false} + runner, err := cairo_run.CairoRun("../../cairo_programs/fibonacci.json", cairoRunConfig) + if err != nil { + t.Errorf("Program execution failed with error: %s", err) + } + expectedExecutionResources := runners.ExecutionResources{ + NSteps: 80, + BuiltinsInstanceCounter: make(map[string]uint), + } + executionResources, _ := runner.GetExecutionResources() + if !reflect.DeepEqual(executionResources, expectedExecutionResources) { + t.Errorf("Wong ExecutionResources.\n Expected : %+v, got: %+v", expectedExecutionResources, executionResources) + } +} + +// This test will run the `fib` function in the fibonacci.json program +func TestRunFromEntryPointFibonacci(t *testing.T) { + compiledProgram, _ := parser.Parse("../../cairo_programs/fibonacci.json") + programJson := vm.DeserializeProgramJson(compiledProgram) + + entrypoint := programJson.Identifiers["__main__.fib"].PC + args := []any{ + *memory.NewMaybeRelocatableFelt(lambdaworks.FeltOne()), + *memory.NewMaybeRelocatableFelt(lambdaworks.FeltOne()), + *memory.NewMaybeRelocatableFelt(lambdaworks.FeltFromUint(10)), + } + runner, _ := runners.NewCairoRunner(programJson, "all_cairo", false) + hintProcessor := hints.CairoVmHintProcessor{} + + runner.InitializeBuiltins() + runner.InitializeSegments() + err := runner.RunFromEntrypoint(uint(entrypoint), args, &hintProcessor) + + if err != nil { + t.Errorf("Running fib entrypoint failed with error %s", err.Error()) + } + + // Check result + res, err := runner.Vm.GetReturnValues(1) + if err != nil { + t.Errorf("Failed to fetch return values from fib with error %s", err.Error()) + } + if len(res) != 1 || !reflect.DeepEqual(res[0], *memory.NewMaybeRelocatableFelt(lambdaworks.FeltFromUint(144))) { + t.Errorf("Wrong value returned by fib entrypoint.\n Expected [144], got: %+v", res) + } + +} diff --git a/pkg/runners/execution_resources.go b/pkg/runners/execution_resources.go new file mode 100644 index 00000000..e3872021 --- /dev/null +++ b/pkg/runners/execution_resources.go @@ -0,0 +1,7 @@ +package runners + +type ExecutionResources struct { + NSteps uint + NMemoryHoles uint + BuiltinsInstanceCounter map[string]uint +} diff --git a/pkg/vm/cairo_run/cairo_run.go b/pkg/vm/cairo_run/cairo_run.go index 678a658d..d54a1f99 100644 --- a/pkg/vm/cairo_run/cairo_run.go +++ b/pkg/vm/cairo_run/cairo_run.go @@ -53,18 +53,18 @@ func CairoRun(programPath string, cairoRunConfig CairoRunConfig) (*runners.Cairo if err != nil { return nil, err } - err = cairoRunner.EndRun(cairoRunConfig.DisableTracePadding, false, &cairoRunner.Vm, &hintProcessor) + err = cairoRunner.EndRun(cairoRunConfig.DisableTracePadding, false, &hintProcessor) if err != nil { return nil, err } - err = cairoRunner.ReadReturnValues(&cairoRunner.Vm) + err = cairoRunner.ReadReturnValues() if err != nil { return nil, err } if cairoRunConfig.ProofMode { - cairoRunner.FinalizeSegments(cairoRunner.Vm) + cairoRunner.FinalizeSegments() } err = cairoRunner.Vm.Relocate() diff --git a/pkg/vm/memory/segments.go b/pkg/vm/memory/segments.go index 732ee0b7..32c5acda 100644 --- a/pkg/vm/memory/segments.go +++ b/pkg/vm/memory/segments.go @@ -1,6 +1,8 @@ package memory import ( + "errors" + "github.com/lambdaclass/cairo-vm.go/pkg/lambdaworks" ) @@ -186,3 +188,39 @@ func (m *MemorySegmentManager) GetFeltRange(start Relocatable, size uint) ([]lam } return feltRange, nil } + +/* +Converts a generic argument into a MaybeRelocatable +If the argument is a slice, it loads it into memory in a new segment and returns its base +Accepts MaybeRelocatable, []MaybeRelocatable, [][]MaybeRelocatable +*/ +func (m *MemorySegmentManager) GenArg(arg any) (MaybeRelocatable, error) { + // Attempt to cast to MaybeRelocatable + a, ok := arg.(MaybeRelocatable) + if ok { + return a, nil + } + // Attempt to cast to []MaybeRelocatable + data, ok := arg.([]MaybeRelocatable) + if ok { + base := m.AddSegment() + _, err := m.LoadData(base, &data) + return *NewMaybeRelocatableRelocatable(base), err + } + // Attempt to cast to [][]MaybeRelocatable + datas, ok := arg.([][]MaybeRelocatable) + if ok { + args := make([]MaybeRelocatable, 0) + for _, data = range datas { + dataBase, err := m.GenArg(data) + if err != nil { + return *NewMaybeRelocatableFelt(lambdaworks.FeltZero()), err + } + args = append(args, dataBase) + } + base := m.AddSegment() + _, err := m.LoadData(base, &args) + return *NewMaybeRelocatableRelocatable(base), err + } + return *NewMaybeRelocatableFelt(lambdaworks.FeltZero()), errors.New("GenArg: found argument of invalid type.") +} diff --git a/pkg/vm/memory/segments_test.go b/pkg/vm/memory/segments_test.go index cdca3585..e4604188 100644 --- a/pkg/vm/memory/segments_test.go +++ b/pkg/vm/memory/segments_test.go @@ -310,3 +310,52 @@ func TestGetFeltRangeRelocatable(t *testing.T) { t.Errorf("GetFeltRange should have failed") } } + +func TestGenArgMaybeRelocatable(t *testing.T) { + segments := memory.NewMemorySegmentManager() + arg := any(*memory.NewMaybeRelocatableFelt(lambdaworks.FeltZero())) + expectedArg := *memory.NewMaybeRelocatableFelt(lambdaworks.FeltZero()) + genedArg, err := segments.GenArg(arg) + if err != nil || !reflect.DeepEqual(expectedArg, genedArg) { + t.Error("GenArg failed or returned wrong value") + } +} + +func TestGenArgSliceMaybeRelocatable(t *testing.T) { + segments := memory.NewMemorySegmentManager() + arg := any([]memory.MaybeRelocatable{*memory.NewMaybeRelocatableFelt(lambdaworks.FeltZero())}) + + expectedBase := memory.NewRelocatable(0, 0) + expectedArg := *memory.NewMaybeRelocatableRelocatable(expectedBase) + genedArg, err := segments.GenArg(arg) + if err != nil || !reflect.DeepEqual(expectedArg, genedArg) { + t.Error("GenArg failed or returned wrong value") + } + val, err := segments.Memory.GetFelt(expectedBase) + if err != nil || !val.IsZero() { + t.Error("GenArg inserted wrong value into memory") + } +} + +func TestGenArgSliceSliceMaybeRelocatable(t *testing.T) { + segments := memory.NewMemorySegmentManager() + arg := any([][]memory.MaybeRelocatable{{*memory.NewMaybeRelocatableFelt(lambdaworks.FeltZero())}}) + + expectedBaseA := memory.NewRelocatable(1, 0) + expectedBaseB := memory.NewRelocatable(0, 0) + expectedArg := *memory.NewMaybeRelocatableRelocatable(expectedBaseA) + genedArg, err := segments.GenArg(arg) + + if err != nil || !reflect.DeepEqual(expectedArg, genedArg) { + t.Error("GenArg failed or returned wrong value") + } + valA, err := segments.Memory.GetRelocatable(expectedBaseA) + if err != nil || valA != expectedBaseB { + t.Error("GenArg inserted wrong value into memory") + } + + valB, err := segments.Memory.GetFelt(expectedBaseB) + if err != nil || !valB.IsZero() { + t.Error("GenArg inserted wrong value into memory") + } +} diff --git a/pkg/vm/vm_core.go b/pkg/vm/vm_core.go index b8aa1e2f..c5cb69b1 100644 --- a/pkg/vm/vm_core.go +++ b/pkg/vm/vm_core.go @@ -642,3 +642,12 @@ func (vm *VirtualMachine) GetRangeCheckBound() (lambdaworks.Felt, error) { return rcBuiltin.Bound(), nil } + +// Gets `nRet` return values from memory +func (vm *VirtualMachine) GetReturnValues(nRet uint) ([]memory.MaybeRelocatable, error) { + ptr, err := vm.RunContext.Ap.SubUint(nRet) + if err != nil { + return nil, err + } + return vm.Segments.Memory.GetRange(ptr, nRet) +} diff --git a/pkg/vm/vm_test.go b/pkg/vm/vm_test.go index 8f3ab08c..1d537321 100644 --- a/pkg/vm/vm_test.go +++ b/pkg/vm/vm_test.go @@ -1088,3 +1088,31 @@ func TestGetFooBuiltinReturnsNilAndError(t *testing.T) { t.Error("Obtained a non existant builtin, or didn't raise an error") } } + +func TestReadReturnValuesOk(t *testing.T) { + vm := vm.NewVirtualMachine() + vm.Segments.AddSegment() + // Load data at ap and advance ap + data := []memory.MaybeRelocatable{ + *memory.NewMaybeRelocatableFelt(lambdaworks.FeltFromUint(1)), + *memory.NewMaybeRelocatableFelt(lambdaworks.FeltFromUint(2)), + *memory.NewMaybeRelocatableFelt(lambdaworks.FeltFromUint(3)), + *memory.NewMaybeRelocatableFelt(lambdaworks.FeltFromUint(4)), + *memory.NewMaybeRelocatableFelt(lambdaworks.FeltFromUint(5)), + } + vm.RunContext.Ap, _ = vm.Segments.LoadData(vm.RunContext.Ap, &data) + // Fetch 3 return values + expectedReturnValues := []memory.MaybeRelocatable{ + *memory.NewMaybeRelocatableFelt(lambdaworks.FeltFromUint(3)), + *memory.NewMaybeRelocatableFelt(lambdaworks.FeltFromUint(4)), + *memory.NewMaybeRelocatableFelt(lambdaworks.FeltFromUint(5)), + } + returnValues, err := vm.GetReturnValues(3) + if err != nil { + t.Errorf("GetReturnValues failed with error: %s", err.Error()) + } + + if !reflect.DeepEqual(expectedReturnValues, returnValues) { + t.Errorf("Wrong return values.\n Expected: %+v, got: %+v", expectedReturnValues, returnValues) + } +} From 9cde79b36064d67d802e7f8d4df08d3fd9ff3a70 Mon Sep 17 00:00:00 2001 From: fmoletta <99273364+fmoletta@users.noreply.github.com> Date: Tue, 3 Oct 2023 01:44:55 +0300 Subject: [PATCH 6/7] Add testing util CheckScopeVar` (#308) --- pkg/hints/hint_utils/testing_utils.go | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/pkg/hints/hint_utils/testing_utils.go b/pkg/hints/hint_utils/testing_utils.go index 46403f07..a52ddb0d 100644 --- a/pkg/hints/hint_utils/testing_utils.go +++ b/pkg/hints/hint_utils/testing_utils.go @@ -1,8 +1,12 @@ package hint_utils import ( + "reflect" + "testing" + "github.com/lambdaclass/cairo-vm.go/pkg/lambdaworks" "github.com/lambdaclass/cairo-vm.go/pkg/parser" + "github.com/lambdaclass/cairo-vm.go/pkg/types" . "github.com/lambdaclass/cairo-vm.go/pkg/vm" "github.com/lambdaclass/cairo-vm.go/pkg/vm/memory" ) @@ -56,3 +60,13 @@ func SetupConstantsForTest(new_constants map[string]lambdaworks.Felt, ids *IdsMa } return constants } + +func CheckScopeVar[T any](name string, expectedVal T, scopes *types.ExecutionScopes, t *testing.T) { + val, err := types.FetchScopeVar[T](name, scopes) + if err != nil { + t.Error(err.Error()) + } + if !reflect.DeepEqual(val, expectedVal) { + t.Errorf("Wrong scope var %s.\n Expected: %v, got: %v", name, expectedVal, val) + } +} From 76d65f699b7dc9579a40f364b5fce136e9a49bda Mon Sep 17 00:00:00 2001 From: fmoletta <99273364+fmoletta@users.noreply.github.com> Date: Tue, 3 Oct 2023 17:48:16 +0300 Subject: [PATCH 7/7] Implement most signature hints (#291) * Add ec hints * Implement hints * Add the hints to the processor * Test pack86 function * Test hint * Delete debug info, Test ec negative op * Second hint test * Test embedded hint * Change to Camel case * Implement slope hints * Fix format * Delete github conflict string * Tests hints * Tests hints slopes * Rename misleading name function * Fix function name * Fix error in function call * Delete debug info * Delete unused import * Secp hints * Secpr21 * Add it to the hint processor * Hints secp * bigint3 nondet * Zero verify * Merge main * Add hint to hint processor * Debug info * Prints * Test verify with unit test * Debug unit test * Test verify zero with debug * Non det big 3 test * Modify test to use ids manager * Add hint codes * Implement base hint * Add hints * Add hints to ExecuteHint * debug info * Fix broken test * Move integration test to cairo_run_test.go * Move file from hints_utils and rename * Delete debug * Return error of IdsData.Insert * Change to camel case * Add unit test * Add unit test * Add hint codes * Implement hint * Add SafeDivBig * Add generic way to fetch scope variables * Add generic fetch * Add generic way to fetch scope variables * Use more specific error * Add hints to ExecuteHint * Add extra hint * Fix logic, add unit test * Add unit test * use boolean flag instead or arg * Fix scope var name * Fix scope var name in tests * Make FetchScopeVar work despite references * Revert "Make FetchScopeVar work despite references" This reverts commit 69993be48a9a8fea8d241450463e4dd240091056. * Handle scope variables as big.Int instead of *big.Int * Fix merge cnflicts * Fix tests * Implement Igcdex + add tests * Implement DivMod * Use DivMod instead of Div + Mod * Dont modify the original value in bigint3_split function * Push test file * Remove redundant check * Merge math_utils/utils & utils/math_utils --------- Co-authored-by: Milton Co-authored-by: mmsc2 <88055861+mmsc2@users.noreply.github.com> Co-authored-by: Mariano A. Nicolini Co-authored-by: Pedro Fontana --- cairo_programs/div_mod_n.cairo | 129 +++++++++++++++ pkg/builtins/ec_op.go | 5 +- pkg/hints/hint_codes/signature_hint_codes.go | 21 +++ pkg/hints/hint_processor.go | 10 ++ pkg/hints/hint_utils/bigint_utils.go | 17 +- pkg/hints/hint_utils/secp_utils.go | 2 +- pkg/hints/math_hints.go | 2 +- pkg/hints/signature_hints.go | 84 ++++++++++ pkg/hints/signature_hints_test.go | 159 ++++++++++++++++++ pkg/math_utils/utils.go | 27 ---- pkg/math_utils/utils_test.go | 129 --------------- pkg/utils/math_utils.go | 46 ++++++ pkg/utils/math_utils_test.go | 161 +++++++++++++++++++ pkg/vm/cairo_run/cairo_run_test.go | 4 + 14 files changed, 626 insertions(+), 170 deletions(-) create mode 100644 cairo_programs/div_mod_n.cairo create mode 100644 pkg/hints/hint_codes/signature_hint_codes.go create mode 100644 pkg/hints/signature_hints.go create mode 100644 pkg/hints/signature_hints_test.go delete mode 100644 pkg/math_utils/utils.go delete mode 100644 pkg/math_utils/utils_test.go diff --git a/cairo_programs/div_mod_n.cairo b/cairo_programs/div_mod_n.cairo new file mode 100644 index 00000000..4dbe3c82 --- /dev/null +++ b/cairo_programs/div_mod_n.cairo @@ -0,0 +1,129 @@ +%builtins range_check + +from starkware.cairo.common.cairo_secp.bigint import BigInt3, nondet_bigint3, BASE, bigint_mul +from starkware.cairo.common.cairo_secp.constants import BETA, N0, N1, N2 + +// Source: https://github.com/myBraavos/efficient-secp256r1/blob/73cca4d53730cb8b2dcf34e36c7b8f34b96b3230/src/secp256r1/signature.cairo + +// Computes a * b^(-1) modulo the size of the elliptic curve (N). +// +// Prover assumptions: +// * All the limbs of a are in the range (-2 ** 210.99, 2 ** 210.99). +// * All the limbs of b are in the range (-2 ** 124.99, 2 ** 124.99). +// * b is in the range [0, 2 ** 256). +// +// Soundness assumptions: +// * The limbs of a are in the range (-2 ** 249, 2 ** 249). +// * The limbs of b are in the range (-2 ** 159.83, 2 ** 159.83). +func div_mod_n{range_check_ptr}(a: BigInt3, b: BigInt3) -> (res: BigInt3) { + %{ + from starkware.cairo.common.cairo_secp.secp_utils import N, pack + from starkware.python.math_utils import div_mod, safe_div + + a = pack(ids.a, PRIME) + b = pack(ids.b, PRIME) + value = res = div_mod(a, b, N) + %} + let (res) = nondet_bigint3(); + + %{ value = k_plus_one = safe_div(res * b - a, N) + 1 %} + let (k_plus_one) = nondet_bigint3(); + let k = BigInt3(d0=k_plus_one.d0 - 1, d1=k_plus_one.d1, d2=k_plus_one.d2); + + let (res_b) = bigint_mul(res, b); + let n = BigInt3(N0, N1, N2); + let (k_n) = bigint_mul(k, n); + + // We should now have res_b = k_n + a. Since the numbers are in unreduced form, + // we should handle the carry. + + tempvar carry1 = (res_b.d0 - k_n.d0 - a.d0) / BASE; + assert [range_check_ptr + 0] = carry1 + 2 ** 127; + + tempvar carry2 = (res_b.d1 - k_n.d1 - a.d1 + carry1) / BASE; + assert [range_check_ptr + 1] = carry2 + 2 ** 127; + + tempvar carry3 = (res_b.d2 - k_n.d2 - a.d2 + carry2) / BASE; + assert [range_check_ptr + 2] = carry3 + 2 ** 127; + + tempvar carry4 = (res_b.d3 - k_n.d3 + carry3) / BASE; + assert [range_check_ptr + 3] = carry4 + 2 ** 127; + + assert res_b.d4 - k_n.d4 + carry4 = 0; + + let range_check_ptr = range_check_ptr + 4; + + return (res=res); +} + +func div_mod_n_alt{range_check_ptr}(a: BigInt3, b: BigInt3) -> (res: BigInt3) { + // just used to import N + %{ + from starkware.cairo.common.cairo_secp.secp_utils import N, pack + from starkware.python.math_utils import div_mod, safe_div + + a = pack(ids.a, PRIME) + b = pack(ids.b, PRIME) + value = res = div_mod(a, b, N) + %} + + %{ + from starkware.cairo.common.cairo_secp.secp_utils import pack + from starkware.python.math_utils import div_mod, safe_div + + a = pack(ids.a, PRIME) + b = pack(ids.b, PRIME) + value = res = div_mod(a, b, N) + %} + let (res) = nondet_bigint3(); + + %{ value = k_plus_one = safe_div(res * b - a, N) + 1 %} + let (k_plus_one) = nondet_bigint3(); + let k = BigInt3(d0=k_plus_one.d0 - 1, d1=k_plus_one.d1, d2=k_plus_one.d2); + + let (res_b) = bigint_mul(res, b); + let n = BigInt3(N0, N1, N2); + let (k_n) = bigint_mul(k, n); + + tempvar carry1 = (res_b.d0 - k_n.d0 - a.d0) / BASE; + assert [range_check_ptr + 0] = carry1 + 2 ** 127; + + tempvar carry2 = (res_b.d1 - k_n.d1 - a.d1 + carry1) / BASE; + assert [range_check_ptr + 1] = carry2 + 2 ** 127; + + tempvar carry3 = (res_b.d2 - k_n.d2 - a.d2 + carry2) / BASE; + assert [range_check_ptr + 2] = carry3 + 2 ** 127; + + tempvar carry4 = (res_b.d3 - k_n.d3 + carry3) / BASE; + assert [range_check_ptr + 3] = carry4 + 2 ** 127; + + assert res_b.d4 - k_n.d4 + carry4 = 0; + + let range_check_ptr = range_check_ptr + 4; + + return (res=res); +} + +func test_div_mod_n{range_check_ptr: felt}() { + let a: BigInt3 = BigInt3(100, 99, 98); + let b: BigInt3 = BigInt3(10, 9, 8); + + let (res) = div_mod_n(a, b); + + assert res = BigInt3( + 3413472211745629263979533, 17305268010345238170172332, 11991751872105858217578135 + ); + + // test alternative hint + let (res_alt) = div_mod_n_alt(a, b); + + assert res_alt = res; + + return (); +} + +func main{range_check_ptr: felt}() { + test_div_mod_n(); + + return (); +} diff --git a/pkg/builtins/ec_op.go b/pkg/builtins/ec_op.go index 0e973a0a..9053c25c 100644 --- a/pkg/builtins/ec_op.go +++ b/pkg/builtins/ec_op.go @@ -5,7 +5,6 @@ import ( "math/big" "github.com/lambdaclass/cairo-vm.go/pkg/lambdaworks" - "github.com/lambdaclass/cairo-vm.go/pkg/math_utils" "github.com/lambdaclass/cairo-vm.go/pkg/utils" "github.com/lambdaclass/cairo-vm.go/pkg/vm/memory" "github.com/pkg/errors" @@ -263,7 +262,7 @@ func LineSlope(point_a PartialSumB, point_b DoublePointB, prime big.Int) (big.In n := new(big.Int).Sub(&point_a.Y, &point_b.Y) m := new(big.Int).Sub(&point_a.X, &point_b.X) - z, err := math_utils.DivMod(n, m, &prime) + z, err := utils.DivMod(n, m, &prime) if err != nil { return big.Int{}, err } @@ -299,7 +298,7 @@ func EcDoubleSlope(point DoublePointB, alpha big.Int, prime big.Int) (big.Int, e n.Add(n, &alpha) m := new(big.Int).Mul(&point.Y, big.NewInt(2)) - z, err := math_utils.DivMod(n, m, &prime) + z, err := utils.DivMod(n, m, &prime) if err != nil { return big.Int{}, err diff --git a/pkg/hints/hint_codes/signature_hint_codes.go b/pkg/hints/hint_codes/signature_hint_codes.go new file mode 100644 index 00000000..37b37051 --- /dev/null +++ b/pkg/hints/hint_codes/signature_hint_codes.go @@ -0,0 +1,21 @@ +package hint_codes + +const DIV_MOD_N_PACKED_DIVMOD_V1 = `from starkware.cairo.common.cairo_secp.secp_utils import N, pack +from starkware.python.math_utils import div_mod, safe_div + +a = pack(ids.a, PRIME) +b = pack(ids.b, PRIME) +value = res = div_mod(a, b, N)` + +const DIV_MOD_N_PACKED_DIVMOD_EXTERNAL_N = `from starkware.cairo.common.cairo_secp.secp_utils import pack +from starkware.python.math_utils import div_mod, safe_div + +a = pack(ids.a, PRIME) +b = pack(ids.b, PRIME) +value = res = div_mod(a, b, N)` + +const DIV_MOD_N_SAFE_DIV = "value = k = safe_div(res * b - a, N)" + +const DIV_MOD_N_SAFE_DIV_PLUS_ONE = "value = k_plus_one = safe_div(res * b - a, N) + 1" + +const XS_SAFE_DIV = "value = k = safe_div(res * s - x, N)" diff --git a/pkg/hints/hint_processor.go b/pkg/hints/hint_processor.go index ccac7bd2..9685de6e 100644 --- a/pkg/hints/hint_processor.go +++ b/pkg/hints/hint_processor.go @@ -188,6 +188,16 @@ func (p *CairoVmHintProcessor) ExecuteHint(vm *vm.VirtualMachine, hintData *any, return splitInt(data.Ids, vm) case SPLIT_INT_ASSERT_RANGE: return splitIntAssertRange(data.Ids, vm) + case DIV_MOD_N_PACKED_DIVMOD_V1: + return divModNPackedDivMod(data.Ids, vm, execScopes) + case DIV_MOD_N_PACKED_DIVMOD_EXTERNAL_N: + return divModNPackedDivModExternalN(data.Ids, vm, execScopes) + case XS_SAFE_DIV: + return divModNSafeDiv(data.Ids, execScopes, "x", "s", false) + case DIV_MOD_N_SAFE_DIV: + return divModNSafeDiv(data.Ids, execScopes, "a", "b", false) + case DIV_MOD_N_SAFE_DIV_PLUS_ONE: + return divModNSafeDiv(data.Ids, execScopes, "a", "b", true) case VERIFY_ZERO_EXTERNAL_SECP: return verifyZeroWithExternalConst(*vm, *execScopes, data.Ids) case FAST_EC_ADD_ASSIGN_NEW_X: diff --git a/pkg/hints/hint_utils/bigint_utils.go b/pkg/hints/hint_utils/bigint_utils.go index 9a900443..2fac8054 100644 --- a/pkg/hints/hint_utils/bigint_utils.go +++ b/pkg/hints/hint_utils/bigint_utils.go @@ -96,15 +96,14 @@ func BigInt3FromBaseAddr(addr Relocatable, name string, vm *VirtualMachine) (Big } func BigInt3FromVarName(name string, ids IdsManager, vm *VirtualMachine) (BigInt3, error) { - bigIntAddr, err := ids.GetAddr(name, vm) - if err != nil { - return BigInt3{}, err - } + limbs, err := limbsFromVarName(3, name, ids, vm) + return BigInt3{Limbs: limbs}, err +} - bigInt, err := BigInt3FromBaseAddr(bigIntAddr, name, vm) - if err != nil { - return BigInt3{}, err - } +// Uint384 + +type Uint384 = BigInt3 - return bigInt, err +func Uint384FromVarName(name string, ids IdsManager, vm *VirtualMachine) (Uint384, error) { + return BigInt3FromVarName(name, ids, vm) } diff --git a/pkg/hints/hint_utils/secp_utils.go b/pkg/hints/hint_utils/secp_utils.go index a125823a..bec3b650 100644 --- a/pkg/hints/hint_utils/secp_utils.go +++ b/pkg/hints/hint_utils/secp_utils.go @@ -46,7 +46,7 @@ func Bigint3Split(integer big.Int) ([]big.Int, error) { for i := 0; i < 3; i++ { canonicalRepr[i] = *new(big.Int).And(&num, BASE_MINUS_ONE()) - num.Rsh(&num, 86) + num = *new(big.Int).Rsh(&num, 86) } if num.Cmp(big.NewInt(0)) != 0 { return nil, errors.New("HintError SecpSplitOutOfRange") diff --git a/pkg/hints/math_hints.go b/pkg/hints/math_hints.go index 03b239c3..a577131d 100644 --- a/pkg/hints/math_hints.go +++ b/pkg/hints/math_hints.go @@ -7,8 +7,8 @@ import ( . "github.com/lambdaclass/cairo-vm.go/pkg/hints/hint_utils" "github.com/lambdaclass/cairo-vm.go/pkg/lambdaworks" . "github.com/lambdaclass/cairo-vm.go/pkg/lambdaworks" - . "github.com/lambdaclass/cairo-vm.go/pkg/math_utils" . "github.com/lambdaclass/cairo-vm.go/pkg/types" + . "github.com/lambdaclass/cairo-vm.go/pkg/utils" . "github.com/lambdaclass/cairo-vm.go/pkg/vm" . "github.com/lambdaclass/cairo-vm.go/pkg/vm/memory" "github.com/pkg/errors" diff --git a/pkg/hints/signature_hints.go b/pkg/hints/signature_hints.go new file mode 100644 index 00000000..37b57f7f --- /dev/null +++ b/pkg/hints/signature_hints.go @@ -0,0 +1,84 @@ +package hints + +import ( + "math/big" + + . "github.com/lambdaclass/cairo-vm.go/pkg/hints/hint_utils" + . "github.com/lambdaclass/cairo-vm.go/pkg/types" + "github.com/lambdaclass/cairo-vm.go/pkg/utils" + . "github.com/lambdaclass/cairo-vm.go/pkg/vm" +) + +func divModNPacked(ids IdsManager, vm *VirtualMachine, scopes *ExecutionScopes, n *big.Int) error { + a, err := Uint384FromVarName("a", ids, vm) + if err != nil { + return err + } + b, err := Uint384FromVarName("b", ids, vm) + if err != nil { + return err + } + packedA := a.Pack86() + packedB := b.Pack86() + + val, err := utils.DivMod(&packedA, &packedB, n) + if err != nil { + return err + } + + scopes.AssignOrUpdateVariable("a", packedA) + scopes.AssignOrUpdateVariable("b", packedB) + scopes.AssignOrUpdateVariable("value", *val) + scopes.AssignOrUpdateVariable("res", *val) + + return nil +} + +func divModNPackedDivMod(ids IdsManager, vm *VirtualMachine, scopes *ExecutionScopes) error { + n, _ := new(big.Int).SetString("115792089237316195423570985008687907852837564279074904382605163141518161494337", 10) + scopes.AssignOrUpdateVariable("N", *n) + return divModNPacked(ids, vm, scopes, n) +} + +func divModNPackedDivModExternalN(ids IdsManager, vm *VirtualMachine, scopes *ExecutionScopes) error { + n, err := FetchScopeVar[big.Int]("N", scopes) + if err != nil { + return err + } + return divModNPacked(ids, vm, scopes, &n) +} + +func divModNSafeDiv(ids IdsManager, scopes *ExecutionScopes, aAlias string, bAlias string, addOne bool) error { + // Fetch scope variables + a, err := FetchScopeVar[big.Int](aAlias, scopes) + if err != nil { + return err + } + + b, err := FetchScopeVar[big.Int](bAlias, scopes) + if err != nil { + return err + } + + res, err := FetchScopeVar[big.Int]("res", scopes) + if err != nil { + return err + } + + n, err := FetchScopeVar[big.Int]("N", scopes) + if err != nil { + return err + } + + // Hint logic + value, err := utils.SafeDivBig(new(big.Int).Sub(new(big.Int).Mul(&res, &b), &a), &n) + if err != nil { + return err + } + if addOne { + value = new(big.Int).Add(value, big.NewInt(1)) + } + // Update scope + scopes.AssignOrUpdateVariable("value", *value) + return nil +} diff --git a/pkg/hints/signature_hints_test.go b/pkg/hints/signature_hints_test.go new file mode 100644 index 00000000..4c3de522 --- /dev/null +++ b/pkg/hints/signature_hints_test.go @@ -0,0 +1,159 @@ +package hints_test + +import ( + "math/big" + "testing" + + . "github.com/lambdaclass/cairo-vm.go/pkg/hints" + . "github.com/lambdaclass/cairo-vm.go/pkg/hints/hint_codes" + . "github.com/lambdaclass/cairo-vm.go/pkg/hints/hint_utils" + . "github.com/lambdaclass/cairo-vm.go/pkg/lambdaworks" + . "github.com/lambdaclass/cairo-vm.go/pkg/types" + . "github.com/lambdaclass/cairo-vm.go/pkg/vm" + . "github.com/lambdaclass/cairo-vm.go/pkg/vm/memory" +) + +func TestDivModNPackedDivMod(t *testing.T) { + vm := NewVirtualMachine() + vm.Segments.AddSegment() + idsManager := SetupIdsForTest( + map[string][]*MaybeRelocatable{ + "a": { + NewMaybeRelocatableFelt(FeltFromUint64(10)), + NewMaybeRelocatableFelt(FeltFromUint64(0)), + NewMaybeRelocatableFelt(FeltFromUint64(0)), + }, + "b": { + NewMaybeRelocatableFelt(FeltFromUint64(2)), + NewMaybeRelocatableFelt(FeltFromUint64(0)), + NewMaybeRelocatableFelt(FeltFromUint64(0)), + }, + }, + vm, + ) + hintProcessor := CairoVmHintProcessor{} + hintData := any(HintData{ + Ids: idsManager, + Code: DIV_MOD_N_PACKED_DIVMOD_V1, + }) + scopes := NewExecutionScopes() + err := hintProcessor.ExecuteHint(vm, &hintData, nil, scopes) + if err != nil { + t.Errorf("DIV_MOD_N_PACKED_DIVMOD_V1 hint test failed with error %s", err) + } + // Check result in scope + expectedRes := big.NewInt(5) + + res, err := FetchScopeVar[big.Int]("res", scopes) + if err != nil || res.Cmp(expectedRes) != 0 { + t.Error("Wrong/No scope value res") + } + + val, err := FetchScopeVar[big.Int]("value", scopes) + if err != nil || val.Cmp(expectedRes) != 0 { + t.Error("Wrong/No scope var value") + } +} + +func TestDivModNPackedDivModExternalN(t *testing.T) { + vm := NewVirtualMachine() + vm.Segments.AddSegment() + idsManager := SetupIdsForTest( + map[string][]*MaybeRelocatable{ + "a": { + NewMaybeRelocatableFelt(FeltFromUint64(20)), + NewMaybeRelocatableFelt(FeltFromUint64(0)), + NewMaybeRelocatableFelt(FeltFromUint64(0)), + }, + "b": { + NewMaybeRelocatableFelt(FeltFromUint64(2)), + NewMaybeRelocatableFelt(FeltFromUint64(0)), + NewMaybeRelocatableFelt(FeltFromUint64(0)), + }, + }, + vm, + ) + hintProcessor := CairoVmHintProcessor{} + hintData := any(HintData{ + Ids: idsManager, + Code: DIV_MOD_N_PACKED_DIVMOD_EXTERNAL_N, + }) + scopes := NewExecutionScopes() + scopes.AssignOrUpdateVariable("N", *big.NewInt(7)) + err := hintProcessor.ExecuteHint(vm, &hintData, nil, scopes) + if err != nil { + t.Errorf("DIV_MOD_N_PACKED_DIVMOD_EXTERNAL_N hint test failed with error %s", err) + } + // Check result in scope + expectedRes := big.NewInt(3) + + res, err := FetchScopeVar[big.Int]("res", scopes) + if err != nil || res.Cmp(expectedRes) != 0 { + t.Error("Wrong/No scope value res") + } + + val, err := FetchScopeVar[big.Int]("value", scopes) + if err != nil || val.Cmp(expectedRes) != 0 { + t.Error("Wrong/No scope var value") + } +} + +func TestDivModSafeDivOk(t *testing.T) { + vm := NewVirtualMachine() + vm.Segments.AddSegment() + idsManager := SetupIdsForTest( + map[string][]*MaybeRelocatable{}, + vm, + ) + hintProcessor := CairoVmHintProcessor{} + hintData := any(HintData{ + Ids: idsManager, + Code: DIV_MOD_N_SAFE_DIV, + }) + scopes := NewExecutionScopes() + scopes.AssignOrUpdateVariable("N", *big.NewInt(5)) + scopes.AssignOrUpdateVariable("a", *big.NewInt(10)) + scopes.AssignOrUpdateVariable("b", *big.NewInt(30)) + scopes.AssignOrUpdateVariable("res", *big.NewInt(2)) + err := hintProcessor.ExecuteHint(vm, &hintData, nil, scopes) + if err != nil { + t.Errorf("DIV_MOD_N_SAFE_DIV hint test failed with error %s", err) + } + // Check result in scope + expectedValue := big.NewInt(10) // (2 * 30 - 10) / 5 = 10 + + val, err := FetchScopeVar[big.Int]("value", scopes) + if err != nil || val.Cmp(expectedValue) != 0 { + t.Error("Wrong/No scope value val") + } +} + +func TestDivModSafeDivPlusOneOk(t *testing.T) { + vm := NewVirtualMachine() + vm.Segments.AddSegment() + idsManager := SetupIdsForTest( + map[string][]*MaybeRelocatable{}, + vm, + ) + hintProcessor := CairoVmHintProcessor{} + hintData := any(HintData{ + Ids: idsManager, + Code: DIV_MOD_N_SAFE_DIV_PLUS_ONE, + }) + scopes := NewExecutionScopes() + scopes.AssignOrUpdateVariable("N", *big.NewInt(5)) + scopes.AssignOrUpdateVariable("a", *big.NewInt(10)) + scopes.AssignOrUpdateVariable("b", *big.NewInt(30)) + scopes.AssignOrUpdateVariable("res", *big.NewInt(2)) + err := hintProcessor.ExecuteHint(vm, &hintData, nil, scopes) + if err != nil { + t.Errorf("DIV_MOD_N_SAFE_DIV_PLUS_ONE hint test failed with error %s", err) + } + // Check result in scope + expectedValue := big.NewInt(11) // (2 * 30 - 10) / 5 + 1 = 11 + + val, err := FetchScopeVar[big.Int]("value", scopes) + if err != nil || val.Cmp(expectedValue) != 0 { + t.Error("Wrong/No scope value val") + } +} diff --git a/pkg/math_utils/utils.go b/pkg/math_utils/utils.go deleted file mode 100644 index 282ba4a3..00000000 --- a/pkg/math_utils/utils.go +++ /dev/null @@ -1,27 +0,0 @@ -package math_utils - -import ( - "github.com/pkg/errors" - "math/big" -) - -// Finds a nonnegative integer x < p such that (m * x) % p == n. -func DivMod(n *big.Int, m *big.Int, p *big.Int) (*big.Int, error) { - a := new(big.Int) - gcd := new(big.Int) - gcd.GCD(a, nil, m, p) - - if gcd.Cmp(big.NewInt(1)) != 0 { - return nil, errors.Errorf("gcd(%s, %s) != 1", m, p) - } - - return n.Mul(n, a).Mod(n, p), nil -} - -func ISqrt(x *big.Int) (*big.Int, error) { - if x.Sign() == -1 { - return nil, errors.Errorf("Expected x: %s to be non-negative", x) - } - res := new(big.Int) - return res.Sqrt(x), nil -} diff --git a/pkg/math_utils/utils_test.go b/pkg/math_utils/utils_test.go deleted file mode 100644 index c4eee853..00000000 --- a/pkg/math_utils/utils_test.go +++ /dev/null @@ -1,129 +0,0 @@ -package math_utils_test - -import ( - "math/big" - "testing" - - . "github.com/lambdaclass/cairo-vm.go/pkg/math_utils" -) - -func TestDivModOk(t *testing.T) { - a := new(big.Int) - b := new(big.Int) - prime := new(big.Int) - expected := new(big.Int) - - a.SetString("11260647941622813594563746375280766662237311019551239924981511729608487775604310196863705127454617186486639011517352066501847110680463498585797912894788", 10) - b.SetString("4020711254448367604954374443741161860304516084891705811279711044808359405970", 10) - prime.SetString("800000000000011000000000000000000000000000000000000000000000001", 16) - expected.SetString("2904750555256547440469454488220756360634457312540595732507835416669695939476", 10) - - num, err := DivMod(a, b, prime) - if err != nil { - t.Errorf("DivMod failed with error: %s", err) - } - if num.Cmp(expected) != 0 { - t.Errorf("Expected result: %s to be equal to %s", num, expected) - } -} - -func TestDivModMZeroFail(t *testing.T) { - a := new(big.Int) - b := new(big.Int) - prime := new(big.Int) - - a.SetString("11260647941622813594563746375280766662237311019551239924981511729608487775604310196863705127454617186486639011517352066501847110680463498585797912894788", 10) - prime.SetString("800000000000011000000000000000000000000000000000000000000000001", 16) - - _, err := DivMod(a, b, prime) - if err == nil { - t.Errorf("DivMod expected to failed with gcd != 1") - } -} - -func TestDivModMEqPFail(t *testing.T) { - a := new(big.Int) - b := new(big.Int) - prime := new(big.Int) - - a.SetString("11260647941622813594563746375280766662237311019551239924981511729608487775604310196863705127454617186486639011517352066501847110680463498585797912894788", 10) - b.SetString("800000000000011000000000000000000000000000000000000000000000001", 16) - prime.SetString("800000000000011000000000000000000000000000000000000000000000001", 16) - - _, err := DivMod(a, b, prime) - if err == nil { - t.Errorf("DivMod expected to failed with gcd != 1") - } -} - -func TestIsSqrtOk(t *testing.T) { - x := new(big.Int) - y := new(big.Int) - x.SetString("4573659632505831259480", 10) - y.Mul(x, x) - - sqr_y, err := ISqrt(y) - if err != nil { - t.Errorf("ISqrt failed with error: %s", err) - } - if x.Cmp(sqr_y) != 0 { - t.Errorf("Failed to get square root of x^2, x: %s", x) - } -} - -func TestCalculateIsqrtA(t *testing.T) { - x := new(big.Int) - x.SetString("81", 10) - sqrt, err := ISqrt(x) - if err != nil { - t.Error("ISqrt failed") - } - - expected := new(big.Int) - expected.SetString("9", 10) - - if sqrt.Cmp(expected) != 0 { - t.Errorf("ISqrt failed, expected %d, got %d", expected, sqrt) - } -} - -func TestCalculateIsqrtB(t *testing.T) { - x := new(big.Int) - x.SetString("4573659632505831259480", 10) - square := new(big.Int) - square = square.Mul(x, x) - - sqrt, err := ISqrt(square) - if err != nil { - t.Error("ISqrt failed") - } - - if sqrt.Cmp(x) != 0 { - t.Errorf("ISqrt failed, expected %d, got %d", x, sqrt) - } -} - -func TestCalculateIsqrtC(t *testing.T) { - x := new(big.Int) - x.SetString("3618502788666131213697322783095070105623107215331596699973092056135872020481", 10) - square := new(big.Int) - square = square.Mul(x, x) - - sqrt, err := ISqrt(square) - if err != nil { - t.Error("ISqrt failed") - } - - if sqrt.Cmp(x) != 0 { - t.Errorf("ISqrt failed, expected %d, got %d", x, sqrt) - } -} - -func TestIsSqrtFail(t *testing.T) { - x := big.NewInt(-1) - - _, err := ISqrt(x) - if err == nil { - t.Errorf("expected ISqrt to fail") - } -} diff --git a/pkg/utils/math_utils.go b/pkg/utils/math_utils.go index b1054f70..7f011a6c 100644 --- a/pkg/utils/math_utils.go +++ b/pkg/utils/math_utils.go @@ -63,3 +63,49 @@ func SafeDivBig(x *big.Int, y *big.Int) (*big.Int, error) { } return q, nil } + +// Finds a nonnegative integer x < p such that (m * x) % p == n. +func DivMod(n *big.Int, m *big.Int, p *big.Int) (*big.Int, error) { + a, _, c := Igcdex(m, p) + if c.Cmp(big.NewInt(1)) != 0 { + return nil, errors.Errorf("Operation failed: divmod(%s, %s, %s), igcdex(%s, %s) != 1 ", n.Text(10), m.Text(10), p.Text(10), m.Text(10), p.Text(10)) + } + return new(big.Int).Mod(new(big.Int).Mul(n, a), p), nil +} + +func Igcdex(a *big.Int, b *big.Int) (*big.Int, *big.Int, *big.Int) { + zero := big.NewInt(0) + one := big.NewInt(1) + switch true { + case a.Cmp(zero) == 0 && b.Cmp(zero) == 0: + return zero, one, zero + case a.Cmp(zero) == 0: + return zero, big.NewInt(int64(a.Sign())), new(big.Int).Abs(b) + case b.Cmp(zero) == 0: + return big.NewInt(int64(a.Sign())), zero, new(big.Int).Abs(a) + default: + xSign := big.NewInt(int64(a.Sign())) + ySign := big.NewInt(int64(b.Sign())) + a = new(big.Int).Abs(a) + b = new(big.Int).Abs(b) + x, y, r, s := big.NewInt(1), big.NewInt(0), big.NewInt(0), big.NewInt(1) + for b.Cmp(zero) != 0 { + q, c := new(big.Int).DivMod(a, b, new(big.Int)) + x = new(big.Int).Sub(x, new(big.Int).Mul(q, r)) + y = new(big.Int).Sub(y, new(big.Int).Mul(q, s)) + + a, b, r, s, x, y = b, c, x, y, r, s + } + + return new(big.Int).Mul(x, xSign), new(big.Int).Mul(y, ySign), a + + } +} + +func ISqrt(x *big.Int) (*big.Int, error) { + if x.Sign() == -1 { + return nil, errors.Errorf("Expected x: %s to be non-negative", x) + } + res := new(big.Int) + return res.Sqrt(x), nil +} diff --git a/pkg/utils/math_utils_test.go b/pkg/utils/math_utils_test.go index e3b2a152..3b308322 100644 --- a/pkg/utils/math_utils_test.go +++ b/pkg/utils/math_utils_test.go @@ -44,3 +44,164 @@ func TestSafeDivBigErrZeroDivison(t *testing.T) { t.Error("SafeDivBig should have failed") } } + +func TestIgcdex11(t *testing.T) { + a := big.NewInt(1) + b := big.NewInt(1) + expectedX, expectedY, expectedZ := big.NewInt(0), big.NewInt(1), big.NewInt(1) + x, y, z := Igcdex(a, b) + if x.Cmp(expectedX) != 0 || y.Cmp(expectedY) != 0 || z.Cmp(expectedZ) != 0 { + t.Error("Wrong values returned by Igcdex") + } +} + +func TestIgcdex00(t *testing.T) { + a := big.NewInt(0) + b := big.NewInt(0) + expectedX, expectedY, expectedZ := big.NewInt(0), big.NewInt(1), big.NewInt(0) + x, y, z := Igcdex(a, b) + if x.Cmp(expectedX) != 0 || y.Cmp(expectedY) != 0 || z.Cmp(expectedZ) != 0 { + t.Error("Wrong values returned by Igcdex") + } +} + +func TestIgcdex10(t *testing.T) { + a := big.NewInt(1) + b := big.NewInt(0) + expectedX, expectedY, expectedZ := big.NewInt(1), big.NewInt(0), big.NewInt(1) + x, y, z := Igcdex(a, b) + if x.Cmp(expectedX) != 0 || y.Cmp(expectedY) != 0 || z.Cmp(expectedZ) != 0 { + t.Error("Wrong values returned by Igcdex") + } +} + +func TestIgcdex46(t *testing.T) { + a := big.NewInt(4) + b := big.NewInt(6) + expectedX, expectedY, expectedZ := big.NewInt(-1), big.NewInt(1), big.NewInt(2) + x, y, z := Igcdex(a, b) + if x.Cmp(expectedX) != 0 || y.Cmp(expectedY) != 0 || z.Cmp(expectedZ) != 0 { + t.Error("Wrong values returned by Igcdex") + } +} + +func TestDivModOk(t *testing.T) { + a := new(big.Int) + b := new(big.Int) + prime := new(big.Int) + expected := new(big.Int) + + a.SetString("11260647941622813594563746375280766662237311019551239924981511729608487775604310196863705127454617186486639011517352066501847110680463498585797912894788", 10) + b.SetString("4020711254448367604954374443741161860304516084891705811279711044808359405970", 10) + prime.SetString("800000000000011000000000000000000000000000000000000000000000001", 16) + expected.SetString("2904750555256547440469454488220756360634457312540595732507835416669695939476", 10) + + num, err := DivMod(a, b, prime) + if err != nil { + t.Errorf("DivMod failed with error: %s", err) + } + if num.Cmp(expected) != 0 { + t.Errorf("Expected result: %s to be equal to %s", num, expected) + } +} + +func TestDivModMZeroFail(t *testing.T) { + a := new(big.Int) + b := new(big.Int) + prime := new(big.Int) + + a.SetString("11260647941622813594563746375280766662237311019551239924981511729608487775604310196863705127454617186486639011517352066501847110680463498585797912894788", 10) + prime.SetString("800000000000011000000000000000000000000000000000000000000000001", 16) + + _, err := DivMod(a, b, prime) + if err == nil { + t.Errorf("DivMod expected to failed with gcd != 1") + } +} + +func TestDivModMEqPFail(t *testing.T) { + a := new(big.Int) + b := new(big.Int) + prime := new(big.Int) + + a.SetString("11260647941622813594563746375280766662237311019551239924981511729608487775604310196863705127454617186486639011517352066501847110680463498585797912894788", 10) + b.SetString("800000000000011000000000000000000000000000000000000000000000001", 16) + prime.SetString("800000000000011000000000000000000000000000000000000000000000001", 16) + + _, err := DivMod(a, b, prime) + if err == nil { + t.Errorf("DivMod expected to failed with gcd != 1") + } +} + +func TestIsSqrtOk(t *testing.T) { + x := new(big.Int) + y := new(big.Int) + x.SetString("4573659632505831259480", 10) + y.Mul(x, x) + + sqr_y, err := ISqrt(y) + if err != nil { + t.Errorf("ISqrt failed with error: %s", err) + } + if x.Cmp(sqr_y) != 0 { + t.Errorf("Failed to get square root of x^2, x: %s", x) + } +} + +func TestCalculateIsqrtA(t *testing.T) { + x := new(big.Int) + x.SetString("81", 10) + sqrt, err := ISqrt(x) + if err != nil { + t.Error("ISqrt failed") + } + + expected := new(big.Int) + expected.SetString("9", 10) + + if sqrt.Cmp(expected) != 0 { + t.Errorf("ISqrt failed, expected %d, got %d", expected, sqrt) + } +} + +func TestCalculateIsqrtB(t *testing.T) { + x := new(big.Int) + x.SetString("4573659632505831259480", 10) + square := new(big.Int) + square = square.Mul(x, x) + + sqrt, err := ISqrt(square) + if err != nil { + t.Error("ISqrt failed") + } + + if sqrt.Cmp(x) != 0 { + t.Errorf("ISqrt failed, expected %d, got %d", x, sqrt) + } +} + +func TestCalculateIsqrtC(t *testing.T) { + x := new(big.Int) + x.SetString("3618502788666131213697322783095070105623107215331596699973092056135872020481", 10) + square := new(big.Int) + square = square.Mul(x, x) + + sqrt, err := ISqrt(square) + if err != nil { + t.Error("ISqrt failed") + } + + if sqrt.Cmp(x) != 0 { + t.Errorf("ISqrt failed, expected %d, got %d", x, sqrt) + } +} + +func TestIsSqrtFail(t *testing.T) { + x := big.NewInt(-1) + + _, err := ISqrt(x) + if err == nil { + t.Errorf("expected ISqrt to fail") + } +} diff --git a/pkg/vm/cairo_run/cairo_run_test.go b/pkg/vm/cairo_run/cairo_run_test.go index 845a3076..f0685afb 100644 --- a/pkg/vm/cairo_run/cairo_run_test.go +++ b/pkg/vm/cairo_run/cairo_run_test.go @@ -329,6 +329,10 @@ func TestSplitIntHintProofMode(t *testing.T) { testProgramProof("split_int", t) } +func TestDivModN(t *testing.T) { + testProgram("div_mod_n", t) +} + func TestEcDoubleAssign(t *testing.T) { testProgram("ec_double_assign", t) }