From 73388ceb35f5793ef5a894f283fbeeb523fd5a09 Mon Sep 17 00:00:00 2001 From: Ori Ziv Date: Thu, 26 Sep 2024 15:42:20 +0300 Subject: [PATCH] Fixing const folded NonZero types. commit-id:51f0d664 --- .../src/optimizations/const_folding.rs | 33 +++++++++++++++---- .../src/optimizations/test_data/const_folding | 16 ++++----- 2 files changed, 33 insertions(+), 16 deletions(-) diff --git a/crates/cairo-lang-lowering/src/optimizations/const_folding.rs b/crates/cairo-lang-lowering/src/optimizations/const_folding.rs index 469f2f5e4ff..a38c3338087 100644 --- a/crates/cairo-lang-lowering/src/optimizations/const_folding.rs +++ b/crates/cairo-lang-lowering/src/optimizations/const_folding.rs @@ -230,7 +230,7 @@ impl<'a> ConstFoldingContext<'a> { || id == self.bounded_int_add || id == self.bounded_int_sub { - let lhs = self.as_int(stmt.inputs[0].var_id)?; + let (lhs, nz_ty) = self.as_int_ex(stmt.inputs[0].var_id)?; let rhs = self.as_int(stmt.inputs[1].var_id)?; let value = if id == self.bounded_int_add { lhs + rhs @@ -240,8 +240,10 @@ impl<'a> ConstFoldingContext<'a> { lhs * rhs }; let output = stmt.outputs[0]; - let ty = self.variables[output].ty; - let value = ConstValue::Int(value, ty); + let mut value = ConstValue::Int(value, self.variables[output].ty); + if nz_ty { + value = ConstValue::NonZero(Box::new(value)); + } self.var_info.insert(output, VarInfo::Const(value.clone())); Some(Statement::Const(StatementConst { value, output })) } else if id == self.storage_base_address_from_felt252 { @@ -355,7 +357,7 @@ impl<'a> ConstFoldingContext<'a> { }) } else if id == self.bounded_int_constrain { let input_var = info.inputs[0].var_id; - let value = self.as_int(input_var)?; + let (value, nz_ty) = self.as_int_ex(input_var)?; let semantic_id = extract_matches!(info.function.lookup_intern(self.db), FunctionLongId::Semantic); let generic_arg = semantic_id.get_concrete(self.db.upcast()).generic_args[1]; @@ -365,7 +367,10 @@ impl<'a> ConstFoldingContext<'a> { .unwrap(); let arm_idx = if value < &constrain_value { 0 } else { 1 }; let output = info.arms[arm_idx].var_ids[0]; - let value = ConstValue::Int(value.clone(), self.variables[output].ty); + let mut value = ConstValue::Int(value.clone(), self.variables[output].ty); + if nz_ty { + value = ConstValue::NonZero(Box::new(value)); + } self.var_info.insert(output, VarInfo::Const(value.clone())); Some(( Some(Statement::Const(StatementConst { value, output })), @@ -381,9 +386,25 @@ impl<'a> ConstFoldingContext<'a> { try_extract_matches!(self.var_info.get(&var_id)?, VarInfo::Const) } + /// Return the const value as a int if it exists and is an integer, aditionally, if it is of a + /// non-zero type. + fn as_int_ex(&self, var_id: VariableId) -> Option<(&BigInt, bool)> { + match self.as_const(var_id)? { + ConstValue::Int(value, _) => Some((value, false)), + ConstValue::NonZero(const_value) => { + if let ConstValue::Int(value, _) = const_value.as_ref() { + Some((value, true)) + } else { + None + } + } + _ => None, + } + } + /// Return the const value as a int if it exists and is an integer. fn as_int(&self, var_id: VariableId) -> Option<&BigInt> { - if let ConstValue::Int(value, _) = self.as_const(var_id)? { Some(value) } else { None } + Some(self.as_int_ex(var_id)?.0) } /// Replaces the inputs in place if they are in the var_info map. diff --git a/crates/cairo-lang-lowering/src/optimizations/test_data/const_folding b/crates/cairo-lang-lowering/src/optimizations/test_data/const_folding index bacc0a4a54d..a84b5456bfc 100644 --- a/crates/cairo-lang-lowering/src/optimizations/test_data/const_folding +++ b/crates/cairo-lang-lowering/src/optimizations/test_data/const_folding @@ -3629,15 +3629,13 @@ Parameters: blk0 (root): Statements: (v0: core::zeroable::NonZero::) <- NonZero(-5) + (v1: core::zeroable::NonZero::>) <- NonZero(-5) End: - Match(match core::internal::bounded_int::bounded_int_constrain::, 0, core::internal::bounded_int::NonZeroConstrainHelper::>>(v0) { - Result::Ok(v1) => blk1, - Result::Err(v2) => blk2, - }) + Goto(blk1, {}) blk1: Statements: - (v3: core::box::Box::>>) <- core::box::into_box::>>(v1) + (v3: core::box::Box::>>) <- NonZero(-5).into_box() (v4: core::result::Result::>>, core::box::Box::>>>) <- Result::Ok(v3) End: Goto(blk3, {v4 -> v5}) @@ -3714,11 +3712,9 @@ Parameters: blk0 (root): Statements: (v0: core::zeroable::NonZero::) <- NonZero(5) + (v2: core::zeroable::NonZero::>) <- NonZero(5) End: - Match(match core::internal::bounded_int::bounded_int_constrain::, 0, core::internal::bounded_int::NonZeroConstrainHelper::>>(v0) { - Result::Ok(v1) => blk1, - Result::Err(v2) => blk2, - }) + Goto(blk2, {}) blk1: Statements: @@ -3729,7 +3725,7 @@ End: blk2: Statements: - (v6: core::box::Box::>>) <- core::box::into_box::>>(v2) + (v6: core::box::Box::>>) <- NonZero(5).into_box() (v7: core::result::Result::>>, core::box::Box::>>>) <- Result::Err(v6) End: Goto(blk3, {v7 -> v5})