Skip to content

Commit

Permalink
Added support for div-rem const folding.
Browse files Browse the repository at this point in the history
commit-id:b582d31f
  • Loading branch information
orizi committed Sep 29, 2024
1 parent a6b0530 commit bac6339
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 26 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions crates/cairo-lang-lowering/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ id-arena.workspace = true
itertools = { workspace = true, default-features = true }
log.workspace = true
num-bigint = { workspace = true, default-features = true }
num-integer = { workspace = true, default-features = true }
num-traits = { workspace = true, default-features = true }
salsa.workspace = true
smol_str.workspace = true
Expand Down
65 changes: 48 additions & 17 deletions crates/cairo-lang-lowering/src/optimizations/const_folding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ use cairo_lang_utils::{extract_matches, try_extract_matches, Intern, LookupInter
use id_arena::Arena;
use itertools::{chain, zip_eq};
use num_bigint::BigInt;
use num_integer::Integer;
use num_traits::Zero;
use smol_str::SmolStr;

Expand Down Expand Up @@ -64,6 +65,7 @@ pub fn const_folding(db: &dyn LoweringGroup, lowered: &mut FlatLowered) {
visited[block_id.0] = true;

let block = &mut lowered.blocks[block_id];
let mut additional_consts = vec![];
for stmt in block.statements.iter_mut() {
ctx.maybe_replace_inputs(stmt.inputs_mut());
match stmt {
Expand Down Expand Up @@ -92,8 +94,9 @@ pub fn const_folding(db: &dyn LoweringGroup, lowered: &mut FlatLowered) {
}
}
Statement::Call(call_stmt) => {
if let Some(updated_stmt) = ctx.handle_statement_call(call_stmt) {
*stmt = updated_stmt;
if let Some((updated_stmt, additional)) = ctx.handle_statement_call(call_stmt) {
*stmt = Statement::Const(updated_stmt);
additional_consts.extend(additional);
}
}
Statement::StructConstruct(StatementStructConstruct { inputs, output }) => {
Expand Down Expand Up @@ -164,6 +167,7 @@ pub fn const_folding(db: &dyn LoweringGroup, lowered: &mut FlatLowered) {
}
}
}
block.statements.splice(0..0, additional_consts.into_iter().map(Statement::Const));

match &mut block.end {
FlatBlockEnd::Goto(block_id, remappings) => {
Expand All @@ -188,7 +192,7 @@ pub fn const_folding(db: &dyn LoweringGroup, lowered: &mut FlatLowered) {
MatchInfo::Extern(info) => {
if let Some((extra_stmt, updated_end)) = ctx.handle_extern_block_end(info) {
if let Some(stmt) = extra_stmt {
block.statements.push(stmt);
block.statements.push(Statement::Const(stmt));
}
block.end = updated_end;
}
Expand Down Expand Up @@ -216,8 +220,13 @@ struct ConstFoldingContext<'a> {
impl<'a> ConstFoldingContext<'a> {
/// Handles a statement call.
/// Returns None if no additional changes are required.
/// If changes are required, returns an updated statement.
fn handle_statement_call(&mut self, stmt: &mut StatementCall) -> Option<Statement> {
/// If changes are required, returns an updated const-statement (to override the current
/// statement), and a possible additional const-statement, if multiple statements are required
/// for replacing the existing statement.
fn handle_statement_call(
&mut self,
stmt: &mut StatementCall,
) -> Option<(StatementConst, Option<StatementConst>)> {
let id = stmt.function.get_extern(self.db)?;
if id == self.felt_sub {
// (a - 0) can be replaced by a.
Expand Down Expand Up @@ -245,7 +254,21 @@ impl<'a> ConstFoldingContext<'a> {
value = ConstValue::NonZero(Box::new(value));
}
self.var_info.insert(output, VarInfo::Const(value.clone()));
Some(Statement::Const(StatementConst { value, output }))
Some((StatementConst { value, output }, None))
} else if self.div_rem_fns.contains(&id) {
let lhs = self.as_int(stmt.inputs[0].var_id)?;
let rhs = self.as_int(stmt.inputs[1].var_id)?;
let (q, r) = lhs.div_rem(rhs);
let q_output = stmt.outputs[0];
let q_value = ConstValue::Int(q, self.variables[q_output].ty);
self.var_info.insert(q_output, VarInfo::Const(q_value.clone()));
let r_output = stmt.outputs[1];
let r_value = ConstValue::Int(r, self.variables[r_output].ty);
self.var_info.insert(r_output, VarInfo::Const(r_value.clone()));
Some((
StatementConst { value: q_value, output: q_output },
Some(StatementConst { value: r_value, output: r_output }),
))
} else if id == self.storage_base_address_from_felt252 {
let input_var = stmt.inputs[0].var_id;
if let Some(ConstValue::Int(val, ty)) = self.as_const(input_var) {
Expand All @@ -268,25 +291,26 @@ impl<'a> ConstFoldingContext<'a> {
let value = ConstValue::Boxed(const_value.clone().into());
// Not inserting the value into the `var_info` map because the
// resulting box isn't an actual const at the Sierra level.
Some(Statement::Const(StatementConst { value, output: stmt.outputs[0] }))
Some((StatementConst { value, output: stmt.outputs[0] }, None))
} else if id == self.upcast {
let int_value = self.as_int(stmt.inputs[0].var_id)?;
let value = ConstValue::Int(int_value.clone(), self.variables[stmt.outputs[0]].ty);
self.var_info.insert(stmt.outputs[0], VarInfo::Const(value.clone()));
Some(Statement::Const(StatementConst { value, output: stmt.outputs[0] }))
let output = stmt.outputs[0];
let value = ConstValue::Int(int_value.clone(), self.variables[output].ty);
self.var_info.insert(output, VarInfo::Const(value.clone()));
Some((StatementConst { value, output }, None))
} else {
None
}
}

/// Handles the end of an extern block.
/// Returns None if no additional changes are required.
/// If changes are required, returns a possible additional statement to the block, as well as an
/// updated block end.
/// If changes are required, returns a possible additional const-statement to the block, as well
/// as an updated block end.
fn handle_extern_block_end(
&mut self,
info: &mut MatchExternInfo,
) -> Option<(Option<Statement>, FlatBlockEnd)> {
) -> Option<(Option<StatementConst>, FlatBlockEnd)> {
let id = info.function.get_extern(self.db)?;
if self.nz_fns.contains(&id) {
let val = self.as_const(info.inputs[0].var_id)?;
Expand All @@ -305,7 +329,7 @@ impl<'a> ConstFoldingContext<'a> {
let nz_val = ConstValue::NonZero(Box::new(val.clone()));
self.var_info.insert(nz_var, VarInfo::Const(nz_val.clone()));
(
Some(Statement::Const(StatementConst { value: nz_val, output: nz_var })),
Some(StatementConst { value: nz_val, output: nz_var }),
FlatBlockEnd::Goto(arm.block_id, Default::default()),
)
})
Expand Down Expand Up @@ -336,7 +360,7 @@ impl<'a> ConstFoldingContext<'a> {
let value = ConstValue::Int(value, ty);
self.var_info.insert(actual_output, VarInfo::Const(value.clone()));
Some((
Some(Statement::Const(StatementConst { value, output: actual_output })),
Some(StatementConst { value, output: actual_output }),
FlatBlockEnd::Goto(arm.block_id, Default::default()),
))
} else if id == self.downcast {
Expand All @@ -349,7 +373,7 @@ impl<'a> ConstFoldingContext<'a> {
let value = ConstValue::Int(value, ty);
self.var_info.insert(success_output, VarInfo::Const(value.clone()));
(
Some(Statement::Const(StatementConst { value, output: success_output })),
Some(StatementConst { value, output: success_output }),
FlatBlockEnd::Goto(info.arms[0].block_id, Default::default()),
)
} else {
Expand All @@ -373,7 +397,7 @@ impl<'a> ConstFoldingContext<'a> {
}
self.var_info.insert(output, VarInfo::Const(value.clone()));
Some((
Some(Statement::Const(StatementConst { value, output })),
Some(StatementConst { value, output }),
FlatBlockEnd::Goto(info.arms[arm_idx].block_id, Default::default()),
))
} else {
Expand Down Expand Up @@ -500,6 +524,8 @@ pub struct ConstFoldingLibfuncInfo {
isub_fns: OrderedHashSet<ExternFunctionId>,
/// The set of functions to multiply integers.
wide_mul_fns: OrderedHashSet<ExternFunctionId>,
/// The set of functions to divide and get the remainder of integers.
div_rem_fns: OrderedHashSet<ExternFunctionId>,
/// The `bounded_int_add` libfunc.
bounded_int_add: ExternFunctionId,
/// The `bounded_int_sub` libfunc.
Expand Down Expand Up @@ -556,6 +582,10 @@ impl ConstFoldingLibfuncInfo {
["u8", "u16", "u32", "u64", "i8", "i16", "i32", "i64"]
.map(|ty| integer_module.extern_function_id(format!("{ty}_wide_mul"))),
));
let div_rem_fns = OrderedHashSet::<_>::from_iter(chain!(
[bounded_int_module.extern_function_id("bounded_int_div_rem")],
utypes.map(|ty| integer_module.extern_function_id(format!("{ty}_safe_divmod"))),
));
let bounded_int_add = bounded_int_module.extern_function_id("bounded_int_add");
let bounded_int_sub = bounded_int_module.extern_function_id("bounded_int_sub");
let bounded_int_constrain = bounded_int_module.extern_function_id("bounded_int_constrain");
Expand Down Expand Up @@ -589,6 +619,7 @@ impl ConstFoldingLibfuncInfo {
iadd_fns,
isub_fns,
wide_mul_fns,
div_rem_fns,
bounded_int_add,
bounded_int_sub,
bounded_int_constrain,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -173,8 +173,8 @@ End:
test_match_optimizer

//! > function
fn foo(x: u8) -> u8 {
x / 4
fn foo() -> u8 {
8 / 4
}

//! > function_name
Expand All @@ -185,9 +185,10 @@ foo
//! > semantic_diagnostics

//! > before
Parameters: v0: core::integer::u8
Parameters:
blk0 (root):
Statements:
(v0: core::integer::u8) <- 8
(v1: core::integer::u8) <- 4
End:
Match(match core::integer::u8_is_zero(v1) {
Expand Down Expand Up @@ -235,9 +236,10 @@ End:
Return(v19)

//! > after
Parameters: v0: core::integer::u8
Parameters:
blk0 (root):
Statements:
(v0: core::integer::u8) <- 8
(v1: core::integer::u8) <- 4
(v2: core::zeroable::NonZero::<core::integer::u8>) <- NonZero(4)
End:
Expand All @@ -256,7 +258,8 @@ End:

blk2:
Statements:
(v10: core::integer::u8, v11: core::integer::u8) <- core::integer::u8_safe_divmod(v0, v2)
(v11: core::integer::u8) <- 0
(v10: core::integer::u8) <- 2
(v12: (core::integer::u8,)) <- struct_construct(v10)
(v13: core::panics::PanicResult::<(core::integer::u8,)>) <- PanicResult::Ok(v12)
End:
Expand Down Expand Up @@ -3203,8 +3206,8 @@ End:
test_match_optimizer

//! > function
fn foo(x: i8) -> i8 {
x / 4
fn foo() -> i8 {
8 / 4
}

//! > function_name
Expand All @@ -3215,9 +3218,10 @@ foo
//! > semantic_diagnostics

//! > before
Parameters: v0: core::integer::i8
Parameters:
blk0 (root):
Statements:
(v0: core::integer::i8) <- 8
(v1: core::integer::i8) <- 4
End:
Match(match core::internal::bounded_int::bounded_int_is_zero::<core::integer::i8>(v1) {
Expand Down Expand Up @@ -3281,9 +3285,10 @@ End:
Return(v24)

//! > after
Parameters: v0: core::integer::i8
Parameters:
blk0 (root):
Statements:
(v0: core::integer::i8) <- 8
(v1: core::integer::i8) <- 4
(v2: core::zeroable::NonZero::<core::integer::i8>) <- NonZero(4)
End:
Expand Down

0 comments on commit bac6339

Please sign in to comment.