Skip to content

Commit

Permalink
Fixing const folded NonZero types.
Browse files Browse the repository at this point in the history
commit-id:51f0d664
  • Loading branch information
orizi committed Sep 26, 2024
1 parent 0f19b5e commit 5a3f04a
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 23 deletions.
41 changes: 28 additions & 13 deletions crates/cairo-lang-lowering/src/optimizations/const_folding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ impl<'a> ConstFoldingContext<'a> {
let id = stmt.function.get_extern(self.db)?;
if id == self.felt_sub {
// (a - 0) can be replaced by a.
let val = self.as_int(stmt.inputs[1].var_id)?;
let val = self.as_int(stmt.inputs[1].var_id)?.0;
if val.is_zero() {
self.var_info.insert(stmt.outputs[0], VarInfo::Var(stmt.inputs[0]));
}
Expand All @@ -230,8 +230,8 @@ impl<'a> ConstFoldingContext<'a> {
|| id == self.bounded_int_add
|| id == self.bounded_int_sub
{
let lhs = self.as_int(stmt.inputs[0].var_id)?;
let rhs = self.as_int(stmt.inputs[1].var_id)?;
let (lhs, nz_ty) = self.as_int(stmt.inputs[0].var_id)?;
let rhs = self.as_int(stmt.inputs[1].var_id)?.0;
let value = if id == self.bounded_int_add {
lhs + rhs
} else if id == self.bounded_int_sub {
Expand All @@ -240,8 +240,10 @@ impl<'a> ConstFoldingContext<'a> {
lhs * rhs
};
let output = stmt.outputs[0];
let ty = self.variables[output].ty;
let value = ConstValue::Int(value, ty);
let mut value = ConstValue::Int(value, self.variables[output].ty);
if nz_ty {
value = ConstValue::NonZero(Box::new(value));
}
self.var_info.insert(output, VarInfo::Const(value.clone()));
Some(Statement::Const(StatementConst { value, output }))
} else if id == self.storage_base_address_from_felt252 {
Expand All @@ -268,7 +270,7 @@ impl<'a> ConstFoldingContext<'a> {
// resulting box isn't an actual const at the Sierra level.
Some(Statement::Const(StatementConst { value, output: stmt.outputs[0] }))
} else if id == self.upcast {
let int_value = self.as_int(stmt.inputs[0].var_id)?;
let int_value = self.as_int(stmt.inputs[0].var_id)?.0;
let value = ConstValue::Int(int_value.clone(), self.variables[stmt.outputs[0]].ty);
self.var_info.insert(stmt.outputs[0], VarInfo::Const(value.clone()));
Some(Statement::Const(StatementConst { value, output: stmt.outputs[0] }))
Expand Down Expand Up @@ -312,8 +314,8 @@ impl<'a> ConstFoldingContext<'a> {
|| self.iadd_fns.contains(&id)
|| self.isub_fns.contains(&id)
{
let lhs = self.as_int(info.inputs[0].var_id)?;
let rhs = self.as_int(info.inputs[1].var_id)?;
let lhs = self.as_int(info.inputs[0].var_id)?.0;
let rhs = self.as_int(info.inputs[1].var_id)?.0;
let value = if self.uadd_fns.contains(&id) || self.iadd_fns.contains(&id) {
lhs + rhs
} else {
Expand All @@ -339,7 +341,7 @@ impl<'a> ConstFoldingContext<'a> {
))
} else if id == self.downcast {
let input_var = info.inputs[0].var_id;
let value = self.as_int(input_var)?;
let value = self.as_int(input_var)?.0;
let success_output = info.arms[0].var_ids[0];
let ty = self.variables[success_output].ty;
let range = self.type_value_ranges.get(&ty)?;
Expand All @@ -355,7 +357,7 @@ impl<'a> ConstFoldingContext<'a> {
})
} else if id == self.bounded_int_constrain {
let input_var = info.inputs[0].var_id;
let value = self.as_int(input_var)?;
let (value, nz_ty) = self.as_int(input_var)?;
let semantic_id =
extract_matches!(info.function.lookup_intern(self.db), FunctionLongId::Semantic);
let generic_arg = semantic_id.get_concrete(self.db.upcast()).generic_args[1];
Expand All @@ -365,7 +367,10 @@ impl<'a> ConstFoldingContext<'a> {
.unwrap();
let arm_idx = if value < &constrain_value { 0 } else { 1 };
let output = info.arms[arm_idx].var_ids[0];
let value = ConstValue::Int(value.clone(), self.variables[output].ty);
let mut value = ConstValue::Int(value.clone(), self.variables[output].ty);
if nz_ty {
value = ConstValue::NonZero(Box::new(value));
}
self.var_info.insert(output, VarInfo::Const(value.clone()));
Some((
Some(Statement::Const(StatementConst { value, output })),
Expand All @@ -382,8 +387,18 @@ impl<'a> ConstFoldingContext<'a> {
}

/// Return the const value as a int if it exists and is an integer.
fn as_int(&self, var_id: VariableId) -> Option<&BigInt> {
if let ConstValue::Int(value, _) = self.as_const(var_id)? { Some(value) } else { None }
fn as_int(&self, var_id: VariableId) -> Option<(&BigInt, bool)> {
match self.as_const(var_id)? {
ConstValue::Int(value, _) => Some((value, false)),
ConstValue::NonZero(const_value) => {
if let ConstValue::Int(value, _) = const_value.as_ref() {
Some((value, true))
} else {
None
}
}
_ => None,
}
}

/// Replaces the inputs in place if they are in the var_info map.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3629,15 +3629,13 @@ Parameters:
blk0 (root):
Statements:
(v0: core::zeroable::NonZero::<core::integer::i8>) <- NonZero(-5)
(v1: core::zeroable::NonZero::<core::internal::bounded_int::BoundedInt::<-128, -1>>) <- NonZero(-5)
End:
Match(match core::internal::bounded_int::bounded_int_constrain::<core::zeroable::NonZero::<core::integer::i8>, 0, core::internal::bounded_int::NonZeroConstrainHelper::<core::integer::i8, 0, core::internal::bounded_int::constrain0::Impl::<core::integer::i8, -128, 127>>>(v0) {
Result::Ok(v1) => blk1,
Result::Err(v2) => blk2,
})
Goto(blk1, {})

blk1:
Statements:
(v3: core::box::Box::<core::zeroable::NonZero::<core::internal::bounded_int::BoundedInt::<-128, -1>>>) <- core::box::into_box::<core::zeroable::NonZero::<core::internal::bounded_int::BoundedInt::<-128, -1>>>(v1)
(v3: core::box::Box::<core::zeroable::NonZero::<core::internal::bounded_int::BoundedInt::<-128, -1>>>) <- NonZero(-5).into_box()
(v4: core::result::Result::<core::box::Box::<core::zeroable::NonZero::<core::internal::bounded_int::BoundedInt::<-128, -1>>>, core::box::Box::<core::zeroable::NonZero::<core::internal::bounded_int::BoundedInt::<0, 127>>>>) <- Result::Ok(v3)
End:
Goto(blk3, {v4 -> v5})
Expand Down Expand Up @@ -3714,11 +3712,9 @@ Parameters:
blk0 (root):
Statements:
(v0: core::zeroable::NonZero::<core::integer::i8>) <- NonZero(5)
(v2: core::zeroable::NonZero::<core::internal::bounded_int::BoundedInt::<0, 127>>) <- NonZero(5)
End:
Match(match core::internal::bounded_int::bounded_int_constrain::<core::zeroable::NonZero::<core::integer::i8>, 0, core::internal::bounded_int::NonZeroConstrainHelper::<core::integer::i8, 0, core::internal::bounded_int::constrain0::Impl::<core::integer::i8, -128, 127>>>(v0) {
Result::Ok(v1) => blk1,
Result::Err(v2) => blk2,
})
Goto(blk2, {})

blk1:
Statements:
Expand All @@ -3729,7 +3725,7 @@ End:

blk2:
Statements:
(v6: core::box::Box::<core::zeroable::NonZero::<core::internal::bounded_int::BoundedInt::<0, 127>>>) <- core::box::into_box::<core::zeroable::NonZero::<core::internal::bounded_int::BoundedInt::<0, 127>>>(v2)
(v6: core::box::Box::<core::zeroable::NonZero::<core::internal::bounded_int::BoundedInt::<0, 127>>>) <- NonZero(5).into_box()
(v7: core::result::Result::<core::box::Box::<core::zeroable::NonZero::<core::internal::bounded_int::BoundedInt::<-128, -1>>>, core::box::Box::<core::zeroable::NonZero::<core::internal::bounded_int::BoundedInt::<0, 127>>>>) <- Result::Err(v6)
End:
Goto(blk3, {v7 -> v5})
Expand Down

0 comments on commit 5a3f04a

Please sign in to comment.