Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

STARK constraints for signed inequality instructions #168

Merged
merged 3 commits into from
May 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion alu_u32/src/lt/columns.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,26 @@ pub struct Lt32Cols<T> {
pub byte_flag: [T; 4],

/// Bit decomposition of 256 + input_1 - input_2
pub bits: [T; 10],
morganthomas marked this conversation as resolved.
Show resolved Hide resolved
pub bits: [T; 9],

pub output: T,

pub multiplicity: T,

pub is_lt: T,
pub is_lte: T,
pub is_slt: T,
pub is_sle: T,

// inverse of input_1[i] - input_2[i] where i is the first byte that differs
pub diff_inv: T,

// bit decomposition of top bytes for input_1 and input_2
pub top_bits_1: [T; 8],
pub top_bits_2: [T; 8],

// boolean flag for whether the sign of the two inputs is different
pub different_signs: T,
}

pub const NUM_LT_COLS: usize = size_of::<Lt32Cols<u8>>();
Expand Down
46 changes: 33 additions & 13 deletions alu_u32/src/lt/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ where
vec![
(LT_COL_MAP.is_lt, SC::Val::from_canonical_u32(LT32)),
(LT_COL_MAP.is_lte, SC::Val::from_canonical_u32(LTE32)),
(LT_COL_MAP.is_slt, SC::Val::from_canonical_u32(SLT32)),
(LT_COL_MAP.is_sle, SC::Val::from_canonical_u32(SLE32)),
],
SC::Val::zero(),
);
Expand Down Expand Up @@ -94,28 +96,32 @@ impl Lt32Chip {
match op {
Operation::Lt32(a, b, c) => {
cols.is_lt = F::one();
self.set_cols(cols, a, b, c);
self.set_cols(cols, false, a, b, c);
}
Operation::Lte32(a, b, c) => {
cols.is_lte = F::one();
self.set_cols(cols, a, b, c);
self.set_cols(cols, false, a, b, c);
}
Operation::Slt32(a, b, c) => {
// TODO: this is just a placeholder
cols.is_lt = F::one();
self.set_cols(cols, a, b, c);
cols.is_slt = F::one();
self.set_cols(cols, true, a, b, c);
}
Operation::Sle32(a, b, c) => {
// TODO: this is just a placeholder
cols.is_lte = F::one();
self.set_cols(cols, a, b, c);
cols.is_sle = F::one();
self.set_cols(cols, true, a, b, c);
}
}
row
}

fn set_cols<F>(&self, cols: &mut Lt32Cols<F>, a: &Word<u8>, b: &Word<u8>, c: &Word<u8>)
where
fn set_cols<F>(
&self,
cols: &mut Lt32Cols<F>,
is_signed: bool,
a: &Word<u8>,
b: &Word<u8>,
c: &Word<u8>,
) where
F: PrimeField,
{
// Set the input columns
Expand All @@ -133,13 +139,29 @@ impl Lt32Chip {
.find_map(|(n, (x, y))| if x == y { None } else { Some(n) })
{
let z = 256u16 + b[n] as u16 - c[n] as u16;
for i in 0..10 {
for i in 0..9 {
cols.bits[i] = F::from_canonical_u16(z >> i & 1);
}
cols.byte_flag[n] = F::one();
// b[n] != c[n] always here, so the difference is never zero.
cols.diff_inv = (cols.input_1[n] - cols.input_2[n]).inverse();
}
// compute (little-endian) bit decomposition of the top bytes
for i in 0..8 {
cols.top_bits_1[i] = F::from_canonical_u8(b[0] >> i & 1);
cols.top_bits_2[i] = F::from_canonical_u8(c[0] >> i & 1);
}
// check if sign bits agree and set different_signs accordingly
tess-eract marked this conversation as resolved.
Show resolved Hide resolved
cols.different_signs = if is_signed {
if cols.top_bits_1[7] != cols.top_bits_2[7] {
F::one()
} else {
F::zero()
}
} else {
F::zero()
};

cols.multiplicity = F::one();
}

Expand Down Expand Up @@ -218,7 +240,6 @@ where
let opcode = <Self as Instruction<M, F>>::OPCODE;
let comp = |a, b| a < b;
let (dst, src1, src2) = Lt32Chip::execute_with_closure(state, ops, opcode, comp);

state
.lt_u32_mut()
.operations
Expand Down Expand Up @@ -281,7 +302,6 @@ where
a_i <= b_i
};
let (dst, src1, src2) = Lt32Chip::execute_with_closure(state, ops, opcode, comp);

state
.lt_u32_mut()
.operations
Expand Down
91 changes: 79 additions & 12 deletions alu_u32/src/lt/stark.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ where
let main = builder.main();
let local: &Lt32Cols<AB::Var> = main.row_slice(0).borrow();

let base_2 = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512].map(AB::Expr::from_canonical_u32);
let base_2 = [1, 2, 4, 8, 16, 32, 64, 128, 256].map(AB::Expr::from_canonical_u32);

