Skip to content

Commit

Permalink
Added *_eq const folding. (#6435)
Browse files Browse the repository at this point in the history
  • Loading branch information
orizi authored Oct 6, 2024
1 parent dbc88ef commit 1b6162e
Show file tree
Hide file tree
Showing 2 changed files with 176 additions and 0 deletions.
16 changes: 16 additions & 0 deletions crates/cairo-lang-lowering/src/optimizations/const_folding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,16 @@ impl<'a> ConstFoldingContext<'a> {
FlatBlockEnd::Goto(arm.block_id, Default::default()),
)
})
} else if self.eq_fns.contains(&id) {
let lhs = self.as_int(info.inputs[0].var_id)?;
let rhs = self.as_int(info.inputs[1].var_id)?;
Some((
None,
FlatBlockEnd::Goto(
info.arms[if lhs == rhs { 1 } else { 0 }].block_id,
Default::default(),
),
))
} else if self.uadd_fns.contains(&id)
|| self.usub_fns.contains(&id)
|| self.iadd_fns.contains(&id)
Expand Down Expand Up @@ -516,6 +526,8 @@ pub struct ConstFoldingLibfuncInfo {
storage_base_address_from_felt252: ExternFunctionId,
/// The set of functions that check if a number is zero.
nz_fns: OrderedHashSet<ExternFunctionId>,
/// The set of functions that check if numbers are equal.
eq_fns: OrderedHashSet<ExternFunctionId>,
/// The set of functions to add unsigned ints.
uadd_fns: OrderedHashSet<ExternFunctionId>,
/// The set of functions to subtract unsigned ints.
Expand Down Expand Up @@ -563,6 +575,9 @@ impl ConstFoldingLibfuncInfo {
));
let utypes = ["u8", "u16", "u32", "u64", "u128"];
let itypes = ["i8", "i16", "i32", "i64", "i128"];
let eq_fns = OrderedHashSet::<_>::from_iter(
chain!(utypes, itypes).map(|ty| integer_module.extern_function_id(format!("{ty}_eq"))),
);
let uadd_fns = OrderedHashSet::<_>::from_iter(
utypes.map(|ty| integer_module.extern_function_id(format!("{ty}_overflowing_add"))),
);
Expand Down Expand Up @@ -616,6 +631,7 @@ impl ConstFoldingLibfuncInfo {
downcast,
storage_base_address_from_felt252,
nz_fns,
eq_fns,
uadd_fns,
usub_fns,
iadd_fns,
Expand Down
160 changes: 160 additions & 0 deletions crates/cairo-lang-lowering/src/optimizations/test_data/const_folding
Original file line number Diff line number Diff line change
Expand Up @@ -3741,3 +3741,163 @@ End:
Return(v5)

//! > lowering_diagnostics

//! > ==========================================================================

//! > Eq const fold.

//! > test_runner_name
test_match_optimizer

//! > function
fn foo() -> bool {
1_u8 == 1
}

//! > function_name
foo

//! > module_code

//! > semantic_diagnostics

//! > before
Parameters:
blk0 (root):
Statements:
(v0: core::integer::u8) <- 1
(v3: core::integer::u8) <- 1
End:
Match(match core::integer::u8_eq(v0, v3) {
bool::False => blk1,
bool::True => blk2,
})

blk1:
Statements:
(v8: ()) <- struct_construct()
(v9: core::bool) <- bool::False(v8)
End:
Goto(blk3, {v9 -> v10})

blk2:
Statements:
(v11: ()) <- struct_construct()
(v12: core::bool) <- bool::True(v11)
End:
Goto(blk3, {v12 -> v10})

blk3:
Statements:
End:
Return(v10)

//! > after
Parameters:
blk0 (root):
Statements:
(v0: core::integer::u8) <- 1
(v3: core::integer::u8) <- 1
End:
Goto(blk2, {})

blk1:
Statements:
(v8: ()) <- struct_construct()
(v9: core::bool) <- bool::False(v8)
End:
Goto(blk3, {v9 -> v10})

blk2:
Statements:
(v11: ()) <- struct_construct()
(v12: core::bool) <- bool::True(v11)
End:
Goto(blk3, {v12 -> v10})

blk3:
Statements:
End:
Return(v10)

//! > lowering_diagnostics

//! > ==========================================================================

//! > Non eq const fold.

//! > test_runner_name
test_match_optimizer

//! > function
fn foo() -> bool {
2_u8 == 1
}

//! > function_name
foo

//! > module_code

//! > semantic_diagnostics

//! > before
Parameters:
blk0 (root):
Statements:
(v0: core::integer::u8) <- 2
(v3: core::integer::u8) <- 1
End:
Match(match core::integer::u8_eq(v0, v3) {
bool::False => blk1,
bool::True => blk2,
})

blk1:
Statements:
(v8: ()) <- struct_construct()
(v9: core::bool) <- bool::False(v8)
End:
Goto(blk3, {v9 -> v10})

blk2:
Statements:
(v11: ()) <- struct_construct()
(v12: core::bool) <- bool::True(v11)
End:
Goto(blk3, {v12 -> v10})

blk3:
Statements:
End:
Return(v10)

//! > after
Parameters:
blk0 (root):
Statements:
(v0: core::integer::u8) <- 2
(v3: core::integer::u8) <- 1
End:
Goto(blk1, {})

blk1:
Statements:
(v8: ()) <- struct_construct()
(v9: core::bool) <- bool::False(v8)
End:
Goto(blk3, {v9 -> v10})

blk2:
Statements:
(v11: ()) <- struct_construct()
(v12: core::bool) <- bool::True(v11)
End:
Goto(blk3, {v12 -> v10})

blk3:
Statements:
End:
Return(v10)

//! > lowering_diagnostics

0 comments on commit 1b6162e

Please sign in to comment.