diff --git a/bench/compliance_circuit_bench.exs b/bench/compliance_circuit_bench.exs index b3491e4..16c3f72 100644 --- a/bench/compliance_circuit_bench.exs +++ b/bench/compliance_circuit_bench.exs @@ -1,5 +1,5 @@ -{:ok, program} = File.read("./native/cairo_vm/compliance.json") -{:ok, input} = File.read("./native/cairo_vm/compliance_input.json") +{:ok, program} = File.read("./juvix/compliance.json") +{:ok, input} = File.read("./juvix/compliance_input.json") {_output, trace, memory, public_inputs} = Cairo.cairo_vm_runner( diff --git a/bench/logic_circuit_bench.exs b/bench/logic_circuit_bench.exs index 3c7a0df..fd94909 100644 --- a/bench/logic_circuit_bench.exs +++ b/bench/logic_circuit_bench.exs @@ -1,7 +1,7 @@ -{:ok, program} = File.read("./native/cairo_vm/trivial_resource_logic.json") +{:ok, program} = File.read("./juvix/trivial_resource_logic.json") {:ok, input} = - File.read("./native/cairo_vm/trivial_resource_logic_input.json") + File.read("./juvix/trivial_resource_logic_input.json") {_output, trace, memory, public_inputs} = Cairo.cairo_vm_runner( diff --git a/native/cairo_vm/cairo.json b/juvix/cairo.json similarity index 100% rename from native/cairo_vm/cairo.json rename to juvix/cairo.json diff --git a/native/cairo_vm/cairo.juvix b/juvix/cairo.juvix similarity index 100% rename from native/cairo_vm/cairo.juvix rename to juvix/cairo.juvix diff --git a/native/cairo_vm/cairo_input.json b/juvix/cairo_input.json similarity index 100% rename from native/cairo_vm/cairo_input.json rename to juvix/cairo_input.json diff --git a/native/cairo_vm/compliance.json b/juvix/compliance.json similarity index 100% rename from native/cairo_vm/compliance.json rename to juvix/compliance.json diff --git a/native/cairo_vm/compliance.juvix b/juvix/compliance.juvix similarity index 100% rename from native/cairo_vm/compliance.juvix rename to juvix/compliance.juvix diff --git a/native/cairo_vm/compliance_input.json b/juvix/compliance_input.json similarity index 100% rename from native/cairo_vm/compliance_input.json rename to juvix/compliance_input.json diff --git a/native/cairo_vm/encryption.json b/juvix/encryption.json similarity index 100% rename from native/cairo_vm/encryption.json rename to juvix/encryption.json diff --git a/native/cairo_vm/encryption.juvix b/juvix/encryption.juvix similarity index 100% rename from native/cairo_vm/encryption.juvix rename to juvix/encryption.juvix diff --git a/native/cairo_vm/encryption_input.json b/juvix/encryption_input.json similarity index 100% rename from native/cairo_vm/encryption_input.json rename to juvix/encryption_input.json diff --git a/native/cairo_vm/trivial_resource_logic.json b/juvix/trivial_resource_logic.json similarity index 100% rename from native/cairo_vm/trivial_resource_logic.json rename to juvix/trivial_resource_logic.json diff --git a/native/cairo_vm/trivial_resource_logic.juvix b/juvix/trivial_resource_logic.juvix similarity index 100% rename from native/cairo_vm/trivial_resource_logic.juvix rename to juvix/trivial_resource_logic.juvix diff --git a/native/cairo_vm/trivial_resource_logic_input.json b/juvix/trivial_resource_logic_input.json similarity index 100% rename from native/cairo_vm/trivial_resource_logic_input.json rename to juvix/trivial_resource_logic_input.json diff --git a/lib/cairo.ex b/lib/cairo.ex index c6fb908..ba86fe5 100644 --- a/lib/cairo.ex +++ b/lib/cairo.ex @@ -88,7 +88,7 @@ defmodule Cairo do to: Cairo.CairoProver, as: :program_hash - @spec felt_to_string(list(byte())) :: binary() + @spec felt_to_string(list(byte())) :: binary() | {:error, term()} defdelegate felt_to_string(felt), to: Cairo.CairoProver, as: :cairo_felt_to_string diff --git a/lib/cairo/cairo.ex b/lib/cairo/cairo.ex index 002bc55..0810bb8 100644 --- a/lib/cairo/cairo.ex +++ b/lib/cairo/cairo.ex @@ -51,6 +51,7 @@ defmodule Cairo.CairoProver do @spec program_hash(list(byte())) :: nif_result(list(byte())) def program_hash(_public_inputs), do: error() + @spec cairo_felt_to_string(list(byte())) :: nif_result(binary()) def cairo_felt_to_string(_felt), do: error() def cairo_generate_compliance_input_json( diff --git a/native/cairo_prover/Cargo.lock b/native/cairo_prover/Cargo.lock index 8cd1ef7..d8c4825 100644 --- a/native/cairo_prover/Cargo.lock +++ b/native/cairo_prover/Cargo.lock @@ -189,7 +189,7 @@ dependencies = [ [[package]] name = "cairo-platinum-prover" version = "0.9.0" -source = "git+https://github.com/lambdaclass/lambdaworks#c4fa1f21b98a56825c76b2c38108e3a7f79b3995" +source = "git+https://github.com/heliaxdev/lambdaworks?branch=cairo_rm#ea328f5ca24448c0e2d7816a76b86c81eadb2d9f" dependencies = [ "bincode 2.0.0-rc.2", "cairo-vm", @@ -254,6 +254,7 @@ dependencies = [ "starknet-crypto 0.7.1", "starknet-curve 0.5.0", "starknet-types-core", + "thiserror", ] [[package]] @@ -485,7 +486,7 @@ dependencies = [ [[package]] name = "lambdaworks-crypto" version = "0.9.0" -source = "git+https://github.com/lambdaclass/lambdaworks#c4fa1f21b98a56825c76b2c38108e3a7f79b3995" +source = "git+https://github.com/heliaxdev/lambdaworks?branch=cairo_rm#ea328f5ca24448c0e2d7816a76b86c81eadb2d9f" dependencies = [ "lambdaworks-math 0.9.0", "serde", @@ -506,7 +507,7 @@ dependencies = [ [[package]] name = "lambdaworks-math" version = "0.9.0" -source = "git+https://github.com/lambdaclass/lambdaworks#c4fa1f21b98a56825c76b2c38108e3a7f79b3995" +source = "git+https://github.com/heliaxdev/lambdaworks?branch=cairo_rm#ea328f5ca24448c0e2d7816a76b86c81eadb2d9f" dependencies = [ "rayon", "serde", @@ -858,7 +859,7 @@ checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" [[package]] name = "stark-platinum-prover" version = "0.9.0" -source = "git+https://github.com/lambdaclass/lambdaworks#c4fa1f21b98a56825c76b2c38108e3a7f79b3995" +source = "git+https://github.com/heliaxdev/lambdaworks?branch=cairo_rm#ea328f5ca24448c0e2d7816a76b86c81eadb2d9f" dependencies = [ "bincode 2.0.0-rc.2", "itertools 0.11.0", diff --git a/native/cairo_prover/Cargo.toml b/native/cairo_prover/Cargo.toml index 422a35e..79e1acd 100644 --- a/native/cairo_prover/Cargo.toml +++ b/native/cairo_prover/Cargo.toml @@ -11,9 +11,9 @@ crate-type = ["cdylib"] [dependencies] rustler = "0.31.0" -cairo-platinum-prover = { git = "https://github.com/lambdaclass/lambdaworks", version = "0.9.0"} -stark-platinum-prover = { git = "https://github.com/lambdaclass/lambdaworks", version = "0.9.0"} -lambdaworks-math = { git = "https://github.com/lambdaclass/lambdaworks", version = "0.9.0"} +cairo-platinum-prover = { git = "https://github.com/heliaxdev/lambdaworks", branch = "cairo_rm"} +stark-platinum-prover = { git = "https://github.com/heliaxdev/lambdaworks", branch = "cairo_rm"} +lambdaworks-math = { git = "https://github.com/heliaxdev/lambdaworks", branch = "cairo_rm"} bincode = "2.0.0-rc.3" serde_json = { version = "1.0", features = ["preserve_order"] } hashbrown = { version = "0.14.0", features = ["serde"] } @@ -26,3 +26,4 @@ num-integer = { version = "0.1.45", default-features = false } rand = "0.8.5" lazy_static = "1.4" serde = { version = "1.0.160", features = ["derive"] } +thiserror = "1.0" diff --git a/native/cairo_prover/src/binding_signature.rs b/native/cairo_prover/src/binding_signature.rs new file mode 100644 index 0000000..4d8ac96 --- /dev/null +++ b/native/cairo_prover/src/binding_signature.rs @@ -0,0 +1,117 @@ +use crate::{ + error::CairoError, + utils::{bytes_to_affine, bytes_to_felt, bytes_to_felt_vec}, +}; +use num_bigint::BigInt; +use num_integer::Integer; +use num_traits::Zero; +use rand::{thread_rng, RngCore}; +use rustler::NifResult; +use starknet_crypto::{poseidon_hash_many, sign, verify}; +use starknet_curve::curve_params::{EC_ORDER, GENERATOR}; +use starknet_types_core::{curve::ProjectivePoint, felt::Felt}; +use std::ops::Add; + +// The private_key_segments are random values used in delta commitments. +// The messages are nullifiers and resource commitments in the transaction. +#[rustler::nif] +fn cairo_binding_sig_sign( + private_key_segments: Vec, + messages: Vec>, +) -> NifResult> { + if private_key_segments.is_empty() || private_key_segments.len() % 32 != 0 { + return Err(CairoError::InvalidInputs.into()); + } + // Compute private key + let private_key = { + let result = private_key_segments + .chunks(32) + .fold(BigInt::zero(), |acc, key_segment| { + let key = BigInt::from_bytes_be(num_bigint::Sign::Plus, key_segment); + acc.add(key) + }) + .mod_floor(&EC_ORDER.to_bigint()); + + let (_, buffer) = result.to_bytes_be(); + let mut result = [0u8; 32]; + result[(32 - buffer.len())..].copy_from_slice(&buffer[..]); + + Felt::from_bytes_be(&result) + }; + + // Message digest + let sig_hash = message_digest(messages)?; + + // ECDSA sign + let mut rng = thread_rng(); + let k = { + let mut felt: [u8; 32] = Default::default(); + rng.fill_bytes(&mut felt); + Felt::from_bytes_be(&felt) + }; + let signature = sign(&private_key, &sig_hash, &k).map_err(CairoError::from)?; + + // Serialize signature + let mut ret = Vec::new(); + ret.extend(signature.r.to_bytes_be()); + ret.extend(signature.s.to_bytes_be()); + // We don't need the v to recover pubkey + // ret.extend(signature.v.to_bytes_be()); + Ok(ret) +} + +// The pub_key_segments are delta commitments in compliance input inputs. +#[rustler::nif] +fn cairo_binding_sig_verify( + pub_key_segments: Vec>, + messages: Vec>, + signature: Vec, +) -> NifResult { + // Generate the public key + let mut pub_key = ProjectivePoint::identity(); + for pk_seg_bytes in pub_key_segments.into_iter() { + let pk_seg = bytes_to_affine(pk_seg_bytes)?; + pub_key += pk_seg; + } + let pub_key_x = pub_key + .to_affine() + .map_err(|_| CairoError::InvalidAffinePoint)? + .x(); + + // Message digest + let msg = message_digest(messages)?; + + // Decode the signature + if signature.len() != 64 { + return Err(CairoError::InvalidSignatureFormat.into()); + } + + let (r_bytes, s_bytes) = signature.split_at(32); + let r = bytes_to_felt(r_bytes.to_vec())?; + let s = bytes_to_felt(s_bytes.to_vec())?; + + // Verify the signature + verify(&pub_key_x, &msg, &r, &s).map_err(|_| CairoError::SigVerifyError.into()) +} + +#[rustler::nif] +fn get_public_key(priv_key: Vec) -> NifResult> { + let priv_key_felt = bytes_to_felt(priv_key)?; + + let generator = ProjectivePoint::from_affine(GENERATOR.x(), GENERATOR.y()) + .map_err(|_| CairoError::InvalidAffinePoint)?; + + let pub_key = (&generator * priv_key_felt) + .to_affine() + .map_err(|_| CairoError::InvalidAffinePoint)?; + + let mut ret = pub_key.x().to_bytes_be().to_vec(); + let mut y = pub_key.y().to_bytes_be().to_vec(); + ret.append(&mut y); + Ok(ret) +} + +fn message_digest(msg: Vec>) -> NifResult { + let felt_msg_vec: Vec = bytes_to_felt_vec(msg)?; + Ok(poseidon_hash_many(&felt_msg_vec)) +} diff --git a/native/cairo_prover/src/compliance_input.rs b/native/cairo_prover/src/compliance_input.rs index 27ef73b..6111988 100644 --- a/native/cairo_prover/src/compliance_input.rs +++ b/native/cairo_prover/src/compliance_input.rs @@ -1,6 +1,28 @@ -use crate::utils::felt_to_string; +use crate::{error::CairoError, utils::felt_to_string}; +use rustler::NifResult; use serde::{Deserialize, Serialize}; +#[rustler::nif] +fn cairo_generate_compliance_input_json( + input_resource: Vec, + output_resource: Vec, + path: Vec>, + pos: u64, + input_nf_key: Vec, + eph_root: Vec, + rcv: Vec, +) -> NifResult { + Ok(ComplianceInputJson::to_json_string( + input_resource, + output_resource, + path, + pos, + input_nf_key, + eph_root, + rcv, + )?) +} + #[derive(Debug, Default, Clone, Serialize, Deserialize)] pub struct ComplianceInputJson { input: ResourceJson, @@ -31,32 +53,31 @@ struct PathNode { impl ComplianceInputJson { pub fn to_json_string( - input_resource: &Vec, - output_resource: &Vec, - path: &Vec>, + input_resource: Vec, + output_resource: Vec, + path: Vec>, pos: u64, - input_nf_key: &Vec, - eph_root: &Vec, - rcv: &Vec, - ) -> String { - let input = ResourceJson::from_bytes(input_resource); - let output = ResourceJson::from_bytes(output_resource); + input_nf_key: Vec, + eph_root: Vec, + rcv: Vec, + ) -> Result { + let input = ResourceJson::from_bytes(input_resource)?; + let output = ResourceJson::from_bytes(output_resource)?; - let rcv = felt_to_string(rcv); - let eph_root = felt_to_string(eph_root); - let input_nf_key = felt_to_string(input_nf_key); + let rcv = felt_to_string(rcv)?; + let eph_root = felt_to_string(eph_root)?; + let input_nf_key = felt_to_string(input_nf_key)?; let mut next_pos = pos; - let merkle_path = path - .iter() - .map(|v| { - let snd = if next_pos % 2 == 0 { false } else { true }; - next_pos >>= 1; - PathNode { - fst: felt_to_string(v), - snd, - } - }) - .collect(); + let mut merkle_path = Vec::new(); + for node in path.into_iter() { + let snd = next_pos % 2 != 0; + next_pos >>= 1; + let node = PathNode { + fst: felt_to_string(node)?, + snd, + }; + merkle_path.push(node); + } let compliance_input = Self { input, @@ -66,22 +87,22 @@ impl ComplianceInputJson { rcv, eph_root, }; - serde_json::to_string(&compliance_input).unwrap() + Ok(serde_json::to_string(&compliance_input)?) } } impl ResourceJson { - pub fn from_bytes(bytes: &Vec) -> Self { - Self { - logic: felt_to_string(&bytes[0..32].to_vec()), - label: felt_to_string(&bytes[32..64].to_vec()), - quantity: felt_to_string(&bytes[64..96].to_vec()), - data: felt_to_string(&bytes[96..128].to_vec()), - nonce: felt_to_string(&bytes[128..160].to_vec()), - npk: felt_to_string(&bytes[160..192].to_vec()), - rseed: felt_to_string(&bytes[192..224].to_vec()), - eph: if bytes[224] == 0 { false } else { true }, - } + pub fn from_bytes(bytes: Vec) -> Result { + Ok(Self { + logic: felt_to_string(bytes[0..32].to_vec())?, + label: felt_to_string(bytes[32..64].to_vec())?, + quantity: felt_to_string(bytes[64..96].to_vec())?, + data: felt_to_string(bytes[96..128].to_vec())?, + nonce: felt_to_string(bytes[128..160].to_vec())?, + npk: felt_to_string(bytes[160..192].to_vec())?, + rseed: felt_to_string(bytes[192..224].to_vec())?, + eph: bytes[224] != 0, + }) } } @@ -97,14 +118,27 @@ fn test_compliance_input_json() { let path = (0..32).map(|_| random_felt()).collect(); let json = ComplianceInputJson::to_json_string( - &random_resouce.to_vec(), - &random_resouce.to_vec(), - &path, + random_resouce.to_vec(), + random_resouce.to_vec(), + path, 0, - &random_felt(), - &random_felt(), - &random_felt(), - ); + random_felt(), + random_felt(), + random_felt(), + ) + .unwrap(); println!("compliance_input_json: {}", json); } + +#[test] +fn generate_compliance_input_test_params() { + use starknet_crypto::poseidon_hash; + use starknet_types_core::felt::Felt; + + println!("Felf one hex: {:?}", Felt::ONE.to_hex_string()); + let input_nf_key = Felt::ONE; + let input_npk = poseidon_hash(input_nf_key, Felt::ZERO); + println!("input_npk: {:?}", input_npk.to_bytes_be()); + println!("input_npk: {:?}", input_npk.to_hex_string()); +} diff --git a/native/cairo_prover/src/constants.rs b/native/cairo_prover/src/constants.rs new file mode 100644 index 0000000..9600cc7 --- /dev/null +++ b/native/cairo_prover/src/constants.rs @@ -0,0 +1,43 @@ +use lazy_static::lazy_static; + +// The PLAINTEXT_NUM should be fixed to achieve the indistinguishability of resource logics +// Make it 10 +pub const PLAINTEXT_NUM: usize = 10; +pub const CIPHERTEXT_MAC: usize = PLAINTEXT_NUM; +pub const CIPHERTEXT_PK_X: usize = PLAINTEXT_NUM + 1; +pub const CIPHERTEXT_PK_Y: usize = PLAINTEXT_NUM + 2; +pub const CIPHERTEXT_NONCE: usize = PLAINTEXT_NUM + 3; +pub const CIPHERTEXT_NUM: usize = PLAINTEXT_NUM + 4; + +lazy_static! { + // Bytes: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 67, 97, 105, 114, 111, 95, 69, 120, 112, 97, 110, 100, 83, 101, 101, 100] + // Hexstring: "0x436169726f5f457870616e6453656564" + // Decimal string(used in juvix): "89564067232354163924078705540990330212" + pub static ref PRF_EXPAND_PERSONALIZATION_FELT: Vec = { + let personalization: Vec = b"Cairo_ExpandSeed".to_vec(); + let mut result = [0u8; 32]; + result[(32 - personalization.len())..].copy_from_slice(&personalization[..]); + + result.to_vec() + }; +} + +#[test] +fn test_prf_expand_personalization() { + use starknet_types_core::felt::Felt; + println!( + "PRF_EXPAND_PERSONALIZATION_FELT bytes: {:?}", + *PRF_EXPAND_PERSONALIZATION_FELT + ); + + println!( + "hex: {:?}", + Felt::from_bytes_be( + &PRF_EXPAND_PERSONALIZATION_FELT + .as_slice() + .try_into() + .unwrap() + ) + .to_hex_string() + ); +} diff --git a/native/cairo_prover/src/encryption.rs b/native/cairo_prover/src/encryption.rs index 735368d..02bc966 100644 --- a/native/cairo_prover/src/encryption.rs +++ b/native/cairo_prover/src/encryption.rs @@ -1,3 +1,12 @@ +use crate::{ + constants::{ + CIPHERTEXT_MAC, CIPHERTEXT_NONCE, CIPHERTEXT_NUM, CIPHERTEXT_PK_X, CIPHERTEXT_PK_Y, + PLAINTEXT_NUM, + }, + error::CairoError, + utils::{bytes_to_affine, bytes_to_felt, bytes_to_felt_vec}, +}; +use rustler::NifResult; use starknet_crypto::{poseidon_hash, poseidon_hash_many}; use starknet_curve::curve_params::GENERATOR; use starknet_types_core::{ @@ -5,14 +14,50 @@ use starknet_types_core::{ felt::Felt, }; -// The PLAINTEXT_NUM should be fixed to achieve the indistinguishability of resource logics -// Make it 10 -pub const PLAINTEXT_NUM: usize = 10; -pub const CIPHERTEXT_MAC: usize = PLAINTEXT_NUM; -pub const CIPHERTEXT_PK_X: usize = PLAINTEXT_NUM + 1; -pub const CIPHERTEXT_PK_Y: usize = PLAINTEXT_NUM + 2; -pub const CIPHERTEXT_NONCE: usize = PLAINTEXT_NUM + 3; -pub const CIPHERTEXT_NUM: usize = PLAINTEXT_NUM + 4; +#[rustler::nif] +fn encrypt( + messages: Vec>, + pk: Vec, + sk: Vec, + nonce: Vec, +) -> NifResult>> { + // Decode messages + let msgs_felt = bytes_to_felt_vec(messages)?; + + // Decode pk + let pk_affine = bytes_to_affine(pk)?; + + // Decode sk + let sk_felt = bytes_to_felt(sk)?; + + // Decode nonce + let nonce_felt = bytes_to_felt(nonce)?; + + // Encrypt + let cipher = Ciphertext::encrypt(&msgs_felt, &pk_affine, &sk_felt, &nonce_felt)?; + let cipher_bytes = cipher + .inner() + .iter() + .map(|x| x.to_bytes_be().to_vec()) + .collect(); + + Ok(cipher_bytes) +} + +#[rustler::nif] +fn decrypt(cihper: Vec>, sk: Vec) -> NifResult>> { + // Decode messages + let cipher = Ciphertext::from_bytes(cihper)?; + + // Decode sk + let sk_felt = bytes_to_felt(sk)?; + + // Encrypt + let plaintext = cipher.decrypt(&sk_felt)?; + let plaintext_bytes = plaintext.iter().map(|x| x.to_bytes_be().to_vec()).collect(); + + Ok(plaintext_bytes) +} #[derive(Debug, Clone)] pub struct Ciphertext([Felt; CIPHERTEXT_NUM]); @@ -29,9 +74,15 @@ impl Ciphertext { &self.0 } - pub fn encrypt(messages: &[Felt], pk: &AffinePoint, sk: &Felt, encrypt_nonce: &Felt) -> Self { + pub fn encrypt( + messages: &[Felt], + pk: &AffinePoint, + sk: &Felt, + encrypt_nonce: &Felt, + ) -> Result { // Generate the secret key - let (secret_key_x, secret_key_y) = SecretKey::from_dh_exchange(pk, sk).get_coordinates(); + let secret_key = SecretKey::from_dh_exchange(pk, sk)?; + let (secret_key_x, secret_key_y) = secret_key.get_coordinates(); // Pad the messages let plaintext = Plaintext::padding(messages); @@ -56,33 +107,34 @@ impl Ciphertext { cipher.push(poseidon_state); // Add sender's public key - let generator = ProjectivePoint::from_affine(GENERATOR.x(), GENERATOR.y()).unwrap(); - let sender_pk = (&generator * *sk).to_affine().unwrap(); + let generator = ProjectivePoint::from_affine(GENERATOR.x(), GENERATOR.y()) + .map_err(|_| CairoError::InvalidAffinePoint)?; + let sender_pk = (&generator * *sk) + .to_affine() + .map_err(|_| CairoError::InvalidAffinePoint)?; cipher.push(sender_pk.x()); cipher.push(sender_pk.y()); // Add encrypt_nonce cipher.push(*encrypt_nonce); - cipher.into() - } + let ret: [Felt; CIPHERTEXT_NUM] = cipher + .try_into() + .map_err(|_| CairoError::InvalidCiphertextLength)?; - pub fn decrypt(&self, sk: &Felt) -> Option> { - let cipher_text = self.inner(); - let cipher_len = cipher_text.len(); - if cipher_len != CIPHERTEXT_NUM { - return None; - } + Ok(Self(ret)) + } - let mac = cipher_text[CIPHERTEXT_MAC]; - let pk_x = cipher_text[CIPHERTEXT_PK_X]; - let pk_y = cipher_text[CIPHERTEXT_PK_Y]; - let encrypt_nonce = cipher_text[CIPHERTEXT_NONCE]; + pub fn decrypt(&self, sk: &Felt) -> Result, CairoError> { + let mac = self.inner()[CIPHERTEXT_MAC]; + let pk_x = self.inner()[CIPHERTEXT_PK_X]; + let pk_y = self.inner()[CIPHERTEXT_PK_Y]; + let encrypt_nonce = self.inner()[CIPHERTEXT_NONCE]; if let Ok(pk) = AffinePoint::new(pk_x, pk_y) { // Generate the secret key - let (secret_key_x, secret_key_y) = - SecretKey::from_dh_exchange(&pk, sk).get_coordinates(); + let sk = SecretKey::from_dh_exchange(&pk, sk)?; + let (secret_key_x, secret_key_y) = sk.get_coordinates(); // Init poseidon sponge state let mut poseidon_state = poseidon_hash_many(&vec![ @@ -94,30 +146,28 @@ impl Ciphertext { // Decrypt let mut msg = vec![]; - for cipher_element in &cipher_text[0..PLAINTEXT_NUM] { + for cipher_element in &self.inner()[0..PLAINTEXT_NUM] { let msg_element = *cipher_element - poseidon_state; msg.push(msg_element); poseidon_state = poseidon_hash(*cipher_element, secret_key_x); } if mac != poseidon_state { - return None; + return Err(CairoError::DecryptionFailure); } - Some(msg) + Ok(msg) } else { - return None; + Err(CairoError::InvalidPublicKey) } } -} -impl From> for Ciphertext { - fn from(input_vec: Vec) -> Self { - Ciphertext( - input_vec - .try_into() - .expect("public input with incorrect length"), - ) + pub fn from_bytes(input_vec: Vec>) -> Result { + let cipher_felt = bytes_to_felt_vec(input_vec)?; + let cipher: [Felt; CIPHERTEXT_NUM] = cipher_felt + .try_into() + .map_err(|_| CairoError::InvalidCiphertextLength)?; + Ok(Self(cipher)) } } @@ -126,10 +176,6 @@ impl Plaintext { &self.0 } - pub fn to_vec(&self) -> Vec { - self.0.to_vec() - } - pub fn padding(msg: &[Felt]) -> Self { let mut plaintext = msg.to_owned(); let padding = std::iter::repeat(Felt::ZERO).take(PLAINTEXT_NUM - msg.len()); @@ -149,12 +195,13 @@ impl From> for Plaintext { } impl SecretKey { - pub fn from_dh_exchange(pk: &AffinePoint, sk: &Felt) -> Self { - Self( - (&ProjectivePoint::try_from(pk.clone()).unwrap() * *sk) - .to_affine() - .unwrap(), - ) + pub fn from_dh_exchange(pk: &AffinePoint, sk: &Felt) -> Result { + let pk_projective = + ProjectivePoint::try_from(pk.clone()).map_err(|_| CairoError::InvalidAffinePoint)?; + let key = (&pk_projective * *sk) + .to_affine() + .map_err(|_| CairoError::InvalidDHKey)?; + Ok(Self(key)) } pub fn get_coordinates(&self) -> (Felt, Felt) { @@ -173,11 +220,11 @@ fn test_encryption() { let encrypt_nonce = Felt::ONE; // Encryption - let cipher = Ciphertext::encrypt(&messages, &pk, &sender_sk, &encrypt_nonce); + let cipher = Ciphertext::encrypt(&messages, &pk, &sender_sk, &encrypt_nonce).unwrap(); // Decryption let decryption = cipher.decrypt(&Felt::ONE).unwrap(); let padded_plaintext = Plaintext::padding(&messages); - assert_eq!(padded_plaintext.to_vec(), decryption); + assert_eq!(padded_plaintext.inner().to_vec(), decryption); } diff --git a/native/cairo_prover/src/error.rs b/native/cairo_prover/src/error.rs new file mode 100644 index 0000000..bd3fa1a --- /dev/null +++ b/native/cairo_prover/src/error.rs @@ -0,0 +1,59 @@ +use bincode::error::{DecodeError, EncodeError}; +use rustler::{Encoder, Env, Term}; +use serde_json::error::Error as JsonError; +use starknet_crypto::SignError; +use thiserror::Error; + +#[derive(Debug, Error)] +pub enum CairoError { + #[error("Inputs should not be empty")] + EmptyInputs, + #[error("Bytes should be a multiple of 24 for trace or 40 for memory")] + CairoImportError, + #[error("Parse public input error: {0}")] + ParsePublicInputError(String), + #[error("Proving error")] + ProvingError, + #[error(transparent)] + EncodeError(#[from] EncodeError), + #[error(transparent)] + DecodeError(#[from] DecodeError), + #[error("Segment not found in memory(public input)")] + SegmentNotFound, + #[error("Address({0}) not found in memory(public input)")] + AddressNotFound(u64), + #[error(transparent)] + SignError(#[from] SignError), + #[error("Invalid inputs")] + InvalidInputs, + #[error("Invalid finite field: 32 bytes needed")] + InvalidFiniteField, + #[error("Invalid Point")] + InvalidAffinePoint, + #[error("Invalid signature: 64 bytes needed")] + InvalidSignatureFormat, + #[error("Signature verification failed")] + SigVerifyError, + #[error(transparent)] + JsonError(#[from] JsonError), + #[error("Invalid public key")] + InvalidPublicKey, + #[error("Invalid DH key")] + InvalidDHKey, + #[error("Invalid mac in decryption")] + DecryptionFailure, + #[error("The length of ciphertext is not correct")] + InvalidCiphertextLength, +} + +impl Encoder for CairoError { + fn encode<'a>(&self, env: Env<'a>) -> Term<'a> { + self.to_string().encode(env) + } +} + +impl From for rustler::Error { + fn from(e: CairoError) -> Self { + rustler::Error::Term(Box::new(e)) + } +} diff --git a/native/cairo_prover/src/errors.rs b/native/cairo_prover/src/errors.rs deleted file mode 100644 index f9f55db..0000000 --- a/native/cairo_prover/src/errors.rs +++ /dev/null @@ -1,164 +0,0 @@ -use rustler::{Encoder, Env, Term}; - -#[derive(Debug)] -pub(crate) enum CairoProveError { - RegisterStatesError(String), - CairoMemoryError(String), - ProofGenerationError(String), - PublicInputError(String), - EncodingError(String), -} - -impl std::fmt::Display for CairoProveError { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - match self { - CairoProveError::RegisterStatesError(msg) => { - write!(f, "Register states error: {}", msg) - } - CairoProveError::CairoMemoryError(msg) => write!(f, "Cairo memory error: {}", msg), - CairoProveError::ProofGenerationError(msg) => { - write!(f, "Proof generation failed: {}", msg) - } - CairoProveError::PublicInputError(msg) => write!(f, "Public input error: {}", msg), - CairoProveError::EncodingError(msg) => write!(f, "Encoding error: {}", msg), - } - } -} - -impl Encoder for CairoProveError { - fn encode<'a>(&self, env: Env<'a>) -> Term<'a> { - self.to_string().encode(env) - } -} - -#[derive(Debug)] -pub(crate) enum CairoVerifyError { - ProofDecodingError(String), - PublicInputDecodingError(String), -} - -impl std::fmt::Display for CairoVerifyError { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - match self { - CairoVerifyError::ProofDecodingError(msg) => write!(f, "Proof decoding error: {}", msg), - CairoVerifyError::PublicInputDecodingError(msg) => { - write!(f, "Public input decoding error: {}", msg) - } - } - } -} - -impl Encoder for CairoVerifyError { - fn encode<'a>(&self, env: Env<'a>) -> Term<'a> { - self.to_string().encode(env) - } -} - -#[derive(Debug)] -pub(crate) enum CairoGetOutputError { - DecodingError(String), - SegmentNotFound, - AddressNotFound(u64), -} - -impl std::fmt::Display for CairoGetOutputError { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - match self { - CairoGetOutputError::DecodingError(msg) => write!(f, "Decoding error: {}", msg), - CairoGetOutputError::SegmentNotFound => { - write!(f, "Output segment not found in memory segments") - } - CairoGetOutputError::AddressNotFound(addr) => { - write!(f, "Address {} not found in public memory", addr) - } - } - } -} - -impl Encoder for CairoGetOutputError { - fn encode<'a>(&self, env: Env<'a>) -> Term<'a> { - self.to_string().encode(env) - } -} - -#[derive(Debug)] -pub(crate) enum CairoSignError { - SignatureGenerationError(String), -} - -impl std::fmt::Display for CairoSignError { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - match self { - CairoSignError::SignatureGenerationError(msg) => { - write!(f, "Binding Signature generation error: {}", msg) - } - } - } -} - -impl Encoder for CairoSignError { - fn encode<'a>(&self, env: Env<'a>) -> Term<'a> { - self.to_string().encode(env) - } -} - -#[derive(Debug)] -pub enum CairoBindingSigVerifyError { - InputError, - VerificationError, -} - -impl std::fmt::Display for CairoBindingSigVerifyError { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - match self { - CairoBindingSigVerifyError::InputError => write!(f, "Invalid input data"), - CairoBindingSigVerifyError::VerificationError => { - write!(f, "Signature verification failed") - } - } - } -} - -impl Encoder for CairoBindingSigVerifyError { - fn encode<'a>(&self, env: Env<'a>) -> Term<'a> { - self.to_string().encode(env) - } -} - -#[derive(Debug)] -pub enum CairoBindingSigError { - KeyGenerationError, -} - -impl std::fmt::Display for CairoBindingSigError { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - match self { - CairoBindingSigError::KeyGenerationError => write!(f, "Error generating key"), - } - } -} - -impl Encoder for CairoBindingSigError { - fn encode<'a>(&self, env: Env<'a>) -> Term<'a> { - self.to_string().encode(env) - } -} - -#[derive(Debug)] -pub enum TypeError { - DecodingError(String), -} - -impl std::fmt::Display for TypeError { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - match self { - TypeError::DecodingError(msg) => write!(f, "Type error: {}", msg), - } - } -} - -impl Encoder for TypeError { - fn encode<'a>(&self, env: Env<'a>) -> Term<'a> { - self.to_string().encode(env) - } -} diff --git a/native/cairo_prover/src/lib.rs b/native/cairo_prover/src/lib.rs index 2641f37..64d649c 100644 --- a/native/cairo_prover/src/lib.rs +++ b/native/cairo_prover/src/lib.rs @@ -1,573 +1,30 @@ -#![allow(dead_code)] - +mod binding_signature; mod compliance_input; +mod constants; mod encryption; -mod errors; +mod error; +mod poseidon; +mod prover; mod utils; - -use crate::{ - compliance_input::ComplianceInputJson, - encryption::Ciphertext, - errors::{ - CairoBindingSigError, CairoBindingSigVerifyError, CairoGetOutputError, CairoProveError, - CairoSignError, CairoVerifyError, - }, - utils::{bytes_to_affine, bytes_to_felt, bytes_to_felt_vec, felt_to_string, random_felt}, -}; -use cairo_platinum_prover::{ - air::{generate_cairo_proof, verify_cairo_proof, PublicInputs, Segment, SegmentName}, - cairo_mem::CairoMemory, - execution_trace::build_main_trace, - register_states::RegisterStates, - Felt252, -}; -use hashbrown::HashMap; -use lambdaworks_math::traits::ByteConversion; -use num_bigint::BigInt; -use num_integer::Integer; -use num_traits::Zero; -use rand::{thread_rng, RngCore}; -use rustler::{Error, NifResult}; -use stark_platinum_prover::proof::options::{ProofOptions, SecurityLevel}; -use starknet_crypto::{poseidon_hash, poseidon_hash_many, poseidon_hash_single, sign, verify}; -use starknet_curve::curve_params::{EC_ORDER, GENERATOR}; -use starknet_types_core::{ - curve::{AffinePoint, ProjectivePoint}, - felt::Felt, -}; -use std::ops::Add; - -#[rustler::nif(schedule = "DirtyCpu")] -fn cairo_prove( - trace: Vec, - memory: Vec, - public_input: Vec, -) -> NifResult<(Vec, Vec)> { - // Generating the prover args - let register_states = RegisterStates::from_bytes_le(&trace).map_err(|e| { - Error::Term(Box::new(CairoProveError::RegisterStatesError(format!( - "{:?}", - e - )))) - })?; - - let memory = CairoMemory::from_bytes_le(&memory).map_err(|e| { - Error::Term(Box::new(CairoProveError::CairoMemoryError(format!( - "{:?}", - e - )))) - })?; - - // Handle public inputs - let (rc_min, rc_max, public_memory, memory_segments) = parse_public_input(&public_input) - .map_err(|e| Error::Term(Box::new(CairoProveError::PublicInputError(e.to_string()))))?; - - let num_steps = register_states.steps(); - let mut pub_inputs = PublicInputs { - pc_init: Felt252::from(register_states.rows[0].pc), - ap_init: Felt252::from(register_states.rows[0].ap), - fp_init: Felt252::from(register_states.rows[0].fp), - pc_final: Felt252::from(register_states.rows[num_steps - 1].pc), - ap_final: Felt252::from(register_states.rows[num_steps - 1].ap), - range_check_min: Some(rc_min), - range_check_max: Some(rc_max), - memory_segments, - public_memory, - num_steps, - }; - - // Build main trace - let main_trace = build_main_trace(®ister_states, &memory, &mut pub_inputs); - - // Generating proof - let proof_options = ProofOptions::new_secure(SecurityLevel::Conjecturable100Bits, 3); - let proof = generate_cairo_proof(&main_trace, &pub_inputs, &proof_options).map_err(|e| { - Error::Term(Box::new(CairoProveError::ProofGenerationError(format!( - "{:?}", - e - )))) - })?; - - // Encode proof and pub_inputs - let proof_bytes = bincode::serde::encode_to_vec(proof, bincode::config::standard()) - .map_err(|e| Error::Term(Box::new(CairoProveError::EncodingError(format!("{:?}", e)))))?; - let pub_input_bytes = bincode::serde::encode_to_vec(&pub_inputs, bincode::config::standard()) - .map_err(|e| { - Error::Term(Box::new(CairoProveError::EncodingError(format!("{:?}", e)))) - })?; - - Ok((proof_bytes, pub_input_bytes)) -} - -fn parse_public_input( - public_input: &[u8], -) -> Result< - ( - u16, - u16, - HashMap, - HashMap, - ), - &'static str, -> { - let rc_min = u16::from_le_bytes( - public_input - .get(0..2) - .ok_or("Input must be at least 2 bytes long for rc_min")? - .try_into() - .map_err(|_| "Failed to convert rc_min bytes")?, - ); - - let rc_max = u16::from_le_bytes( - public_input - .get(2..4) - .ok_or("Input must be at least 4 bytes long for rc_max")? - .try_into() - .map_err(|_| "Failed to convert rc_max bytes")?, - ); - - let mem_len = u64::from_le_bytes( - public_input - .get(4..12) - .ok_or("Input must be at least 12 bytes long for mem_len")? - .try_into() - .map_err(|_| "Failed to convert mem_len bytes")?, - ) as usize; - - let mut public_memory: HashMap = HashMap::new(); - for i in 0..mem_len { - let start_index = 12 + i * 40; - let addr = Felt252::from(u64::from_le_bytes( - public_input - .get(start_index..start_index + 8) - .ok_or("Input too short for public memory address")? - .try_into() - .map_err(|_| "Failed to convert public memory address bytes")?, - )); - let value = Felt252::from_bytes_le( - public_input - .get(start_index + 8..start_index + 40) - .ok_or("Input too short for public memory value")? - .try_into() - .map_err(|_| "Failed to convert public memory value bytes")?, - ) - .map_err(|_| "Failed to create Felt252 from bytes")?; - public_memory.insert(addr, value); - } - - let memory_segments_len = *public_input - .get(12 + 40 * mem_len) - .ok_or("Input too short for memory segments length")? - as usize; - let mut memory_segments = HashMap::new(); - for i in 0..memory_segments_len { - let start_index = 12 + 40 * mem_len + 1 + i * 17; - let segment_type = match public_input - .get(start_index) - .ok_or("Input too short for segment type")? - { - 0u8 => SegmentName::RangeCheck, - 1u8 => SegmentName::Output, - 2u8 => SegmentName::Program, - 3u8 => SegmentName::Execution, - 4u8 => SegmentName::Ecdsa, - 5u8 => SegmentName::Pedersen, - _ => continue, // skip unknown type - }; - - let segment_begin = u64::from_le_bytes( - public_input - .get(start_index + 1..start_index + 9) - .ok_or("Input too short for segment begin")? - .try_into() - .map_err(|_| "Failed to convert segment begin bytes")?, - ); - let segment_stop = u64::from_le_bytes( - public_input - .get(start_index + 9..start_index + 17) - .ok_or("Input too short for segment stop")? - .try_into() - .map_err(|_| "Failed to convert segment stop bytes")?, - ); - memory_segments.insert(segment_type, Segment::new(segment_begin, segment_stop)); - } - - Ok((rc_min, rc_max, public_memory, memory_segments)) -} - -#[rustler::nif(schedule = "DirtyCpu")] -fn cairo_verify(proof: Vec, public_input: Vec) -> NifResult { - let proof_options = ProofOptions::new_secure(SecurityLevel::Conjecturable100Bits, 3); - - // Decode proof - let proof = bincode::serde::decode_from_slice(&proof, bincode::config::standard()) - .map_err(|e| { - Error::Term(Box::new(CairoVerifyError::ProofDecodingError( - e.to_string(), - ))) - })? - .0; - - // Decode public inputs - let pub_inputs = bincode::serde::decode_from_slice(&public_input, bincode::config::standard()) - .map_err(|e| { - Error::Term(Box::new(CairoVerifyError::PublicInputDecodingError( - e.to_string(), - ))) - })? - .0; - - Ok(verify_cairo_proof(&proof, &pub_inputs, &proof_options)) -} - -#[rustler::nif()] -fn cairo_get_output(public_input: Vec) -> NifResult>> { - // Decode public inputs - let (pub_inputs, _): (PublicInputs, usize) = - bincode::serde::decode_from_slice(&public_input, bincode::config::standard()).map_err( - |e| Error::Term(Box::new(CairoGetOutputError::DecodingError(e.to_string()))), - )?; - - // Get output segments - let output_segments = pub_inputs - .memory_segments - .get(&SegmentName::Output) - .ok_or_else(|| Error::Term(Box::new(CairoGetOutputError::SegmentNotFound)))?; - - let begin_addr: u64 = output_segments.begin_addr as u64; - let stop_addr: u64 = output_segments.stop_ptr as u64; - - let mut output_values = Vec::new(); - for addr in begin_addr..stop_addr { - // Convert addr to FieldElement (assuming this is the correct way to create a FieldElement from an address) - let addr_field_element = Felt252::from(addr); - - if let Some(value) = pub_inputs.public_memory.get(&addr_field_element) { - output_values.push(value.clone().to_bytes_be().to_vec()); - } else { - return Err(Error::Term(Box::new(CairoGetOutputError::AddressNotFound( - addr, - )))); - } - } - - Ok(output_values) -} - -// The private_key_segments are random values used in delta commitments. -// The messages are nullifiers and resource commitments in the transaction. -#[rustler::nif] -fn cairo_binding_sig_sign( - private_key_segments: Vec, - messages: Vec>, -) -> NifResult> { - // Compute private key - let private_key = { - let result = private_key_segments - .chunks(32) - .fold(BigInt::zero(), |acc, key_segment| { - let key = BigInt::from_bytes_be(num_bigint::Sign::Plus, &key_segment); - acc.add(key) - }) - .mod_floor(&EC_ORDER.to_bigint()); - - let (_, buffer) = result.to_bytes_be(); - let mut result = [0u8; 32]; - result[(32 - buffer.len())..].copy_from_slice(&buffer[..]); - - Felt::from_bytes_be(&result) - }; - - // Message digest - let sig_hash = message_digest(messages)?; - - // ECDSA sign - let mut rng = thread_rng(); - let k = { - let mut felt: [u8; 32] = Default::default(); - rng.fill_bytes(&mut felt); - Felt::from_bytes_be(&felt) - }; - let signature = sign(&private_key, &sig_hash, &k).map_err(|e| { - Error::Term(Box::new(CairoSignError::SignatureGenerationError( - e.to_string(), - ))) - })?; - - // Serialize signature - let mut ret = Vec::new(); - ret.extend(signature.r.to_bytes_be()); - ret.extend(signature.s.to_bytes_be()); - // We don't need the v to recover pubkey - // ret.extend(signature.v.to_bytes_be()); - Ok(ret) -} - -// The pub_key_segments are delta commitments in compliance input inputs. -#[rustler::nif] -fn cairo_binding_sig_verify( - pub_key_segments: Vec>, - messages: Vec>, - signature: Vec, -) -> NifResult { - // Generate the public key - let pub_key = pub_key_segments - .into_iter() - .try_fold(ProjectivePoint::identity(), |acc, bytes| { - let key_x = Felt::from_bytes_be( - &bytes[0..32] - .try_into() - .map_err(|_| CairoBindingSigVerifyError::InputError)?, - ); - let key_y = Felt::from_bytes_be( - &bytes[32..64] - .try_into() - .map_err(|_| CairoBindingSigVerifyError::InputError)?, - ); - let key_segment_affine = AffinePoint::new(key_x, key_y) - .map_err(|_| CairoBindingSigVerifyError::InputError)?; - Ok(acc.add(key_segment_affine)) - }) - .map_err(|e: CairoBindingSigVerifyError| Error::Term(Box::new(e)))? - .to_affine() - .map_err(|_| Error::Term(Box::new(CairoBindingSigVerifyError::InputError)))? - .x(); - - // Message digest - let msg = message_digest(messages)?; - - // Decode the signature - let r = Felt::from_bytes_be( - signature[0..32] - .try_into() - .map_err(|_| Error::Term(Box::new(CairoBindingSigVerifyError::InputError)))?, - ); - let s = Felt::from_bytes_be( - signature[32..64] - .try_into() - .map_err(|_| Error::Term(Box::new(CairoBindingSigVerifyError::InputError)))?, - ); - - // Verify the signature - verify(&pub_key, &msg, &r, &s) - .map_err(|_| Error::Term(Box::new(CairoBindingSigVerifyError::VerificationError))) -} - -// random_felt can help create private key in signature -#[rustler::nif] -fn cairo_random_felt() -> NifResult> { - Ok(random_felt()) -} - -#[rustler::nif] -fn get_public_key(priv_key: Vec) -> NifResult> { - let priv_key_felt = Felt::from_bytes_be_slice(&priv_key); - - let generator = ProjectivePoint::from_affine(GENERATOR.x(), GENERATOR.y()) - .map_err(|_| Error::Term(Box::new(CairoBindingSigError::KeyGenerationError)))?; - - let pub_key = (&generator * priv_key_felt) - .to_affine() - .map_err(|_| Error::Term(Box::new(CairoBindingSigError::KeyGenerationError)))?; - - let mut ret = pub_key.x().to_bytes_be().to_vec(); - let mut y = pub_key.y().to_bytes_be().to_vec(); - ret.append(&mut y); - Ok(ret) -} -fn message_digest(msg: Vec>) -> NifResult { - let felt_msg_vec: Vec = bytes_to_felt_vec(msg)?; - Ok(poseidon_hash_many(&felt_msg_vec)) -} - -#[rustler::nif] -fn poseidon_single(x: Vec) -> NifResult> { - let x_field = bytes_to_felt(x)?; - Ok(poseidon_hash_single(x_field).to_bytes_be().to_vec()) -} - -#[rustler::nif] -fn poseidon(x: Vec, y: Vec) -> NifResult> { - let x_field = bytes_to_felt(x)?; - let y_field = bytes_to_felt(y)?; - Ok(poseidon_hash(x_field, y_field).to_bytes_be().to_vec()) -} - -#[rustler::nif] -fn poseidon_many(inputs: Vec>) -> NifResult> { - let vec_fe = bytes_to_felt_vec(inputs)?; - let result_fe = poseidon_hash_many(&vec_fe); - Ok(result_fe.to_bytes_be().to_vec()) -} - -// Get the program from public inputs and return the program hash as the -// resource label -#[rustler::nif] -fn program_hash(public_inputs: Vec) -> NifResult> { - let (pub_inputs, _): (PublicInputs, usize) = - bincode::serde::decode_from_slice(&public_inputs, bincode::config::standard()).unwrap(); - let program_segments = match pub_inputs.memory_segments.get(&SegmentName::Program) { - Some(segment) => segment, - None => { - eprintln!("Error: 'Program' segment not found in memory_segments"); - return Ok(vec![]); - } - }; - - let begin_addr: u64 = program_segments.begin_addr as u64; - let stop_addr: u64 = program_segments.stop_ptr as u64; - - let mut program = Vec::new(); - for addr in begin_addr..stop_addr { - // Convert addr to FieldElement (assuming this is the correct way to create a FieldElement from an address) - let addr_field_element = Felt252::from(addr); - - if let Some(value) = pub_inputs.public_memory.get(&addr_field_element) { - program.push(Felt::from_raw(value.to_raw().limbs)); - } else { - eprintln!( - "Error: Address {:?} not found in public memory", - addr_field_element - ); - return Ok(vec![]); - } - } - - let program_hash = poseidon_hash_many(&program); - - Ok(program_hash.to_bytes_be().to_vec()) -} - -#[rustler::nif] -fn cairo_felt_to_string(felt: Vec) -> String { - felt_to_string(&felt) -} - -#[rustler::nif] -fn cairo_generate_compliance_input_json( - input_resource: Vec, - output_resource: Vec, - path: Vec>, - pos: u64, - input_nf_key: Vec, - eph_root: Vec, - rcv: Vec, -) -> String { - ComplianceInputJson::to_json_string( - &input_resource, - &output_resource, - &path, - pos, - &input_nf_key, - &eph_root, - &rcv, - ) -} - -#[rustler::nif] -fn encrypt( - messages: Vec>, - pk: Vec, - sk: Vec, - nonce: Vec, -) -> NifResult>> { - // Decode messages - let msgs_felt = bytes_to_felt_vec(messages)?; - - // Decode pk - let pk_affine = bytes_to_affine(pk)?; - - // Decode sk - let sk_felt = bytes_to_felt(sk)?; - - // Decode nonce - let nonce_felt = bytes_to_felt(nonce)?; - - // Encrypt - let cipher = Ciphertext::encrypt(&msgs_felt, &pk_affine, &sk_felt, &nonce_felt); - let cipher_bytes = cipher - .inner() - .iter() - .map(|x| x.to_bytes_be().to_vec()) - .collect(); - - Ok(cipher_bytes) -} - -#[rustler::nif] -fn decrypt(cihper: Vec>, sk: Vec) -> NifResult>> { - // Decode messages - let cipher_felt = bytes_to_felt_vec(cihper)?; - - // Decode sk - let sk_felt = bytes_to_felt(sk)?; - - // Encrypt - let plaintext = Ciphertext::from(cipher_felt).decrypt(&sk_felt).unwrap(); - let plaintext_bytes = plaintext.iter().map(|x| x.to_bytes_be().to_vec()).collect(); - - Ok(plaintext_bytes) -} +mod verifier; rustler::init!( "Elixir.Cairo.CairoProver", [ - cairo_prove, - cairo_verify, - cairo_get_output, - cairo_binding_sig_sign, - cairo_binding_sig_verify, - cairo_random_felt, - get_public_key, - poseidon_single, - poseidon, - poseidon_many, - program_hash, - cairo_felt_to_string, - cairo_generate_compliance_input_json, - encrypt, - decrypt, + prover::cairo_prove, + verifier::cairo_verify, + verifier::cairo_get_output, + verifier::program_hash, + binding_signature::cairo_binding_sig_sign, + binding_signature::cairo_binding_sig_verify, + binding_signature::get_public_key, + poseidon::poseidon_single, + poseidon::poseidon, + poseidon::poseidon_many, + utils::cairo_random_felt, + utils::cairo_felt_to_string, + compliance_input::cairo_generate_compliance_input_json, + encryption::encrypt, + encryption::decrypt, ] ); - -use lazy_static::lazy_static; -lazy_static! { - // Bytes: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 67, 97, 105, 114, 111, 95, 69, 120, 112, 97, 110, 100, 83, 101, 101, 100] - // Hexstring: "0x436169726f5f457870616e6453656564" - // Decimal string(used in juvix): "89564067232354163924078705540990330212" - pub static ref PRF_EXPAND_PERSONALIZATION_FELT: Vec = { - let personalization: Vec = b"Cairo_ExpandSeed".to_vec(); - let mut result = [0u8; 32]; - result[(32 - personalization.len())..].copy_from_slice(&personalization[..]); - - result.to_vec() - }; -} - -#[test] -fn test_prf_expand_personalization() { - println!( - "PRF_EXPAND_PERSONALIZATION_FELT bytes: {:?}", - *PRF_EXPAND_PERSONALIZATION_FELT - ); - - println!( - "hex: {:?}", - Felt::from_bytes_be( - &PRF_EXPAND_PERSONALIZATION_FELT - .as_slice() - .try_into() - .unwrap() - ) - .to_hex_string() - ); -} - -#[test] -fn generate_compliance_input_test_params() { - println!("Felf one hex: {:?}", Felt::ONE.to_hex_string()); - let input_nf_key = Felt::ONE; - let input_npk = poseidon_hash(input_nf_key, Felt::ZERO); - println!("input_npk: {:?}", input_npk.to_bytes_be()); - println!("input_npk: {:?}", input_npk.to_hex_string()); -} diff --git a/native/cairo_prover/src/poseidon.rs b/native/cairo_prover/src/poseidon.rs new file mode 100644 index 0000000..a9d8bd8 --- /dev/null +++ b/native/cairo_prover/src/poseidon.rs @@ -0,0 +1,23 @@ +use crate::utils::{bytes_to_felt, bytes_to_felt_vec}; +use rustler::NifResult; +use starknet_crypto::{poseidon_hash, poseidon_hash_many, poseidon_hash_single}; + +#[rustler::nif] +fn poseidon_single(x: Vec) -> NifResult> { + let x_field = bytes_to_felt(x)?; + Ok(poseidon_hash_single(x_field).to_bytes_be().to_vec()) +} + +#[rustler::nif] +fn poseidon(x: Vec, y: Vec) -> NifResult> { + let x_field = bytes_to_felt(x)?; + let y_field = bytes_to_felt(y)?; + Ok(poseidon_hash(x_field, y_field).to_bytes_be().to_vec()) +} + +#[rustler::nif] +fn poseidon_many(inputs: Vec>) -> NifResult> { + let vec_fe = bytes_to_felt_vec(inputs)?; + let result_fe = poseidon_hash_many(&vec_fe); + Ok(result_fe.to_bytes_be().to_vec()) +} diff --git a/native/cairo_prover/src/prover.rs b/native/cairo_prover/src/prover.rs new file mode 100644 index 0000000..6648d7c --- /dev/null +++ b/native/cairo_prover/src/prover.rs @@ -0,0 +1,157 @@ +use crate::error::CairoError; +use cairo_platinum_prover::{ + air::{generate_cairo_proof, PublicInputs, Segment, SegmentName}, + cairo_mem::CairoMemory, + execution_trace::build_main_trace, + register_states::RegisterStates, + Felt252, +}; +use hashbrown::HashMap; +use lambdaworks_math::traits::ByteConversion; +use rustler::NifResult; +use stark_platinum_prover::proof::options::{ProofOptions, SecurityLevel}; + +#[rustler::nif(schedule = "DirtyCpu")] +fn cairo_prove( + trace: Vec, + memory: Vec, + public_input: Vec, +) -> NifResult<(Vec, Vec)> { + if trace.is_empty() || memory.is_empty() || public_input.is_empty() { + return Err(CairoError::EmptyInputs.into()); + } + // Generating the prover args + let register_states = + RegisterStates::from_bytes_le(&trace).map_err(|_| CairoError::CairoImportError)?; + + let memory = CairoMemory::from_bytes_le(&memory).map_err(|_| CairoError::CairoImportError)?; + + // Handle public inputs + let (rc_min, rc_max, public_memory, memory_segments) = parse_public_input(&public_input) + .map_err(|e| CairoError::ParsePublicInputError(e.to_string()))?; + + let num_steps = register_states.steps(); + let mut pub_inputs = PublicInputs { + pc_init: Felt252::from(register_states.rows[0].pc), + ap_init: Felt252::from(register_states.rows[0].ap), + fp_init: Felt252::from(register_states.rows[0].fp), + pc_final: Felt252::from(register_states.rows[num_steps - 1].pc), + ap_final: Felt252::from(register_states.rows[num_steps - 1].ap), + range_check_min: Some(rc_min), + range_check_max: Some(rc_max), + memory_segments, + public_memory, + num_steps, + }; + + // Build main trace + let main_trace = build_main_trace(®ister_states, &memory, &mut pub_inputs); + + // Generating proof + let proof_options = ProofOptions::new_secure(SecurityLevel::Conjecturable100Bits, 3); + let proof = generate_cairo_proof(&main_trace, &pub_inputs, &proof_options) + .map_err(|_| CairoError::ProvingError)?; + + // Encode proof and pub_inputs + let proof_bytes = bincode::serde::encode_to_vec(proof, bincode::config::standard()) + .map_err(CairoError::from)?; + let pub_input_bytes = bincode::serde::encode_to_vec(&pub_inputs, bincode::config::standard()) + .map_err(CairoError::from)?; + + Ok((proof_bytes, pub_input_bytes)) +} + +#[allow(clippy::type_complexity)] +fn parse_public_input( + public_input: &[u8], +) -> Result< + ( + u16, + u16, + HashMap, + HashMap, + ), + &'static str, +> { + let rc_min = u16::from_le_bytes( + public_input + .get(0..2) + .ok_or("Input must be at least 2 bytes long for rc_min")? + .try_into() + .map_err(|_| "Failed to convert rc_min bytes")?, + ); + + let rc_max = u16::from_le_bytes( + public_input + .get(2..4) + .ok_or("Input must be at least 4 bytes long for rc_max")? + .try_into() + .map_err(|_| "Failed to convert rc_max bytes")?, + ); + + let mem_len = u64::from_le_bytes( + public_input + .get(4..12) + .ok_or("Input must be at least 12 bytes long for mem_len")? + .try_into() + .map_err(|_| "Failed to convert mem_len bytes")?, + ) as usize; + + let mut public_memory: HashMap = HashMap::new(); + for i in 0..mem_len { + let start_index = 12 + i * 40; + let addr = Felt252::from(u64::from_le_bytes( + public_input + .get(start_index..start_index + 8) + .ok_or("Input too short for public memory address")? + .try_into() + .map_err(|_| "Failed to convert public memory address bytes")?, + )); + let value = Felt252::from_bytes_le( + public_input + .get(start_index + 8..start_index + 40) + .ok_or("Input too short for public memory value")?, + ) + .map_err(|_| "Failed to create Felt252 from bytes")?; + public_memory.insert(addr, value); + } + + let memory_segments_len = *public_input + .get(12 + 40 * mem_len) + .ok_or("Input too short for memory segments length")? + as usize; + let mut memory_segments = HashMap::new(); + for i in 0..memory_segments_len { + let start_index = 12 + 40 * mem_len + 1 + i * 17; + let segment_type = match public_input + .get(start_index) + .ok_or("Input too short for segment type")? + { + 0u8 => SegmentName::RangeCheck, + 1u8 => SegmentName::Output, + 2u8 => SegmentName::Program, + 3u8 => SegmentName::Execution, + 4u8 => SegmentName::Ecdsa, + 5u8 => SegmentName::Pedersen, + _ => continue, // skip unknown type + }; + + let segment_begin = u64::from_le_bytes( + public_input + .get(start_index + 1..start_index + 9) + .ok_or("Input too short for segment begin")? + .try_into() + .map_err(|_| "Failed to convert segment begin bytes")?, + ); + let segment_stop = u64::from_le_bytes( + public_input + .get(start_index + 9..start_index + 17) + .ok_or("Input too short for segment stop")? + .try_into() + .map_err(|_| "Failed to convert segment stop bytes")?, + ); + memory_segments.insert(segment_type, Segment::new(segment_begin, segment_stop)); + } + + Ok((rc_min, rc_max, public_memory, memory_segments)) +} diff --git a/native/cairo_prover/src/utils.rs b/native/cairo_prover/src/utils.rs index d4e5cf9..e4d4e44 100644 --- a/native/cairo_prover/src/utils.rs +++ b/native/cairo_prover/src/utils.rs @@ -1,17 +1,25 @@ -use crate::errors::TypeError; +use crate::error::CairoError; use rand::{thread_rng, RngCore}; -use rustler::{Error, NifResult}; +use rustler::NifResult; use starknet_types_core::curve::AffinePoint; use starknet_types_core::felt::Felt; -pub fn felt_to_string(felt: &Vec) -> String { - assert_eq!(felt.len(), 32, "The felt size is not 32 bytes"); - Felt::from_bytes_be( - felt.as_slice() - .try_into() - .expect("Slice with incorrect length"), - ) - .to_hex_string() +#[rustler::nif] +fn cairo_felt_to_string(felt: Vec) -> NifResult { + Ok(felt_to_string(felt)?) +} + +// random_felt can help create private key in signature +#[rustler::nif] +fn cairo_random_felt() -> NifResult> { + Ok(random_felt()) +} + +pub fn felt_to_string(bytes: Vec) -> Result { + let felt: [u8; 32] = bytes + .try_into() + .map_err(|_| CairoError::InvalidFiniteField)?; + Ok(Felt::from_bytes_be(&felt).to_hex_string()) } pub fn random_felt() -> Vec { @@ -22,45 +30,35 @@ pub fn random_felt() -> Vec { felt.to_bytes_be().to_vec() } -pub fn bytes_to_felt_vec(bytes: Vec>) -> NifResult> { +pub fn bytes_to_felt_vec(bytes_vec: Vec>) -> Result, CairoError> { + if bytes_vec.is_empty() { + return Err(CairoError::InvalidInputs); + } let mut vec_fe = Vec::new(); - for i in bytes { - let i_bytes: [u8; 32] = i.as_slice().try_into().map_err(|_| { - Error::Term(Box::new(TypeError::DecodingError( - "invalid felt".to_string(), - ))) - })?; - vec_fe.push(Felt::from_bytes_be(&i_bytes)) + for fe_bytes in bytes_vec { + let fe = bytes_to_felt(fe_bytes)?; + vec_fe.push(fe) } Ok(vec_fe) } -pub fn bytes_to_felt(bytes: Vec) -> NifResult { - let felt: [u8; 32] = bytes.try_into().map_err(|_| { - Error::Term(Box::new(TypeError::DecodingError( - "invalid felt".to_string(), - ))) - })?; +pub fn bytes_to_felt(bytes: Vec) -> Result { + let felt: [u8; 32] = bytes + .try_into() + .map_err(|_| CairoError::InvalidFiniteField)?; Ok(Felt::from_bytes_be(&felt)) } -pub fn bytes_to_affine(bytes: Vec) -> NifResult { +pub fn bytes_to_affine(bytes: Vec) -> Result { if bytes.len() != 64 { - return Err(Error::Term(Box::new(TypeError::DecodingError( - "invalid pk".to_string(), - )))); + return Err(CairoError::InvalidAffinePoint); } - let key_x = - Felt::from_bytes_be(&bytes[0..32].try_into().map_err(|_| { - Error::Term(Box::new(TypeError::DecodingError("invalid pk".to_string()))) - })?); - let key_y = - Felt::from_bytes_be(&bytes[32..64].try_into().map_err(|_| { - Error::Term(Box::new(TypeError::DecodingError("invalid pk".to_string()))) - })?); - AffinePoint::new(key_x, key_y) - .map_err(|_| Error::Term(Box::new(TypeError::DecodingError("invalid pk".to_string())))) + let (x, y) = bytes.split_at(32); + let key_x = bytes_to_felt(x.to_vec())?; + let key_y = bytes_to_felt(y.to_vec())?; + + AffinePoint::new(key_x, key_y).map_err(|_| CairoError::InvalidAffinePoint) } diff --git a/native/cairo_prover/src/verifier.rs b/native/cairo_prover/src/verifier.rs new file mode 100644 index 0000000..b491929 --- /dev/null +++ b/native/cairo_prover/src/verifier.rs @@ -0,0 +1,88 @@ +use crate::error::CairoError; +use cairo_platinum_prover::{ + air::{verify_cairo_proof, PublicInputs, SegmentName}, + Felt252, +}; +use rustler::NifResult; +use stark_platinum_prover::proof::options::{ProofOptions, SecurityLevel}; +use starknet_crypto::poseidon_hash_many; +use starknet_types_core::felt::Felt; + +#[rustler::nif(schedule = "DirtyCpu")] +fn cairo_verify(proof: Vec, public_input: Vec) -> NifResult { + let proof_options = ProofOptions::new_secure(SecurityLevel::Conjecturable100Bits, 3); + + // Decode proof + let proof = bincode::serde::decode_from_slice(&proof, bincode::config::standard()) + .map_err(CairoError::from)? + .0; + + // Decode public inputs + let pub_inputs = bincode::serde::decode_from_slice(&public_input, bincode::config::standard()) + .map_err(CairoError::from)? + .0; + + Ok(verify_cairo_proof(&proof, &pub_inputs, &proof_options)) +} + +#[rustler::nif()] +fn cairo_get_output(public_input: Vec) -> NifResult>> { + // Decode public inputs + let (pub_inputs, _): (PublicInputs, usize) = + bincode::serde::decode_from_slice(&public_input, bincode::config::standard()) + .map_err(CairoError::from)?; + + // Get output segments + let output_segments = pub_inputs + .memory_segments + .get(&SegmentName::Output) + .ok_or_else(|| CairoError::SegmentNotFound)?; + + let begin_addr: u64 = output_segments.begin_addr as u64; + let stop_addr: u64 = output_segments.stop_ptr as u64; + + let mut output_values = Vec::new(); + for addr in begin_addr..stop_addr { + // Convert addr to FieldElement (assuming this is the correct way to create a FieldElement from an address) + let addr_field_element = Felt252::from(addr); + + if let Some(value) = pub_inputs.public_memory.get(&addr_field_element) { + output_values.push(value.clone().to_bytes_be().to_vec()); + } else { + return Err(CairoError::AddressNotFound(addr).into()); + } + } + + Ok(output_values) +} + +// Get the program from public inputs and return the program hash as the +// resource label +#[rustler::nif] +fn program_hash(public_inputs: Vec) -> NifResult> { + let (pub_inputs, _): (PublicInputs, usize) = + bincode::serde::decode_from_slice(&public_inputs, bincode::config::standard()) + .map_err(CairoError::from)?; + let program_segments = pub_inputs + .memory_segments + .get(&SegmentName::Program) + .ok_or_else(|| CairoError::SegmentNotFound)?; + + let begin_addr: u64 = program_segments.begin_addr as u64; + let stop_addr: u64 = program_segments.stop_ptr as u64; + + let mut program = Vec::new(); + for addr in begin_addr..stop_addr { + // Convert addr to FieldElement (assuming this is the correct way to create a FieldElement from an address) + let addr_field_element = Felt252::from(addr); + let value = pub_inputs + .public_memory + .get(&addr_field_element) + .ok_or_else(|| CairoError::AddressNotFound(addr))?; + program.push(Felt::from_raw(value.to_raw().limbs)); + } + + let program_hash = poseidon_hash_many(&program); + + Ok(program_hash.to_bytes_be().to_vec()) +} diff --git a/native/cairo_vm/Cargo.lock b/native/cairo_vm/Cargo.lock index 3b95479..e8d6bb5 100644 --- a/native/cairo_vm/Cargo.lock +++ b/native/cairo_vm/Cargo.lock @@ -287,6 +287,7 @@ dependencies = [ "juvix-cairo-vm", "rustler", "serde_json", + "thiserror", ] [[package]] diff --git a/native/cairo_vm/Cargo.toml b/native/cairo_vm/Cargo.toml index 98e5090..8fcac8c 100644 --- a/native/cairo_vm/Cargo.toml +++ b/native/cairo_vm/Cargo.toml @@ -14,3 +14,4 @@ rustler = "0.31.0" bincode = "2.0.0-rc.3" juvix-cairo-vm = { git = "https://github.com/anoma/juvix-cairo-vm"} serde_json = "1.0.120" +thiserror = "1.0" diff --git a/native/cairo_vm/README.md b/native/cairo_vm/README.md index b944144..4e09a26 100644 --- a/native/cairo_vm/README.md +++ b/native/cairo_vm/README.md @@ -32,10 +32,10 @@ An example can be found in "cairo_api_test" # Run cairo-vm test "cairo_api_test" do // The file cairo.json is the output of Juvix compiler - {:ok, program} = File.read("./native/cairo_vm/cairo.json") + {:ok, program} = File.read("./juvix/cairo.json") // The file cairo_input.json is what we use to input data into the program. If there's no input, it'll just be an empty string. - {:ok, input} = File.read("./native/cairo_vm/cairo_input.json") + {:ok, input} = File.read("./juvix/cairo_input.json") // Run cairo vm {output, trace, memory, public_inputs} = diff --git a/native/cairo_vm/src/error.rs b/native/cairo_vm/src/error.rs new file mode 100644 index 0000000..18f4a71 --- /dev/null +++ b/native/cairo_vm/src/error.rs @@ -0,0 +1,24 @@ +use rustler::{Encoder, Env, Term}; +use thiserror::Error; + +#[derive(Debug, Error)] +pub enum CairoVMError { + #[error("Invalid program content")] + InvalidProgramContent, + #[error("Invalid input JSON")] + InvalidInputJSON, + #[error("Runtime error: {0}")] + RuntimeError(String), +} + +impl Encoder for CairoVMError { + fn encode<'a>(&self, env: Env<'a>) -> Term<'a> { + self.to_string().encode(env) + } +} + +impl From for rustler::Error { + fn from(e: CairoVMError) -> Self { + rustler::Error::Term(Box::new(e)) + } +} diff --git a/native/cairo_vm/src/errors.rs b/native/cairo_vm/src/errors.rs deleted file mode 100644 index 0c55249..0000000 --- a/native/cairo_vm/src/errors.rs +++ /dev/null @@ -1,24 +0,0 @@ -use rustler::{Encoder, Env, Term}; - -#[derive(Debug)] -pub(crate) enum CairoVMError { - InvalidProgramContent, - InvalidInputJSON, - RuntimeError(String), -} - -impl std::fmt::Display for CairoVMError { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - match self { - CairoVMError::InvalidProgramContent => write!(f, "Invalid program content"), - CairoVMError::InvalidInputJSON => write!(f, "Invalid input JSON"), - CairoVMError::RuntimeError(msg) => write!(f, "Runtime error: {}", msg), - } - } -} - -impl Encoder for CairoVMError { - fn encode<'a>(&self, env: Env<'a>) -> Term<'a> { - self.to_string().encode(env) - } -} diff --git a/native/cairo_vm/src/lib.rs b/native/cairo_vm/src/lib.rs index 89eb768..43c5ae6 100644 --- a/native/cairo_vm/src/lib.rs +++ b/native/cairo_vm/src/lib.rs @@ -1,11 +1,12 @@ -mod errors; +mod error; -use crate::errors::CairoVMError; +use crate::error::CairoVMError; use juvix_cairo_vm::{anoma_cairo_vm_runner, program_input::ProgramInput}; -use rustler::{Error, NifResult}; +use rustler::NifResult; use serde_json::Value; use std::collections::HashMap; +#[allow(clippy::type_complexity)] #[rustler::nif(schedule = "DirtyCpu")] fn cairo_vm_runner( program_content: String, @@ -13,18 +14,17 @@ fn cairo_vm_runner( ) -> NifResult<(String, Vec, Vec, Vec)> { // Validate program content serde_json::from_str::(&program_content) - .map_err(|_| Error::Term(Box::new(CairoVMError::InvalidProgramContent)))?; + .map_err(|_| CairoVMError::InvalidProgramContent)?; // Load program input let program_input = if inputs.is_empty() { ProgramInput::new(HashMap::new()) } else { - ProgramInput::from_json(&inputs) - .map_err(|_| Error::Term(Box::new(CairoVMError::InvalidInputJSON)))? + ProgramInput::from_json(&inputs).map_err(|_| CairoVMError::InvalidInputJSON)? }; - anoma_cairo_vm_runner(&program_content.as_bytes(), program_input) - .map_err(|e| Error::Term(Box::new(CairoVMError::RuntimeError(e.to_string())))) + anoma_cairo_vm_runner(program_content.as_bytes(), program_input) + .map_err(|e| CairoVMError::RuntimeError(e.to_string()).into()) } rustler::init!("Elixir.Cairo.CairoVM", [cairo_vm_runner]); diff --git a/test/cairo_binding_signature.exs b/test/cairo_binding_signature.exs index bdff799..0ef7f16 100644 --- a/test/cairo_binding_signature.exs +++ b/test/cairo_binding_signature.exs @@ -16,5 +16,56 @@ defmodule BindingSignatureTest do # Sign and verify signature = (priv_key_1 ++ priv_key_2) |> Cairo.sign(msg) assert true = Cairo.sig_verify(pub_keys, msg, signature) + + # Wrong pub_key + wrong_pub_key = Cairo.get_public_key(priv_key_1) + refute Cairo.sig_verify([wrong_pub_key], msg, signature) + + # Wrong msg + refute Cairo.sig_verify(pub_keys, [List.duplicate(1, 32)], signature) + + # Wrong signature + refute Cairo.sig_verify(pub_keys, msg, List.duplicate(1, 64)) + end + + test "cairo_binding_signature_invalid_input_test" do + priv_key_1 = Cairo.random_felt() + priv_key_2 = Cairo.random_felt() + + pub_keys = + [priv_key_1, priv_key_2] + |> Enum.map(fn x -> Cairo.get_public_key(x) end) + + msg = [Cairo.random_felt(), Cairo.random_felt()] + + assert {:error, "Invalid inputs"} = Cairo.sign([], msg) + + assert {:error, "Invalid finite field: 32 bytes needed"} = + Cairo.sign(priv_key_1, [[]]) + + assert {:error, "Invalid inputs"} = Cairo.sign([1, 2], msg) + + assert {:error, "Invalid finite field: 32 bytes needed"} = + Cairo.sign(priv_key_1, [[1, 2]]) + + assert {:error, "Invalid finite field: 32 bytes needed"} = + Cairo.get_public_key([]) + + signature = (priv_key_1 ++ priv_key_2) |> Cairo.sign(msg) + assert {:error, "Invalid Point"} = Cairo.sig_verify([[]], msg, signature) + + assert {:error, "Invalid finite field: 32 bytes needed"} = + Cairo.sig_verify(pub_keys, [[]], signature) + + assert {:error, "Invalid signature: 64 bytes needed"} = + Cairo.sig_verify(pub_keys, msg, []) + + assert {:error, "Invalid Point"} = Cairo.sig_verify([[1]], msg, signature) + + assert {:error, "Invalid finite field: 32 bytes needed"} = + Cairo.sig_verify(pub_keys, [[1]], signature) + + assert {:error, "Invalid signature: 64 bytes needed"} = + Cairo.sig_verify(pub_keys, msg, [1]) end end diff --git a/test/cairo_compliance_test.exs b/test/cairo_compliance_test.exs index 351cdf2..9053760 100644 --- a/test/cairo_compliance_test.exs +++ b/test/cairo_compliance_test.exs @@ -5,8 +5,8 @@ defmodule CairoComplianceTest do doctest Cairo.CairoVM test "compliance_circuit" do - {:ok, program} = File.read("./native/cairo_vm/compliance.json") - # {:ok, input} = File.read("./native/cairo_vm/compliance_input.json") + {:ok, program} = File.read("./juvix/compliance.json") + # {:ok, input} = File.read("./juvix/compliance_input.json") input_resource = List.duplicate(1, 225) output_resource = List.duplicate(2, 225) path = List.duplicate(Cairo.random_felt(), 32) diff --git a/test/cairo_encryption.exs b/test/cairo_encryption.exs index ffcd9e1..1a1f248 100644 --- a/test/cairo_encryption.exs +++ b/test/cairo_encryption.exs @@ -6,8 +6,8 @@ defmodule NifTest do test "cairo_encryption_test" do # encryption circuit test - {:ok, program} = File.read("./native/cairo_vm/encryption.json") - {:ok, input} = File.read("./native/cairo_vm/encryption_input.json") + {:ok, program} = File.read("./juvix/encryption.json") + {:ok, input} = File.read("./juvix/encryption_input.json") {_output, trace, memory, vm_public_input} = Cairo.cairo_vm_runner( @@ -36,8 +36,46 @@ defmodule NifTest do assert Cairo.get_output(public_input) == expected_cipher # decryption - plaintext = Cairo.decrypt(expected_cipher, felt_bytes_1) + plaintext = Cairo.decrypt(expected_cipher, sk) assert plaintext == expected_plaintext + + # decryption: wrong sk + assert {:error, "Invalid DH key"} = + Cairo.decrypt(expected_cipher, felt_bytes_0) + end + + test "cairo_encryption_invalid_input_test" do + felt_bytes = List.duplicate(1, 32) + plaintext = List.duplicate(felt_bytes, 10) + pk = Cairo.get_public_key(felt_bytes) + invalid_pk = List.duplicate(1, 64) + sk = felt_bytes + nonce = felt_bytes + + assert {:error, "Invalid finite field: 32 bytes needed"} = + Cairo.encrypt([[]], pk, sk, nonce) + + assert {:error, "Invalid Point"} = Cairo.encrypt(plaintext, [], sk, nonce) + + assert {:error, "Invalid Point"} = + Cairo.encrypt(plaintext, invalid_pk, sk, nonce) + + assert {:error, "Invalid finite field: 32 bytes needed"} = + Cairo.encrypt(plaintext, pk, [], nonce) + + assert {:error, "Invalid finite field: 32 bytes needed"} = + Cairo.encrypt(plaintext, pk, sk, []) + + cipher = List.duplicate(felt_bytes, 14) + + assert {:error, "Invalid finite field: 32 bytes needed"} = + Cairo.decrypt([[]], sk) + + assert {:error, "The length of ciphertext is not correct"} = + Cairo.decrypt([felt_bytes], sk) + + assert {:error, "Invalid finite field: 32 bytes needed"} = + Cairo.decrypt(cipher, []) end end diff --git a/test/cairo_logic_test.exs b/test/cairo_logic_test.exs index d03f7eb..23be093 100644 --- a/test/cairo_logic_test.exs +++ b/test/cairo_logic_test.exs @@ -6,10 +6,10 @@ defmodule CairoResourceLogicTest do test "resource_logic_circuit" do {:ok, program} = - File.read("./native/cairo_vm/trivial_resource_logic.json") + File.read("./juvix/trivial_resource_logic.json") {:ok, input} = - File.read("./native/cairo_vm/trivial_resource_logic_input.json") + File.read("./juvix/trivial_resource_logic_input.json") {_output, trace, memory, public_inputs} = Cairo.cairo_vm_runner( diff --git a/test/cairo_negative_test.exs b/test/cairo_negative_test.exs index 829b389..6967190 100644 --- a/test/cairo_negative_test.exs +++ b/test/cairo_negative_test.exs @@ -6,7 +6,7 @@ defmodule NegativeTest do test "cairo_vm_runner with invalid program content" do invalid_program = "This is not valid JSON" - {:ok, input} = File.read("./native/cairo_vm/cairo_input.json") + {:ok, input} = File.read("./juvix/cairo_input.json") assert {:error, error_message} = Cairo.cairo_vm_runner(invalid_program, input) @@ -15,7 +15,7 @@ defmodule NegativeTest do end test "cairo_vm_runner with invalid input JSON" do - {:ok, program} = File.read("./native/cairo_vm/cairo.json") + {:ok, program} = File.read("./juvix/cairo.json") invalid_input = "This is not valid JSON" assert {:error, error_message} = @@ -37,67 +37,51 @@ defmodule NegativeTest do assert String.starts_with?(error_message, "Runtime error:") end - test "cairo_prove with invalid trace (RegisterStatesError)" do - {:ok, program} = File.read("./native/cairo_vm/cairo.json") - {:ok, input} = File.read("./native/cairo_vm/cairo_input.json") - - {_output, _trace, memory, vm_public_input} = - Cairo.cairo_vm_runner( - program, - input - ) - - invalid_trace = [0, 1, 2, 3] - - assert {:error, error_message} = - Cairo.prove(invalid_trace, memory, vm_public_input) - - assert String.starts_with?(error_message, "Register states error:") - end - - test "cairo_prove with invalid memory (CairoMemoryError)" do - {:ok, program} = File.read("./native/cairo_vm/cairo.json") - {:ok, input} = File.read("./native/cairo_vm/cairo_input.json") - - {_output, trace, _memory, vm_public_input} = - Cairo.cairo_vm_runner( - program, - input - ) - - invalid_memory = [0, 1, 2, 3] - - assert {:error, error_message} = - Cairo.prove(trace, invalid_memory, vm_public_input) - - assert String.starts_with?(error_message, "Cairo memory error:") - end - - test "cairo_verify with invalid proof" do - {:ok, program} = File.read("./native/cairo_vm/cairo.json") - {:ok, input} = File.read("./native/cairo_vm/cairo_input.json") - - {_output, trace, memory, vm_public_input} = - Cairo.cairo_vm_runner(program, input) - - {_proof, public_input} = Cairo.prove(trace, memory, vm_public_input) - invalid_proof = [0, 1, 2, 3] - - assert {:error, error_message} = Cairo.verify(invalid_proof, public_input) - assert String.starts_with?(error_message, "Proof decoding error:") + test "cairo_get_output" do + assert {:error, _} = Cairo.get_output([]) + assert {:error, _} = Cairo.get_output([1, 2, 3, 4]) end - test "cairo_verify with invalid public input" do - {:ok, program} = File.read("./native/cairo_vm/cairo.json") - {:ok, input} = File.read("./native/cairo_vm/cairo_input.json") - - {_output, trace, memory, vm_public_input} = - Cairo.cairo_vm_runner(program, input) - - {proof, _public_input} = Cairo.prove(trace, memory, vm_public_input) - invalid_public_input = [] - - assert {:error, error_message} = Cairo.verify(proof, invalid_public_input) - assert String.starts_with?(error_message, "Public input decoding error:") + test "cairo_felt_to_string" do + assert "0x0" = Cairo.felt_to_string(List.duplicate(0, 32)) + + assert "0x7752582c54a42fe0fa35c40f07293bb7d8efe90e21d8d2c06a7db52d7d9b7a1" = + Cairo.felt_to_string([ + 7, + 117, + 37, + 130, + 197, + 74, + 66, + 254, + 15, + 163, + 92, + 64, + 240, + 114, + 147, + 187, + 125, + 142, + 254, + 144, + 226, + 29, + 141, + 44, + 6, + 167, + 219, + 82, + 215, + 217, + 183, + 161 + ]) + + assert {:error, "Invalid finite field: 32 bytes needed"} = + Cairo.felt_to_string([1, 2, 3, 4]) end end diff --git a/test/cairo_test.exs b/test/cairo_test.exs index b15a4b1..555b276 100644 --- a/test/cairo_test.exs +++ b/test/cairo_test.exs @@ -4,9 +4,9 @@ defmodule NifTest do doctest Cairo.CairoProver doctest Cairo.CairoVM - test "cairo_api_test" do - {:ok, program} = File.read("./native/cairo_vm/cairo.json") - {:ok, input} = File.read("./native/cairo_vm/cairo_input.json") + test "cairo_prove_test" do + {:ok, program} = File.read("./juvix/cairo.json") + {:ok, input} = File.read("./juvix/cairo_input.json") {output, trace, memory, vm_public_input} = Cairo.cairo_vm_runner( @@ -25,5 +25,16 @@ defmodule NifTest do Cairo.get_program_hash(public_input) |> Cairo.felt_to_string() # IO.inspect(program_hash) + + assert {:error, _} = Cairo.prove([], memory, vm_public_input) + assert {:error, _} = Cairo.prove(trace, [], vm_public_input) + assert {:error, _} = Cairo.prove(trace, memory, []) + assert {:error, _} = Cairo.prove([1], memory, vm_public_input) + assert {:error, _} = Cairo.prove(trace, [1], vm_public_input) + assert {:error, _} = Cairo.prove(trace, memory, [1]) + assert {:error, _} = Cairo.verify([], public_input) + assert {:error, _} = Cairo.verify(proof, []) + assert {:error, _} = Cairo.verify([1], public_input) + assert {:error, _} = Cairo.verify(proof, [1]) end end diff --git a/test/poseidon_test.exs b/test/poseidon_test.exs index bdb0db0..9195678 100644 --- a/test/poseidon_test.exs +++ b/test/poseidon_test.exs @@ -71,4 +71,20 @@ defmodule PoseidonTest do assert hash_bytes == output end + + test "poseidon_hash_invalid_input" do + assert {:error, "Invalid finite field: 32 bytes needed"} = + Cairo.poseidon_single([]) + + assert {:error, "Invalid finite field: 32 bytes needed"} = + Cairo.poseidon([], List.duplicate(1, 32)) + + assert {:error, "Invalid finite field: 32 bytes needed"} = + Cairo.poseidon(List.duplicate(1, 32), []) + + assert {:error, "Invalid inputs"} = Cairo.poseidon_many([]) + + assert {:error, "Invalid finite field: 32 bytes needed"} = + Cairo.poseidon_many([[1]]) + end end