From 14765dfce12fd26e3119e5ea48539a606ed7b9e5 Mon Sep 17 00:00:00 2001 From: Ori Ziv Date: Thu, 26 Sep 2024 22:06:16 +0300 Subject: [PATCH] Added `*_eq` const folding. commit-id:5850139f --- .../src/optimizations/const_folding.rs | 16 ++ .../src/optimizations/test_data/const_folding | 160 ++++++++++++++++++ 2 files changed, 176 insertions(+) diff --git a/crates/cairo-lang-lowering/src/optimizations/const_folding.rs b/crates/cairo-lang-lowering/src/optimizations/const_folding.rs index fbfdcce8ec0..0cb6c57d38f 100644 --- a/crates/cairo-lang-lowering/src/optimizations/const_folding.rs +++ b/crates/cairo-lang-lowering/src/optimizations/const_folding.rs @@ -335,6 +335,16 @@ impl<'a> ConstFoldingContext<'a> { FlatBlockEnd::Goto(arm.block_id, Default::default()), ) }) + } else if self.eq_fns.contains(&id) { + let lhs = self.as_int(info.inputs[0].var_id)?; + let rhs = self.as_int(info.inputs[1].var_id)?; + Some(( + None, + FlatBlockEnd::Goto( + info.arms[if lhs == rhs { 1 } else { 0 }].block_id, + Default::default(), + ), + )) } else if self.uadd_fns.contains(&id) || self.usub_fns.contains(&id) || self.iadd_fns.contains(&id) @@ -516,6 +526,8 @@ pub struct ConstFoldingLibfuncInfo { storage_base_address_from_felt252: ExternFunctionId, /// The set of functions that check if a number is zero. nz_fns: OrderedHashSet, + /// The set of functions that check if numbers are equal. + eq_fns: OrderedHashSet, /// The set of functions to add unsigned ints. uadd_fns: OrderedHashSet, /// The set of functions to subtract unsigned ints. @@ -563,6 +575,9 @@ impl ConstFoldingLibfuncInfo { )); let utypes = ["u8", "u16", "u32", "u64", "u128"]; let itypes = ["i8", "i16", "i32", "i64", "i128"]; + let eq_fns = OrderedHashSet::<_>::from_iter( + chain!(utypes, itypes).map(|ty| integer_module.extern_function_id(format!("{ty}_eq"))), + ); let uadd_fns = OrderedHashSet::<_>::from_iter( utypes.map(|ty| integer_module.extern_function_id(format!("{ty}_overflowing_add"))), ); @@ -616,6 +631,7 @@ impl ConstFoldingLibfuncInfo { downcast, storage_base_address_from_felt252, nz_fns, + eq_fns, uadd_fns, usub_fns, iadd_fns, diff --git a/crates/cairo-lang-lowering/src/optimizations/test_data/const_folding b/crates/cairo-lang-lowering/src/optimizations/test_data/const_folding index 526b8a05d4b..dde9331c119 100644 --- a/crates/cairo-lang-lowering/src/optimizations/test_data/const_folding +++ b/crates/cairo-lang-lowering/src/optimizations/test_data/const_folding @@ -3741,3 +3741,163 @@ End: Return(v5) //! > lowering_diagnostics + +//! > ========================================================================== + +//! > Eq const fold. + +//! > test_runner_name +test_match_optimizer + +//! > function +fn foo() -> bool { + 1_u8 == 1 +} + +//! > function_name +foo + +//! > module_code + +//! > semantic_diagnostics + +//! > before +Parameters: +blk0 (root): +Statements: + (v0: core::integer::u8) <- 1 + (v3: core::integer::u8) <- 1 +End: + Match(match core::integer::u8_eq(v0, v3) { + bool::False => blk1, + bool::True => blk2, + }) + +blk1: +Statements: + (v8: ()) <- struct_construct() + (v9: core::bool) <- bool::False(v8) +End: + Goto(blk3, {v9 -> v10}) + +blk2: +Statements: + (v11: ()) <- struct_construct() + (v12: core::bool) <- bool::True(v11) +End: + Goto(blk3, {v12 -> v10}) + +blk3: +Statements: +End: + Return(v10) + +//! > after +Parameters: +blk0 (root): +Statements: + (v0: core::integer::u8) <- 1 + (v3: core::integer::u8) <- 1 +End: + Goto(blk2, {}) + +blk1: +Statements: + (v8: ()) <- struct_construct() + (v9: core::bool) <- bool::False(v8) +End: + Goto(blk3, {v9 -> v10}) + +blk2: +Statements: + (v11: ()) <- struct_construct() + (v12: core::bool) <- bool::True(v11) +End: + Goto(blk3, {v12 -> v10}) + +blk3: +Statements: +End: + Return(v10) + +//! > lowering_diagnostics + +//! > ========================================================================== + +//! > Non eq const fold. + +//! > test_runner_name +test_match_optimizer + +//! > function +fn foo() -> bool { + 2_u8 == 1 +} + +//! > function_name +foo + +//! > module_code + +//! > semantic_diagnostics + +//! > before +Parameters: +blk0 (root): +Statements: + (v0: core::integer::u8) <- 2 + (v3: core::integer::u8) <- 1 +End: + Match(match core::integer::u8_eq(v0, v3) { + bool::False => blk1, + bool::True => blk2, + }) + +blk1: +Statements: + (v8: ()) <- struct_construct() + (v9: core::bool) <- bool::False(v8) +End: + Goto(blk3, {v9 -> v10}) + +blk2: +Statements: + (v11: ()) <- struct_construct() + (v12: core::bool) <- bool::True(v11) +End: + Goto(blk3, {v12 -> v10}) + +blk3: +Statements: +End: + Return(v10) + +//! > after +Parameters: +blk0 (root): +Statements: + (v0: core::integer::u8) <- 2 + (v3: core::integer::u8) <- 1 +End: + Goto(blk1, {}) + +blk1: +Statements: + (v8: ()) <- struct_construct() + (v9: core::bool) <- bool::False(v8) +End: + Goto(blk3, {v9 -> v10}) + +blk2: +Statements: + (v11: ()) <- struct_construct() + (v12: core::bool) <- bool::True(v11) +End: + Goto(blk3, {v12 -> v10}) + +blk3: +Statements: +End: + Return(v10) + +//! > lowering_diagnostics