Skip to content

Commit

Permalink
Merge pull request #168 from valida-xyz/dorebell-issue-160
Browse files Browse the repository at this point in the history
STARK constraints for signed inequality instructions
  • Loading branch information
tess-eract authored May 6, 2024
2 parents 7a6d813 + 5a4bf94 commit f980d06
Show file tree
Hide file tree
Showing 4 changed files with 315 additions and 28 deletions.
11 changes: 10 additions & 1 deletion alu_u32/src/lt/columns.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,26 @@ pub struct Lt32Cols<T> {
pub byte_flag: [T; 4],

/// Bit decomposition of 256 + input_1 - input_2
pub bits: [T; 10],
pub bits: [T; 9],

pub output: T,

pub multiplicity: T,

pub is_lt: T,
pub is_lte: T,
pub is_slt: T,
pub is_sle: T,

// inverse of input_1[i] - input_2[i] where i is the first byte that differs
pub diff_inv: T,

// bit decomposition of top bytes for input_1 and input_2
pub top_bits_1: [T; 8],
pub top_bits_2: [T; 8],

// boolean flag for whether the sign of the two inputs is different
pub different_signs: T,
}

pub const NUM_LT_COLS: usize = size_of::<Lt32Cols<u8>>();
Expand Down
46 changes: 33 additions & 13 deletions alu_u32/src/lt/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ where
vec![
(LT_COL_MAP.is_lt, SC::Val::from_canonical_u32(LT32)),
(LT_COL_MAP.is_lte, SC::Val::from_canonical_u32(LTE32)),
(LT_COL_MAP.is_slt, SC::Val::from_canonical_u32(SLT32)),
(LT_COL_MAP.is_sle, SC::Val::from_canonical_u32(SLE32)),
],
SC::Val::zero(),
);
Expand Down Expand Up @@ -94,28 +96,32 @@ impl Lt32Chip {
match op {
Operation::Lt32(a, b, c) => {
cols.is_lt = F::one();
self.set_cols(cols, a, b, c);
self.set_cols(cols, false, a, b, c);
}
Operation::Lte32(a, b, c) => {
cols.is_lte = F::one();
self.set_cols(cols, a, b, c);
self.set_cols(cols, false, a, b, c);
}
Operation::Slt32(a, b, c) => {
// TODO: this is just a placeholder
cols.is_lt = F::one();
self.set_cols(cols, a, b, c);
cols.is_slt = F::one();
self.set_cols(cols, true, a, b, c);
}
Operation::Sle32(a, b, c) => {
// TODO: this is just a placeholder
cols.is_lte = F::one();
self.set_cols(cols, a, b, c);
cols.is_sle = F::one();
self.set_cols(cols, true, a, b, c);
}
}
row
}

fn set_cols<F>(&self, cols: &mut Lt32Cols<F>, a: &Word<u8>, b: &Word<u8>, c: &Word<u8>)
where
fn set_cols<F>(
&self,
cols: &mut Lt32Cols<F>,
is_signed: bool,
a: &Word<u8>,
b: &Word<u8>,
c: &Word<u8>,
) where
F: PrimeField,
{
// Set the input columns
Expand All @@ -133,13 +139,29 @@ impl Lt32Chip {
.find_map(|(n, (x, y))| if x == y { None } else { Some(n) })
{
let z = 256u16 + b[n] as u16 - c[n] as u16;
for i in 0..10 {
for i in 0..9 {
cols.bits[i] = F::from_canonical_u16(z >> i & 1);
}
cols.byte_flag[n] = F::one();
// b[n] != c[n] always here, so the difference is never zero.
cols.diff_inv = (cols.input_1[n] - cols.input_2[n]).inverse();
}
// compute (little-endian) bit decomposition of the top bytes
for i in 0..8 {
cols.top_bits_1[i] = F::from_canonical_u8(b[0] >> i & 1);
cols.top_bits_2[i] = F::from_canonical_u8(c[0] >> i & 1);
}
// check if sign bits agree and set different_signs accordingly
cols.different_signs = if is_signed {
if cols.top_bits_1[7] != cols.top_bits_2[7] {
F::one()
} else {
F::zero()
}
} else {
F::zero()
};

cols.multiplicity = F::one();
}

Expand Down Expand Up @@ -218,7 +240,6 @@ where
let opcode = <Self as Instruction<M, F>>::OPCODE;
let comp = |a, b| a < b;
let (dst, src1, src2) = Lt32Chip::execute_with_closure(state, ops, opcode, comp);

state
.lt_u32_mut()
.operations
Expand Down Expand Up @@ -281,7 +302,6 @@ where
a_i <= b_i
};
let (dst, src1, src2) = Lt32Chip::execute_with_closure(state, ops, opcode, comp);

state
.lt_u32_mut()
.operations
Expand Down
91 changes: 79 additions & 12 deletions alu_u32/src/lt/stark.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ where
let main = builder.main();
let local: &Lt32Cols<AB::Var> = main.row_slice(0).borrow();

let base_2 = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512].map(AB::Expr::from_canonical_u32);
let base_2 = [1, 2, 4, 8, 16, 32, 64, 128, 256].map(AB::Expr::from_canonical_u32);

