Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[feat] PoseidonHasher supports multiple inputs in compact format #127

Merged
merged 3 commits into from
Aug 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 79 additions & 1 deletion halo2-base/src/poseidon/hasher/mod.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::{
gates::{GateInstructions, RangeInstructions},
poseidon::hasher::{spec::OptimizedPoseidonSpec, state::PoseidonState},
safe_types::SafeTypeChip,
safe_types::{SafeBool, SafeTypeChip},
utils::BigPrimeField,
AssignedValue, Context,
QuantumCell::Constant,
Expand Down Expand Up @@ -49,6 +49,52 @@ impl<F: ScalarField, const T: usize, const RATE: usize> PoseidonHasherConsts<F,
}
}

/// 1 logical row of compact input for Poseidon hasher.
pub struct PoseidonCompactInput<F: ScalarField, const RATE: usize> {
// Right padded inputs. No constrains on paddings.
inputs: [AssignedValue<F>; RATE],
// is_final = 1 triggers squeeze.
is_final: SafeBool<F>,
// Length of `inputs`.
len: AssignedValue<F>,
}

impl<F: ScalarField, const RATE: usize> PoseidonCompactInput<F, RATE> {
/// Create a new PoseidonCompactInput.
pub fn new(
inputs: [AssignedValue<F>; RATE],
is_final: SafeBool<F>,
len: AssignedValue<F>,
) -> Self {
Self { inputs, is_final, len }
}

/// Add data validation constraints.
pub fn add_validation_constraints(
&self,
ctx: &mut Context<F>,
range: &impl RangeInstructions<F>,
) {
range.is_less_than_safe(ctx, self.len, (RATE + 1) as u64);
// Invalid case: (!is_final && len != RATE) ==> !(is_final || len == RATE)
let is_full: AssignedValue<F> =
range.gate().is_equal(ctx, self.len, Constant(F::from(RATE as u64)));
let invalid_cond = range.gate().or(ctx, *self.is_final.as_ref(), is_full);
range.gate().assert_is_const(ctx, &invalid_cond, &F::ZERO);
}
}

/// 1 logical row of compact output for Poseidon hasher.
#[derive(Getters)]
pub struct PoseidonCompactOutput<F: ScalarField> {
/// hash of 1 logical input.
#[getset(get = "pub")]
hash: AssignedValue<F>,
/// is_final = 1 ==> this is the end of a logical input.
#[getset(get = "pub")]
is_final: SafeBool<F>,
}