let bit_comp: AB::Expr = local
.bits
Expand Down Expand Up @@ -76,26 +76,93 @@ where
builder.assert_bool(local.byte_flag[i]);
}

// Check the bit decomposition of the top bytes:
let top_comp_1: AB::Expr = local
morganthomas marked this conversation as resolved.
Show resolved Hide resolved
.top_bits_1
.into_iter()
.zip(base_2.iter().cloned())
.map(|(bit, base)| bit * base)
.sum();
let top_comp_2: AB::Expr = local
.top_bits_2
.into_iter()
.zip(base_2.iter().cloned())
.map(|(bit, base)| bit * base)
.sum();
builder.assert_eq(top_comp_1, local.input_1[0]);
builder.assert_eq(top_comp_2, local.input_2[0]);

let is_signed = local.is_slt + local.is_sle;
let is_unsigned = AB::Expr::one() - is_signed.clone();
let same_sign = AB::Expr::one() - local.different_signs;
let are_equal = AB::Expr::one() - flag_sum.clone();

builder
.when(is_unsigned.clone())
.assert_zero(local.different_signs);

// Check that `different_signs` is set correctly by comparing sign bits.
builder
.when(is_signed.clone())
.when_ne(local.top_bits_1[7], local.top_bits_2[7])
.assert_eq(local.different_signs, AB::Expr::one());
builder
.when(local.different_signs)
.assert_eq(local.byte_flag[0], AB::Expr::one());
// local.top_bits_1[7] and local.top_bits_2[7] are boolean; their sum is 1 iff they are unequal.
builder
.when(local.different_signs)
.assert_eq(local.top_bits_1[7] + local.top_bits_2[7], AB::Expr::one());

builder.assert_bool(local.is_lt);
builder.assert_bool(local.is_lte);
builder.assert_bool(local.is_lt + local.is_lte);
builder.assert_bool(local.is_slt);
builder.assert_bool(local.is_sle);
builder.assert_bool(local.is_lt + local.is_lte + local.is_slt + local.is_sle);

// Output constraints
// local.bits[8] is 1 iff input_1 > input_2: output should be 0
builder.when(local.bits[8]).assert_zero(local.output);
// output should be 1 if is_lte & input_1 == input_2
// Case 0: input_1 > input_2 as unsigned ints; equivalently, local.bits[8] == 1
// when both inputs have the same sign, signed and unsigned inequality agree.
builder
.when(local.is_lte)
.when_ne(flag_sum.clone(), AB::Expr::one())
.when(local.bits[8])
.when(is_unsigned.clone() + same_sign.clone())
.assert_zero(local.output);
// when the inputs have different signs, signed inequality is the opposite of unsigned inequality.
builder
.when(local.bits[8])
.when(local.different_signs)
.assert_one(local.output);

// Case 1: input_1 < input_2 as unsigned ints; equivalently, local.bits[8] == is_equal == 0.
morganthomas marked this conversation as resolved.
Show resolved Hide resolved
builder
// when are_equal == 1, we have already enforced that local.bits[8] == 0
.when_ne(local.bits[8] + are_equal.clone(), AB::Expr::one())
.when(is_unsigned.clone() + same_sign.clone())
.assert_one(local.output);
// output should be 0 if is_lt & input_1 == input_2
builder
.when(local.is_lt)
.when_ne(flag_sum, AB::Expr::one())
.when_ne(local.bits[8] + are_equal.clone(), AB::Expr::one())
.when(local.different_signs)
.assert_zero(local.output);

// Check bit decomposition
for bit in local.bits.into_iter() {
// Case 2: input_1 == input_2; equivalently, are_equal == 1
// output should be 1 if is_lte or is_sle
builder
.when(are_equal.clone())
.when(local.is_lte + local.is_sle)
.assert_one(local.output);
// output should be 0 if is_lt or is_slt
builder
.when(are_equal.clone())
.when(local.is_lt + local.is_slt)
.assert_zero(local.output);

// Check "bit" values are all boolean
for bit in local
.bits
.into_iter()
.chain(local.top_bits_1.into_iter())
.chain(local.top_bits_2.into_iter())
{
builder.assert_bool(bit);
}
}
Expand Down
Loading
Loading