let bit_comp: AB::Expr = local
.bits
Expand Down Expand Up @@ -76,26 +76,93 @@ where
builder.assert_bool(local.byte_flag[i]);
}

// Check the bit decomposition of the top bytes:
let top_comp_1: AB::Expr = local
.top_bits_1
.into_iter()
.zip(base_2.iter().cloned())
.map(|(bit, base)| bit * base)
.sum();
let top_comp_2: AB::Expr = local
.top_bits_2
.into_iter()
.zip(base_2.iter().cloned())
.map(|(bit, base)| bit * base)
.sum();
builder.assert_eq(top_comp_1, local.input_1[0]);
builder.assert_eq(top_comp_2, local.input_2[0]);

let is_signed = local.is_slt + local.is_sle;
let is_unsigned = AB::Expr::one() - is_signed.clone();
let same_sign = AB::Expr::one() - local.different_signs;
let are_equal = AB::Expr::one() - flag_sum.clone();

builder
.when(is_unsigned.clone())
.assert_zero(local.different_signs);

// Check that `different_signs` is set correctly by comparing sign bits.
builder
.when(is_signed.clone())
.when_ne(local.top_bits_1[7], local.top_bits_2[7])
.assert_eq(local.different_signs, AB::Expr::one());
builder
.when(local.different_signs)
.assert_eq(local.byte_flag[0], AB::Expr::one());
// local.top_bits_1[7] and local.top_bits_2[7] are boolean; their sum is 1 iff they are unequal.
builder
.when(local.different_signs)
.assert_eq(local.top_bits_1[7] + local.top_bits_2[7], AB::Expr::one());

builder.assert_bool(local.is_lt);
builder.assert_bool(local.is_lte);
builder.assert_bool(local.is_lt + local.is_lte);
builder.assert_bool(local.is_slt);
builder.assert_bool(local.is_sle);
builder.assert_bool(local.is_lt + local.is_lte + local.is_slt + local.is_sle);

// Output constraints
// local.bits[8] is 1 iff input_1 > input_2: output should be 0
builder.when(local.bits[8]).assert_zero(local.output);
// output should be 1 if is_lte & input_1 == input_2
// Case 0: input_1 > input_2 as unsigned ints; equivalently, local.bits[8] == 1
// when both inputs have the same sign, signed and unsigned inequality agree.
builder
.when(local.is_lte)
.when_ne(flag_sum.clone(), AB::Expr::one())
.when(local.bits[8])
.when(is_unsigned.clone() + same_sign.clone())
.assert_zero(local.output);
// when the inputs have different signs, signed inequality is the opposite of unsigned inequality.
builder
.when(local.bits[8])
.when(local.different_signs)
.assert_one(local.output);

// Case 1: input_1 < input_2 as unsigned ints; equivalently, local.bits[8] == is_equal == 0.
builder
// when are_equal == 1, we have already enforced that local.bits[8] == 0
.when_ne(local.bits[8] + are_equal.clone(), AB::Expr::one())
.when(is_unsigned.clone() + same_sign.clone())
.assert_one(local.output);
// output should be 0 if is_lt & input_1 == input_2
builder
.when(local.is_lt)
.when_ne(flag_sum, AB::Expr::one())
.when_ne(local.bits[8] + are_equal.clone(), AB::Expr::one())
.when(local.different_signs)
.assert_zero(local.output);

// Check bit decomposition
for bit in local.bits.into_iter() {
// Case 2: input_1 == input_2; equivalently, are_equal == 1
// output should be 1 if is_lte or is_sle
builder
.when(are_equal.clone())
.when(local.is_lte + local.is_sle)
.assert_one(local.output);
// output should be 0 if is_lt or is_slt
builder
.when(are_equal.clone())
.when(local.is_lt + local.is_slt)
.assert_zero(local.output);

// Check "bit" values are all boolean
for bit in local
.bits
.into_iter()
.chain(local.top_bits_1.into_iter())
.chain(local.top_bits_2.into_iter())
{
builder.assert_bool(bit);
}
}
Expand Down
Loading

0 comments on commit f980d06

Please sign in to comment.