impl<F: ScalarField, const T: usize, const RATE: usize> PoseidonHasher<F, T, RATE> {
/// Create a poseidon hasher from an existing spec.
pub fn new(spec: OptimizedPoseidonSpec<F, T, RATE>) -> Self {
Expand Down Expand Up @@ -82,6 +128,7 @@ impl<F: ScalarField, const T: usize, const RATE: usize> PoseidonHasher<F, T, RAT
where
F: BigPrimeField,
{
// TODO: rewrite this using hash_compact_input.
let max_len = inputs.len();
if max_len == 0 {
return *self.empty_hash();
Expand Down Expand Up @@ -147,6 +194,37 @@ impl<F: ScalarField, const T: usize, const RATE: usize> PoseidonHasher<F, T, RAT
let mut state = self.init_state().clone();
fix_len_array_squeeze(ctx, range.gate(), inputs, &mut state, &self.spec)
}

/// Constrains and returns hashes of inputs in a compact format. Length of `compact_inputs` should be determined at compile time.
pub fn hash_compact_input(
&self,
ctx: &mut Context<F>,
range: &impl RangeInstructions<F>,
compact_inputs: &[PoseidonCompactInput<F, RATE>],
) -> Vec<PoseidonCompactOutput<F>>
where
F: BigPrimeField,
{
let mut outputs = Vec::with_capacity(compact_inputs.len());
let mut state = self.init_state().clone();
for input in compact_inputs {
// Assume this is the last row of a logical input:
// Depending on if len == RATE.
let is_full = range.gate().is_equal(ctx, input.len, Constant(F::from(RATE as u64)));
// Case 1: if len != RATE.
state.permutation(ctx, range.gate(), &input.inputs, Some(input.len), &self.spec);
// Case 2: if len == RATE, an extra permuation is needed for squeeze.
let mut state_2 = state.clone();
state_2.permutation(ctx, range.gate(), &[], None, &self.spec);
// Select the result of case 1/2 depending on if len == RATE.
let hash = range.gate().select(ctx, state_2.s[1], state.s[1], is_full);
outputs.push(PoseidonCompactOutput { hash, is_final: input.is_final });
// Reset state to init_state if this is the end of a logical input.
// TODO: skip this if this is the last row.
state.select(ctx, range.gate(), input.is_final, self.init_state());
}
outputs
}
}

/// Poseidon sponge. This is stateful.
Expand Down
152 changes: 130 additions & 22 deletions halo2-base/src/poseidon/hasher/tests/hasher.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
use crate::{
gates::{circuit::builder::RangeCircuitBuilder, range::RangeInstructions},
gates::{range::RangeInstructions, RangeChip},
halo2_proofs::halo2curves::bn256::Fr,
poseidon::hasher::{spec::OptimizedPoseidonSpec, PoseidonHasher},
utils::{testing::base_test, BigPrimeField, ScalarField},
poseidon::hasher::{spec::OptimizedPoseidonSpec, PoseidonCompactInput, PoseidonHasher},
safe_types::SafeTypeChip,
utils::{testing::base_test, ScalarField},
Context,
};
use halo2_proofs_axiom::arithmetic::Field;
use pse_poseidon::Poseidon;
use rand::Rng;

Expand All @@ -15,39 +18,96 @@ struct Payload<F: ScalarField> {
pub len: usize,
}

// check if the results from hasher and native sponge are same.
// check if the results from hasher and native sponge are same for hash_var_len_array.
fn hasher_compatiblity_verification<
F: ScalarField,
const T: usize,
const RATE: usize,
const R_F: usize,
const R_P: usize,
>(
payloads: Vec<Payload<F>>,
) where
F: BigPrimeField,
{
let lookup_bits = 3;
payloads: Vec<Payload<Fr>>,
) {
base_test().k(12).run(|ctx, range| {
// Construct in-circuit Poseidon hasher. Assuming SECURE_MDS = 0.
let spec = OptimizedPoseidonSpec::<Fr, T, RATE>::new::<R_F, R_P, 0>();
let mut hasher = PoseidonHasher::<Fr, T, RATE>::new(spec);
hasher.initialize_consts(ctx, range.gate());

let mut builder = RangeCircuitBuilder::new(true).use_lookup_bits(lookup_bits);
let range = builder.range_chip();
let ctx = builder.main(0);
for payload in payloads {
// Construct native Poseidon sponge.
let mut native_sponge = Poseidon::<Fr, T, RATE>::new(R_F, R_P);
native_sponge.update(&payload.values[..payload.len]);
let native_result = native_sponge.squeeze();
let inputs = ctx.assign_witnesses(payload.values);
let len = ctx.load_witness(Fr::from(payload.len as u64));
let hasher_result = hasher.hash_var_len_array(ctx, range, &inputs, len);
assert_eq!(native_result, *hasher_result.value());
}
});
}

// check if the results from hasher and native sponge are same for hash_compact_input.
fn hasher_compact_inputs_compatiblity_verification<
const T: usize,
const RATE: usize,
const R_F: usize,
const R_P: usize,
>(
payloads: Vec<Payload<Fr>>,
ctx: &mut Context<Fr>,
range: &RangeChip<Fr>,
) {
// Construct in-circuit Poseidon hasher. Assuming SECURE_MDS = 0.
let spec = OptimizedPoseidonSpec::<F, T, RATE>::new::<R_F, R_P, 0>();
let mut hasher = PoseidonHasher::<F, T, RATE>::new(spec);
let spec = OptimizedPoseidonSpec::<Fr, T, RATE>::new::<R_F, R_P, 0>();
let mut hasher = PoseidonHasher::<Fr, T, RATE>::new(spec);
hasher.initialize_consts(ctx, range.gate());

let mut native_results = Vec::with_capacity(payloads.len());
let mut compact_inputs = Vec::<PoseidonCompactInput<Fr, RATE>>::new();
let rate_witness = ctx.load_constant(Fr::from(RATE as u64));
let true_witness = ctx.load_constant(Fr::ONE);
let false_witness = ctx.load_zero();
for payload in payloads {
assert!(payload.values.len() % RATE == 0);
assert!(payload.values.len() >= payload.len);
assert!(payload.values.len() == RATE || payload.values.len() - payload.len < RATE);
let num_chunk = payload.values.len() / RATE;
let last_chunk_len = RATE - (payload.values.len() - payload.len);
let inputs = ctx.assign_witnesses(payload.values.clone());
for (chunk_idx, input_chunk) in inputs.chunks(RATE).enumerate() {
let len_witness = if chunk_idx + 1 == num_chunk {
ctx.load_witness(Fr::from(last_chunk_len as u64))
} else {
rate_witness
};
let is_final_witness = SafeTypeChip::unsafe_to_bool(if chunk_idx + 1 == num_chunk {
true_witness
} else {
false_witness
});
compact_inputs.push(PoseidonCompactInput {
inputs: input_chunk.try_into().unwrap(),
len: len_witness,
is_final: is_final_witness,
});
}
// Construct native Poseidon sponge.
let mut native_sponge = Poseidon::<F, T, RATE>::new(R_F, R_P);
let mut native_sponge = Poseidon::<Fr, T, RATE>::new(R_F, R_P);
native_sponge.update(&payload.values[..payload.len]);
let native_result = native_sponge.squeeze();
let inputs = ctx.assign_witnesses(payload.values);
let len = ctx.load_witness(F::from(payload.len as u64));
let hasher_result = hasher.hash_var_len_array(ctx, &range, &inputs, len);
// 0x1f0db93536afb96e038f897b4fb5548b6aa3144c46893a6459c4b847951a23b4
assert_eq!(native_result, *hasher_result.value());
native_results.push(native_result);
}
let compact_outputs = hasher.hash_compact_input(ctx, range, &compact_inputs);
let mut output_offset = 0;
for (compact_output, compact_input) in compact_outputs.iter().zip(compact_inputs) {
// into() doesn't work if ! is in the beginning in the bool expression...
let is_not_final_input: bool = compact_input.is_final.as_ref().value().is_zero().into();
let is_not_final_output: bool = compact_output.is_final().as_ref().value().is_zero().into();
assert_eq!(is_not_final_input, is_not_final_output);
if !is_not_final_output {
assert_eq!(native_results[output_offset], *compact_output.hash().value());
output_offset += 1;
}
}
}

Expand Down Expand Up @@ -98,7 +158,7 @@ fn test_poseidon_hasher_compatiblity() {
random_payload(RATE * 2 + 1, RATE * 2 + 1, usize::MAX),
random_payload(RATE * 5 + 1, RATE * 5 + 1, usize::MAX),
];
hasher_compatiblity_verification::<Fr, T, RATE, 8, 57>(payloads);
hasher_compatiblity_verification::<T, RATE, 8, 57>(payloads);
}
}

