Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add halo2 verifier with custom algorithm #23

Merged
merged 1 commit into from
Aug 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 56 additions & 0 deletions halo2-verfier/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
[package]
name = "halo2-verifier"
version = "0.1.0"
edition = "2021"

[dependencies]
rand = "=0.8"
ark-std = { version = "=0.3.0", features = ["print-trace"] }
serde = { version = "=1.0", default-features = false, features = ["derive"] }
serde_json = "=1.0"
log = "=0.4"
env_logger = "=0.10"
clap = { version = "=4.0", features = ["derive"] }
clap-num = "=1.0.2"

# halo2
halo2_proofs = { git = "https://github.com/privacy-scaling-explorations/halo2.git", tag = "v2023_02_02" }

# Axiom's helper API with basic functions
halo2-base = { git = "https://github.com/axiom-crypto/halo2-lib", branch = "community-edition" }
snark-verifier-sdk = { git = "https://github.com/axiom-crypto/snark-verifier.git", branch = "community-edition" }

[dev-dependencies]
test-log = "=0.2.11"
ethers-core = "=2.0.6"

[features]
default = []

# Dev / testing mode. We make opt-level = 3 to improve proving times (otherwise it is really slow)
[profile.dev]
opt-level = 3
debug = 1 # change to 0 or 2 for more or less debug info
overflow-checks = true # default
incremental = true # default

# Local "release" mode, more optimized than dev but faster to compile than release
[profile.local]
inherits = "dev"
opt-level = 3
# Set this to 1 or 2 to get more useful backtraces
debug = 1
debug-assertions = false
panic = 'unwind'
# better recompile times
incremental = true
lto = "thin"
codegen-units = 16

[profile.release]
opt-level = 3
debug = false
debug-assertions = false
lto = "fat"
panic = "abort"
incremental = false
36 changes: 36 additions & 0 deletions halo2-verfier/src/cmd.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
use clap::{Parser, Subcommand};
use std::path::PathBuf;

#[derive(Clone, Copy, Debug, Subcommand)]
pub enum SnarkCmd {
/// Verify a proof
Verify,
}

impl std::fmt::Display for SnarkCmd {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Verify => write!(f, "verify"),
}
}
}

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
/// Command-line helper for various steps in ZK proving.
pub struct Cli {
#[command(subcommand)]
pub command: SnarkCmd,
#[arg(short, long = "name")]
pub name: String,
#[arg(short = 'k', long = "degree")]
pub degree: u32,
#[arg(short, long = "input")]
pub input_path: Option<PathBuf>,
#[arg(long = "create-contract")]
pub create_contract: bool,
#[arg(short, long = "config-path")]
pub config_path: Option<PathBuf>,
#[arg(short, long = "data-path")]
pub data_path: Option<PathBuf>,
}
218 changes: 218 additions & 0 deletions halo2-verfier/src/helpers.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,218 @@
use halo2_base::{
gates::{
circuit::{builder::BaseCircuitBuilder, BaseCircuitParams, CircuitBuilderStage},
flex_gate::MultiPhaseThreadBreakPoints,
},
halo2_proofs::{
dev::MockProver,
halo2curves::bn256::{Bn256, Fr, G1Affine},
plonk::{verify_proof, Circuit, ProvingKey, VerifyingKey},
poly::{
commitment::{Params, ParamsProver},
kzg::{
commitment::{KZGCommitmentScheme, ParamsKZG},
multiopen::VerifierSHPLONK,
strategy::SingleStrategy,
},
},
SerdeFormat,
},
utils::fs::gen_srs,
AssignedValue,
};
use serde::de::DeserializeOwned;
use snark_verifier_sdk::{
gen_pk,
halo2::{gen_snark_shplonk, read_snark, PoseidonTranscript},
read_pk, CircuitExt, NativeLoader,
};
use std::{
env::var,
fs::{self, File},
io::{BufReader, BufWriter},
path::{Path, PathBuf},
time::Instant,
};

use self::cmd::{Cli, SnarkCmd};

pub mod cmd;

pub struct CircuitScaffold<T, Fn> {
f: Fn,
private_inputs: T,
}

pub fn run<T: DeserializeOwned>(
f: impl FnOnce(&mut BaseCircuitBuilder<Fr>, T, &mut Vec<AssignedValue<Fr>>),
cli: Cli,
) {
let name = &cli.name;
let input_path = PathBuf::from("data").join(
cli.input_path
.clone()
.unwrap_or_else(|| PathBuf::from(format!("{name}.in"))),
);
let private_inputs: T = serde_json::from_reader(
File::open(&input_path)
.unwrap_or_else(|e| panic!("Input file not found at {input_path:?}. {e:?}")),
)
.expect("Input file should be a valid JSON file");
run_on_inputs(f, cli, private_inputs)
}

