Skip to content

Commit

Permalink
Implement remaining hints in keccak module (#276)
Browse files Browse the repository at this point in the history
* Add hint codes

* Move divrem hint codes to math hint codes

* Finish implementation of unsignedDivRem hint

* Add some fields to range check builtin runner

* WIP implementation of signedDivRem hint

* Dummy commit

* Save work in progress

* Add FeltFromBigInt function

* Add unsigned div rem integration test

* Implement UnsafeKeccak

* Update dependencies

* Add unit test

* Fix division bug and make integration test pass

* Fix hash

* Save work in progress

* Add unit tests

* Add integration test

* Finished unit test for divrem hints

* Remove unused commented code

* Add missing file

* Add constant + GetStructFieldRelocatable + start hint

* Add test file

* Add MemorySegmentManager.GetFeltRange

* Progress

* Finish hint

* Add unit test

* Add integration test

* Remove nParts and bound from range check runner fields, add method to calculate bound from constant and add test to assert bound is never zero

* Remove unused input parameter to NewRangeCheckBuiltinRunner

* Change number to constant in Bound impl

* Add hint + test

* Add integration tests

* Add hint + tests

* Add hint

* Fix test

* Add test

* Implement CAIRO_KECCAK_FINALIZE

* Extract aliased constants

* Fix index out of range

* Fix var name

* Make error more expressive

* Fix hint

* Add test, fix test

* Add test files

* Implement keccakWriteArgs

* Extend hint parsing

* Update program.go

* Fix DivCeil

* Add integration test

* Fix DivCeil

* v2

* Fix bug

* Fix bug but better

* Update keccak_add_uint256.cairo

* Update cairo_run_test.go

---------

Co-authored-by: Mariano Nicolini <[email protected]>
  • Loading branch information
fmoletta and entropidelic authored Sep 26, 2023
1 parent 64e07f9 commit d638fec
Show file tree
Hide file tree
Showing 13 changed files with 616 additions and 5 deletions.
29 changes: 29 additions & 0 deletions cairo_programs/cairo_keccak.cairo
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
%builtins range_check bitwise

from starkware.cairo.common.cairo_keccak.keccak import cairo_keccak, finalize_keccak
from starkware.cairo.common.uint256 import Uint256
from starkware.cairo.common.cairo_builtins import BitwiseBuiltin
from starkware.cairo.common.alloc import alloc

func main{range_check_ptr: felt, bitwise_ptr: BitwiseBuiltin*}() {
alloc_locals;

let (keccak_ptr: felt*) = alloc();
let keccak_ptr_start = keccak_ptr;

let (inputs: felt*) = alloc();

assert inputs[0] = 8031924123371070792;
assert inputs[1] = 560229490;

let n_bytes = 16;

let (res: Uint256) = cairo_keccak{keccak_ptr=keccak_ptr}(inputs=inputs, n_bytes=n_bytes);

assert res.low = 293431514620200399776069983710520819074;
assert res.high = 317109767021952548743448767588473366791;

finalize_keccak(keccak_ptr_start=keccak_ptr_start, keccak_ptr_end=keccak_ptr);

return ();
}
31 changes: 31 additions & 0 deletions cairo_programs/keccak_add_uint256.cairo
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
%builtins output range_check bitwise

from starkware.cairo.common.keccak_utils.keccak_utils import keccak_add_uint256
from starkware.cairo.common.uint256 import Uint256
from starkware.cairo.common.cairo_builtins import BitwiseBuiltin
from starkware.cairo.common.alloc import alloc
from starkware.cairo.common.serialize import serialize_word

func main{output_ptr: felt*, range_check_ptr, bitwise_ptr: BitwiseBuiltin*}() {
alloc_locals;

let (inputs) = alloc();
let inputs_start = inputs;

let num = Uint256(34623634663146736, 598249824422424658356);

keccak_add_uint256{inputs=inputs_start}(num=num, bigend=0);

assert inputs[0] = 34623634663146736;
assert inputs[1] = 0;
assert inputs[2] = 7954014063719006644;
assert inputs[3] = 32;

serialize_word(inputs[0]);
serialize_word(inputs[1]);
serialize_word(inputs[2]);
serialize_word(inputs[3]);

return ();
}

107 changes: 107 additions & 0 deletions cairo_programs/keccak_integration_tests.cairo
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
%builtins range_check bitwise

from starkware.cairo.common.keccak import unsafe_keccak, unsafe_keccak_finalize, KeccakState
from starkware.cairo.common.cairo_keccak.keccak import cairo_keccak, finalize_keccak
from starkware.cairo.common.keccak_utils.keccak_utils import keccak_add_uint256
from starkware.cairo.common.alloc import alloc
from starkware.cairo.common.uint256 import Uint256
from starkware.cairo.common.cairo_builtins import BitwiseBuiltin
from starkware.cairo.common.math import unsigned_div_rem

func fill_array(array: felt*, base: felt, step: felt, array_length: felt, iter: felt) {
if (iter == array_length) {
return ();
}
assert array[iter] = base + step * iter;
return fill_array(array, base, step, array_length, iter + 1);
}

func test_integration{range_check_ptr: felt, bitwise_ptr: BitwiseBuiltin*}(iter: felt, last: felt) {
alloc_locals;
if (iter == last) {
return ();
}

let (data_1: felt*) = alloc();
let data_len: felt = 15;
let chunk_len: felt = 5;

fill_array(data_1, iter, iter + 1, data_len, 0);

let (low_1: felt, high_1: felt) = unsafe_keccak(data_1, chunk_len);
let (low_2: felt, high_2: felt) = unsafe_keccak(data_1 + chunk_len, chunk_len);
let (low_3: felt, high_3: felt) = unsafe_keccak(data_1 + 2 * chunk_len, chunk_len);

// With the results of unsafe_keccak, create an array to pass to unsafe_keccak_finalize
// through a KeccakState
let (data_2: felt*) = alloc();
assert data_2[0] = low_1;
assert data_2[1] = high_1;
assert data_2[2] = low_2;
assert data_2[3] = high_2;
assert data_2[4] = low_3;
assert data_2[5] = high_3;

let keccak_state: KeccakState = KeccakState(start_ptr=data_2, end_ptr=data_2 + 6);
let res_1: Uint256 = unsafe_keccak_finalize(keccak_state);

let (data_3: felt*) = alloc();

// This is done to make sure that the numbers inserted in data_3
// fit in a u64
let (q, r) = unsigned_div_rem(res_1.low, 18446744073709551615);
assert data_3[0] = q;
let (q, r) = unsigned_div_rem(res_1.high, 18446744073709551615);
assert data_3[1] = q;

let (keccak_ptr: felt*) = alloc();
let keccak_ptr_start = keccak_ptr;

let res_2: Uint256 = cairo_keccak{keccak_ptr=keccak_ptr}(data_3, 16);

finalize_keccak(keccak_ptr_start=keccak_ptr_start, keccak_ptr_end=keccak_ptr);

let (inputs) = alloc();
let inputs_start = inputs;
keccak_add_uint256{inputs=inputs_start}(num=res_2, bigend=0);

// These values are hardcoded for last = 10
// Since we are dealing with hash functions and using the output of one of them
// as the input of the other, asserting only the last results of the iteration
// should be enough
if (iter == last - 1 and last == 10) {
assert res_2.low = 3896836249413878817054429671793519200;
assert res_2.high = 253424239110447628170109510737834198489;
assert inputs[0] = 16681956707691293280;
assert inputs[1] = 211247916371739620;
assert inputs[2] = 6796127878994642393;
assert inputs[3] = 13738155530201662906;
}

// These values are hardcoded for last = 100
// This should be used for benchmarking.
if (iter == last - 1 and last == 100) {
assert res_2.low = 52798800345724801884797411011515944813;
assert res_2.high = 159010026777930121161844734347918361509;
assert inputs[0] = 14656556134934286189;
assert inputs[1] = 2862228701973161639;
assert inputs[2] = 206697371206337445;
assert inputs[3] = 8619950823980503604;
}

return test_integration{range_check_ptr=range_check_ptr, bitwise_ptr=bitwise_ptr}(
iter + 1, last
);
}

func run_test{range_check_ptr: felt, bitwise_ptr: BitwiseBuiltin*}(last: felt) {
test_integration(0, last);
return ();
}

func main{range_check_ptr: felt, bitwise_ptr: BitwiseBuiltin*}() {
run_test(10);
return ();
}
4 changes: 2 additions & 2 deletions pkg/builtins/keccak.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ func (k *KeccakBuiltinRunner) DeduceMemoryCell(address Relocatable, mem *Memory)
for i := 0; i < 25; i++ {
output_message_u64[i] = binary.LittleEndian.Uint64(output_message_bytes[8*i : 8*i+8])
}
keccakF1600(&output_message_u64)
KeccakF1600(&output_message_u64)

// Convert back to bytes
output_message := make([]byte, 0, 200)
Expand Down Expand Up @@ -152,7 +152,7 @@ var rc = [24]uint64{

// keccakF1600 applies the Keccak permutation to a 1600b-wide
// state represented as a slice of 25 uint64s.
func keccakF1600(a *[25]uint64) {
func KeccakF1600(a *[25]uint64) {
// Implementation translated from Keccak-inplace.c
// in the keccak reference code.
var t, bc0, bc1, bc2, bc3, bc4, d0, d1, d2, d3, d4 uint64
Expand Down
33 changes: 33 additions & 0 deletions pkg/hints/hint_codes/keccak_hint_codes.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,36 @@ package hint_codes
const UNSAFE_KECCAK = "from eth_hash.auto import keccak\n\ndata, length = ids.data, ids.length\n\nif '__keccak_max_size' in globals():\n assert length <= __keccak_max_size, \\\n f'unsafe_keccak() can only be used with length<={__keccak_max_size}. ' \\\n f'Got: length={length}.'\n\nkeccak_input = bytearray()\nfor word_i, byte_i in enumerate(range(0, length, 16)):\n word = memory[data + word_i]\n n_bytes = min(16, length - byte_i)\n assert 0 <= word < 2 ** (8 * n_bytes)\n keccak_input += word.to_bytes(n_bytes, 'big')\n\nhashed = keccak(keccak_input)\nids.high = int.from_bytes(hashed[:16], 'big')\nids.low = int.from_bytes(hashed[16:32], 'big')"

const UNSAFE_KECCAK_FINALIZE = "from eth_hash.auto import keccak\nkeccak_input = bytearray()\nn_elms = ids.keccak_state.end_ptr - ids.keccak_state.start_ptr\nfor word in memory.get_range(ids.keccak_state.start_ptr, n_elms):\n keccak_input += word.to_bytes(16, 'big')\nhashed = keccak(keccak_input)\nids.high = int.from_bytes(hashed[:16], 'big')\nids.low = int.from_bytes(hashed[16:32], 'big')"

const COMPARE_BYTES_IN_WORD_NONDET = "memory[ap] = to_felt_or_relocatable(ids.n_bytes < ids.BYTES_IN_WORD)"

const COMPARE_KECCAK_FULL_RATE_IN_BYTES_NONDET = "memory[ap] = to_felt_or_relocatable(ids.n_bytes >= ids.KECCAK_FULL_RATE_IN_BYTES)"

const BLOCK_PERMUTATION = `from starkware.cairo.common.keccak_utils.keccak_utils import keccak_func
_keccak_state_size_felts = int(ids.KECCAK_STATE_SIZE_FELTS)
assert 0 <= _keccak_state_size_felts < 100
output_values = keccak_func(memory.get_range(
ids.keccak_ptr - _keccak_state_size_felts, _keccak_state_size_felts))
segments.write_arg(ids.keccak_ptr, output_values)`

const CAIRO_KECCAK_FINALIZE_V1 = `# Add dummy pairs of input and output.
_keccak_state_size_felts = int(ids.KECCAK_STATE_SIZE_FELTS)
_block_size = int(ids.BLOCK_SIZE)
assert 0 <= _keccak_state_size_felts < 100
assert 0 <= _block_size < 10
inp = [0] * _keccak_state_size_felts
padding = (inp + keccak_func(inp)) * _block_size
segments.write_arg(ids.keccak_ptr_end, padding)`

const CAIRO_KECCAK_FINALIZE_V2 = `# Add dummy pairs of input and output.
_keccak_state_size_felts = int(ids.KECCAK_STATE_SIZE_FELTS)
_block_size = int(ids.BLOCK_SIZE)
assert 0 <= _keccak_state_size_felts < 100
assert 0 <= _block_size < 1000
inp = [0] * _keccak_state_size_felts
padding = (inp + keccak_func(inp)) * _block_size
segments.write_arg(ids.keccak_ptr_end, padding)`

const KECCAK_WRITE_ARGS = `segments.write_arg(ids.inputs, [ids.low % 2 ** 64, ids.low // 2 ** 64])
segments.write_arg(ids.inputs + 2, [ids.high % 2 ** 64, ids.high // 2 ** 64])`
12 changes: 12 additions & 0 deletions pkg/hints/hint_processor.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,18 @@ func (p *CairoVmHintProcessor) ExecuteHint(vm *vm.VirtualMachine, hintData *any,
return unsafeKeccak(data.Ids, vm, *execScopes)
case UNSAFE_KECCAK_FINALIZE:
return unsafeKeccakFinalize(data.Ids, vm)
case COMPARE_BYTES_IN_WORD_NONDET:
return compareBytesInWordNondet(data.Ids, vm, constants)
case COMPARE_KECCAK_FULL_RATE_IN_BYTES_NONDET:
return compareKeccakFullRateInBytesNondet(data.Ids, vm, constants)
case BLOCK_PERMUTATION:
return blockPermutation(data.Ids, vm, constants)
case CAIRO_KECCAK_FINALIZE_V1:
return cairoKeccakFinalize(data.Ids, vm, constants, 10)
case CAIRO_KECCAK_FINALIZE_V2:
return cairoKeccakFinalize(data.Ids, vm, constants, 1000)
case KECCAK_WRITE_ARGS:
return keccakWriteArgs(data.Ids, vm)
case UNSIGNED_DIV_REM:
return unsignedDivRem(data.Ids, vm)
case SIGNED_DIV_REM:
Expand Down
12 changes: 12 additions & 0 deletions pkg/hints/hint_utils/hint_reference.go
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,18 @@ func ParseHintReference(reference parser.Reference) HintReference {
ValueType: valueType,
}
}
// Reference no dereference 2 offsets - + : cast(reg - off1 + off2, type)
_, err = fmt.Sscanf(valueString, "cast%c%c - %d + %d, %s", &off1Reg0, &off1Reg1, &off1, &off2, &valueType)
if err == nil {
off1Reg := getRegister(off1Reg0, off1Reg1)
return HintReference{
ApTrackingData: reference.ApTrackingData,
Offset1: OffsetValue{ValueType: Reference, Register: off1Reg, Value: -off1},
Offset2: OffsetValue{Value: off2},
Dereference: dereference,
ValueType: valueType,
}
}
// No matches (aka wrong format)
return HintReference{ApTrackingData: reference.ApTrackingData}
}
Expand Down
14 changes: 14 additions & 0 deletions pkg/hints/hint_utils/hint_reference_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -326,3 +326,17 @@ func TestParseHintDereferenceReferenceDoubleDerefBothOffOmitted(t *testing.T) {
t.Errorf("Wrong parsed reference, %+v", ParseHintReference(reference))
}
}

func TestParseHintDereferenceValueMinusValPlusVal(t *testing.T) {
reference := parser.Reference{Value: "[cast(ap - 0 + (-1), felt*)]"}
expected := HintReference{
Offset1: OffsetValue{ValueType: Reference, Value: 0, Dereference: false},
Offset2: OffsetValue{ValueType: Value, Value: -1, Dereference: false},
ValueType: "felt*",
Dereference: true,
}

if ParseHintReference(reference) != expected {
t.Errorf("Wrong parsed reference, %+v", ParseHintReference(reference))
}
}
Loading

0 comments on commit d638fec

Please sign in to comment.