diff --git a/alu_u32/src/lt/columns.rs b/alu_u32/src/lt/columns.rs index a87a249..48bd5ea 100644 --- a/alu_u32/src/lt/columns.rs +++ b/alu_u32/src/lt/columns.rs @@ -13,7 +13,7 @@ pub struct Lt32Cols { pub byte_flag: [T; 4], /// Bit decomposition of 256 + input_1 - input_2 - pub bits: [T; 10], + pub bits: [T; 9], pub output: T, @@ -21,9 +21,18 @@ pub struct Lt32Cols { pub is_lt: T, pub is_lte: T, + pub is_slt: T, + pub is_sle: T, // inverse of input_1[i] - input_2[i] where i is the first byte that differs pub diff_inv: T, + + // bit decomposition of top bytes for input_1 and input_2 + pub top_bits_1: [T; 8], + pub top_bits_2: [T; 8], + + // boolean flag for whether the sign of the two inputs is different + pub different_signs: T, } pub const NUM_LT_COLS: usize = size_of::>(); diff --git a/alu_u32/src/lt/mod.rs b/alu_u32/src/lt/mod.rs index c1348d0..05d1e83 100644 --- a/alu_u32/src/lt/mod.rs +++ b/alu_u32/src/lt/mod.rs @@ -60,6 +60,8 @@ where vec![ (LT_COL_MAP.is_lt, SC::Val::from_canonical_u32(LT32)), (LT_COL_MAP.is_lte, SC::Val::from_canonical_u32(LTE32)), + (LT_COL_MAP.is_slt, SC::Val::from_canonical_u32(SLT32)), + (LT_COL_MAP.is_sle, SC::Val::from_canonical_u32(SLE32)), ], SC::Val::zero(), ); @@ -94,28 +96,32 @@ impl Lt32Chip { match op { Operation::Lt32(a, b, c) => { cols.is_lt = F::one(); - self.set_cols(cols, a, b, c); + self.set_cols(cols, false, a, b, c); } Operation::Lte32(a, b, c) => { cols.is_lte = F::one(); - self.set_cols(cols, a, b, c); + self.set_cols(cols, false, a, b, c); } Operation::Slt32(a, b, c) => { - // TODO: this is just a placeholder - cols.is_lt = F::one(); - self.set_cols(cols, a, b, c); + cols.is_slt = F::one(); + self.set_cols(cols, true, a, b, c); } Operation::Sle32(a, b, c) => { - // TODO: this is just a placeholder - cols.is_lte = F::one(); - self.set_cols(cols, a, b, c); + cols.is_sle = F::one(); + self.set_cols(cols, true, a, b, c); } } row } - fn set_cols(&self, cols: &mut Lt32Cols, a: &Word, b: &Word, c: &Word) - where + fn set_cols( + &self, + cols: &mut Lt32Cols, + is_signed: bool, + a: &Word, + b: &Word, + c: &Word, + ) where F: PrimeField, { // Set the input columns @@ -133,13 +139,29 @@ impl Lt32Chip { .find_map(|(n, (x, y))| if x == y { None } else { Some(n) }) { let z = 256u16 + b[n] as u16 - c[n] as u16; - for i in 0..10 { + for i in 0..9 { cols.bits[i] = F::from_canonical_u16(z >> i & 1); } cols.byte_flag[n] = F::one(); // b[n] != c[n] always here, so the difference is never zero. cols.diff_inv = (cols.input_1[n] - cols.input_2[n]).inverse(); } + // compute (little-endian) bit decomposition of the top bytes + for i in 0..8 { + cols.top_bits_1[i] = F::from_canonical_u8(b[0] >> i & 1); + cols.top_bits_2[i] = F::from_canonical_u8(c[0] >> i & 1); + } + // check if sign bits agree and set different_signs accordingly + cols.different_signs = if is_signed { + if cols.top_bits_1[7] != cols.top_bits_2[7] { + F::one() + } else { + F::zero() + } + } else { + F::zero() + }; + cols.multiplicity = F::one(); } @@ -218,7 +240,6 @@ where let opcode = >::OPCODE; let comp = |a, b| a < b; let (dst, src1, src2) = Lt32Chip::execute_with_closure(state, ops, opcode, comp); - state .lt_u32_mut() .operations @@ -281,7 +302,6 @@ where a_i <= b_i }; let (dst, src1, src2) = Lt32Chip::execute_with_closure(state, ops, opcode, comp); - state .lt_u32_mut() .operations diff --git a/alu_u32/src/lt/stark.rs b/alu_u32/src/lt/stark.rs index 48684df..c875b71 100644 --- a/alu_u32/src/lt/stark.rs +++ b/alu_u32/src/lt/stark.rs @@ -22,7 +22,7 @@ where let main = builder.main(); let local: &Lt32Cols = main.row_slice(0).borrow(); - let base_2 = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512].map(AB::Expr::from_canonical_u32); + let base_2 = [1, 2, 4, 8, 16, 32, 64, 128, 256].map(AB::Expr::from_canonical_u32); let bit_comp: AB::Expr = local .bits @@ -76,26 +76,93 @@ where builder.assert_bool(local.byte_flag[i]); } + // Check the bit decomposition of the top bytes: + let top_comp_1: AB::Expr = local + .top_bits_1 + .into_iter() + .zip(base_2.iter().cloned()) + .map(|(bit, base)| bit * base) + .sum(); + let top_comp_2: AB::Expr = local + .top_bits_2 + .into_iter() + .zip(base_2.iter().cloned()) + .map(|(bit, base)| bit * base) + .sum(); + builder.assert_eq(top_comp_1, local.input_1[0]); + builder.assert_eq(top_comp_2, local.input_2[0]); + + let is_signed = local.is_slt + local.is_sle; + let is_unsigned = AB::Expr::one() - is_signed.clone(); + let same_sign = AB::Expr::one() - local.different_signs; + let are_equal = AB::Expr::one() - flag_sum.clone(); + + builder + .when(is_unsigned.clone()) + .assert_zero(local.different_signs); + + // Check that `different_signs` is set correctly by comparing sign bits. + builder + .when(is_signed.clone()) + .when_ne(local.top_bits_1[7], local.top_bits_2[7]) + .assert_eq(local.different_signs, AB::Expr::one()); + builder + .when(local.different_signs) + .assert_eq(local.byte_flag[0], AB::Expr::one()); + // local.top_bits_1[7] and local.top_bits_2[7] are boolean; their sum is 1 iff they are unequal. + builder + .when(local.different_signs) + .assert_eq(local.top_bits_1[7] + local.top_bits_2[7], AB::Expr::one()); + builder.assert_bool(local.is_lt); builder.assert_bool(local.is_lte); - builder.assert_bool(local.is_lt + local.is_lte); + builder.assert_bool(local.is_slt); + builder.assert_bool(local.is_sle); + builder.assert_bool(local.is_lt + local.is_lte + local.is_slt + local.is_sle); // Output constraints - // local.bits[8] is 1 iff input_1 > input_2: output should be 0 - builder.when(local.bits[8]).assert_zero(local.output); - // output should be 1 if is_lte & input_1 == input_2 + // Case 0: input_1 > input_2 as unsigned ints; equivalently, local.bits[8] == 1 + // when both inputs have the same sign, signed and unsigned inequality agree. builder - .when(local.is_lte) - .when_ne(flag_sum.clone(), AB::Expr::one()) + .when(local.bits[8]) + .when(is_unsigned.clone() + same_sign.clone()) + .assert_zero(local.output); + // when the inputs have different signs, signed inequality is the opposite of unsigned inequality. + builder + .when(local.bits[8]) + .when(local.different_signs) + .assert_one(local.output); + + // Case 1: input_1 < input_2 as unsigned ints; equivalently, local.bits[8] == is_equal == 0. + builder + // when are_equal == 1, we have already enforced that local.bits[8] == 0 + .when_ne(local.bits[8] + are_equal.clone(), AB::Expr::one()) + .when(is_unsigned.clone() + same_sign.clone()) .assert_one(local.output); - // output should be 0 if is_lt & input_1 == input_2 builder - .when(local.is_lt) - .when_ne(flag_sum, AB::Expr::one()) + .when_ne(local.bits[8] + are_equal.clone(), AB::Expr::one()) + .when(local.different_signs) .assert_zero(local.output); - // Check bit decomposition - for bit in local.bits.into_iter() { + // Case 2: input_1 == input_2; equivalently, are_equal == 1 + // output should be 1 if is_lte or is_sle + builder + .when(are_equal.clone()) + .when(local.is_lte + local.is_sle) + .assert_one(local.output); + // output should be 0 if is_lt or is_slt + builder + .when(are_equal.clone()) + .when(local.is_lt + local.is_slt) + .assert_zero(local.output); + + // Check "bit" values are all boolean + for bit in local + .bits + .into_iter() + .chain(local.top_bits_1.into_iter()) + .chain(local.top_bits_2.into_iter()) + { builder.assert_bool(bit); } } diff --git a/basic/tests/test_prover.rs b/basic/tests/test_prover.rs index 4c6c672..611616a 100644 --- a/basic/tests/test_prover.rs +++ b/basic/tests/test_prover.rs @@ -3,7 +3,7 @@ extern crate core; use p3_baby_bear::BabyBear; use p3_fri::{TwoAdicFriPcs, TwoAdicFriPcsConfig}; use valida_alu_u32::add::{Add32Instruction, MachineWithAdd32Chip}; -use valida_alu_u32::lt::{Lt32Instruction, Lte32Instruction}; +use valida_alu_u32::lt::{Lt32Instruction, Lte32Instruction, Sle32Instruction, Slt32Instruction}; use valida_basic::BasicMachine; use valida_cpu::{ BeqInstruction, BneInstruction, Imm32Instruction, JalInstruction, JalvInstruction, @@ -261,6 +261,123 @@ fn left_imm_ops_program() -> Vec() -> Vec> { + let mut program = vec![]; + + // imm32 -4(fp), 0, 0, 0, 1 + // imm32 -8(fp), 255, 255, 255, 255 + // imm32 -12(fp), 255, 255, 255, 254 + program.extend([ + InstructionWord { + opcode: , Val>>::OPCODE, + operands: Operands([-4, 0, 0, 0, 1]), + }, + InstructionWord { + opcode: , Val>>::OPCODE, + operands: Operands([-8, 255, 255, 255, 255]), + }, + InstructionWord { + opcode: , Val>>::OPCODE, + operands: Operands([-12, 255, 255, 255, 254]), + }, + ]); + + // slt32 4(fp), -12(fp), -8(fp), 0, 0 + // slt32 8(fp), -12(fp), -4(fp), 0, 0 + // slt32 12(fp), -4(fp), -1, 0, 1 + // slt32 16(fp), -1, -8(fp), 1, 0 + // sle32 20(fp), -1, -8(fp), 1, 0 + // slt32 24(fp), -1, -12(fp), 1, 0 + // slt32 28(fp), -8(fp), -12(fp), 0, 0 + // slt32 32(fp), -8(fp), -4(fp), 0, 0 + + program.extend([ + InstructionWord { + opcode: , Val>>::OPCODE, + operands: Operands([4, -12, -8, 0, 0]), + }, + InstructionWord { + opcode: , Val>>::OPCODE, + operands: Operands([8, -12, -4, 0, 0]), + }, + InstructionWord { + opcode: , Val>>::OPCODE, + operands: Operands([12, -4, -1, 0, 1]), + }, + InstructionWord { + opcode: , Val>>::OPCODE, + operands: Operands([16, -1, -8, 1, 0]), + }, + InstructionWord { + opcode: , Val>>::OPCODE, + operands: Operands([20, -1, -8, 1, 0]), + }, + InstructionWord { + opcode: , Val>>::OPCODE, + operands: Operands([24, -1, -12, 1, 0]), + }, + InstructionWord { + opcode: , Val>>::OPCODE, + operands: Operands([28, -8, -12, 0, 0]), + }, + InstructionWord { + opcode: , Val>>::OPCODE, + operands: Operands([32, -8, -4, 0, 0]), + }, + ]); + + // lt32 36(fp), -12(fp), -8(fp), 0, 0 + // lt32 40(fp), -12(fp), -4(fp), 0, 0 + // lt32 44(fp), -4(fp), -1, 0, 1 + // lt32 48(fp), -1, -8(fp), 1, 0 + // lte32 52(fp), -1, -8(fp), 1, 0 + // lt32 56(fp), -1, -12(fp), 1, 0 + // lt32 60(fp), -8(fp), -12(fp), 0, 0 + // lt32 64(fp), -8(fp), -4(fp), 0, 0 + // stop + program.extend([ + InstructionWord { + opcode: , Val>>::OPCODE, + operands: Operands([36, -12, -8, 0, 0]), + }, + InstructionWord { + opcode: , Val>>::OPCODE, + operands: Operands([40, -12, -4, 0, 0]), + }, + InstructionWord { + opcode: , Val>>::OPCODE, + operands: Operands([44, -4, -1, 0, 1]), + }, + InstructionWord { + opcode: , Val>>::OPCODE, + operands: Operands([48, -1, -8, 1, 0]), + }, + InstructionWord { + opcode: , Val>>::OPCODE, + operands: Operands([52, -1, -8, 1, 0]), + }, + InstructionWord { + opcode: , Val>>::OPCODE, + operands: Operands([56, -1, -12, 1, 0]), + }, + InstructionWord { + opcode: , Val>>::OPCODE, + operands: Operands([60, -8, -12, 0, 0]), + }, + InstructionWord { + opcode: , Val>>::OPCODE, + operands: Operands([64, -8, -4, 0, 0]), + }, + // stop 0, 0, 0, 0, 0 + InstructionWord { + opcode: , Val>>::OPCODE, + operands: Operands([0, 0, 0, 0, 0]), + }, + ]); + + program +} + fn prove_program(program: Vec>) -> BasicMachine { let mut machine = BasicMachine::::default(); let rom = ProgramROM::new(program); @@ -351,7 +468,6 @@ fn prove_left_imm_ops() { let program = left_imm_ops_program::(); let machine = prove_program(program); - assert_eq!( *machine.mem().cells.get(&(0x1000 + 4)).unwrap(), Word([0, 0, 0, 0]) // 3 < 3 (false) @@ -393,3 +509,78 @@ fn prove_left_imm_ops() { Word([0, 0, 0, 1]) // 3 <= 256 (false) ); } + +#[test] +fn prove_signed_inequality() { + let program = signed_inequality_program::(); + + let machine = prove_program(program); + + // signed inequalities + assert_eq!( + *machine.mem().cells.get(&(0x1000 + 4)).unwrap(), + Word([0, 0, 0, 1]) // -2 < -1 (true) + ); + assert_eq!( + *machine.mem().cells.get(&(0x1000 + 8)).unwrap(), + Word([0, 0, 0, 1]) // -2 < 1 (true) + ); + assert_eq!( + *machine.mem().cells.get(&(0x1000 + 12)).unwrap(), + Word([0, 0, 0, 0]) // 1 < -1 (false) + ); + assert_eq!( + *machine.mem().cells.get(&(0x1000 + 16)).unwrap(), + Word([0, 0, 0, 0]) // -1 < -1 (false) + ); + assert_eq!( + *machine.mem().cells.get(&(0x1000 + 20)).unwrap(), + Word([0, 0, 0, 1]) // -1 <= -1 (true) + ); + assert_eq!( + *machine.mem().cells.get(&(0x1000 + 24)).unwrap(), + Word([0, 0, 0, 0]) // -1 < -2 (false) + ); + assert_eq!( + *machine.mem().cells.get(&(0x1000 + 28)).unwrap(), + Word([0, 0, 0, 0]) // -1 < -2 (false) + ); + assert_eq!( + *machine.mem().cells.get(&(0x1000 + 32)).unwrap(), + Word([0, 0, 0, 1]) // -1 < 1 (true) + ); + + // unsigned inequalities + assert_eq!( + *machine.mem().cells.get(&(0x1000 + 36)).unwrap(), + Word([0, 0, 0, 1]) // 0xFFFFFFFE < 0xFFFFFFFF (true) + ); + assert_eq!( + *machine.mem().cells.get(&(0x1000 + 40)).unwrap(), + Word([0, 0, 0, 0]) // 0xFFFFFFFE < 1 (false) + ); + assert_eq!( + *machine.mem().cells.get(&(0x1000 + 44)).unwrap(), + Word([0, 0, 0, 1]) // 1 < 0xFFFFFFFF (true) + ); + assert_eq!( + *machine.mem().cells.get(&(0x1000 + 48)).unwrap(), + Word([0, 0, 0, 0]) // 0xFFFFFFFF < 0xFFFFFFFFFF (false) + ); + assert_eq!( + *machine.mem().cells.get(&(0x1000 + 52)).unwrap(), + Word([0, 0, 0, 1]) // 0xFFFFFFFF <= 0xFFFFFFFF (true) + ); + assert_eq!( + *machine.mem().cells.get(&(0x1000 + 56)).unwrap(), + Word([0, 0, 0, 0]) // 0xFFFFFFFF < 0xFFFFFFFE (false) + ); + assert_eq!( + *machine.mem().cells.get(&(0x1000 + 60)).unwrap(), + Word([0, 0, 0, 0]) // 0xFFFFFFFF < 0xFFFFFFFE (false) + ); + assert_eq!( + *machine.mem().cells.get(&(0x1000 + 64)).unwrap(), + Word([0, 0, 0, 0]) // 0xFFFFFFFF < 1 (false) + ); +}