pub fn run_on_inputs<T: DeserializeOwned>(
f: impl FnOnce(&mut BaseCircuitBuilder<Fr>, T, &mut Vec<AssignedValue<Fr>>),
cli: Cli,
private_inputs: T,
) {
let precircuit = CircuitScaffold { f, private_inputs };

let name = cli.name;
let k = cli.degree;

let config_path = cli.config_path.unwrap_or_else(|| PathBuf::from("configs"));
let data_path = cli.data_path.unwrap_or_else(|| PathBuf::from("data"));
fs::create_dir_all(&config_path).unwrap();
fs::create_dir_all(&data_path).unwrap();

let params = gen_srs(k);

println!("Universal trusted setup (unsafe!) available at: params/kzg_bn254_{k}.srs");
match cli.command {
let pinning_path = config_path.join(PathBuf::from(format!("{name}.json")));
let mut pinning_file = File::open(&pinning_path)
.unwrap_or_else(|_| panic!("Could not read file at {pinning_path:?}"));
let pinning: (BaseCircuitParams, MultiPhaseThreadBreakPoints) =
serde_json::from_reader(&mut pinning_file).expect("Could not read pinning file");
let circuit =
precircuit.create_circuit(CircuitBuilderStage::Prover, Some(pinning), &params);
let pk_path = data_path.join(PathBuf::from(format!("{name}.pk")));
let pk = custom_read_pk(pk_path, &circuit);
let snark_path = data_path.join(PathBuf::from(format!("{name}.snark")));
if snark_path.exists() {
fs::remove_file(&snark_path).unwrap();
}
let start = Instant::now();
gen_snark_shplonk(&params, &pk, circuit, Some(&snark_path));
let prover_time = start.elapsed();
println!("Proving time: {:?}", prover_time);
println!("Snark written to: {snark_path:?}");
}
SnarkCmd::Verify => {
let vk_path = data_path.join(PathBuf::from(format!("{name}.vk")));
let mut circuit = precircuit.create_circuit(CircuitBuilderStage::Keygen, None, &params);
let vk = custom_read_vk(vk_path, &circuit);
let snark_path = data_path.join(PathBuf::from(format!("{name}.snark")));
let snark = read_snark(&snark_path)
.unwrap_or_else(|e| panic!("Snark not found at {snark_path:?}. {e:?}"));

let verifier_params = params.verifier_params();
let strategy = SingleStrategy::new(&params);
let mut transcript =
PoseidonTranscript::<NativeLoader, &[u8]>::new::<0>(&snark.proof[..]);
let instance = &snark.instances[0][..];
let start = Instant::now();
verify_proof::<
KZGCommitmentScheme<Bn256>,
VerifierSHPLONK<'_, Bn256>,
_,
_,
SingleStrategy<'_, Bn256>,
>(
verifier_params,
&vk,
strategy,
&[&[instance]],
&mut transcript,
)
.unwrap();
let verification_time = start.elapsed();
println!("Snark verified successfully in {:?}", verification_time);
circuit.clear();
}
}


fn custom_read_pk<C, P>(fname: P, circuit: &C) -> ProvingKey<G1Affine>
where
C: Circuit<Fr>,
P: AsRef<Path>,
{
read_pk::<C>(fname.as_ref(), circuit.params())
.unwrap_or_else(|e| panic!("Failed to open file: {:?}: {e:?}", fname.as_ref()))
}

fn custom_read_vk<C, P>(fname: P, circuit: &C) -> VerifyingKey<G1Affine>
where
C: Circuit<Fr>,
P: AsRef<Path>,
{
let f = File::open(&fname)
.unwrap_or_else(|e| panic!("Failed to open file: {:?}: {e:?}", fname.as_ref()));
let mut bufreader = BufReader::new(f);
VerifyingKey::read::<_, C>(&mut bufreader, SerdeFormat::RawBytes, circuit.params())
.expect("Could not read vkey")
}

