From 98aa21d68c389c69de6ef3008ed6e2e362750f7e Mon Sep 17 00:00:00 2001
From: liorgold2 <38202661+liorgold2@users.noreply.github.com>
Date: Wed, 28 Aug 2024 11:27:10 +0300
Subject: [PATCH] Add int_range_try_new libfunc. (#6288)

---
 .../src/core_libfunc_ap_change.rs             |  1 +
 .../src/core_libfunc_cost_base.rs             |  6 ++
 .../src/invocations/range.rs                  | 58 +++++++++++++++-
 .../src/extensions/modules/range.rs           | 51 ++++++++++++++
 .../src/allowed_libfuncs_lists/all.json       |  1 +
 tests/e2e_test_data/libfuncs/range            | 69 +++++++++++++++++++
 6 files changed, 185 insertions(+), 1 deletion(-)

diff --git a/crates/cairo-lang-sierra-ap-change/src/core_libfunc_ap_change.rs b/crates/cairo-lang-sierra-ap-change/src/core_libfunc_ap_change.rs
index 254a8436a28..9c30c72dca3 100644
--- a/crates/cairo-lang-sierra-ap-change/src/core_libfunc_ap_change.rs
+++ b/crates/cairo-lang-sierra-ap-change/src/core_libfunc_ap_change.rs
@@ -413,6 +413,7 @@ pub fn core_libfunc_ap_change<InfoProvider: InvocationApChangeInfoProvider>(
             vec![ApChange::Known(0)]
         }
         IntRange(libfunc) => match libfunc {
+            IntRangeConcreteLibfunc::TryNew(_) => vec![ApChange::Known(2), ApChange::Known(3)],
             IntRangeConcreteLibfunc::PopFront(_) => vec![ApChange::Known(1), ApChange::Known(1)],
         },
     }
