Skip to content

Commit

Permalink
Refactored const folding. (#6418)
Browse files Browse the repository at this point in the history
  • Loading branch information
orizi authored Sep 26, 2024
1 parent 7d1e58b commit f0e7c4e
Showing 1 changed file with 90 additions and 99 deletions.
189 changes: 90 additions & 99 deletions crates/cairo-lang-lowering/src/optimizations/const_folding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -218,22 +218,24 @@ impl<'a> ConstFoldingContext<'a> {
/// Returns None if no additional changes are required.
/// If changes are required, returns an updated statement.
fn handle_statement_call(&mut self, stmt: &mut StatementCall) -> Option<Statement> {
// (a - 0) can be replaced by a.
if stmt.function == self.felt_sub {
let id = stmt.function.get_extern(self.db)?;
if id == self.felt_sub {
// (a - 0) can be replaced by a.
let val = self.as_int(stmt.inputs[1].var_id)?;
if val.is_zero() {
self.var_info.insert(stmt.outputs[0], VarInfo::Var(stmt.inputs[0]));
}
} else if self.wide_mul_fns.contains(&stmt.function) {
None
} else if self.wide_mul_fns.contains(&id) {
let lhs = self.as_int(stmt.inputs[0].var_id)?;
let rhs = self.as_int(stmt.inputs[1].var_id)?;
let value = lhs * rhs;
let output = stmt.outputs[0];
let ty = self.variables[output].ty;
let value = ConstValue::Int(value, ty);
self.var_info.insert(output, VarInfo::Const(value.clone()));
return Some(Statement::Const(StatementConst { value, output }));
} else if stmt.function == self.storage_base_address_from_felt252 {
Some(Statement::Const(StatementConst { value, output }))
} else if id == self.storage_base_address_from_felt252 {
let input_var = stmt.inputs[0].var_id;
if let Some(ConstValue::Int(val, ty)) = self.as_const(input_var) {
stmt.inputs.clear();
Expand All @@ -245,25 +247,25 @@ impl<'a> ConstFoldingContext<'a> {
)],
);
}
} else if let Some(extrn) = stmt.function.get_extern(self.db) {
if extrn == self.into_box {
let const_value = match self.var_info.get(&stmt.inputs[0].var_id)? {
VarInfo::Const(val) => val,
VarInfo::Snapshot(info) => try_extract_matches!(info.as_ref(), VarInfo::Const)?,
_ => return None,
};
let value = ConstValue::Boxed(const_value.clone().into());
// Not inserting the value into the `var_info` map because the
// resulting box isn't an actual const at the Sierra level.
return Some(Statement::Const(StatementConst { value, output: stmt.outputs[0] }));
} else if extrn == self.upcast {
let int_value = self.as_int(stmt.inputs[0].var_id)?;
let value = ConstValue::Int(int_value.clone(), self.variables[stmt.outputs[0]].ty);
self.var_info.insert(stmt.outputs[0], VarInfo::Const(value.clone()));
return Some(Statement::Const(StatementConst { value, output: stmt.outputs[0] }));
}
None
} else if id == self.into_box {
let const_value = match self.var_info.get(&stmt.inputs[0].var_id)? {
VarInfo::Const(val) => val,
VarInfo::Snapshot(info) => try_extract_matches!(info.as_ref(), VarInfo::Const)?,
_ => return None,
};
let value = ConstValue::Boxed(const_value.clone().into());
// Not inserting the value into the `var_info` map because the
// resulting box isn't an actual const at the Sierra level.
Some(Statement::Const(StatementConst { value, output: stmt.outputs[0] }))
} else if id == self.upcast {
let int_value = self.as_int(stmt.inputs[0].var_id)?;
let value = ConstValue::Int(int_value.clone(), self.variables[stmt.outputs[0]].ty);
self.var_info.insert(stmt.outputs[0], VarInfo::Const(value.clone()));
Some(Statement::Const(StatementConst { value, output: stmt.outputs[0] }))
} else {
None
}
None
}

/// Handles the end of an extern block.
Expand All @@ -274,7 +276,8 @@ impl<'a> ConstFoldingContext<'a> {
&mut self,
info: &mut MatchExternInfo,
) -> Option<(Option<Statement>, FlatBlockEnd)> {
if self.nz_fns.contains(&info.function) {
let id = info.function.get_extern(self.db)?;
if self.nz_fns.contains(&id) {
let val = self.as_const(info.inputs[0].var_id)?;
let is_zero = match val {
ConstValue::Int(v, _) => v.is_zero(),
Expand All @@ -283,7 +286,7 @@ impl<'a> ConstFoldingContext<'a> {
}),
_ => unreachable!(),
};
return Some(if is_zero {
Some(if is_zero {
(None, FlatBlockEnd::Goto(info.arms[0].block_id, Default::default()))
} else {
let arm = &info.arms[1];
Expand All @@ -294,69 +297,56 @@ impl<'a> ConstFoldingContext<'a> {
Some(Statement::Const(StatementConst { value: nz_val, output: nz_var })),
FlatBlockEnd::Goto(arm.block_id, Default::default()),
)
});
} else if self.uadd_fns.contains(&info.function) || self.usub_fns.contains(&info.function) {
})
} else if self.uadd_fns.contains(&id)
|| self.usub_fns.contains(&id)
|| self.iadd_fns.contains(&id)
|| self.isub_fns.contains(&id)
{
let lhs = self.as_int(info.inputs[0].var_id)?;
let rhs = self.as_int(info.inputs[1].var_id)?;
let value = if self.uadd_fns.contains(&info.function) { lhs + rhs } else { lhs - rhs };
let ty = self.variables[info.arms[0].var_ids[0]].ty;
let range = self.type_value_ranges.get(&ty)?;
let (arm_index, value) = match range.normalized(value) {
NormalizedResult::InRange(value) => (0, value),
NormalizedResult::Over(value) | NormalizedResult::Under(value) => (1, value),
let value = if self.uadd_fns.contains(&id) || self.iadd_fns.contains(&id) {
lhs + rhs
} else {
lhs - rhs
};
let arm = &info.arms[arm_index];
let actual_output = arm.var_ids[0];
let value = ConstValue::Int(value, ty);
self.var_info.insert(actual_output, VarInfo::Const(value.clone()));
return Some((
Some(Statement::Const(StatementConst { value, output: actual_output })),
FlatBlockEnd::Goto(arm.block_id, Default::default()),
));
} else if self.iadd_fns.contains(&info.function) || self.isub_fns.contains(&info.function) {
let lhs = self.as_int(info.inputs[0].var_id)?;
let rhs = self.as_int(info.inputs[1].var_id)?;
let value = if self.iadd_fns.contains(&info.function) { lhs + rhs } else { lhs - rhs };
let ty = self.variables[info.arms[0].var_ids[0]].ty;
let range = self.type_value_ranges.get(&ty)?;
let (arm_index, value) = match range.normalized(value) {
NormalizedResult::InRange(value) => (0, value),
NormalizedResult::Under(value) => (1, value),
NormalizedResult::Over(value) => (2, value),
NormalizedResult::Over(value) => (
if self.iadd_fns.contains(&id) || self.isub_fns.contains(&id) { 2 } else { 1 },
value,
),
};
let arm = &info.arms[arm_index];
let actual_output = arm.var_ids[0];
let value = ConstValue::Int(value, ty);
self.var_info.insert(actual_output, VarInfo::Const(value.clone()));
return Some((
Some((
Some(Statement::Const(StatementConst { value, output: actual_output })),
FlatBlockEnd::Goto(arm.block_id, Default::default()),
));
} else if let Some(extrn) = info.function.get_extern(self.db) {
if extrn == self.downcast {
let input_var = info.inputs[0].var_id;
let value = self.as_int(input_var)?;
let success_output = info.arms[0].var_ids[0];
let ty = self.variables[success_output].ty;
let range = self.type_value_ranges.get(&ty)?;
return Some(
if let NormalizedResult::InRange(value) = range.normalized(value.clone()) {
let value = ConstValue::Int(value, ty);
self.var_info.insert(success_output, VarInfo::Const(value.clone()));
(
Some(Statement::Const(StatementConst {
value,
output: success_output,
})),
FlatBlockEnd::Goto(info.arms[0].block_id, Default::default()),
)
} else {
(None, FlatBlockEnd::Goto(info.arms[1].block_id, Default::default()))
},
);
}
))
} else if id == self.downcast {
let input_var = info.inputs[0].var_id;
let value = self.as_int(input_var)?;
let success_output = info.arms[0].var_ids[0];
let ty = self.variables[success_output].ty;
let range = self.type_value_ranges.get(&ty)?;
Some(if let NormalizedResult::InRange(value) = range.normalized(value.clone()) {
let value = ConstValue::Int(value, ty);
self.var_info.insert(success_output, VarInfo::Const(value.clone()));
(
Some(Statement::Const(StatementConst { value, output: success_output })),
FlatBlockEnd::Goto(info.arms[0].block_id, Default::default()),
)
} else {
(None, FlatBlockEnd::Goto(info.arms[1].block_id, Default::default()))
})
} else {
None
}
None
}

/// Returns the const value of a variable if it exists.
Expand Down Expand Up @@ -411,11 +401,12 @@ impl<'a> ModuleHelper<'a> {
Self { db: self.db, id }
}
/// Returns the id of an extern function named `name` in the current module.
fn extern_function_id(&self, name: &str) -> ExternFunctionId {
fn extern_function_id(&self, name: impl Into<SmolStr>) -> ExternFunctionId {
let name = name.into();
let Ok(Some(ModuleItemId::ExternFunction(id))) =
self.db.module_item_by_name(self.id, name.into())
self.db.module_item_by_name(self.id, name.clone())
else {
panic!("`{name}` not found in `{}`.", self.id.full_path(self.db.upcast()));
panic!("`{}` not found in `{}`.", name, self.id.full_path(self.db.upcast()));
};
id
}
Expand All @@ -440,27 +431,27 @@ impl<'a> ModuleHelper<'a> {
#[derive(Debug, PartialEq, Eq)]
pub struct ConstFoldingLibfuncInfo {
/// The `felt252_sub` libfunc.
felt_sub: FunctionId,
felt_sub: ExternFunctionId,
/// The `into_box` libfunc.
into_box: ExternFunctionId,
/// The `upcast` libfunc.
upcast: ExternFunctionId,
/// The `downcast` libfunc.
downcast: ExternFunctionId,
/// The `storage_base_address_from_felt252` libfunc.
storage_base_address_from_felt252: FunctionId,
storage_base_address_from_felt252: ExternFunctionId,
/// The set of functions that check if a number is zero.
nz_fns: OrderedHashSet<FunctionId>,
nz_fns: OrderedHashSet<ExternFunctionId>,
/// The set of functions to add unsigned ints.
uadd_fns: OrderedHashSet<FunctionId>,
uadd_fns: OrderedHashSet<ExternFunctionId>,
/// The set of functions to subtract unsigned ints.
usub_fns: OrderedHashSet<FunctionId>,
usub_fns: OrderedHashSet<ExternFunctionId>,
/// The set of functions to add signed ints.
iadd_fns: OrderedHashSet<FunctionId>,
iadd_fns: OrderedHashSet<ExternFunctionId>,
/// The set of functions to subtract signed ints.
isub_fns: OrderedHashSet<FunctionId>,
isub_fns: OrderedHashSet<ExternFunctionId>,
/// The set of functions to multiply integers.
wide_mul_fns: OrderedHashSet<FunctionId>,
wide_mul_fns: OrderedHashSet<ExternFunctionId>,
/// The storage access module.
storage_access_module: ModuleId,
/// Type ranges.
Expand All @@ -469,7 +460,7 @@ pub struct ConstFoldingLibfuncInfo {
impl ConstFoldingLibfuncInfo {
fn new(db: &dyn LoweringGroup) -> Self {
let core = ModuleHelper::core(db);
let felt_sub = core.function_id("felt252_sub", vec![]);
let felt_sub = core.extern_function_id("felt252_sub");
let box_module = core.submodule("box");
let into_box = box_module.extern_function_id("into_box");
let integer_module = core.submodule("integer");
Expand All @@ -478,33 +469,33 @@ impl ConstFoldingLibfuncInfo {
let starknet_module = core.submodule("starknet");
let storage_access_module = starknet_module.submodule("storage_access");
let storage_base_address_from_felt252 =
storage_access_module.function_id("storage_base_address_from_felt252", vec![]);
storage_access_module.extern_function_id("storage_base_address_from_felt252");
let nz_fns = OrderedHashSet::<_>::from_iter(chain!(
[core.function_id("felt252_is_zero", vec![])],
[core.extern_function_id("felt252_is_zero")],
["u8", "u16", "u32", "u64", "u128", "u256", "i8", "i16", "i32", "i64", "i128"]
.map(|ty| integer_module.function_id(format!("{}_is_zero", ty), vec![]))
.map(|ty| integer_module.extern_function_id(format!("{ty}_is_zero")))
));
let utypes = ["u8", "u16", "u32", "u64", "u128"];
let itypes = ["i8", "i16", "i32", "i64", "i128"];
let uadd_fns = OrderedHashSet::<_>::from_iter(
utypes.map(|ty| integer_module.function_id(format!("{ty}_overflowing_add"), vec![])),
utypes.map(|ty| integer_module.extern_function_id(format!("{ty}_overflowing_add"))),
);
let usub_fns = OrderedHashSet::<_>::from_iter(chain!(
utypes.map(|ty| integer_module.function_id(format!("{ty}_overflowing_sub"), vec![])),
utypes.map(|ty| integer_module.extern_function_id(format!("{ty}_overflowing_sub"))),
// Considering `i*_diff` as `usub` operations - as they act exactly the same.
itypes.map(|ty| integer_module.function_id(format!("{ty}_diff"), vec![])),
itypes.map(|ty| integer_module.extern_function_id(format!("{ty}_diff"))),
));
let iadd_fns =
OrderedHashSet::<_>::from_iter(itypes.map(|ty| {
integer_module.function_id(format!("{ty}_overflowing_add_impl"), vec![])
}));
let isub_fns =
OrderedHashSet::<_>::from_iter(itypes.map(|ty| {
integer_module.function_id(format!("{ty}_overflowing_sub_impl"), vec![])
}));
let iadd_fns = OrderedHashSet::<_>::from_iter(
itypes
.map(|ty| integer_module.extern_function_id(format!("{ty}_overflowing_add_impl"))),
);
let isub_fns = OrderedHashSet::<_>::from_iter(
itypes
.map(|ty| integer_module.extern_function_id(format!("{ty}_overflowing_sub_impl"))),
);
let wide_mul_fns = OrderedHashSet::<_>::from_iter(
["u8", "u16", "u32", "u64", "i8", "i16", "i32", "i64"]
.map(|ty| integer_module.function_id(format!("{ty}_wide_mul"), vec![])),
.map(|ty| integer_module.extern_function_id(format!("{ty}_wide_mul"))),
);
let type_value_ranges = OrderedHashMap::from_iter(
[
Expand Down

0 comments on commit f0e7c4e

Please sign in to comment.