impl<T, Fn> CircuitScaffold<T, Fn>
where
Fn: FnOnce(&mut BaseCircuitBuilder<Fr>, T, &mut Vec<AssignedValue<Fr>>),
{
/// Creates a Halo2 circuit from the given function.
fn create_circuit(
self,
stage: CircuitBuilderStage,
pinning: Option<(BaseCircuitParams, MultiPhaseThreadBreakPoints)>,
params: &ParamsKZG<Bn256>,
) -> BaseCircuitBuilder<Fr> {
let mut builder = BaseCircuitBuilder::from_stage(stage);
if let Some((params, break_points)) = pinning {
builder.set_params(params);
builder.set_break_points(break_points);
} else {
let k = params.k() as usize;
// we use env var `LOOKUP_BITS` to determine whether to use `GateThreadBuilder` or `RangeCircuitBuilder`. The difference is that the latter creates a lookup table with 2^LOOKUP_BITS rows, while the former does not.
let lookup_bits: Option<usize> = var("LOOKUP_BITS")
.map(|str| {
let lookup_bits = str.parse::<usize>().unwrap();
// we use a lookup table with 2^LOOKUP_BITS rows. Due to blinding factors, we need a little more than 2^LOOKUP_BITS rows total in our circuit
assert!(lookup_bits < k, "LOOKUP_BITS needs to be less than DEGREE");
lookup_bits
})
.ok();
// we initiate a "thread builder". This is what keeps track of the execution trace of our program. If not in proving mode, it also keeps track of the ZK constraints.
builder.set_k(k);
if let Some(lookup_bits) = lookup_bits {
builder.set_lookup_bits(lookup_bits);
}
builder.set_instance_columns(1);
};

// builder.main(phase) gets a default "main" thread for the given phase. For most purposes we only need to think about phase 0
// we need a 64-bit number as input in this case
// while `some_algorithm_in_zk` was written generically for any field `F`, in practice we use the scalar field of the BN254 curve because that's what the proving system backend uses
let mut assigned_instances = vec![];
(self.f)(&mut builder, self.private_inputs, &mut assigned_instances);
if !assigned_instances.is_empty() {
assert_eq!(
builder.assigned_instances.len(),
1,
"num_instance_columns != 1"
);
builder.assigned_instances[0] = assigned_instances;
}

if !stage.witness_gen_only() {
// now `builder` contains the execution trace, and we are ready to actually create the circuit
// minimum rows is the number of rows used for blinding factors. This depends on the circuit itself, but we can guess the number and change it if something breaks (default 9 usually works)
let minimum_rows = var("MINIMUM_ROWS")
.unwrap_or_else(|_| "20".to_string())
.parse()
.unwrap();
builder.calculate_params(Some(minimum_rows));
}

builder
}
}
78 changes: 78 additions & 0 deletions halo2-verfier/src/main.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
use clap::Parser;
use halo2_base::gates::circuit::builder::BaseCircuitBuilder;
use halo2_base::gates::{GateChip, GateInstructions};
use halo2_base::utils::ScalarField;
use halo2_base::AssignedValue;
#[allow(unused_imports)]
use halo2_base::{
Context,
QuantumCell::{Constant, Existing, Witness},
};
use halo2_scaffold::scaffold::cmd::Cli;
use halo2_scaffold::scaffold::run;
use serde::{Deserialize, Serialize};

#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct CircuitInput {
pub x: String, // field element, but easier to deserialize as a string
}

// this algorithm takes a public input x, computes x^2 + 72, and outputs the result as public output
fn some_algorithm_in_zk<F: ScalarField>(
builder: &mut BaseCircuitBuilder<F>,
input: CircuitInput,
make_public: &mut Vec<AssignedValue<F>>,
) {
let x = F::from_str_vartime(&input.x).expect("deserialize field element should not fail");
// `Context` can roughly be thought of as a single-threaded execution trace of a program we want to ZK prove. We do some post-processing on `Context` to optimally divide the execution trace into multiple columns in a PLONKish arithmetization
let ctx = builder.main(0);
// More advanced usage with multi-threaded witness generation is possible, but we do not explain it here

// first we load a number `x` into as system, as a "witness"
let x = ctx.load_witness(x);
// by default, all numbers in the system are private
// we can make it public like so:
make_public.push(x);

// create a Gate chip that contains methods for basic arithmetic operations
let gate = GateChip::<F>::default();

// ===== way 1 =====
// now we can perform arithmetic operations almost like a normal program using halo2-lib API functions
// square x
let x_sq = gate.mul(ctx, x, x);

// x^2 + 72
let c = F::from(72);
// the implicit type of most variables is an "Existing" assigned value
// a known constant is a separate type that we specify by `Constant(c)`:
let out = gate.add(ctx, x_sq, Constant(c));
// Halo2 does not distinguish between public inputs vs outputs because the verifier seems them all at the same time
// However in traditional terms, `out` is our output number. It is currently still private.
// Let's make it public:
make_public.push(out);
// ==== way 2 =======
// here is a more optimal way to compute x^2 + 72 using the lower level `assign_region` API:
let val = *x.value() * x.value() + c;
let _val_assigned =
ctx.assign_region_last([Constant(c), Existing(x), Existing(x), Witness(val)], [0]);
// the `[0]` tells us to turn on a vertical `a + b * c = d` gate at row position 0.
// this imposes the constraint c + x * x = val

// ==== way 3 ======
// this does the exact same thing as way 2, but with a pre-existing function
let _val_assigned = gate.mul_add(ctx, x, x, Constant(c));

println!("x: {:?}", x.value());
println!("val_assigned: {:?}", out.value());
assert_eq!(*x.value() * x.value() + c, *out.value());
}

fn main() {
env_logger::init();

let args = Cli::parse();

// run different zk commands based on the command line arguments
run(some_algorithm_in_zk, args);
}