From 5a3f04a94af27ac1f838b24414af43ed83446020 Mon Sep 17 00:00:00 2001 From: Ori Ziv Date: Thu, 26 Sep 2024 15:42:20 +0300 Subject: [PATCH] Fixing const folded NonZero types. commit-id:51f0d664 --- .../src/optimizations/const_folding.rs | 41 +++++++++++++------ .../src/optimizations/test_data/const_folding | 16 +++----- 2 files changed, 34 insertions(+), 23 deletions(-) diff --git a/crates/cairo-lang-lowering/src/optimizations/const_folding.rs b/crates/cairo-lang-lowering/src/optimizations/const_folding.rs index 469f2f5e4ff..2566189230f 100644 --- a/crates/cairo-lang-lowering/src/optimizations/const_folding.rs +++ b/crates/cairo-lang-lowering/src/optimizations/const_folding.rs @@ -221,7 +221,7 @@ impl<'a> ConstFoldingContext<'a> { let id = stmt.function.get_extern(self.db)?; if id == self.felt_sub { // (a - 0) can be replaced by a. - let val = self.as_int(stmt.inputs[1].var_id)?; + let val = self.as_int(stmt.inputs[1].var_id)?.0; if val.is_zero() { self.var_info.insert(stmt.outputs[0], VarInfo::Var(stmt.inputs[0])); } @@ -230,8 +230,8 @@ impl<'a> ConstFoldingContext<'a> { || id == self.bounded_int_add || id == self.bounded_int_sub { - let lhs = self.as_int(stmt.inputs[0].var_id)?; - let rhs = self.as_int(stmt.inputs[1].var_id)?; + let (lhs, nz_ty) = self.as_int(stmt.inputs[0].var_id)?; + let rhs = self.as_int(stmt.inputs[1].var_id)?.0; let value = if id == self.bounded_int_add { lhs + rhs } else if id == self.bounded_int_sub { @@ -240,8 +240,10 @@ impl<'a> ConstFoldingContext<'a> { lhs * rhs }; let output = stmt.outputs[0]; - let ty = self.variables[output].ty; - let value = ConstValue::Int(value, ty); + let mut value = ConstValue::Int(value, self.variables[output].ty); + if nz_ty { + value = ConstValue::NonZero(Box::new(value)); + } self.var_info.insert(output, VarInfo::Const(value.clone())); Some(Statement::Const(StatementConst { value, output })) } else if id == self.storage_base_address_from_felt252 { @@ -268,7 +270,7 @@ impl<'a> ConstFoldingContext<'a> { // resulting box isn't an actual const at the Sierra level. Some(Statement::Const(StatementConst { value, output: stmt.outputs[0] })) } else if id == self.upcast { - let int_value = self.as_int(stmt.inputs[0].var_id)?; + let int_value = self.as_int(stmt.inputs[0].var_id)?.0; let value = ConstValue::Int(int_value.clone(), self.variables[stmt.outputs[0]].ty); self.var_info.insert(stmt.outputs[0], VarInfo::Const(value.clone())); Some(Statement::Const(StatementConst { value, output: stmt.outputs[0] })) @@ -312,8 +314,8 @@ impl<'a> ConstFoldingContext<'a> { || self.iadd_fns.contains(&id) || self.isub_fns.contains(&id) { - let lhs = self.as_int(info.inputs[0].var_id)?; - let rhs = self.as_int(info.inputs[1].var_id)?; + let lhs = self.as_int(info.inputs[0].var_id)?.0; + let rhs = self.as_int(info.inputs[1].var_id)?.0; let value = if self.uadd_fns.contains(&id) || self.iadd_fns.contains(&id) { lhs + rhs } else { @@ -339,7 +341,7 @@ impl<'a> ConstFoldingContext<'a> { )) } else if id == self.downcast { let input_var = info.inputs[0].var_id; - let value = self.as_int(input_var)?; + let value = self.as_int(input_var)?.0; let success_output = info.arms[0].var_ids[0]; let ty = self.variables[success_output].ty; let range = self.type_value_ranges.get(&ty)?; @@ -355,7 +357,7 @@ impl<'a> ConstFoldingContext<'a> { }) } else if id == self.bounded_int_constrain { let input_var = info.inputs[0].var_id; - let value = self.as_int(input_var)?; + let (value, nz_ty) = self.as_int(input_var)?; let semantic_id = extract_matches!(info.function.lookup_intern(self.db), FunctionLongId::Semantic); let generic_arg = semantic_id.get_concrete(self.db.upcast()).generic_args[1]; @@ -365,7 +367,10 @@ impl<'a> ConstFoldingContext<'a> { .unwrap(); let arm_idx = if value < &constrain_value { 0 } else { 1 }; let output = info.arms[arm_idx].var_ids[0]; - let value = ConstValue::Int(value.clone(), self.variables[output].ty); + let mut value = ConstValue::Int(value.clone(), self.variables[output].ty); + if nz_ty { + value = ConstValue::NonZero(Box::new(value)); + } self.var_info.insert(output, VarInfo::Const(value.clone())); Some(( Some(Statement::Const(StatementConst { value, output })), @@ -382,8 +387,18 @@ impl<'a> ConstFoldingContext<'a> { } /// Return the const value as a int if it exists and is an integer. - fn as_int(&self, var_id: VariableId) -> Option<&BigInt> { - if let ConstValue::Int(value, _) = self.as_const(var_id)? { Some(value) } else { None } + fn as_int(&self, var_id: VariableId) -> Option<(&BigInt, bool)> { + match self.as_const(var_id)? { + ConstValue::Int(value, _) => Some((value, false)), + ConstValue::NonZero(const_value) => { + if let ConstValue::Int(value, _) = const_value.as_ref() { + Some((value, true)) + } else { + None + } + } + _ => None, + } } /// Replaces the inputs in place if they are in the var_info map. diff --git a/crates/cairo-lang-lowering/src/optimizations/test_data/const_folding b/crates/cairo-lang-lowering/src/optimizations/test_data/const_folding index bacc0a4a54d..a84b5456bfc 100644 --- a/crates/cairo-lang-lowering/src/optimizations/test_data/const_folding +++ b/crates/cairo-lang-lowering/src/optimizations/test_data/const_folding @@ -3629,15 +3629,13 @@ Parameters: blk0 (root): Statements: (v0: core::zeroable::NonZero::) <- NonZero(-5) + (v1: core::zeroable::NonZero::>) <- NonZero(-5) End: - Match(match core::internal::bounded_int::bounded_int_constrain::, 0, core::internal::bounded_int::NonZeroConstrainHelper::>>(v0) { - Result::Ok(v1) => blk1, - Result::Err(v2) => blk2, - }) + Goto(blk1, {}) blk1: Statements: - (v3: core::box::Box::>>) <- core::box::into_box::>>(v1) + (v3: core::box::Box::>>) <- NonZero(-5).into_box() (v4: core::result::Result::>>, core::box::Box::>>>) <- Result::Ok(v3) End: Goto(blk3, {v4 -> v5}) @@ -3714,11 +3712,9 @@ Parameters: blk0 (root): Statements: (v0: core::zeroable::NonZero::) <- NonZero(5) + (v2: core::zeroable::NonZero::>) <- NonZero(5) End: - Match(match core::internal::bounded_int::bounded_int_constrain::, 0, core::internal::bounded_int::NonZeroConstrainHelper::>>(v0) { - Result::Ok(v1) => blk1, - Result::Err(v2) => blk2, - }) + Goto(blk2, {}) blk1: Statements: @@ -3729,7 +3725,7 @@ End: blk2: Statements: - (v6: core::box::Box::>>) <- core::box::into_box::>>(v2) + (v6: core::box::Box::>>) <- NonZero(5).into_box() (v7: core::result::Result::>>, core::box::Box::>>>) <- Result::Err(v6) End: Goto(blk3, {v7 -> v5})