Skip to content

Commit

Permalink
Added const-folding *_eq with 0 into *_iz_zero.
Browse files Browse the repository at this point in the history
commit-id:a9bfa8e4
  • Loading branch information
orizi committed Oct 6, 2024
1 parent 23164f7 commit 55d6230
Show file tree
Hide file tree
Showing 20 changed files with 36,470 additions and 36,740 deletions.
27 changes: 6 additions & 21 deletions crates/cairo-lang-lowering/src/lower/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ use defs::diagnostic_utils::StableLocation;
use id_arena::Arena;
use itertools::{zip_eq, Itertools};
use semantic::corelib::{core_module, get_ty_by_name};
use semantic::expr::inference::InferenceError;
use semantic::types::wrap_in_snapshots;
use semantic::{ExprVarMemberPath, MatchArmSelector, TypeLongId};
use {cairo_lang_defs as defs, cairo_lang_semantic as semantic};
Expand Down Expand Up @@ -63,26 +62,12 @@ impl<'db> VariableAllocator<'db> {

/// Allocates a new variable in the context's variable arena according to the context.
pub fn new_var(&mut self, req: VarRequest) -> VariableId {
let ty_info = self.db.type_info(self.lookup_context.clone(), req.ty);
self.variables.alloc(Variable {
copyable: ty_info
.clone()
.map_err(InferenceError::Reported)
.and_then(|info| info.copyable),
droppable: ty_info
.clone()
.map_err(InferenceError::Reported)
.and_then(|info| info.droppable),
destruct_impl: ty_info
.clone()
.map_err(InferenceError::Reported)
.and_then(|info| info.destruct_impl),
panic_destruct_impl: ty_info
.map_err(InferenceError::Reported)
.and_then(|info| info.panic_destruct_impl),
ty: req.ty,
location: req.location,
})
self.variables.alloc(Variable::new(
self.db,
self.lookup_context.clone(),
req.ty,
req.location,
))
}

/// Retrieves the LocationId of a stable syntax pointer in the current function file.
Expand Down
19 changes: 9 additions & 10 deletions crates/cairo-lang-lowering/src/lower/test_data/loop
Original file line number Diff line number Diff line change
Expand Up @@ -879,26 +879,25 @@ End:
blk1:
Statements:
(v7: core::integer::u8) <- 0
(v8: core::integer::u8) <- 0
End:
Match(match core::integer::u8_eq(v7, v8) {
bool::False => blk2,
bool::True => blk3,
Match(match core::integer::u8_is_zero(v7) {
IsZeroResult::Zero => blk2,
IsZeroResult::NonZero(v8) => blk3,
})

blk2:
Statements:
(v9: core::RangeCheck, v10: core::gas::GasBuiltin, v11: core::panics::PanicResult::<(core::integer::u8, ())>) <- test::foo[expr25](v3, v4, v7)
(v9: ()) <- struct_construct()
(v10: (core::integer::u8, ())) <- struct_construct(v7, v9)
(v11: core::panics::PanicResult::<(core::integer::u8, ())>) <- PanicResult::Ok(v10)
End:
Return(v9, v10, v11)
Return(v3, v4, v11)

blk3:
Statements:
(v12: ()) <- struct_construct()
(v13: (core::integer::u8, ())) <- struct_construct(v7, v12)
(v14: core::panics::PanicResult::<(core::integer::u8, ())>) <- PanicResult::Ok(v13)
(v12: core::RangeCheck, v13: core::gas::GasBuiltin, v14: core::panics::PanicResult::<(core::integer::u8, ())>) <- test::foo[expr25](v3, v4, v7)
End:
Return(v3, v4, v14)
Return(v12, v13, v14)

blk4:
Statements:
Expand Down
22 changes: 22 additions & 0 deletions crates/cairo-lang-lowering/src/objects.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ use cairo_lang_debug::DebugWithDb;
use cairo_lang_defs::diagnostic_utils::StableLocation;
use cairo_lang_diagnostics::{DiagnosticNote, Diagnostics};
use cairo_lang_semantic as semantic;
use cairo_lang_semantic::items::imp::ImplLookupContext;
use cairo_lang_semantic::types::TypeInfo;
use cairo_lang_semantic::{ConcreteEnumId, ConcreteVariant};
use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
use cairo_lang_utils::{Intern, LookupIntern};
Expand Down Expand Up @@ -220,6 +222,26 @@ pub struct Variable {
/// Location of the variable.
pub location: LocationId,
}
impl Variable {
pub fn new(
db: &dyn LoweringGroup,
ctx: ImplLookupContext,
ty: semantic::TypeId,
location: LocationId,
) -> Self {
let TypeInfo { droppable, copyable, destruct_impl, panic_destruct_impl } =
match db.type_info(ctx, ty) {
Ok(info) => info,
Err(diag_added) => TypeInfo {
droppable: Err(InferenceError::Reported(diag_added)),
copyable: Err(InferenceError::Reported(diag_added)),
destruct_impl: Err(InferenceError::Reported(diag_added)),
panic_destruct_impl: Err(InferenceError::Reported(diag_added)),
},
};
Self { copyable, droppable, destruct_impl, panic_destruct_impl, ty, location }
}
}

/// Lowered statement.
#[derive(Clone, Debug, PartialEq, Eq)]
Expand Down
103 changes: 74 additions & 29 deletions crates/cairo-lang-lowering/src/optimizations/const_folding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@ use std::sync::Arc;

use cairo_lang_defs::ids::{ExternFunctionId, ModuleId, ModuleItemId};
use cairo_lang_semantic::items::constant::ConstValue;
use cairo_lang_semantic::{corelib, GenericArgumentId, TypeId};
use cairo_lang_semantic::items::imp::ImplLookupContext;
use cairo_lang_semantic::{corelib, GenericArgumentId, MatchArmSelector, TypeId};
use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
use cairo_lang_utils::ordered_hash_set::OrderedHashSet;
use cairo_lang_utils::unordered_hash_map::UnorderedHashMap;
Expand All @@ -21,8 +22,8 @@ use smol_str::SmolStr;
use crate::db::LoweringGroup;
use crate::ids::{FunctionId, FunctionLongId};
use crate::{
BlockId, FlatBlockEnd, FlatLowered, MatchEnumInfo, MatchExternInfo, MatchInfo, Statement,
StatementCall, StatementConst, StatementDesnap, StatementEnumConstruct,
BlockId, FlatBlockEnd, FlatLowered, MatchArm, MatchEnumInfo, MatchExternInfo, MatchInfo,
Statement, StatementCall, StatementConst, StatementDesnap, StatementEnumConstruct,
StatementStructConstruct, StatementStructDestructure, VarUsage, Variable, VariableId,
};

Expand Down Expand Up @@ -53,7 +54,7 @@ pub fn const_folding(db: &dyn LoweringGroup, lowered: &mut FlatLowered) {
let mut ctx = ConstFoldingContext {
db,
var_info: UnorderedHashMap::default(),
variables: &lowered.variables,
variables: &mut lowered.variables,
libfunc_info: &libfunc_info,
};
let mut stack = vec![BlockId::root()];
Expand Down Expand Up @@ -106,7 +107,7 @@ pub fn const_folding(db: &dyn LoweringGroup, lowered: &mut FlatLowered) {
for input in inputs.iter() {
let Some(info) = ctx.var_info.get(&input.var_id) else {
all_args.push(
lowered.variables[input.var_id]
ctx.variables[input.var_id]
.copyable
.is_ok()
.then(|| VarInfo::Var(*input)),
Expand All @@ -120,7 +121,7 @@ pub fn const_folding(db: &dyn LoweringGroup, lowered: &mut FlatLowered) {
all_args.push(Some(info.clone()));
}
if const_args.len() == inputs.len() {
let value = ConstValue::Struct(const_args, lowered.variables[*output].ty);
let value = ConstValue::Struct(const_args, ctx.variables[*output].ty);
ctx.var_info.insert(*output, VarInfo::Const(value));
} else if contains_info {
ctx.var_info.insert(*output, VarInfo::Struct(all_args));
Expand Down Expand Up @@ -210,7 +211,7 @@ struct ConstFoldingContext<'a> {
/// The used database.
db: &'a dyn LoweringGroup,
/// The variables arena, mostly used to get the type of variables.
variables: &'a Arena<Variable>,
variables: &'a mut Arena<Variable>,
/// The accumulated information about the const values of variables.
var_info: UnorderedHashMap<VariableId, VarInfo>,
/// The libfunc information.
Expand Down Expand Up @@ -355,12 +356,53 @@ impl<'a> ConstFoldingContext<'a> {
)
})
} else if self.eq_fns.contains(&id) {
let lhs = self.as_int(info.inputs[0].var_id)?;
let rhs = self.as_int(info.inputs[1].var_id)?;
let lhs = self.as_int(info.inputs[0].var_id);
let rhs = self.as_int(info.inputs[1].var_id);
if (lhs.map(Zero::is_zero).unwrap_or_default() && rhs.is_none())
|| (rhs.map(Zero::is_zero).unwrap_or_default() && lhs.is_none())
{
let db = self.db.upcast();
let nz_input = info.inputs[if lhs.is_some() { 1 } else { 0 }];
let var = &self.variables[nz_input.var_id].clone();
let function = self.type_value_ranges.get(&var.ty)?.is_zero;
let unused_nz_var = Variable::new(
self.db,
ImplLookupContext::default(),
corelib::core_nonzero_ty(db, var.ty),
var.location,
);
let unused_nz_var = self.variables.alloc(unused_nz_var);
return Some((
None,
FlatBlockEnd::Match {
info: MatchInfo::Extern(MatchExternInfo {
function,
inputs: vec![nz_input],
arms: vec![
MatchArm {
arm_selector: MatchArmSelector::VariantId(
corelib::jump_nz_zero_variant(db),
),
block_id: info.arms[1].block_id,
var_ids: vec![],
},
MatchArm {
arm_selector: MatchArmSelector::VariantId(
corelib::jump_nz_nonzero_variant(db),
),
block_id: info.arms[0].block_id,
var_ids: vec![unused_nz_var],
},
],
location: info.location,
}),
},
));
}
Some((
None,
FlatBlockEnd::Goto(
info.arms[if lhs == rhs { 1 } else { 0 }].block_id,
info.arms[if lhs? == rhs? { 1 } else { 0 }].block_id,
Default::default(),
),
))
Expand Down Expand Up @@ -576,7 +618,7 @@ pub struct ConstFoldingLibfuncInfo {
/// The storage access module.
storage_access_module: ModuleId,
/// Type ranges.
type_value_ranges: OrderedHashMap<TypeId, TypeRange>,
type_value_ranges: OrderedHashMap<TypeId, TypeInfo>,
}
impl ConstFoldingLibfuncInfo {
fn new(db: &dyn LoweringGroup) -> Self {
Expand Down Expand Up @@ -636,20 +678,25 @@ impl ConstFoldingLibfuncInfo {
let bounded_int_constrain = bounded_int_module.extern_function_id("bounded_int_constrain");
let type_value_ranges = OrderedHashMap::from_iter(
[
("u8", TypeRange::closed(0, u8::MAX)),
("u16", TypeRange::closed(0, u16::MAX)),
("u32", TypeRange::closed(0, u32::MAX)),
("u64", TypeRange::closed(0, u64::MAX)),
("u128", TypeRange::closed(0, u128::MAX)),
("u256", TypeRange::closed(0, BigInt::from(1) << 256)),
("i8", TypeRange::closed(i8::MIN, i8::MAX)),
("i16", TypeRange::closed(i16::MIN, i16::MAX)),
("i32", TypeRange::closed(i32::MIN, i32::MAX)),
("i64", TypeRange::closed(i64::MIN, i64::MAX)),
("i128", TypeRange::closed(i128::MIN, i128::MAX)),
("u8", BigInt::ZERO, u8::MAX.into()),
("u16", BigInt::ZERO, u16::MAX.into()),
("u32", BigInt::ZERO, u32::MAX.into()),
("u64", BigInt::ZERO, u64::MAX.into()),
("u128", BigInt::ZERO, u128::MAX.into()),
("u256", BigInt::ZERO, BigInt::from(1) << 256),
("i8", i8::MIN.into(), i8::MAX.into()),
("i16", i16::MIN.into(), i16::MAX.into()),
("i32", i32::MIN.into(), i32::MAX.into()),
("i64", i64::MIN.into(), i64::MAX.into()),
("i128", i128::MIN.into(), i128::MAX.into()),
]
.map(|(ty, range)| {
(corelib::get_core_ty_by_name(db.upcast(), ty.into(), vec![]), range)
.map(|(ty, min, max): (&str, BigInt, BigInt)| {
let info = TypeInfo {
min,
max,
is_zero: integer_module.function_id(format!("{ty}_is_zero"), vec![]),
};
(corelib::get_core_ty_by_name(db.upcast(), ty.into(), vec![]), info)
}),
);
Self {
Expand Down Expand Up @@ -685,14 +732,12 @@ impl std::ops::Deref for ConstFoldingContext<'_> {

/// The range of a type for normalizations.
#[derive(Debug, PartialEq, Eq)]
struct TypeRange {
struct TypeInfo {
min: BigInt,
max: BigInt,
is_zero: FunctionId,
}
impl TypeRange {
fn closed(min: impl Into<BigInt>, max: impl Into<BigInt>) -> Self {
Self { min: min.into(), max: max.into() }
}
impl TypeInfo {
/// Normalizes the value to the range.
/// Assumes the value is within size of range of the range.
fn normalized(&self, value: BigInt) -> NormalizedResult {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4368,3 +4368,84 @@ End:
Return(v18)

//! > lowering_diagnostics

//! > ==========================================================================

//! > Eq with 0 const fold.

//! > test_runner_name
test_match_optimizer

//! > function
fn foo(x: u8) -> bool {
x == 0
}

//! > function_name
foo

//! > module_code

//! > semantic_diagnostics

//! > before
Parameters: v0: core::integer::u8
blk0 (root):
Statements:
(v3: core::integer::u8) <- 0
End:
Match(match core::integer::u8_eq(v0, v3) {
bool::False => blk1,
bool::True => blk2,
})

blk1:
Statements:
(v8: ()) <- struct_construct()
(v9: core::bool) <- bool::False(v8)
End:
Goto(blk3, {v9 -> v10})

blk2:
Statements:
(v11: ()) <- struct_construct()
(v12: core::bool) <- bool::True(v11)
End:
Goto(blk3, {v12 -> v10})

blk3:
Statements:
End:
Return(v10)

//! > after
Parameters: v0: core::integer::u8
blk0 (root):
Statements:
(v3: core::integer::u8) <- 0
End:
Match(match core::integer::u8_is_zero(v0) {
IsZeroResult::Zero => blk2,
IsZeroResult::NonZero(v13) => blk1,
})

blk1:
Statements:
(v8: ()) <- struct_construct()
(v9: core::bool) <- bool::False(v8)
End:
Goto(blk3, {v9 -> v10})

blk2:
Statements:
(v11: ()) <- struct_construct()
(v12: core::bool) <- bool::True(v11)
End:
Goto(blk3, {v12 -> v10})

blk3:
Statements:
End:
Return(v10)

//! > lowering_diagnostics
Loading

0 comments on commit 55d6230

Please sign in to comment.