diff --git a/acvm-repo/brillig_vm/src/black_box.rs b/acvm-repo/brillig_vm/src/black_box.rs index 36d045efabf..53599f79bc7 100644 --- a/acvm-repo/brillig_vm/src/black_box.rs +++ b/acvm-repo/brillig_vm/src/black_box.rs @@ -42,7 +42,7 @@ pub(crate) fn evaluate_black_box op: &BlackBoxOp, solver: &Solver, memory: &mut Memory, - bigint_solver: &mut BigIntSolver, + bigint_solver: &mut BrilligBigintSolver, ) -> Result<(), BlackBoxResolutionError> { match op { BlackBoxOp::AES128Encrypt { inputs, iv, key, outputs } => { @@ -270,29 +270,33 @@ pub(crate) fn evaluate_black_box BlackBoxOp::BigIntAdd { lhs, rhs, output } => { let lhs = memory.read(*lhs).try_into().unwrap(); let rhs = memory.read(*rhs).try_into().unwrap(); - let output = memory.read(*output).try_into().unwrap(); - bigint_solver.bigint_op(lhs, rhs, output, BlackBoxFunc::BigIntAdd)?; + + let new_id = bigint_solver.bigint_op(lhs, rhs, BlackBoxFunc::BigIntAdd)?; + memory.write(*output, new_id.into()); Ok(()) } BlackBoxOp::BigIntSub { lhs, rhs, output } => { let lhs = memory.read(*lhs).try_into().unwrap(); let rhs = memory.read(*rhs).try_into().unwrap(); - let output = memory.read(*output).try_into().unwrap(); - bigint_solver.bigint_op(lhs, rhs, output, BlackBoxFunc::BigIntSub)?; + + let new_id = bigint_solver.bigint_op(lhs, rhs, BlackBoxFunc::BigIntSub)?; + memory.write(*output, new_id.into()); Ok(()) } BlackBoxOp::BigIntMul { lhs, rhs, output } => { let lhs = memory.read(*lhs).try_into().unwrap(); let rhs = memory.read(*rhs).try_into().unwrap(); - let output = memory.read(*output).try_into().unwrap(); - bigint_solver.bigint_op(lhs, rhs, output, BlackBoxFunc::BigIntMul)?; + + let new_id = bigint_solver.bigint_op(lhs, rhs, BlackBoxFunc::BigIntMul)?; + memory.write(*output, new_id.into()); Ok(()) } BlackBoxOp::BigIntDiv { lhs, rhs, output } => { let lhs = memory.read(*lhs).try_into().unwrap(); let rhs = memory.read(*rhs).try_into().unwrap(); - let output = memory.read(*output).try_into().unwrap(); - bigint_solver.bigint_op(lhs, rhs, output, BlackBoxFunc::BigIntDiv)?; + + let new_id = bigint_solver.bigint_op(lhs, rhs, BlackBoxFunc::BigIntDiv)?; + memory.write(*output, new_id.into()); Ok(()) } BlackBoxOp::BigIntFromLeBytes { inputs, modulus, output } => { @@ -300,8 +304,10 @@ pub(crate) fn evaluate_black_box let input: Vec = input.iter().map(|x| x.try_into().unwrap()).collect(); let modulus = read_heap_vector(memory, modulus); let modulus: Vec = modulus.iter().map(|x| x.try_into().unwrap()).collect(); - let output = memory.read(*output).try_into().unwrap(); - bigint_solver.bigint_from_bytes(&input, &modulus, output)?; + + let new_id = bigint_solver.bigint_from_bytes(&input, &modulus)?; + memory.write(*output, new_id.into()); + Ok(()) } BlackBoxOp::BigIntToLeBytes { input, output } => { @@ -381,6 +387,46 @@ pub(crate) fn evaluate_black_box } } +/// Wrapper over the generic bigint solver to automatically assign bigint ids in brillig +#[derive(Default, Debug, Clone, PartialEq, Eq)] +pub(crate) struct BrilligBigintSolver { + bigint_solver: BigIntSolver, + last_id: u32, +} + +impl BrilligBigintSolver { + pub(crate) fn create_bigint_id(&mut self) -> u32 { + let output = self.last_id; + self.last_id += 1; + output + } + + pub(crate) fn bigint_from_bytes( + &mut self, + inputs: &[u8], + modulus: &[u8], + ) -> Result { + let id = self.create_bigint_id(); + self.bigint_solver.bigint_from_bytes(inputs, modulus, id)?; + Ok(id) + } + + pub(crate) fn bigint_to_bytes(&self, input: u32) -> Result, BlackBoxResolutionError> { + self.bigint_solver.bigint_to_bytes(input) + } + + pub(crate) fn bigint_op( + &mut self, + lhs: u32, + rhs: u32, + func: BlackBoxFunc, + ) -> Result { + let id = self.create_bigint_id(); + self.bigint_solver.bigint_op(lhs, rhs, id, func)?; + Ok(id) + } +} + fn black_box_function_from_op(op: &BlackBoxOp) -> BlackBoxFunc { match op { BlackBoxOp::AES128Encrypt { .. } => BlackBoxFunc::AES128Encrypt, @@ -414,10 +460,10 @@ mod test { brillig::{BlackBoxOp, MemoryAddress}, FieldElement, }; - use acvm_blackbox_solver::{BigIntSolver, StubbedBlackBoxSolver}; + use acvm_blackbox_solver::StubbedBlackBoxSolver; use crate::{ - black_box::{evaluate_black_box, to_u8_vec, to_value_vec}, + black_box::{evaluate_black_box, to_u8_vec, to_value_vec, BrilligBigintSolver}, HeapArray, HeapVector, Memory, }; @@ -439,8 +485,13 @@ mod test { output: HeapArray { pointer: 2.into(), size: 32 }, }; - evaluate_black_box(&op, &StubbedBlackBoxSolver, &mut memory, &mut BigIntSolver::default()) - .unwrap(); + evaluate_black_box( + &op, + &StubbedBlackBoxSolver, + &mut memory, + &mut BrilligBigintSolver::default(), + ) + .unwrap(); let result = memory.read_slice(MemoryAddress(result_pointer), 32); diff --git a/acvm-repo/brillig_vm/src/lib.rs b/acvm-repo/brillig_vm/src/lib.rs index 01f45bf653c..4d2dd2b8333 100644 --- a/acvm-repo/brillig_vm/src/lib.rs +++ b/acvm-repo/brillig_vm/src/lib.rs @@ -16,9 +16,9 @@ use acir::brillig::{ HeapVector, MemoryAddress, Opcode, ValueOrArray, }; use acir::AcirField; -use acvm_blackbox_solver::{BigIntSolver, BlackBoxFunctionSolver}; +use acvm_blackbox_solver::BlackBoxFunctionSolver; use arithmetic::{evaluate_binary_field_op, evaluate_binary_int_op, BrilligArithmeticError}; -use black_box::evaluate_black_box; +use black_box::{evaluate_black_box, BrilligBigintSolver}; use num_bigint::BigUint; // Re-export `brillig`. @@ -88,7 +88,7 @@ pub struct VM<'a, F, B: BlackBoxFunctionSolver> { /// The solver for blackbox functions black_box_solver: &'a B, // The solver for big integers - bigint_solver: BigIntSolver, + bigint_solver: BrilligBigintSolver, } impl<'a, F: AcirField, B: BlackBoxFunctionSolver> VM<'a, F, B> { diff --git a/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_black_box.rs b/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_black_box.rs index 367cdbe4973..aa9cb8cd7a3 100644 --- a/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_black_box.rs +++ b/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_black_box.rs @@ -243,13 +243,7 @@ pub(crate) fn convert_black_box_call( [BrilligVariable::SingleAddr(output), BrilligVariable::SingleAddr(modulus_id)], ) = (function_arguments, function_results) { - prepare_bigint_output( - brillig_context, - lhs_modulus, - rhs_modulus, - output, - modulus_id, - ); + prepare_bigint_output(brillig_context, lhs_modulus, rhs_modulus, modulus_id); brillig_context.black_box_op_instruction(BlackBoxOp::BigIntAdd { lhs: lhs.address, rhs: rhs.address, @@ -267,13 +261,7 @@ pub(crate) fn convert_black_box_call( [BrilligVariable::SingleAddr(output), BrilligVariable::SingleAddr(modulus_id)], ) = (function_arguments, function_results) { - prepare_bigint_output( - brillig_context, - lhs_modulus, - rhs_modulus, - output, - modulus_id, - ); + prepare_bigint_output(brillig_context, lhs_modulus, rhs_modulus, modulus_id); brillig_context.black_box_op_instruction(BlackBoxOp::BigIntSub { lhs: lhs.address, rhs: rhs.address, @@ -291,13 +279,7 @@ pub(crate) fn convert_black_box_call( [BrilligVariable::SingleAddr(output), BrilligVariable::SingleAddr(modulus_id)], ) = (function_arguments, function_results) { - prepare_bigint_output( - brillig_context, - lhs_modulus, - rhs_modulus, - output, - modulus_id, - ); + prepare_bigint_output(brillig_context, lhs_modulus, rhs_modulus, modulus_id); brillig_context.black_box_op_instruction(BlackBoxOp::BigIntMul { lhs: lhs.address, rhs: rhs.address, @@ -315,13 +297,7 @@ pub(crate) fn convert_black_box_call( [BrilligVariable::SingleAddr(output), BrilligVariable::SingleAddr(modulus_id)], ) = (function_arguments, function_results) { - prepare_bigint_output( - brillig_context, - lhs_modulus, - rhs_modulus, - output, - modulus_id, - ); + prepare_bigint_output(brillig_context, lhs_modulus, rhs_modulus, modulus_id); brillig_context.black_box_op_instruction(BlackBoxOp::BigIntDiv { lhs: lhs.address, rhs: rhs.address, @@ -341,8 +317,6 @@ pub(crate) fn convert_black_box_call( { let inputs_vector = convert_array_or_vector(brillig_context, inputs, bb_func); let modulus_vector = convert_array_or_vector(brillig_context, modulus, bb_func); - let output_id = brillig_context.get_new_bigint_id(); - brillig_context.const_instruction(*output, F::from(output_id as u128)); brillig_context.black_box_op_instruction(BlackBoxOp::BigIntFromLeBytes { inputs: inputs_vector.to_heap_vector(), modulus: modulus_vector.to_heap_vector(), @@ -447,7 +421,6 @@ fn prepare_bigint_output( brillig_context: &mut BrilligContext, lhs_modulus: &SingleAddrVariable, rhs_modulus: &SingleAddrVariable, - output: &SingleAddrVariable, modulus_id: &SingleAddrVariable, ) { // Check moduli @@ -464,8 +437,6 @@ fn prepare_bigint_output( Some("moduli should be identical in BigInt operation".to_string()), ); brillig_context.deallocate_register(condition); - // Set output id - let output_id = brillig_context.get_new_bigint_id(); - brillig_context.const_instruction(*output, F::from(output_id as u128)); + brillig_context.mov_instruction(modulus_id.address, lhs_modulus.address); } diff --git a/compiler/noirc_evaluator/src/brillig/brillig_ir.rs b/compiler/noirc_evaluator/src/brillig/brillig_ir.rs index 9785e073be9..80367d07635 100644 --- a/compiler/noirc_evaluator/src/brillig/brillig_ir.rs +++ b/compiler/noirc_evaluator/src/brillig/brillig_ir.rs @@ -91,8 +91,6 @@ pub(crate) struct BrilligContext { next_section: usize, /// IR printer debug_show: DebugShow, - /// Counter for generating bigint ids in unconstrained functions - bigint_new_id: u32, } impl BrilligContext { @@ -105,15 +103,9 @@ impl BrilligContext { section_label: 0, next_section: 1, debug_show: DebugShow::new(enable_debug_trace), - bigint_new_id: 0, } } - pub(crate) fn get_new_bigint_id(&mut self) -> u32 { - let result = self.bigint_new_id; - self.bigint_new_id += 1; - result - } /// Adds a brillig instruction to the brillig byte code fn push_opcode(&mut self, opcode: BrilligOpcode) { self.obj.push_opcode(opcode); diff --git a/compiler/noirc_evaluator/src/brillig/brillig_ir/entry_point.rs b/compiler/noirc_evaluator/src/brillig/brillig_ir/entry_point.rs index dc06c2fa0d7..d10e31533dc 100644 --- a/compiler/noirc_evaluator/src/brillig/brillig_ir/entry_point.rs +++ b/compiler/noirc_evaluator/src/brillig/brillig_ir/entry_point.rs @@ -23,7 +23,6 @@ impl BrilligContext { section_label: 0, next_section: 1, debug_show: DebugShow::new(false), - bigint_new_id: 0, }; context.codegen_entry_point(&arguments, &return_parameters);