diff --git a/crates/cairo-lang-sierra-gas/src/core_libfunc_cost_base.rs b/crates/cairo-lang-sierra-gas/src/core_libfunc_cost_base.rs
index 8c91fcca9b6..2c3ea029f00 100644
--- a/crates/cairo-lang-sierra-gas/src/core_libfunc_cost_base.rs
+++ b/crates/cairo-lang-sierra-gas/src/core_libfunc_cost_base.rs
@@ -583,6 +583,12 @@ pub fn core_libfunc_cost(
             }
         },
         IntRange(libfunc) => match libfunc {
+            IntRangeConcreteLibfunc::TryNew(_) => {
+                vec![
+                    ConstCost { steps: 3, holes: 0, range_checks: 1, range_checks96: 0 }.into(),
+                    ConstCost { steps: 5, holes: 0, range_checks: 1, range_checks96: 0 }.into(),
+                ]
+            }
             IntRangeConcreteLibfunc::PopFront(_) => {
                 vec![ConstCost::steps(2).into(), ConstCost::steps(2).into()]
             }
diff --git a/crates/cairo-lang-sierra-to-casm/src/invocations/range.rs b/crates/cairo-lang-sierra-to-casm/src/invocations/range.rs
index c5a700cfbff..7cd40294f20 100644
--- a/crates/cairo-lang-sierra-to-casm/src/invocations/range.rs
+++ b/crates/cairo-lang-sierra-to-casm/src/invocations/range.rs
@@ -1,9 +1,13 @@
 use cairo_lang_casm::builder::CasmBuilder;
 use cairo_lang_casm::casm_build_extend;
+use cairo_lang_sierra::extensions::gas::CostTokenType;
 use cairo_lang_sierra::extensions::range::IntRangeConcreteLibfunc;
+use num_bigint::BigInt;
 
 use super::{CompiledInvocation, CompiledInvocationBuilder, InvocationError};
-use crate::invocations::{add_input_variables, get_non_fallthrough_statement_id};
+use crate::invocations::{
+    add_input_variables, get_non_fallthrough_statement_id, BuiltinInfo, CostValidationInfo,
+};
 
 /// Builds instructions for `Range` operations.
 pub fn build(
@@ -12,9 +16,61 @@ pub fn build(
 ) -> Result<CompiledInvocation, InvocationError> {
     match libfunc {
         IntRangeConcreteLibfunc::PopFront(_) => build_pop_front(builder),
+        IntRangeConcreteLibfunc::TryNew(_) => build_try_new(builder),
     }
 }
 
+/// Libfunc for constructing a range `[a, b)` if `a <= b`.
+fn build_try_new(
+    builder: CompiledInvocationBuilder<'_>,
+) -> Result<CompiledInvocation, InvocationError> {
+    let [expr_range_check, expr_start, expr_end] = builder.try_get_refs::<3>()?;
+    let range_check = expr_range_check.try_unpack_single()?;
+    let start = expr_start.try_unpack_single()?;
+    let end = expr_end.try_unpack_single()?;
+
+    let mut casm_builder = CasmBuilder::default();
+    add_input_variables! {casm_builder,
+        deref start;
+        deref end;
+        buffer(1) range_check;
+    };
+    casm_build_extend! {casm_builder,
+        let orig_range_check = range_check;
+
+        tempvar diff = end - start;
+        const bound = BigInt::from(u128::MAX) + BigInt::from(1);
+        tempvar is_valid_range;
+        hint TestLessThan {lhs: diff, rhs: bound} into {dst: is_valid_range};
+        jump Valid if is_valid_range != 0;
+
+        // Invalid range.
+        tempvar diff_fixed = diff + bound;
+        assert diff_fixed = *(range_check++);
+        jump Failure;
+
+        Valid:
+        assert diff = *(range_check++);
+    };
+
+    let failure_handle = get_non_fallthrough_statement_id(&builder);
+    Ok(builder.build_from_casm_builder(
+        casm_builder,
+        [
+            ("Fallthrough", &[&[range_check], &[start, end]], None),
+            ("Failure", &[&[range_check]], Some(failure_handle)),
+        ],
+        CostValidationInfo {
+            builtin_infos: vec![BuiltinInfo {
+                cost_token_ty: CostTokenType::RangeCheck,
+                start: orig_range_check,
+                end: range_check,
+            }],
+            extra_costs: None,
+        },
+    ))
+}
+
 /// Libfunc for reducing `[a, b)` to `[a + 1, b)`.
 fn build_pop_front(
     builder: CompiledInvocationBuilder<'_>,
diff --git a/crates/cairo-lang-sierra/src/extensions/modules/range.rs b/crates/cairo-lang-sierra/src/extensions/modules/range.rs
index e9b80541375..687e14f264a 100644
--- a/crates/cairo-lang-sierra/src/extensions/modules/range.rs
+++ b/crates/cairo-lang-sierra/src/extensions/modules/range.rs
@@ -3,6 +3,8 @@ use super::int::signed::{Sint16Type, Sint32Type, Sint64Type, Sint8Type};
 use super::int::signed128::Sint128Type;
 use super::int::unsigned::{Uint16Type, Uint32Type, Uint64Type, Uint8Type};
 use super::int::unsigned128::Uint128Type;
+use super::range_check::RangeCheckType;
+use super::utils::Range;
 use crate::define_libfunc_hierarchy;
 use crate::extensions::lib_func::{
     BranchSignature, DeferredOutputKind, LibfuncSignature, OutputVarInfo, ParamSignature,
@@ -72,10 +74,59 @@ pub type IntRangeType = GenericTypeArgGenericTypeWrapper<IntRangeTypeWrapped>;
 
 define_libfunc_hierarchy! {
     pub enum IntRangeLibfunc {
+        TryNew(IntRangeTryNewLibfunc),
         PopFront(IntRangePopFrontLibfunc),
     }, IntRangeConcreteLibfunc
 }
 
+/// Libfunc that constructs the range `[x, y)` if `x <= y` and fails otherwise.
+#[derive(Default)]
+pub struct IntRangeTryNewLibfunc {}
+impl SignatureOnlyGenericLibfunc for IntRangeTryNewLibfunc {
+    const STR_ID: &'static str = "int_range_try_new";
+
+    fn specialize_signature(
+        &self,
+        context: &dyn SignatureSpecializationContext,
+        args: &[GenericArg],
+    ) -> Result<LibfuncSignature, SpecializationError> {
+        let ty = args_as_single_type(args)?;
+        let range_ty = context.get_wrapped_concrete_type(IntRangeType::id(), ty.clone())?;
+        let range_check_type = context.get_concrete_type(RangeCheckType::id(), &[])?;
+
+        if !Range::from_type(context, ty.clone())?.is_small_range() {
+            return Err(SpecializationError::UnsupportedGenericArg);
+        }
+
+        Ok(LibfuncSignature {
+            param_signatures: vec![
+                ParamSignature::new(range_check_type.clone()).with_allow_add_const(),
+                ParamSignature::new(ty.clone()),
+                ParamSignature::new(ty.clone()),
+            ],
+            branch_signatures: vec![
+                // Success.
+                BranchSignature {
+                    vars: vec![
+                        OutputVarInfo::new_builtin(range_check_type.clone(), 0),
+                        OutputVarInfo {
+                            ty: range_ty,
+                            ref_info: OutputVarReferenceInfo::SimpleDerefs,
+                        },
+                    ],
+                    ap_change: SierraApChange::Known { new_vars_only: false },
+                },
+                // Failure.
+                BranchSignature {
+                    vars: vec![OutputVarInfo::new_builtin(range_check_type, 0)],
+                    ap_change: SierraApChange::Known { new_vars_only: false },
+                },
+            ],
+            fallthrough: Some(0),
+        })
+    }
+}
+
 /// Libfunc that takes the range `[x, y)` and if `x < y`, returns the range `[x + 1, y)` and the
 /// value `x`.
 #[derive(Default)]
diff --git a/crates/cairo-lang-starknet-classes/src/allowed_libfuncs_lists/all.json b/crates/cairo-lang-starknet-classes/src/allowed_libfuncs_lists/all.json
index c3f6c92d592..c52bd0cc2cf 100644
--- a/crates/cairo-lang-starknet-classes/src/allowed_libfuncs_lists/all.json
+++ b/crates/cairo-lang-starknet-classes/src/allowed_libfuncs_lists/all.json
@@ -137,6 +137,7 @@
         "into_box",
         "into_u96_guarantee",
         "int_range_pop_front",
+        "int_range_try_new",
         "jump",
         "keccak_syscall",
         "sha256_state_handle_init",
diff --git a/tests/e2e_test_data/libfuncs/range b/tests/e2e_test_data/libfuncs/range
index 7c9057bb41a..8d143408b35 100644
--- a/tests/e2e_test_data/libfuncs/range
+++ b/tests/e2e_test_data/libfuncs/range
@@ -58,3 +58,72 @@ store_temp<core::internal::OptionRev::<(test::IntRange::<core::integer::i16>, co
 return([6]); // 10
 
 test::foo@0([0]: IntRange<i16>) -> (core::internal::OptionRev::<(test::IntRange::<core::integer::i16>, core::integer::i16)>);
+
+//! > ==========================================================================
+
+//! > range_try_new libfunc
+
+//! > test_runner_name
+SmallE2ETestRunner
+
+//! > cairo
+// TODO(lior): Move to `range.cairo`.
+extern type IntRange<T>;
+extern fn int_range_try_new<T>(x: T, y: T) -> Option<IntRange<T>> implicits(RangeCheck) nopanic;
+
+fn foo(x: i16, y: i16) -> Option<IntRange<i16>> {
+    int_range_try_new(x, y)
+}
+
+//! > casm
+[fp + -3] = [ap + 0] + [fp + -4], ap++;
+%{ memory[ap + 0] = memory[ap + -1] < 340282366920938463463374607431768211456 %}
+jmp rel 7 if [ap + 0] != 0, ap++;
+[ap + 0] = [ap + -2] + 340282366920938463463374607431768211456, ap++;
+[ap + -1] = [[fp + -5] + 0];
+jmp rel 12;
+[ap + -2] = [[fp + -5] + 0];
+ap += 1;
+[ap + 0] = [fp + -5] + 1, ap++;
+[ap + 0] = 0, ap++;
+[ap + 0] = [fp + -4], ap++;
+[ap + 0] = [fp + -3], ap++;
+ret;
+[ap + 0] = [fp + -5] + 1, ap++;
+[ap + 0] = 1, ap++;
+[ap + 0] = 0, ap++;
+[ap + 0] = 0, ap++;
+ret;
+
+//! > function_costs
+test::foo: OrderedHashMap({Const: 970})
+
+//! > sierra_code
+type RangeCheck = RangeCheck [storable: true, drop: false, dup: false, zero_sized: false];
+type Unit = Struct<ut@Tuple> [storable: true, drop: true, dup: true, zero_sized: true];
+type IntRange<i16> = IntRange<i16> [storable: true, drop: true, dup: true, zero_sized: false];
+type core::option::Option::<test::IntRange::<core::integer::i16>> = Enum<ut@core::option::Option::<test::IntRange::<core::integer::i16>>, IntRange<i16>, Unit> [storable: true, drop: true, dup: true, zero_sized: false];
+type i16 = i16 [storable: true, drop: true, dup: true, zero_sized: false];
+
+libfunc int_range_try_new<i16> = int_range_try_new<i16>;
+libfunc branch_align = branch_align;
+libfunc enum_init<core::option::Option::<test::IntRange::<core::integer::i16>>, 0> = enum_init<core::option::Option::<test::IntRange::<core::integer::i16>>, 0>;
+libfunc store_temp<RangeCheck> = store_temp<RangeCheck>;
+libfunc store_temp<core::option::Option::<test::IntRange::<core::integer::i16>>> = store_temp<core::option::Option::<test::IntRange::<core::integer::i16>>>;
+libfunc struct_construct<Unit> = struct_construct<Unit>;
+libfunc enum_init<core::option::Option::<test::IntRange::<core::integer::i16>>, 1> = enum_init<core::option::Option::<test::IntRange::<core::integer::i16>>, 1>;
+
+int_range_try_new<i16>([0], [1], [2]) { fallthrough([3], [4]) 6([5]) }; // 0
+branch_align() -> (); // 1
+enum_init<core::option::Option::<test::IntRange::<core::integer::i16>>, 0>([4]) -> ([6]); // 2
+store_temp<RangeCheck>([3]) -> ([3]); // 3
+store_temp<core::option::Option::<test::IntRange::<core::integer::i16>>>([6]) -> ([6]); // 4
+return([3], [6]); // 5
+branch_align() -> (); // 6
+struct_construct<Unit>() -> ([7]); // 7
+enum_init<core::option::Option::<test::IntRange::<core::integer::i16>>, 1>([7]) -> ([8]); // 8
+store_temp<RangeCheck>([5]) -> ([5]); // 9
+store_temp<core::option::Option::<test::IntRange::<core::integer::i16>>>([8]) -> ([8]); // 10
+return([5], [8]); // 11
+
+test::foo@0([0]: RangeCheck, [1]: i16, [2]: i16) -> (RangeCheck, core::option::Option::<test::IntRange::<core::integer::i16>>);