Expand Down Expand Up @@ -127,3 +187,51 @@ fn test_poseidon_hasher_with_prover() {
}
}
}

#[test]
fn test_poseidon_hasher_compact_inputs() {
{
const T: usize = 3;
const RATE: usize = 2;
let payloads = vec![
// len == 0
random_payload(RATE, 0, usize::MAX),
// 0 < len < max_len
random_payload(RATE * 2, RATE + 1, usize::MAX),
random_payload(RATE * 5, RATE * 4 + 1, usize::MAX),
// len == max_len
random_payload(RATE * 2, RATE * 2, usize::MAX),
random_payload(RATE * 5, RATE * 5, usize::MAX),
];
base_test().k(12).run(|ctx, range| {
hasher_compact_inputs_compatiblity_verification::<T, RATE, 8, 57>(payloads, ctx, range);
});
}
}

#[test]
fn test_poseidon_hasher_compact_inputs_with_prover() {
{
const T: usize = 3;
const RATE: usize = 2;
let params = vec![
(RATE, 0),
(RATE * 2, RATE + 1),
(RATE * 5, RATE * 4 + 1),
(RATE * 2, RATE * 2),
(RATE * 5, RATE * 5),
];
let init_payloads = params
.iter()
.map(|(max_len, len)| random_payload(*max_len, *len, usize::MAX))
.collect::<Vec<_>>();
let logic_payloads = params
.iter()
.map(|(max_len, len)| random_payload(*max_len, *len, usize::MAX))
.collect::<Vec<_>>();
base_test().k(12).bench_builder(init_payloads, logic_payloads, |pool, range, input| {
let ctx = pool.main();
hasher_compact_inputs_compatiblity_verification::<T, RATE, 8, 57>(input, ctx, range);
});
}
}
Loading