From 98aa21d68c389c69de6ef3008ed6e2e362750f7e Mon Sep 17 00:00:00 2001 From: liorgold2 <38202661+liorgold2@users.noreply.github.com> Date: Wed, 28 Aug 2024 11:27:10 +0300 Subject: [PATCH] Add int_range_try_new libfunc. (#6288) --- .../src/core_libfunc_ap_change.rs | 1 + .../src/core_libfunc_cost_base.rs | 6 ++ .../src/invocations/range.rs | 58 +++++++++++++++- .../src/extensions/modules/range.rs | 51 ++++++++++++++ .../src/allowed_libfuncs_lists/all.json | 1 + tests/e2e_test_data/libfuncs/range | 69 +++++++++++++++++++ 6 files changed, 185 insertions(+), 1 deletion(-) diff --git a/crates/cairo-lang-sierra-ap-change/src/core_libfunc_ap_change.rs b/crates/cairo-lang-sierra-ap-change/src/core_libfunc_ap_change.rs index 254a8436a28..9c30c72dca3 100644 --- a/crates/cairo-lang-sierra-ap-change/src/core_libfunc_ap_change.rs +++ b/crates/cairo-lang-sierra-ap-change/src/core_libfunc_ap_change.rs @@ -413,6 +413,7 @@ pub fn core_libfunc_ap_change<InfoProvider: InvocationApChangeInfoProvider>( vec![ApChange::Known(0)] } IntRange(libfunc) => match libfunc { + IntRangeConcreteLibfunc::TryNew(_) => vec![ApChange::Known(2), ApChange::Known(3)], IntRangeConcreteLibfunc::PopFront(_) => vec![ApChange::Known(1), ApChange::Known(1)], }, } diff --git a/crates/cairo-lang-sierra-gas/src/core_libfunc_cost_base.rs b/crates/cairo-lang-sierra-gas/src/core_libfunc_cost_base.rs index 8c91fcca9b6..2c3ea029f00 100644 --- a/crates/cairo-lang-sierra-gas/src/core_libfunc_cost_base.rs +++ b/crates/cairo-lang-sierra-gas/src/core_libfunc_cost_base.rs @@ -583,6 +583,12 @@ pub fn core_libfunc_cost( } }, IntRange(libfunc) => match libfunc { + IntRangeConcreteLibfunc::TryNew(_) => { + vec![ + ConstCost { steps: 3, holes: 0, range_checks: 1, range_checks96: 0 }.into(), + ConstCost { steps: 5, holes: 0, range_checks: 1, range_checks96: 0 }.into(), + ] + } IntRangeConcreteLibfunc::PopFront(_) => { vec![ConstCost::steps(2).into(), ConstCost::steps(2).into()] } diff --git a/crates/cairo-lang-sierra-to-casm/src/invocations/range.rs b/crates/cairo-lang-sierra-to-casm/src/invocations/range.rs index c5a700cfbff..7cd40294f20 100644 --- a/crates/cairo-lang-sierra-to-casm/src/invocations/range.rs +++ b/crates/cairo-lang-sierra-to-casm/src/invocations/range.rs @@ -1,9 +1,13 @@ use cairo_lang_casm::builder::CasmBuilder; use cairo_lang_casm::casm_build_extend; +use cairo_lang_sierra::extensions::gas::CostTokenType; use cairo_lang_sierra::extensions::range::IntRangeConcreteLibfunc; +use num_bigint::BigInt; use super::{CompiledInvocation, CompiledInvocationBuilder, InvocationError}; -use crate::invocations::{add_input_variables, get_non_fallthrough_statement_id}; +use crate::invocations::{ + add_input_variables, get_non_fallthrough_statement_id, BuiltinInfo, CostValidationInfo, +}; /// Builds instructions for `Range` operations. pub fn build( @@ -12,9 +16,61 @@ pub fn build( ) -> Result<CompiledInvocation, InvocationError> { match libfunc { IntRangeConcreteLibfunc::PopFront(_) => build_pop_front(builder), + IntRangeConcreteLibfunc::TryNew(_) => build_try_new(builder), } } +/// Libfunc for constructing a range `[a, b)` if `a <= b`. +fn build_try_new( + builder: CompiledInvocationBuilder<'_>, +) -> Result<CompiledInvocation, InvocationError> { + let [expr_range_check, expr_start, expr_end] = builder.try_get_refs::<3>()?; + let range_check = expr_range_check.try_unpack_single()?; + let start = expr_start.try_unpack_single()?; + let end = expr_end.try_unpack_single()?; + + let mut casm_builder = CasmBuilder::default(); + add_input_variables! {casm_builder, + deref start; + deref end; + buffer(1) range_check; + }; + casm_build_extend! {casm_builder, + let orig_range_check = range_check; + + tempvar diff = end - start; + const bound = BigInt::from(u128::MAX) + BigInt::from(1); + tempvar is_valid_range; + hint TestLessThan {lhs: diff, rhs: bound} into {dst: is_valid_range}; + jump Valid if is_valid_range != 0; + + // Invalid range. + tempvar diff_fixed = diff + bound; + assert diff_fixed = *(range_check++); + jump Failure; + + Valid: + assert diff = *(range_check++); + }; + + let failure_handle = get_non_fallthrough_statement_id(&builder); + Ok(builder.build_from_casm_builder( + casm_builder, + [ + ("Fallthrough", &[&[range_check], &[start, end]], None), + ("Failure", &[&[range_check]], Some(failure_handle)), + ], + CostValidationInfo { + builtin_infos: vec![BuiltinInfo { + cost_token_ty: CostTokenType::RangeCheck, + start: orig_range_check, + end: range_check, + }], + extra_costs: None, + }, + )) +} + /// Libfunc for reducing `[a, b)` to `[a + 1, b)`. fn build_pop_front( builder: CompiledInvocationBuilder<'_>, diff --git a/crates/cairo-lang-sierra/src/extensions/modules/range.rs b/crates/cairo-lang-sierra/src/extensions/modules/range.rs index e9b80541375..687e14f264a 100644 --- a/crates/cairo-lang-sierra/src/extensions/modules/range.rs +++ b/crates/cairo-lang-sierra/src/extensions/modules/range.rs @@ -3,6 +3,8 @@ use super::int::signed::{Sint16Type, Sint32Type, Sint64Type, Sint8Type}; use super::int::signed128::Sint128Type; use super::int::unsigned::{Uint16Type, Uint32Type, Uint64Type, Uint8Type}; use super::int::unsigned128::Uint128Type; +use super::range_check::RangeCheckType; +use super::utils::Range; use crate::define_libfunc_hierarchy; use crate::extensions::lib_func::{ BranchSignature, DeferredOutputKind, LibfuncSignature, OutputVarInfo, ParamSignature, @@ -72,10 +74,59 @@ pub type IntRangeType = GenericTypeArgGenericTypeWrapper<IntRangeTypeWrapped>; define_libfunc_hierarchy! { pub enum IntRangeLibfunc { + TryNew(IntRangeTryNewLibfunc), PopFront(IntRangePopFrontLibfunc), }, IntRangeConcreteLibfunc } +/// Libfunc that constructs the range `[x, y)` if `x <= y` and fails otherwise. +#[derive(Default)] +pub struct IntRangeTryNewLibfunc {} +impl SignatureOnlyGenericLibfunc for IntRangeTryNewLibfunc { + const STR_ID: &'static str = "int_range_try_new"; + + fn specialize_signature( + &self, + context: &dyn SignatureSpecializationContext, + args: &[GenericArg], + ) -> Result<LibfuncSignature, SpecializationError> { + let ty = args_as_single_type(args)?; + let range_ty = context.get_wrapped_concrete_type(IntRangeType::id(), ty.clone())?; + let range_check_type = context.get_concrete_type(RangeCheckType::id(), &[])?; + + if !Range::from_type(context, ty.clone())?.is_small_range() { + return Err(SpecializationError::UnsupportedGenericArg); + } + + Ok(LibfuncSignature { + param_signatures: vec![ + ParamSignature::new(range_check_type.clone()).with_allow_add_const(), + ParamSignature::new(ty.clone()), + ParamSignature::new(ty.clone()), + ], + branch_signatures: vec![ + // Success. + BranchSignature { + vars: vec![ + OutputVarInfo::new_builtin(range_check_type.clone(), 0), + OutputVarInfo { + ty: range_ty, + ref_info: OutputVarReferenceInfo::SimpleDerefs, + }, + ], + ap_change: SierraApChange::Known { new_vars_only: false }, + }, + // Failure. + BranchSignature { + vars: vec![OutputVarInfo::new_builtin(range_check_type, 0)], + ap_change: SierraApChange::Known { new_vars_only: false }, + }, + ], + fallthrough: Some(0), + }) + } +} + /// Libfunc that takes the range `[x, y)` and if `x < y`, returns the range `[x + 1, y)` and the /// value `x`. #[derive(Default)] diff --git a/crates/cairo-lang-starknet-classes/src/allowed_libfuncs_lists/all.json b/crates/cairo-lang-starknet-classes/src/allowed_libfuncs_lists/all.json index c3f6c92d592..c52bd0cc2cf 100644 --- a/crates/cairo-lang-starknet-classes/src/allowed_libfuncs_lists/all.json +++ b/crates/cairo-lang-starknet-classes/src/allowed_libfuncs_lists/all.json @@ -137,6 +137,7 @@ "into_box", "into_u96_guarantee", "int_range_pop_front", + "int_range_try_new", "jump", "keccak_syscall", "sha256_state_handle_init", diff --git a/tests/e2e_test_data/libfuncs/range b/tests/e2e_test_data/libfuncs/range index 7c9057bb41a..8d143408b35 100644 --- a/tests/e2e_test_data/libfuncs/range +++ b/tests/e2e_test_data/libfuncs/range @@ -58,3 +58,72 @@ store_temp<core::internal::OptionRev::<(test::IntRange::<core::integer::i16>, co return([6]); // 10 test::foo@0([0]: IntRange<i16>) -> (core::internal::OptionRev::<(test::IntRange::<core::integer::i16>, core::integer::i16)>); + +//! > ========================================================================== + +//! > range_try_new libfunc + +//! > test_runner_name +SmallE2ETestRunner + +//! > cairo +// TODO(lior): Move to `range.cairo`. +extern type IntRange<T>; +extern fn int_range_try_new<T>(x: T, y: T) -> Option<IntRange<T>> implicits(RangeCheck) nopanic; + +fn foo(x: i16, y: i16) -> Option<IntRange<i16>> { + int_range_try_new(x, y) +} + +//! > casm +[fp + -3] = [ap + 0] + [fp + -4], ap++; +%{ memory[ap + 0] = memory[ap + -1] < 340282366920938463463374607431768211456 %} +jmp rel 7 if [ap + 0] != 0, ap++; +[ap + 0] = [ap + -2] + 340282366920938463463374607431768211456, ap++; +[ap + -1] = [[fp + -5] + 0]; +jmp rel 12; +[ap + -2] = [[fp + -5] + 0]; +ap += 1; +[ap + 0] = [fp + -5] + 1, ap++; +[ap + 0] = 0, ap++; +[ap + 0] = [fp + -4], ap++; +[ap + 0] = [fp + -3], ap++; +ret; +[ap + 0] = [fp + -5] + 1, ap++; +[ap + 0] = 1, ap++; +[ap + 0] = 0, ap++; +[ap + 0] = 0, ap++; +ret; + +//! > function_costs +test::foo: OrderedHashMap({Const: 970}) + +//! > sierra_code +type RangeCheck = RangeCheck [storable: true, drop: false, dup: false, zero_sized: false]; +type Unit = Struct<ut@Tuple> [storable: true, drop: true, dup: true, zero_sized: true]; +type IntRange<i16> = IntRange<i16> [storable: true, drop: true, dup: true, zero_sized: false]; +type core::option::Option::<test::IntRange::<core::integer::i16>> = Enum<ut@core::option::Option::<test::IntRange::<core::integer::i16>>, IntRange<i16>, Unit> [storable: true, drop: true, dup: true, zero_sized: false]; +type i16 = i16 [storable: true, drop: true, dup: true, zero_sized: false]; + +libfunc int_range_try_new<i16> = int_range_try_new<i16>; +libfunc branch_align = branch_align; +libfunc enum_init<core::option::Option::<test::IntRange::<core::integer::i16>>, 0> = enum_init<core::option::Option::<test::IntRange::<core::integer::i16>>, 0>; +libfunc store_temp<RangeCheck> = store_temp<RangeCheck>; +libfunc store_temp<core::option::Option::<test::IntRange::<core::integer::i16>>> = store_temp<core::option::Option::<test::IntRange::<core::integer::i16>>>; +libfunc struct_construct<Unit> = struct_construct<Unit>; +libfunc enum_init<core::option::Option::<test::IntRange::<core::integer::i16>>, 1> = enum_init<core::option::Option::<test::IntRange::<core::integer::i16>>, 1>; + +int_range_try_new<i16>([0], [1], [2]) { fallthrough([3], [4]) 6([5]) }; // 0 +branch_align() -> (); // 1 +enum_init<core::option::Option::<test::IntRange::<core::integer::i16>>, 0>([4]) -> ([6]); // 2 +store_temp<RangeCheck>([3]) -> ([3]); // 3 +store_temp<core::option::Option::<test::IntRange::<core::integer::i16>>>([6]) -> ([6]); // 4 +return([3], [6]); // 5 +branch_align() -> (); // 6 +struct_construct<Unit>() -> ([7]); // 7 +enum_init<core::option::Option::<test::IntRange::<core::integer::i16>>, 1>([7]) -> ([8]); // 8 +store_temp<RangeCheck>([5]) -> ([5]); // 9 +store_temp<core::option::Option::<test::IntRange::<core::integer::i16>>>([8]) -> ([8]); // 10 +return([5], [8]); // 11 + +test::foo@0([0]: RangeCheck, [1]: i16, [2]: i16) -> (RangeCheck, core::option::Option::<test::IntRange::<core::integer::i16>>);