From 5e152a48b90862f068ce8a9be1b1ee5bc2fa25f6 Mon Sep 17 00:00:00 2001 From: Dan Enman Date: Tue, 1 Oct 2024 13:27:33 -0300 Subject: [PATCH 01/23] feat: initial wip readme for dwn-rs-remote --- crates/dwn-rs-remote/README.md | 48 ++++++++++++++++++++++++++++++++++ 1 file changed, 48 insertions(+) create mode 100644 crates/dwn-rs-remote/README.md diff --git a/crates/dwn-rs-remote/README.md b/crates/dwn-rs-remote/README.md new file mode 100644 index 0000000..292a2f8 --- /dev/null +++ b/crates/dwn-rs-remote/README.md @@ -0,0 +1,48 @@ +# dwn-rs-remote + +A Rust library for interacting with remote Decentralized Web Node (DWN) instances using JSON APIs. + +## Overview + +`dwn-rs-remote` is a new crate designed to simplify the process of interacting with remote DWN instances. This crate provides a single, +unified API for sending messages to remote DWN instances and receiving responses. It supports the `processMessage` function, +which will interact with a JSON API offered by e.g. `dwn-server`. + +## Features + +- **Simple API**: `dwn-rs-remote` provides a straightforward API for interacting with remote DWN instances. +- **JSON API support**: The crate supports the JSON API offered by `dwn-server`, allowing developers to interact with remote DWN instances. + +## Requirements + +- **dwn-rs** (or `dwn-rs-core`): This crate depends on the `dwn-rs` crate, which provides the core data structures and functions for interacting with DWN services. + +## Usage + +To use `dwn-rs-remote`, add the following dependency to your `Cargo.toml`: + +```toml +[dependencies] +dwn-rs-remote = "0.1.0" +``` + +Then, you can use the `processMessage` function to interact with remote DWN instances. Here's an example: + +```rust +use dwn_rss_remote::{process_message, RemoteDWNInstance}; + +let instance = RemoteDWNInstance::new("https://example.com/dwn"); +let message = serde_json::json!({"type": "message", "data": "Hello, world!"}); +let tenant = "did:persona" +let response = process_message(&tenant, &instance, message); + +if let Some(response) = response { + println!("Received response: {}", response); +} else { + println!("Error processing message"); +} +``` + +## Contributing + +Contributions to `dwn-rs-remote` are welcome. Please see the [README file](https://github.com/enmand/dwn-rs) for more information on how to contribute to the project. From 04582fe5d1636343b888af676aa8d0ac3a58d5ed Mon Sep 17 00:00:00 2001 From: Dan Enman Date: Wed, 2 Oct 2024 12:01:27 -0300 Subject: [PATCH 02/23] (wip) feat: dwn-server compatible JSONRPC client --- Cargo.lock | 148 ++++++++++++++++++--- Cargo.toml | 2 +- crates/dwn-rs-core/src/errors.rs | 21 +++ crates/dwn-rs-remote/Cargo.toml | 18 +++ crates/dwn-rs-remote/src/errors.rs | 191 ++++++++++++++++++++++++++++ crates/dwn-rs-remote/src/jsonrpc.rs | 185 +++++++++++++++++++++++++++ crates/dwn-rs-remote/src/lib.rs | 6 + 7 files changed, 553 insertions(+), 18 deletions(-) create mode 100644 crates/dwn-rs-remote/Cargo.toml create mode 100644 crates/dwn-rs-remote/src/errors.rs create mode 100644 crates/dwn-rs-remote/src/jsonrpc.rs create mode 100644 crates/dwn-rs-remote/src/lib.rs diff --git a/Cargo.lock b/Cargo.lock index 16b0d72..6f9a0d1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -55,6 +55,12 @@ version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" +[[package]] +name = "adler2" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "512761e0bb2578dd7380c6baaa0f4ce03e84f95e960231d1dec8bf4d7d6e2627" + [[package]] name = "ahash" version = "0.7.8" @@ -221,6 +227,19 @@ dependencies = [ "pin-project-lite", ] +[[package]] +name = "async-compression" +version = "0.4.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fec134f64e2bc57411226dfc4e52dec859ddfc7e711fc5e07b612584f000e4aa" +dependencies = [ + "flate2", + "futures-core", + "memchr", + "pin-project-lite", + "tokio", +] + [[package]] name = "async-executor" version = "1.12.0" @@ -479,7 +498,7 @@ dependencies = [ "cc", "cfg-if", "libc", - "miniz_oxide", + "miniz_oxide 0.7.3", "object", "rustc-demangle", ] @@ -1412,6 +1431,18 @@ dependencies = [ "xtra", ] +[[package]] +name = "dwn-rs-remote" +version = "0.1.0" +dependencies = [ + "dwn-rs-core", + "futures-core", + "reqwest 0.12.8", + "serde", + "serde_json", + "thiserror", +] + [[package]] name = "dwn-rs-stores" version = "0.1.0" @@ -1754,6 +1785,16 @@ version = "0.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0ce7134b9999ecaf8bcd65542e436736ef32ddca1b3e06094cb6ec5755203b80" +[[package]] +name = "flate2" +version = "1.0.34" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a1b589b4dc103969ad3cf85c950899926ec64300a1a46d76c03a6072957036f0" +dependencies = [ + "crc32fast", + "miniz_oxide 0.8.0", +] + [[package]] name = "float_next_after" version = "1.0.0" @@ -2212,6 +2253,25 @@ dependencies = [ "tracing", ] +[[package]] +name = "h2" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "524e8ac6999421f49a846c2d4411f337e53497d8ec55d67753beffa43c5d9205" +dependencies = [ + "atomic-waker", + "bytes", + "fnv", + "futures-core", + "futures-sink", + "http 1.1.0", + "indexmap 2.2.6", + "slab", + "tokio", + "tokio-util", + "tracing", +] + [[package]] name = "half" version = "2.4.1" @@ -2406,7 +2466,7 @@ dependencies = [ "futures-channel", "futures-core", "futures-util", - "h2", + "h2 0.3.26", "http 0.2.12", "http-body 0.4.6", "httparse", @@ -2429,6 +2489,7 @@ dependencies = [ "bytes", "futures-channel", "futures-util", + "h2 0.4.6", "http 1.1.0", "http-body 1.0.1", "httparse", @@ -2470,6 +2531,22 @@ dependencies = [ "tokio-native-tls", ] +[[package]] +name = "hyper-tls" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "70206fc6890eaca9fde8a0bf71caa2ddfc9fe045ac9e5c70df101a7dbde866e0" +dependencies = [ + "bytes", + "http-body-util", + "hyper 1.4.1", + "hyper-util", + "native-tls", + "tokio", + "tokio-native-tls", + "tower-service", +] + [[package]] name = "hyper-util" version = "0.1.7" @@ -3349,6 +3426,15 @@ dependencies = [ "adler", ] +[[package]] +name = "miniz_oxide" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e2d80299ef12ff69b16a84bb182e3b9df68b5a91574d3d4fa6e41b65deec4df1" +dependencies = [ + "adler2", +] + [[package]] name = "mio" version = "1.0.1" @@ -4586,11 +4672,11 @@ dependencies = [ "encoding_rs", "futures-core", "futures-util", - "h2", + "h2 0.3.26", "http 0.2.12", "http-body 0.4.6", "hyper 0.14.28", - "hyper-tls", + "hyper-tls 0.5.0", "ipnet", "js-sys", "log", @@ -4604,7 +4690,7 @@ dependencies = [ "serde_json", "serde_urlencoded", "sync_wrapper 0.1.2", - "system-configuration", + "system-configuration 0.5.1", "tokio", "tokio-native-tls", "tower-service", @@ -4617,25 +4703,30 @@ dependencies = [ [[package]] name = "reqwest" -version = "0.12.7" +version = "0.12.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f8f4955649ef5c38cc7f9e8aa41761d48fb9677197daea9984dc54f56aad5e63" +checksum = "f713147fbe92361e52392c73b8c9e48c04c6625bce969ef54dc901e58e042a7b" dependencies = [ + "async-compression", "base64 0.22.1", "bytes", + "encoding_rs", "futures-core", "futures-util", + "h2 0.4.6", "http 1.1.0", "http-body 1.0.1", "http-body-util", "hyper 1.4.1", "hyper-rustls", + "hyper-tls 0.6.0", "hyper-util", "ipnet", "js-sys", "log", "mime", "mime_guess", + "native-tls", "once_cell", "percent-encoding", "pin-project-lite", @@ -4647,7 +4738,9 @@ dependencies = [ "serde_json", "serde_urlencoded", "sync_wrapper 1.0.1", + "system-configuration 0.6.1", "tokio", + "tokio-native-tls", "tokio-rustls", "tokio-util", "tower-service", @@ -5146,9 +5239,9 @@ checksum = "cd0b0ec5f1c1ca621c432a25813d8d60c88abe6d3e08a3eb9cf37d97a0fe3d73" [[package]] name = "serde" -version = "1.0.209" +version = "1.0.210" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "99fce0ffe7310761ca6bf9faf5115afbc19688edd00171d81b1bb1b116c63e09" +checksum = "c8e3592472072e6e22e0a54d5904d9febf8508f65fb8552499a1abc7d1078c3a" dependencies = [ "serde_derive", ] @@ -5195,9 +5288,9 @@ dependencies = [ [[package]] name = "serde_derive" -version = "1.0.209" +version = "1.0.210" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a5831b979fd7b5439637af1752d535ff49f4860c0f341d1baeb6faf0f4242170" +checksum = "243902eda00fad750862fc144cea25caca5e20d615af0a81bee94ca738f1df1f" dependencies = [ "proc-macro2", "quote", @@ -5979,7 +6072,7 @@ dependencies = [ "path-clean", "pharos", "reblessive", - "reqwest 0.12.7", + "reqwest 0.12.8", "revision 0.10.0", "ring", "rust_decimal", @@ -6189,7 +6282,18 @@ checksum = "ba3a3adc5c275d719af8cb4272ea1c4a6d668a777f37e115f6d11ddbc1c8e0e7" dependencies = [ "bitflags 1.3.2", "core-foundation", - "system-configuration-sys", + "system-configuration-sys 0.5.0", +] + +[[package]] +name = "system-configuration" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3c879d448e9d986b661742763247d3693ed13609438cf3d006f51f5368a5ba6b" +dependencies = [ + "bitflags 2.5.0", + "core-foundation", + "system-configuration-sys 0.6.0", ] [[package]] @@ -6202,6 +6306,16 @@ dependencies = [ "libc", ] +[[package]] +name = "system-configuration-sys" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e1d1b10ced5ca923a1fcb8d03e96b8d3268065d724548c0211415ff6ac6bac4" +dependencies = [ + "core-foundation-sys", + "libc", +] + [[package]] name = "tap" version = "1.0.1" @@ -6244,18 +6358,18 @@ dependencies = [ [[package]] name = "thiserror" -version = "1.0.63" +version = "1.0.64" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c0342370b38b6a11b6cc11d6a805569958d54cfa061a29969c3b5ce2ea405724" +checksum = "d50af8abc119fb8bb6dbabcfa89656f46f84aa0ac7688088608076ad2b459a84" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.63" +version = "1.0.64" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a4558b58466b9ad7ca0f102865eccc95938dca1a74a856f2b57b6629050da261" +checksum = "08904e7672f5eb876eaaf87e0ce17857500934f4981c4a0ab2b4aa98baac7fc3" dependencies = [ "proc-macro2", "quote", diff --git a/Cargo.toml b/Cargo.toml index 2209efd..186cea5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [workspace] resolver = "2" -members = ["crates/dwn-rs-core", "crates/dwn-rs-stores", "crates/dwn-rs-wasm"] +members = ["crates/dwn-rs-core", "crates/dwn-rs-remote", "crates/dwn-rs-stores", "crates/dwn-rs-wasm"] [profile.release] lto = true diff --git a/crates/dwn-rs-core/src/errors.rs b/crates/dwn-rs-core/src/errors.rs index 55705df..b0cdef3 100644 --- a/crates/dwn-rs-core/src/errors.rs +++ b/crates/dwn-rs-core/src/errors.rs @@ -5,6 +5,27 @@ use ulid::MonotonicError; use crate::{FilterError, QueryError}; +#[derive(Error, Debug)] +pub enum Error { + #[error("error operating store: {0}")] + StoreError(#[from] StoreError), + + #[error("error processing message: {0}")] + MessageError(#[from] MessageStoreError), + + #[error("error processing data: {0}")] + DataError(#[from] DataStoreError), + + #[error("error processing event log: {0}")] + EventLogError(#[from] EventLogError), + + #[error("error processing resumable task: {0}")] + ResumableTaskError(#[from] ResumableTaskStoreError), + + #[error("error processing event stream: {0}")] + EventStreamError(#[from] EventStreamError), +} + #[derive(Error, Debug)] pub enum StoreError { #[error("error opening database: {0}")] diff --git a/crates/dwn-rs-remote/Cargo.toml b/crates/dwn-rs-remote/Cargo.toml new file mode 100644 index 0000000..f59b7a1 --- /dev/null +++ b/crates/dwn-rs-remote/Cargo.toml @@ -0,0 +1,18 @@ +[package] +name = "dwn-rs-remote" +version = "0.1.0" +edition = "2021" +license = "Apache-2.0" +description = "Remote SDK components for dwn-rs" + +[lib] +crate-type = ["cdylib", "rlib"] + +[dependencies] +reqwest = { version = "0.12.8", features = ["stream", "json", "gzip"] } + +dwn-rs-core = { path = "../dwn-rs-core", default-features = false } +futures-core = { version = "0.3.30", default-features = false, features = ["alloc"] } +thiserror = "1.0.64" +serde = "1.0.210" +serde_json = "1.0.128" diff --git a/crates/dwn-rs-remote/src/errors.rs b/crates/dwn-rs-remote/src/errors.rs new file mode 100644 index 0000000..b497142 --- /dev/null +++ b/crates/dwn-rs-remote/src/errors.rs @@ -0,0 +1,191 @@ +use serde::{ser::SerializeMap, Deserialize, Serialize}; +use thiserror::Error; + +pub type Result = std::result::Result; + +#[derive(Error, Debug)] +pub enum RemoteError { + #[error("reqwest error: {0}")] + ReqwestError(#[from] reqwest::Error), + #[error("io error: {0}")] + IoError(#[from] std::io::Error), + #[error("error: {0}")] + Error(String), +} + +#[derive(Debug, Serialize, Deserialize, PartialEq)] +pub struct JSONRpcError { + #[serde(flatten)] + pub error: JSONRpcErrorCodes, + #[serde(skip_serializing_if = "Option::is_none")] + pub data: Option, +} + +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +pub enum JSONRpcErrorCodes { + InvalidRequest = -32600, + MethodNotFound = -32601, + InvalidParams = -32602, + InternalError = -32603, + ParseError = -32700, + + BadRequest = -50400, + Unauthorized = -50401, + Forbidden = -50403, +} + +// Serialize JSONRpcErrorCodes as { "code": number, "message": string } +impl Serialize for JSONRpcErrorCodes { + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + let code = *self as i64; + let message = match self { + JSONRpcErrorCodes::InvalidRequest => "Invalid Request", + JSONRpcErrorCodes::MethodNotFound => "Method Not Found", + JSONRpcErrorCodes::InvalidParams => "Invalid Params", + JSONRpcErrorCodes::InternalError => "Internal Error", + JSONRpcErrorCodes::ParseError => "Parse Error", + JSONRpcErrorCodes::BadRequest => "Bad Request", + JSONRpcErrorCodes::Unauthorized => "Unauthorized", + JSONRpcErrorCodes::Forbidden => "Forbidden", + }; + let mut map = serializer.serialize_map(Some(2))?; + map.serialize_entry("code", &code)?; + map.serialize_entry("message", message)?; + map.end() + } +} + +// Deserialize JSONRpcErrorCodes from { "code": number, "message": string } +// or numeric code +impl<'de> Deserialize<'de> for JSONRpcErrorCodes { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct JSONRpcErrorCodesVisitor; + + impl<'de> serde::de::Visitor<'de> for JSONRpcErrorCodesVisitor { + type Value = JSONRpcErrorCodes; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("a JSONRpcErrorCodes object or a number") + } + + fn visit_map(self, mut map: A) -> std::result::Result + where + A: serde::de::MapAccess<'de>, + { + let mut code = None; + while let Some(key) = map.next_key::()? { + if key.as_str() == "code" { + code = Some(map.next_value()?); + } + } + match code { + Some(code) => match code { + -32600 => Ok(JSONRpcErrorCodes::InvalidRequest), + -32601 => Ok(JSONRpcErrorCodes::MethodNotFound), + -32602 => Ok(JSONRpcErrorCodes::InvalidParams), + -32603 => Ok(JSONRpcErrorCodes::InternalError), + -32700 => Ok(JSONRpcErrorCodes::ParseError), + -50400 => Ok(JSONRpcErrorCodes::BadRequest), + -50401 => Ok(JSONRpcErrorCodes::Unauthorized), + -50403 => Ok(JSONRpcErrorCodes::Forbidden), + _ => Err(serde::de::Error::custom(format!("unknown code: {}", code))), + }, + None => Err(serde::de::Error::missing_field("code")), + } + } + + fn visit_i64(self, v: i64) -> std::result::Result + where + E: serde::de::Error, + { + match v { + -32600 => Ok(JSONRpcErrorCodes::InvalidRequest), + -32601 => Ok(JSONRpcErrorCodes::MethodNotFound), + -32602 => Ok(JSONRpcErrorCodes::InvalidParams), + -32603 => Ok(JSONRpcErrorCodes::InternalError), + -32700 => Ok(JSONRpcErrorCodes::ParseError), + -50400 => Ok(JSONRpcErrorCodes::BadRequest), + -50401 => Ok(JSONRpcErrorCodes::Unauthorized), + -50403 => Ok(JSONRpcErrorCodes::Forbidden), + _ => Err(serde::de::Error::custom(format!("unknown code: {}", v))), + } + } + } + + deserializer.deserialize_any(JSONRpcErrorCodesVisitor) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_jsonrpc_errorcodes_serialization() { + let codes: Vec<(JSONRpcErrorCodes, &'static str)> = vec![ + ( + JSONRpcErrorCodes::InvalidRequest, + r#"{"code":-32600,"message":"Invalid Request"}"#, + ), + ( + JSONRpcErrorCodes::MethodNotFound, + r#"{"code":-32601,"message":"Method Not Found"}"#, + ), + ( + JSONRpcErrorCodes::InvalidParams, + r#"{"code":-32602,"message":"Invalid Params"}"#, + ), + ( + JSONRpcErrorCodes::InternalError, + r#"{"code":-32603,"message":"Internal Error"}"#, + ), + ( + JSONRpcErrorCodes::ParseError, + r#"{"code":-32700,"message":"Parse Error"}"#, + ), + ( + JSONRpcErrorCodes::BadRequest, + r#"{"code":-50400,"message":"Bad Request"}"#, + ), + ( + JSONRpcErrorCodes::Unauthorized, + r#"{"code":-50401,"message":"Unauthorized"}"#, + ), + ( + JSONRpcErrorCodes::Forbidden, + r#"{"code":-50403,"message":"Forbidden"}"#, + ), + ]; + + for (code, json) in codes { + let err: JSONRpcError = serde_json::from_str(json).unwrap(); + + let expected = JSONRpcError { + error: code, + data: None, + }; + + assert_eq!(err, expected); + } + } + + #[test] + fn test_jsonrpc_serialization() { + let error = JSONRpcError { + error: JSONRpcErrorCodes::InvalidRequest, + data: None, + }; + let json = serde_json::to_string(&error).unwrap(); + assert_eq!(json, r#"{"code":-32600,"message":"Invalid Request"}"#); + + let error: JSONRpcError = serde_json::from_str(&json).unwrap(); + assert_eq!(error.error, JSONRpcErrorCodes::InvalidRequest); + assert_eq!(error.data, None); + } +} diff --git a/crates/dwn-rs-remote/src/jsonrpc.rs b/crates/dwn-rs-remote/src/jsonrpc.rs new file mode 100644 index 0000000..e7e7f72 --- /dev/null +++ b/crates/dwn-rs-remote/src/jsonrpc.rs @@ -0,0 +1,185 @@ +use serde::{Deserialize, Serialize}; + +use crate::JSONRpcError; + +#[derive(Debug, Serialize, Deserialize, PartialEq)] +enum Version { + #[serde(rename = "2.0")] + V2, +} + +#[derive(Debug, Serialize, Deserialize, PartialEq)] +#[serde(untagged)] +enum ID { + String(String), + Number(i64), +} + +#[derive(Debug, Serialize, Deserialize, PartialEq)] +pub struct SubscriptionRequest { + id: ID, +} + +#[derive(Debug, Serialize, Deserialize, PartialEq)] +struct Request { + jsonrpc: Version, + #[serde(skip_serializing_if = "Option::is_none")] + id: Option, + method: String, + params: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + subscription: Option, +} + +#[derive(Debug, Serialize, Deserialize, PartialEq)] +struct Result { + result: T, +} + +#[derive(Debug, Serialize, Deserialize, PartialEq)] +struct Error { + error: JSONRpcError, +} + +#[derive(Debug, Serialize, Deserialize, PartialEq)] +#[serde(untagged)] +enum ResultError { + Result(Result), + Error(Error), +} + +// Define the JSONRPCResponse struct +#[derive(Debug, Serialize, Deserialize, PartialEq)] +struct Response { + jsonrpc: Version, + id: ID, + #[serde(flatten)] + result: ResultError, +} + +#[cfg(test)] +mod test { + use crate::JSONRpcErrorCodes; + + use super::*; + + #[test] + fn test_request() { + #[derive(Debug, PartialEq)] + struct TestCase { + request: Request, + expected: &'static str, + } + + // Define your test cases using the TestCase struct + let test_cases = vec![ + TestCase { + request: Request { + jsonrpc: Version::V2, + id: Some(ID::Number(1)), + method: "test".to_string(), + params: Some(vec!["param1".to_string(), "param2".to_string()]), + subscription: None, + }, + expected: r#"{"jsonrpc":"2.0","id":1,"method":"test","params":["param1","param2"]}"#, + }, + TestCase { + request: Request { + jsonrpc: Version::V2, + id: None, + method: "test".to_string(), + params: Some(vec!["param1".to_string(), "param2".to_string()]), + subscription: None, + }, + expected: r#"{"jsonrpc":"2.0","method":"test","params":["param1","param2"]}"#, + }, + TestCase { + request: Request { + jsonrpc: Version::V2, + id: Some(ID::String("1".to_string())), + method: "test".to_string(), + params: None, + subscription: None, + }, + expected: r#"{"jsonrpc":"2.0","id":"1","method":"test"}"#, + }, + TestCase { + request: Request { + jsonrpc: Version::V2, + id: Some(ID::Number(1)), + method: "test".to_string(), + params: None, + subscription: Some(SubscriptionRequest { id: ID::Number(1) }), + }, + expected: r#"{"jsonrpc":"2.0","id":1,"method":"test","subscription":{"id":1}}"#, + }, + ]; + + for test_case in test_cases { + let serialized = serde_json::to_string(&test_case.request).unwrap(); + assert_eq!( + serialized, test_case.expected, + "Mismatch for test case {:?}", + test_case + ); + + let deserialized: Request = serde_json::from_str(&serialized).unwrap(); + assert_eq!( + test_case.request, deserialized, + "Deserialization mismatch for {}", + serialized + ); + } + } + + #[test] + fn test_response() { + #[derive(Debug, PartialEq)] + struct TestCase { + response: Response, + expected: &'static str, + } + + let test_cases = vec![ + TestCase { + response: Response { + jsonrpc: Version::V2, + id: ID::Number(1), + result: ResultError::Result(Result { + result: "test".to_string(), + }), + }, + expected: r#"{"jsonrpc":"2.0","id":1,"result":"test"}"#, + }, + TestCase { + response: Response { + jsonrpc: Version::V2, + id: ID::Number(1), + result: ResultError::Error(Error { + error: JSONRpcError { + error: JSONRpcErrorCodes::InvalidRequest, + data: None, + }, + }), + }, + expected: r#"{"jsonrpc":"2.0","id":1,"error":{"code":-32600,"message":"Invalid Request"}}"#, + }, + ]; + + for test_case in test_cases { + let serialized = serde_json::to_string(&test_case.response).unwrap(); + assert_eq!( + serialized, test_case.expected, + "Mismatch for test case {:?}", + test_case + ); + + let deserialized: Response = serde_json::from_str(&serialized).unwrap(); + assert_eq!( + test_case.response, deserialized, + "Deserialization mismatch for {}", + serialized + ); + } + } +} diff --git a/crates/dwn-rs-remote/src/lib.rs b/crates/dwn-rs-remote/src/lib.rs new file mode 100644 index 0000000..104e143 --- /dev/null +++ b/crates/dwn-rs-remote/src/lib.rs @@ -0,0 +1,6 @@ +pub mod client; +pub mod errors; +pub mod jsonrpc; + +pub use client::*; +pub use errors::*; From 9e94dc80814db9d68b8fd986e2ac525445cf4ad7 Mon Sep 17 00:00:00 2001 From: Dan Enman Date: Sun, 6 Oct 2024 20:40:10 -0300 Subject: [PATCH 03/23] chore: separate SubscriptionID type from Subscription --- Cargo.lock | 40 +++++++++++++++---------- crates/dwn-rs-core/src/events/stream.rs | 10 +++++-- crates/dwn-rs-wasm/src/events.rs | 2 +- 3 files changed, 32 insertions(+), 20 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 6f9a0d1..f4b4cbd 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1435,12 +1435,23 @@ dependencies = [ name = "dwn-rs-remote" version = "0.1.0" dependencies = [ + "bytes", "dwn-rs-core", "futures-core", + "futures-util", + "http 1.1.0", + "http-body 1.0.1", "reqwest 0.12.8", "serde", "serde_json", + "serde_repr", "thiserror", + "tokio", + "tokio-util", + "tower", + "tracing", + "tracing-subscriber", + "ulid", ] [[package]] @@ -2549,9 +2560,9 @@ dependencies = [ [[package]] name = "hyper-util" -version = "0.1.7" +version = "0.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cde7055719c54e36e95e8719f95883f22072a48ede39db7fc17a4e1d5281e9b9" +checksum = "41296eb09f183ac68eec06e03cdbea2e759633d4067b2f6552fc2e009bcad08b" dependencies = [ "bytes", "futures-channel", @@ -2562,7 +2573,6 @@ dependencies = [ "pin-project-lite", "socket2 0.5.7", "tokio", - "tower", "tower-service", "tracing", ] @@ -6507,14 +6517,16 @@ dependencies = [ [[package]] name = "tokio-util" -version = "0.7.11" +version = "0.7.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9cf6b47b3771c49ac75ad09a6162f53ad4b8088b76ac60e8ec1455b31a189fe1" +checksum = "61e7c3654c13bcd040d4a03abee2c75b1d14a37b423cf5a813ceae1cc903ec6a" dependencies = [ "bytes", "futures-core", "futures-io", "futures-sink", + "futures-util", + "hashbrown 0.14.5", "pin-project-lite", "tokio", ] @@ -6549,15 +6561,10 @@ dependencies = [ [[package]] name = "tower" -version = "0.4.13" +version = "0.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b8fa9be0de6cf49e536ce1851f987bd21a43b771b09473c3549a6c853db37c1c" +checksum = "2873938d487c3cfb9aed7546dc9f2711d867c9f90c46b889989a2cb84eba6b4f" dependencies = [ - "futures-core", - "futures-util", - "pin-project", - "pin-project-lite", - "tokio", "tower-layer", "tower-service", ] @@ -6570,9 +6577,9 @@ checksum = "121c2a6cda46980bb0fcd1647ffaf6cd3fc79a013de288782836f6df9c48780e" [[package]] name = "tower-service" -version = "0.3.2" +version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b6bc1c9ce2b5135ac7f93c72918fc37feb872bdc6a5533a8b85eb4b86bfdae52" +checksum = "8df9b6e13f2d32c91b9bd719c00d1958837bc7dec474d94952798cc8e69eeec3" [[package]] name = "tracing" @@ -6580,6 +6587,7 @@ version = "0.1.40" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c3523ab5a71916ccf420eebdf5521fcef02141234bbc0b8a49f2fdc4544364ef" dependencies = [ + "log", "pin-project-lite", "tracing-attributes", "tracing-core", @@ -6745,9 +6753,9 @@ dependencies = [ [[package]] name = "ulid" -version = "1.1.2" +version = "1.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "34778c17965aa2a08913b57e1f34db9b4a63f5de31768b55bf20d2795f921259" +checksum = "04f903f293d11f31c0c29e4148f6dc0d033a7f80cebc0282bea147611667d289" dependencies = [ "getrandom", "rand", diff --git a/crates/dwn-rs-core/src/events/stream.rs b/crates/dwn-rs-core/src/events/stream.rs index 23fdf10..ce75c00 100644 --- a/crates/dwn-rs-core/src/events/stream.rs +++ b/crates/dwn-rs-core/src/events/stream.rs @@ -108,7 +108,7 @@ impl Handler for EventStream { let addr = _ctx.mailbox().address().try_upgrade().unwrap(); let sub = Subscription { - id: id.clone(), + subscription_id: SubscriptionID { id: id.clone() }, close: Box::new(make_close_task(ns.clone(), id.clone(), addr)), }; @@ -140,10 +140,14 @@ impl Handler for EventStream { _ctx.stop_all(); } } +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] +pub struct SubscriptionID { + pub id: String, +} #[allow(clippy::type_complexity)] pub struct Subscription { - pub id: String, + pub subscription_id: SubscriptionID, pub close: Box< dyn Fn() -> Pin> + Send>> + Send @@ -250,7 +254,7 @@ mod test { .instrument(tracing::info_span!("subscribe")) .await .unwrap(); - assert_eq!(sub.id, sub_id); + assert_eq!(sub.subscription_id.id, sub_id); let emit = addr .send(Emit { diff --git a/crates/dwn-rs-wasm/src/events.rs b/crates/dwn-rs-wasm/src/events.rs index 682912d..2a20e7d 100644 --- a/crates/dwn-rs-wasm/src/events.rs +++ b/crates/dwn-rs-wasm/src/events.rs @@ -43,7 +43,7 @@ impl TryFrom for EventSubscription { fn try_from(value: Subscription) -> Result { let obj: EventSubscription = JsCast::unchecked_into(Object::new()); - Reflect::set(&obj, &"id".into(), &value.id.into())?; + Reflect::set(&obj, &"id".into(), &value.subscription_id.id.into())?; Reflect::set( &obj, &"close".into(), From 69224d2689adbada7aa9e49cde0fd75dafca242e Mon Sep 17 00:00:00 2001 From: Dan Enman Date: Sun, 6 Oct 2024 20:41:37 -0300 Subject: [PATCH 04/23] feat: jsonrpc client implementation with http transport --- .../src/interfaces/messages/fields.rs | 13 ++ .../src/interfaces/messages/mod.rs | 36 ++++ crates/dwn-rs-remote/Cargo.toml | 27 ++- crates/dwn-rs-remote/src/client.rs | 86 +++++++++ crates/dwn-rs-remote/src/errors.rs | 182 +----------------- .../src/{jsonrpc.rs => jsonrpc/client.rs} | 141 ++++++++++++-- crates/dwn-rs-remote/src/jsonrpc/dwn.rs | 12 ++ crates/dwn-rs-remote/src/jsonrpc/errors.rs | 173 +++++++++++++++++ crates/dwn-rs-remote/src/jsonrpc/http.rs | 139 +++++++++++++ crates/dwn-rs-remote/src/jsonrpc/mod.rs | 8 + 10 files changed, 615 insertions(+), 202 deletions(-) create mode 100644 crates/dwn-rs-remote/src/client.rs rename crates/dwn-rs-remote/src/{jsonrpc.rs => jsonrpc/client.rs} (59%) create mode 100644 crates/dwn-rs-remote/src/jsonrpc/dwn.rs create mode 100644 crates/dwn-rs-remote/src/jsonrpc/errors.rs create mode 100644 crates/dwn-rs-remote/src/jsonrpc/http.rs create mode 100644 crates/dwn-rs-remote/src/jsonrpc/mod.rs diff --git a/crates/dwn-rs-core/src/interfaces/messages/fields.rs b/crates/dwn-rs-core/src/interfaces/messages/fields.rs index 75b0ed7..f018ee3 100644 --- a/crates/dwn-rs-core/src/interfaces/messages/fields.rs +++ b/crates/dwn-rs-core/src/interfaces/messages/fields.rs @@ -6,11 +6,14 @@ use crate::auth::{ jws::JWS, }; +use super::Message; + #[derive(Serialize, Deserialize, Debug, PartialEq, Clone)] #[serde(untagged)] pub enum Fields { EncodedWrite(EncodedWriteField), Write(WriteFields), + InitialWriteField(InitialWriteField), Authorization(Authorization), AuthorizationDelegatedGrant(AuthorizationDelegatedGrantFields), } @@ -45,6 +48,16 @@ pub struct EncodedWriteField { pub encoded_data: Option, } +// InitialWriteField represents the RecordsWrite interface method response that includes +// the `initialWrite` data field if the original record was not the initial write. +#[derive(Serialize, Deserialize, Debug, PartialEq, Clone)] +pub struct InitialWriteField { + #[serde(flatten)] + pub write_fields: EncodedWriteField, + #[serde(rename = "initialWrite", skip_serializing_if = "Option::is_none")] + pub initial_write: Option>, +} + /// EncryptionAlgorithm represents the encryption algorithm used for encrypting records. Currently /// A256CTR is the only supported algorithm. #[derive(Serialize, Deserialize, Debug, PartialEq, Clone)] diff --git a/crates/dwn-rs-core/src/interfaces/messages/mod.rs b/crates/dwn-rs-core/src/interfaces/messages/mod.rs index 0b24b85..dd99f1d 100644 --- a/crates/dwn-rs-core/src/interfaces/messages/mod.rs +++ b/crates/dwn-rs-core/src/interfaces/messages/mod.rs @@ -4,7 +4,11 @@ pub mod fields; pub use descriptors::Descriptor; pub use fields::Fields; +use fields::InitialWriteField; use serde::{Deserialize, Serialize}; +use serde_with::skip_serializing_none; + +use crate::{Cursor, SubscriptionID}; #[derive(Serialize, Deserialize, Debug, PartialEq, Clone)] pub struct Message { @@ -12,3 +16,35 @@ pub struct Message { #[serde(flatten)] pub fields: Fields, // Fields should be an Enum representing possible fields } + +#[derive(Serialize, Deserialize, Debug, PartialEq, Clone)] +pub enum ResponseEntries { + Message(Message), + String(String), +} + +#[derive(Serialize, Deserialize, Debug, PartialEq, Clone)] +pub enum Record {} + +#[derive(Serialize, Deserialize, Debug, PartialEq, Clone)] +pub struct ReadReplyEntry { + pub cid: cid::Cid, + message: Message, +} + +#[derive(Serialize, Deserialize, Debug, PartialEq, Clone)] +pub struct Status { + pub code: i32, + pub detail: String, +} + +#[skip_serializing_none] +#[derive(Serialize, Deserialize, Debug, PartialEq, Clone)] +pub struct Response { + pub status: Status, + pub entries: Option>, + pub entry: Option, + pub record: Option, + pub cursor: Option, + pub subscription: Option, +} diff --git a/crates/dwn-rs-remote/Cargo.toml b/crates/dwn-rs-remote/Cargo.toml index f59b7a1..98cd601 100644 --- a/crates/dwn-rs-remote/Cargo.toml +++ b/crates/dwn-rs-remote/Cargo.toml @@ -9,10 +9,29 @@ description = "Remote SDK components for dwn-rs" crate-type = ["cdylib", "rlib"] [dependencies] -reqwest = { version = "0.12.8", features = ["stream", "json", "gzip"] } - -dwn-rs-core = { path = "../dwn-rs-core", default-features = false } -futures-core = { version = "0.3.30", default-features = false, features = ["alloc"] } +reqwest = { version = "0.12.8", features = [ + "stream", + "json", + "gzip", + "deflate", +] } +futures-core = { version = "0.3.30", default-features = false, features = [ + "alloc", +] } +tokio = { version = "1.39.2", features = ["io-util", "rt", "macros", "rt-multi-thread"] } thiserror = "1.0.64" serde = "1.0.210" serde_json = "1.0.128" +tower = "0.5.1" +http = "1.1.0" +http-body = "1.0.1" +bytes = "1.7.2" +dwn-rs-core = { path = "../dwn-rs-core", default-features = false } +ulid = { version = "1.1.3", features = ["serde"] } +futures-util = "0.3.30" +tokio-util = { version = "0.7.12", features = ["io", "rt"] } +tracing = { version = "0.1.40", features = ["log-always"] } +serde_repr = "0.1.19" + +[dev-dependencies] +tracing-subscriber = "0.3.18" diff --git a/crates/dwn-rs-remote/src/client.rs b/crates/dwn-rs-remote/src/client.rs new file mode 100644 index 0000000..efb10c7 --- /dev/null +++ b/crates/dwn-rs-remote/src/client.rs @@ -0,0 +1,86 @@ +use crate::{ + errors::Result as ClientResult, + jsonrpc::{self, JSONRpcError}, + RemoteError, +}; + +use bytes::Bytes; +use futures_core::{stream::BoxStream, Stream, TryStream}; +use futures_util::StreamExt; +use tower::Service; + +use dwn_rs_core::{Message, Response as DWNResponse}; + +pub struct RemoteDWNInstance +where + T: Service<(jsonrpc::Request, Option)>, + S: TryStream + Send + 'static, + S::Error: Into>, + Bytes: From, +{ + rpc: jsonrpc::Client, +} + +impl RemoteDWNInstance +where + T: Service< + (jsonrpc::Request, Option), + Response = jsonrpc::Response<( + DWNResponse, + BoxStream<'static, Result>, + )>, + Error = jsonrpc::JSONRpcError, + >, + S: TryStream + Send + 'static, + S::Error: Into>, + Bytes: From, +{ + pub fn new(transport: T) -> ClientResult { + let rpc = jsonrpc::Client::new(transport); + + Ok(RemoteDWNInstance { rpc }) + } + + pub async fn process_message( + &mut self, + tenant: &str, + message: Message, + data: Option, + ) -> ClientResult<(DWNResponse, Option>>)> { + let res = self + .rpc + .request( + jsonrpc::dwn::PROCESS_MESSAGE, + jsonrpc::dwn::ProcessMessageParams { + target: tenant.to_string(), + message, + encoded_data: None, // Data is always sent as a a stream + }, + data, + ) + .await?; + + let (m, d) = match res.result { + jsonrpc::ResultError::Result(m) => Ok(m.reply), + jsonrpc::ResultError::Error(e) => Err(JSONRpcError::from(e)), + }?; + + let d = d.map(|d| match d { + Ok(d) => Ok(d), + Err(e) => Err(RemoteError::from(e)), + }); + + Ok((m, Some(d))) + } +} + +pub type RemoteHTTPDWNInstance = RemoteDWNInstance; +pub fn new_remote_http_dwn(url: String) -> ClientResult> +where + S: TryStream + Send + 'static, + S::Error: Into>, + Bytes: From, +{ + let transport = jsonrpc::HTTPTransport::new(url)?; + RemoteDWNInstance::new(transport) +} diff --git a/crates/dwn-rs-remote/src/errors.rs b/crates/dwn-rs-remote/src/errors.rs index b497142..ad8608e 100644 --- a/crates/dwn-rs-remote/src/errors.rs +++ b/crates/dwn-rs-remote/src/errors.rs @@ -1,6 +1,7 @@ -use serde::{ser::SerializeMap, Deserialize, Serialize}; use thiserror::Error; +use crate::jsonrpc::JSONRpcError; + pub type Result = std::result::Result; #[derive(Error, Debug)] @@ -11,181 +12,6 @@ pub enum RemoteError { IoError(#[from] std::io::Error), #[error("error: {0}")] Error(String), -} - -#[derive(Debug, Serialize, Deserialize, PartialEq)] -pub struct JSONRpcError { - #[serde(flatten)] - pub error: JSONRpcErrorCodes, - #[serde(skip_serializing_if = "Option::is_none")] - pub data: Option, -} - -#[derive(Debug, PartialEq, Eq, Clone, Copy)] -pub enum JSONRpcErrorCodes { - InvalidRequest = -32600, - MethodNotFound = -32601, - InvalidParams = -32602, - InternalError = -32603, - ParseError = -32700, - - BadRequest = -50400, - Unauthorized = -50401, - Forbidden = -50403, -} - -// Serialize JSONRpcErrorCodes as { "code": number, "message": string } -impl Serialize for JSONRpcErrorCodes { - fn serialize(&self, serializer: S) -> std::result::Result - where - S: serde::Serializer, - { - let code = *self as i64; - let message = match self { - JSONRpcErrorCodes::InvalidRequest => "Invalid Request", - JSONRpcErrorCodes::MethodNotFound => "Method Not Found", - JSONRpcErrorCodes::InvalidParams => "Invalid Params", - JSONRpcErrorCodes::InternalError => "Internal Error", - JSONRpcErrorCodes::ParseError => "Parse Error", - JSONRpcErrorCodes::BadRequest => "Bad Request", - JSONRpcErrorCodes::Unauthorized => "Unauthorized", - JSONRpcErrorCodes::Forbidden => "Forbidden", - }; - let mut map = serializer.serialize_map(Some(2))?; - map.serialize_entry("code", &code)?; - map.serialize_entry("message", message)?; - map.end() - } -} - -// Deserialize JSONRpcErrorCodes from { "code": number, "message": string } -// or numeric code -impl<'de> Deserialize<'de> for JSONRpcErrorCodes { - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - struct JSONRpcErrorCodesVisitor; - - impl<'de> serde::de::Visitor<'de> for JSONRpcErrorCodesVisitor { - type Value = JSONRpcErrorCodes; - - fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { - formatter.write_str("a JSONRpcErrorCodes object or a number") - } - - fn visit_map(self, mut map: A) -> std::result::Result - where - A: serde::de::MapAccess<'de>, - { - let mut code = None; - while let Some(key) = map.next_key::()? { - if key.as_str() == "code" { - code = Some(map.next_value()?); - } - } - match code { - Some(code) => match code { - -32600 => Ok(JSONRpcErrorCodes::InvalidRequest), - -32601 => Ok(JSONRpcErrorCodes::MethodNotFound), - -32602 => Ok(JSONRpcErrorCodes::InvalidParams), - -32603 => Ok(JSONRpcErrorCodes::InternalError), - -32700 => Ok(JSONRpcErrorCodes::ParseError), - -50400 => Ok(JSONRpcErrorCodes::BadRequest), - -50401 => Ok(JSONRpcErrorCodes::Unauthorized), - -50403 => Ok(JSONRpcErrorCodes::Forbidden), - _ => Err(serde::de::Error::custom(format!("unknown code: {}", code))), - }, - None => Err(serde::de::Error::missing_field("code")), - } - } - - fn visit_i64(self, v: i64) -> std::result::Result - where - E: serde::de::Error, - { - match v { - -32600 => Ok(JSONRpcErrorCodes::InvalidRequest), - -32601 => Ok(JSONRpcErrorCodes::MethodNotFound), - -32602 => Ok(JSONRpcErrorCodes::InvalidParams), - -32603 => Ok(JSONRpcErrorCodes::InternalError), - -32700 => Ok(JSONRpcErrorCodes::ParseError), - -50400 => Ok(JSONRpcErrorCodes::BadRequest), - -50401 => Ok(JSONRpcErrorCodes::Unauthorized), - -50403 => Ok(JSONRpcErrorCodes::Forbidden), - _ => Err(serde::de::Error::custom(format!("unknown code: {}", v))), - } - } - } - - deserializer.deserialize_any(JSONRpcErrorCodesVisitor) - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_jsonrpc_errorcodes_serialization() { - let codes: Vec<(JSONRpcErrorCodes, &'static str)> = vec![ - ( - JSONRpcErrorCodes::InvalidRequest, - r#"{"code":-32600,"message":"Invalid Request"}"#, - ), - ( - JSONRpcErrorCodes::MethodNotFound, - r#"{"code":-32601,"message":"Method Not Found"}"#, - ), - ( - JSONRpcErrorCodes::InvalidParams, - r#"{"code":-32602,"message":"Invalid Params"}"#, - ), - ( - JSONRpcErrorCodes::InternalError, - r#"{"code":-32603,"message":"Internal Error"}"#, - ), - ( - JSONRpcErrorCodes::ParseError, - r#"{"code":-32700,"message":"Parse Error"}"#, - ), - ( - JSONRpcErrorCodes::BadRequest, - r#"{"code":-50400,"message":"Bad Request"}"#, - ), - ( - JSONRpcErrorCodes::Unauthorized, - r#"{"code":-50401,"message":"Unauthorized"}"#, - ), - ( - JSONRpcErrorCodes::Forbidden, - r#"{"code":-50403,"message":"Forbidden"}"#, - ), - ]; - - for (code, json) in codes { - let err: JSONRpcError = serde_json::from_str(json).unwrap(); - - let expected = JSONRpcError { - error: code, - data: None, - }; - - assert_eq!(err, expected); - } - } - - #[test] - fn test_jsonrpc_serialization() { - let error = JSONRpcError { - error: JSONRpcErrorCodes::InvalidRequest, - data: None, - }; - let json = serde_json::to_string(&error).unwrap(); - assert_eq!(json, r#"{"code":-32600,"message":"Invalid Request"}"#); - - let error: JSONRpcError = serde_json::from_str(&json).unwrap(); - assert_eq!(error.error, JSONRpcErrorCodes::InvalidRequest); - assert_eq!(error.data, None); - } + #[error("jsonrpc error: {0}")] + JSONRpcError(#[from] JSONRpcError), } diff --git a/crates/dwn-rs-remote/src/jsonrpc.rs b/crates/dwn-rs-remote/src/jsonrpc/client.rs similarity index 59% rename from crates/dwn-rs-remote/src/jsonrpc.rs rename to crates/dwn-rs-remote/src/jsonrpc/client.rs index e7e7f72..7cbdefb 100644 --- a/crates/dwn-rs-remote/src/jsonrpc.rs +++ b/crates/dwn-rs-remote/src/jsonrpc/client.rs @@ -1,6 +1,13 @@ -use serde::{Deserialize, Serialize}; +use std::{any::Any, fmt::Debug}; -use crate::JSONRpcError; +use bytes::Bytes; +use dwn_rs_core::Response as DWNResponse; +use futures_core::{stream::BoxStream, Stream, TryStream}; +use serde::{de::DeserializeOwned, Deserialize, Serialize}; +use tower::Service; +use ulid::{Generator, Ulid}; + +use super::JSONRpcError; #[derive(Debug, Serialize, Deserialize, PartialEq)] enum Version { @@ -10,56 +17,149 @@ enum Version { #[derive(Debug, Serialize, Deserialize, PartialEq)] #[serde(untagged)] -enum ID { +pub enum ID { String(String), Number(i64), } +impl From for ID { + fn from(ulid: Ulid) -> Self { + Self::String(ulid.to_string()) + } +} + #[derive(Debug, Serialize, Deserialize, PartialEq)] pub struct SubscriptionRequest { id: ID, } #[derive(Debug, Serialize, Deserialize, PartialEq)] -struct Request { +pub struct Request { jsonrpc: Version, #[serde(skip_serializing_if = "Option::is_none")] id: Option, method: String, - params: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + params: Option, #[serde(skip_serializing_if = "Option::is_none")] subscription: Option, } #[derive(Debug, Serialize, Deserialize, PartialEq)] -struct Result { - result: T, +pub struct ResultData { + pub reply: T, } #[derive(Debug, Serialize, Deserialize, PartialEq)] -struct Error { +pub struct Error { error: JSONRpcError, } +impl From for JSONRpcError { + fn from(error: Error) -> Self { + error.error + } +} + #[derive(Debug, Serialize, Deserialize, PartialEq)] #[serde(untagged)] -enum ResultError { - Result(Result), +pub enum ResultError { + Result(ResultData), Error(Error), } // Define the JSONRPCResponse struct #[derive(Debug, Serialize, Deserialize, PartialEq)] -struct Response { +pub struct Response { jsonrpc: Version, - id: ID, - #[serde(flatten)] - result: ResultError, + pub id: ID, + pub result: ResultError, +} + +impl Response { + pub fn new_v2(id: I, reply: T) -> Self + where + I: Into, + { + Self { + jsonrpc: Version::V2, + id: id.into(), + result: ResultError::Result(ResultData { reply }), + } + } +} + +pub struct Client)>, S> { + ulid: Generator, + transport: T, + _phantom: std::marker::PhantomData, +} + +impl std::fmt::Debug for Client +where + T: Service<(Request, Option)> + Debug, + S: Stream, +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Client") + .field("ulid", &self.ulid.type_id()) + .field("transport", &self.transport) + .finish() + } +} + +impl Client +where + T: Service< + (Request, Option), + Response = Response<(DWNResponse, BoxStream<'static, Result>)>, + Error = JSONRpcError, + >, + S: TryStream + Send + 'static, + S::Error: Into>, + Bytes: From, +{ + pub fn new(transport: T) -> Self { + let ulid = Generator::new(); + + Self { + ulid, + transport, + _phantom: std::marker::PhantomData, + } + } + + pub async fn request( + &mut self, + method: &'static str, + params: P, + data: Option, + ) -> Result< + Response<(DWNResponse, impl Stream>)>, + JSONRpcError, + > { + let id = Some(self.ulid.generate()?.into()); + + let jsonrpc = Version::V2; + let method = method.to_string(); + + let request = Request { + jsonrpc, + id, + method, + params: Some(serde_json::to_value(params)?), + subscription: None, + }; + + self.transport.call((request, data)).await + } } #[cfg(test)] mod test { - use crate::JSONRpcErrorCodes; + use serde_json::json; + + use crate::jsonrpc::JSONRpcErrorCodes; use super::*; @@ -78,7 +178,7 @@ mod test { jsonrpc: Version::V2, id: Some(ID::Number(1)), method: "test".to_string(), - params: Some(vec!["param1".to_string(), "param2".to_string()]), + params: Some(json!(vec!["param1".to_string(), "param2".to_string()])), subscription: None, }, expected: r#"{"jsonrpc":"2.0","id":1,"method":"test","params":["param1","param2"]}"#, @@ -88,7 +188,7 @@ mod test { jsonrpc: Version::V2, id: None, method: "test".to_string(), - params: Some(vec!["param1".to_string(), "param2".to_string()]), + params: Some(json!(vec!["param1".to_string(), "param2".to_string()])), subscription: None, }, expected: r#"{"jsonrpc":"2.0","method":"test","params":["param1","param2"]}"#, @@ -145,8 +245,8 @@ mod test { response: Response { jsonrpc: Version::V2, id: ID::Number(1), - result: ResultError::Result(Result { - result: "test".to_string(), + result: ResultError::Result(ResultData { + reply: "test".to_string(), }), }, expected: r#"{"jsonrpc":"2.0","id":1,"result":"test"}"#, @@ -157,7 +257,8 @@ mod test { id: ID::Number(1), result: ResultError::Error(Error { error: JSONRpcError { - error: JSONRpcErrorCodes::InvalidRequest, + code: JSONRpcErrorCodes::InvalidRequest, + message: "Invalid Request".to_string(), data: None, }, }), diff --git a/crates/dwn-rs-remote/src/jsonrpc/dwn.rs b/crates/dwn-rs-remote/src/jsonrpc/dwn.rs new file mode 100644 index 0000000..ab325c5 --- /dev/null +++ b/crates/dwn-rs-remote/src/jsonrpc/dwn.rs @@ -0,0 +1,12 @@ +use dwn_rs_core::Message; +use serde::{Deserialize, Serialize}; + +pub const PROCESS_MESSAGE: &str = "dwn.processMessage"; + +#[derive(Debug, Serialize, Deserialize, PartialEq)] +pub struct ProcessMessageParams { + pub target: String, + pub message: Message, + #[serde(skip_serializing_if = "Option::is_none", rename = "encodedData")] + pub encoded_data: Option>, +} diff --git a/crates/dwn-rs-remote/src/jsonrpc/errors.rs b/crates/dwn-rs-remote/src/jsonrpc/errors.rs new file mode 100644 index 0000000..ff474d4 --- /dev/null +++ b/crates/dwn-rs-remote/src/jsonrpc/errors.rs @@ -0,0 +1,173 @@ +use std::{error::Error, fmt::Display}; + +use serde::{Deserialize, Serialize}; +use serde_repr::{Deserialize_repr, Serialize_repr}; + +#[derive(Debug, Serialize, Deserialize, PartialEq)] +pub struct JSONRpcError { + pub code: JSONRpcErrorCodes, + pub message: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub data: Option, +} + +impl Display for JSONRpcError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + if let Some(data) = &self.data { + write!(f, "{:?}: {} ({})", self.code, self.message, data) + } else { + write!(f, "{:?}: {}", self.code, self.message) + } + } +} + +impl Error for JSONRpcError {} + +impl From for JSONRpcError { + fn from(err: reqwest::Error) -> Self { + JSONRpcError { + code: JSONRpcErrorCodes::InternalError, + message: err.to_string(), + data: None, + } + } +} + +impl From for JSONRpcError { + fn from(err: serde_json::Error) -> Self { + JSONRpcError { + code: JSONRpcErrorCodes::InternalError, + message: err.to_string(), + data: None, + } + } +} + +impl From for JSONRpcError { + fn from(err: ulid::MonotonicError) -> Self { + JSONRpcError { + code: JSONRpcErrorCodes::InternalError, + message: err.to_string(), + data: None, + } + } +} + +#[derive(Serialize_repr, Deserialize_repr, Debug, PartialEq, Eq, Clone, Copy)] +#[repr(i32)] +pub enum JSONRpcErrorCodes { + InvalidRequest = -32600, + MethodNotFound = -32601, + InvalidParams = -32602, + InternalError = -32603, + ParseError = -32700, + + BadRequest = -50400, + Unauthorized = -50401, + Forbidden = -50403, +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_jsonrpc_error_display() { + let error = JSONRpcError { + code: JSONRpcErrorCodes::InvalidRequest, + message: "Invalid Request".to_string(), + data: None, + }; + assert_eq!(error.to_string(), "InvalidRequest: Invalid Request"); + + let error = JSONRpcError { + code: JSONRpcErrorCodes::InvalidRequest, + message: "Invalid Request".to_string(), + data: Some(serde_json::json!({ "error": "test" })), + }; + assert_eq!( + error.to_string(), + "InvalidRequest: Invalid Request ({\"error\":\"test\"})" + ); + } + + #[tokio::test] + async fn test_jsonrpc_error_from() { + let err = reqwest::get("bad address").await.unwrap_err(); + let errstr = err.to_string(); + let jsonrpc_err: JSONRpcError = err.into(); + assert_eq!(jsonrpc_err.code, JSONRpcErrorCodes::InternalError); + assert_eq!(jsonrpc_err.message, errstr); + + let err = serde_json::from_str::<&str>("bad json").unwrap_err(); + let errstr = err.to_string(); + let jsonrpc_err: JSONRpcError = err.into(); + assert_eq!(jsonrpc_err.code, JSONRpcErrorCodes::InternalError); + assert_eq!(jsonrpc_err.message, errstr); + } + + #[test] + fn test_jsonrpc_errorcodes_serialization() { + let codes: Vec<(JSONRpcErrorCodes, &'static str)> = vec![ + ( + JSONRpcErrorCodes::InvalidRequest, + r#"{"code":-32600,"message":"error"}"#, + ), + ( + JSONRpcErrorCodes::MethodNotFound, + r#"{"code":-32601,"message":"error"}"#, + ), + ( + JSONRpcErrorCodes::InvalidParams, + r#"{"code":-32602,"message":"error"}"#, + ), + ( + JSONRpcErrorCodes::InternalError, + r#"{"code":-32603,"message":"error"}"#, + ), + ( + JSONRpcErrorCodes::ParseError, + r#"{"code":-32700,"message":"error"}"#, + ), + ( + JSONRpcErrorCodes::BadRequest, + r#"{"code":-50400,"message":"error"}"#, + ), + ( + JSONRpcErrorCodes::Unauthorized, + r#"{"code":-50401,"message":"error"}"#, + ), + ( + JSONRpcErrorCodes::Forbidden, + r#"{"code":-50403,"message":"error"}"#, + ), + ]; + + for (code, json) in codes { + let err: JSONRpcError = serde_json::from_str(json).unwrap(); + + let expected = JSONRpcError { + code, + message: "error".to_string(), + data: None, + }; + + assert_eq!(err, expected); + } + } + + #[test] + fn test_jsonrpc_serialization() { + let error = JSONRpcError { + code: JSONRpcErrorCodes::InvalidRequest, + message: "error".to_string(), + data: None, + }; + let json = serde_json::to_string(&error).unwrap(); + assert_eq!(json, r#"{"code":-32600,"message":"error"}"#); + + let error: JSONRpcError = serde_json::from_str(&json).unwrap(); + assert_eq!(error.code, JSONRpcErrorCodes::InvalidRequest); + assert_eq!(error.data, None); + } +} diff --git a/crates/dwn-rs-remote/src/jsonrpc/http.rs b/crates/dwn-rs-remote/src/jsonrpc/http.rs new file mode 100644 index 0000000..7deccb4 --- /dev/null +++ b/crates/dwn-rs-remote/src/jsonrpc/http.rs @@ -0,0 +1,139 @@ +use std::{future::Future, pin::Pin}; + +use bytes::Bytes; +use dwn_rs_core::Response as DWNResponse; +use futures_core::{stream::BoxStream, TryStream}; +use futures_util::TryStreamExt; +use http::header; +use serde_json::json; +use tower::Service; +use tracing::trace; + +use crate::jsonrpc::ResultError; + +use super::{JSONRpcError, JSONRpcErrorCodes, Request, Response}; + +pub const USER_AGENT: &str = concat!( + "rpc-", + env!("CARGO_PKG_NAME"), + "/", + env!("CARGO_PKG_VERSION"), + "(", + env!("CARGO_CRATE_NAME"), + ")" +); + +pub struct HTTPTransport { + url: String, + client: reqwest::Client, +} + +impl Service<(Request, Option)> for HTTPTransport +where + S: TryStream + Send + 'static, + S::Error: Into>, + Bytes: From, +{ + type Response = Response<(DWNResponse, BoxStream<'static, Result>)>; + type Error = JSONRpcError; + type Future = Pin> + Send>>; + + fn call(&mut self, request: (Request, Option)) -> Self::Future { + let url = self.url.clone(); + let mut client = self.client.clone(); + + Box::pin(async move { + let stream = request.1; + + let mut rb = client + .clone() + .request(http::method::Method::POST, url) + .header("dwn-request", json!(request.0).to_string()); + + if let Some(data) = stream { + let (lw, _) = data.size_hint(); + + rb = rb + .header(header::CONTENT_TYPE, "application/octet-stream") + .header(header::TRANSFER_ENCODING, "chunked") + .header(header::CONTENT_LENGTH, lw.to_string()) + .body(reqwest::Body::wrap_stream(data)); + } + + let req = rb.build()?; + + let res = client.call(req).await?; + + trace!(?res, "Received response"); + + let resp = match res.headers().get("dwn-response") { + Some(h) => { + let resp = serde_json::from_slice::>(h.as_bytes())?; + let body = Box::pin( + res.bytes_stream() + .map_err(JSONRpcError::from) + .map_ok(|b| b) + .into_stream(), + ) + as BoxStream<'static, Result>; + + (resp, body) + } + None => { + let body = res.bytes().await?; + + let resp: Response = serde_json::from_slice(&body)?; + trace!(?resp, "Response in body"); + let empty = + Box::pin(futures_util::stream::empty::>()) + as BoxStream<'static, Result>; + + (resp, empty) + } + }; + + let msg = match resp.0.result { + ResultError::Result(m) => Ok(m.reply), + ResultError::Error(e) => Err(JSONRpcError::from(e)), + }?; + let body = resp.1; + + Ok(Response::new_v2(resp.0.id, (msg, body))) + }) + } + + fn poll_ready( + &mut self, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + self.client.poll_ready(cx).map_err(|e| JSONRpcError { + code: JSONRpcErrorCodes::InternalError, + message: e.to_string(), + data: None, + }) + } +} + +impl HTTPTransport { + pub fn new(uri: String) -> Result { + let c = reqwest::ClientBuilder::new() + .user_agent(USER_AGENT) + .default_headers(header::HeaderMap::from_iter(vec![( + header::CONTENT_TYPE, + "application/json-rpc".parse().unwrap(), + )])) + .deflate(true) + .gzip(true) + .build() + .map_err(|e| JSONRpcError { + code: JSONRpcErrorCodes::InternalError, + message: e.to_string(), + data: None, + }); + + Ok(Self { + url: uri, + client: c?, + }) + } +} diff --git a/crates/dwn-rs-remote/src/jsonrpc/mod.rs b/crates/dwn-rs-remote/src/jsonrpc/mod.rs new file mode 100644 index 0000000..2aa0ae8 --- /dev/null +++ b/crates/dwn-rs-remote/src/jsonrpc/mod.rs @@ -0,0 +1,8 @@ +pub mod client; +pub(crate) mod dwn; +mod errors; +mod http; + +pub use client::*; +pub use errors::*; +pub use http::*; From 5f34d59eb697b3e11cb142b5514a37b0ed09849e Mon Sep 17 00:00:00 2001 From: Dan Enman Date: Tue, 22 Oct 2024 21:25:46 -0300 Subject: [PATCH 05/23] Add MessageDescriptor #[descriptor] macro, and add to existing Descriptor definitions --- crates/dwn-rs-core/Cargo.toml | 1 + crates/dwn-rs-core/src/auth/authorization.rs | 10 +- .../messages/descriptors/general.rs | 138 ++++++++ .../messages/descriptors/messages.rs | 19 +- .../interfaces/messages/descriptors/mod.rs | 61 ++-- .../messages/descriptors/protocols.rs | 11 +- .../messages/descriptors/records.rs | 18 +- .../src/interfaces/messages/fields.rs | 300 ++++++++++++------ crates/dwn-rs-message-derive/Cargo.toml | 12 + .../src/derive/descriptor.rs | 256 +++++++++++++++ .../dwn-rs-message-derive/src/derive/mod.rs | 1 + crates/dwn-rs-message-derive/src/lib.rs | 13 + 12 files changed, 697 insertions(+), 143 deletions(-) create mode 100644 crates/dwn-rs-core/src/interfaces/messages/descriptors/general.rs create mode 100644 crates/dwn-rs-message-derive/Cargo.toml create mode 100644 crates/dwn-rs-message-derive/src/derive/descriptor.rs create mode 100644 crates/dwn-rs-message-derive/src/derive/mod.rs create mode 100644 crates/dwn-rs-message-derive/src/lib.rs diff --git a/crates/dwn-rs-core/Cargo.toml b/crates/dwn-rs-core/Cargo.toml index 2104b6f..0cc2d03 100644 --- a/crates/dwn-rs-core/Cargo.toml +++ b/crates/dwn-rs-core/Cargo.toml @@ -42,3 +42,4 @@ bytes = "1.8.0" [dev-dependencies] serde_json = "1.0.113" +dwn-rs-message-derive = { path = "../dwn-rs-message-derive" } diff --git a/crates/dwn-rs-core/src/auth/authorization.rs b/crates/dwn-rs-core/src/auth/authorization.rs index 814aac2..b6502cc 100644 --- a/crates/dwn-rs-core/src/auth/authorization.rs +++ b/crates/dwn-rs-core/src/auth/authorization.rs @@ -1,6 +1,6 @@ use serde::{Deserialize, Serialize}; -use crate::Message; +use crate::{descriptors::records::WriteDescriptor, fields::MessageFields, Message}; use super::jws::JWS; @@ -11,6 +11,8 @@ pub struct Authorization { pub owner_signature: Option, } +impl MessageFields for Authorization {} + #[derive(Serialize, Deserialize, Debug, Default, PartialEq, Clone)] pub struct AuthorizationDelegatedGrant { pub signature: JWS, @@ -18,7 +20,7 @@ pub struct AuthorizationDelegatedGrant { rename = "authorDelegatedGrant", skip_serializing_if = "Option::is_none" )] - pub author_delegated_grant: Option>, + pub author_delegated_grant: Option>>, } #[derive(Serialize, Deserialize, Debug, Default, PartialEq, Clone)] @@ -28,12 +30,12 @@ pub struct AuthorizationOwner { rename = "authorDelegatedGrant", skip_serializing_if = "Option::is_none" )] - pub author_delegated_grant: Option>, + pub author_delegated_grant: Option>>, #[serde(rename = "ownerSignature", skip_serializing_if = "Option::is_none")] pub owner_signature: Option, #[serde( rename = "ownerDelegatedGrant", skip_serializing_if = "Option::is_none" )] - pub owner_delegated_grant: Option>, + pub owner_delegated_grant: Option>>, } diff --git a/crates/dwn-rs-core/src/interfaces/messages/descriptors/general.rs b/crates/dwn-rs-core/src/interfaces/messages/descriptors/general.rs new file mode 100644 index 0000000..9bb348d --- /dev/null +++ b/crates/dwn-rs-core/src/interfaces/messages/descriptors/general.rs @@ -0,0 +1,138 @@ +use serde::{Deserialize, Serialize}; + +use crate::Fields; + +use super::{ + super::descriptors::{ + ConfigureDescriptor, DeleteDescriptor, MessagesQueryDescriptor, MessagesReadDescriptor, + MessagesSubscribeDescriptor, ProtocolQueryDescriptor, ReadDescriptor, + RecordsQueryDescriptor, RecordsWriteDescriptor, SubscribeDescriptor, + }, + MessageDescriptor, CONFIGURE, DELETE, MESSAGES, PROTOCOLS, QUERY, READ, RECORDS, SUBSCRIBE, + WRITE, +}; + +/// Interfaces represent the different Decentralized Web Node message interface types. +/// See for more information. +#[derive(Serialize, Deserialize, Debug, PartialEq, Clone)] +#[serde(untagged)] +pub enum Descriptor { + Records(Records), + Protocols(Protocols), + Messages(Messages), +} + +impl MessageDescriptor for Descriptor { + type Fields = Fields; + + fn interface(&self) -> &'static str { + match self { + Descriptor::Records(_) => RECORDS, + Descriptor::Protocols(_) => PROTOCOLS, + Descriptor::Messages(_) => MESSAGES, + } + } + + fn method(&self) -> &'static str { + match self { + Descriptor::Records(records) => records.method(), + Descriptor::Protocols(protocols) => protocols.method(), + Descriptor::Messages(messages) => messages.method(), + } + } +} + +#[derive(Serialize, Deserialize, Debug, PartialEq, Clone)] +#[serde(untagged)] +pub enum Records { + Read(ReadDescriptor), + Query(RecordsQueryDescriptor), + Write(RecordsWriteDescriptor), + Delete(DeleteDescriptor), + Subscribe(SubscribeDescriptor), +} + +impl MessageDescriptor for Records { + type Fields = Fields; + + fn interface(&self) -> &'static str { + RECORDS + } + + fn method(&self) -> &'static str { + match self { + Records::Read(_) => READ, + Records::Query(_) => QUERY, + Records::Write(_) => WRITE, + Records::Delete(_) => DELETE, + Records::Subscribe(_) => SUBSCRIBE, + } + } +} + +#[derive(Serialize, Deserialize, Debug, PartialEq, Clone)] +#[serde(untagged)] +pub enum Protocols { + Configure(ConfigureDescriptor), + Query(ProtocolQueryDescriptor), +} + +impl MessageDescriptor for Protocols { + type Fields = Fields; + + fn interface(&self) -> &'static str { + PROTOCOLS + } + + fn method(&self) -> &'static str { + match self { + Protocols::Configure(_) => CONFIGURE, + Protocols::Query(_) => QUERY, + } + } +} + +#[derive(Serialize, Deserialize, Debug, PartialEq, Clone)] +#[serde(untagged)] +pub enum Messages { + Read(MessagesReadDescriptor), + Query(MessagesQueryDescriptor), + Subscribe(MessagesSubscribeDescriptor), +} + +impl MessageDescriptor for Messages { + type Fields = Fields; + fn interface(&self) -> &'static str { + MESSAGES + } + + fn method(&self) -> &'static str { + match self { + Messages::Read(_) => READ, + Messages::Query(_) => QUERY, + Messages::Subscribe(_) => SUBSCRIBE, + } + } +} + +#[cfg(test)] +mod test { + use serde_json::json; + + use crate::Filters; + + #[test] + fn test_descriptor_serialize() { + use super::*; + + let now = chrono::Utc::now(); + let desc = Descriptor::Records(Records::Read(ReadDescriptor { + message_timestamp: now, + filter: Filters::default(), + })); + let serialized = json!(&desc); + let expected = json!({"interface": RECORDS,"method": READ, "messageTimestamp": now, "filter": Filters::default()}); + + assert_eq!(serialized, expected); + } +} diff --git a/crates/dwn-rs-core/src/interfaces/messages/descriptors/messages.rs b/crates/dwn-rs-core/src/interfaces/messages/descriptors/messages.rs index c288b17..213fa45 100644 --- a/crates/dwn-rs-core/src/interfaces/messages/descriptors/messages.rs +++ b/crates/dwn-rs-core/src/interfaces/messages/descriptors/messages.rs @@ -1,8 +1,9 @@ +use crate::descriptors::MessageDescriptor; +use crate::interfaces::messages::descriptors::{MESSAGES, QUERY, READ, SUBSCRIBE}; use cid::Cid; -use serde::{Deserialize, Serialize}; -use serde_with::skip_serializing_none; +use dwn_rs_message_derive::descriptor; -#[derive(Serialize, Deserialize, Debug, PartialEq, Clone)] +#[descriptor(interface = MESSAGES, method = READ, fields = crate::auth::Authorization)] pub struct ReadDescriptor { #[serde( rename = "messageTimestamp", @@ -13,8 +14,7 @@ pub struct ReadDescriptor { pub message_cid: Option, } -#[skip_serializing_none] -#[derive(Serialize, Deserialize, Debug, PartialEq, Clone)] +#[descriptor(interface = MESSAGES, method = QUERY, fields = crate::auth::Authorization)] pub struct QueryDescriptor { #[serde( rename = "messageTimestamp", @@ -27,8 +27,7 @@ pub struct QueryDescriptor { pub cursor: Option, } -#[skip_serializing_none] -#[derive(Serialize, Deserialize, Debug, PartialEq, Clone)] +#[descriptor(interface = MESSAGES, method = SUBSCRIBE, fields = crate::auth::Authorization)] pub struct SubscribeDescriptor { #[serde( rename = "messageTimestamp", @@ -65,6 +64,8 @@ mod test { let json = json!({ "messageTimestamp": message_timestamp, "messageCid": message_cid, + "interface": MESSAGES, + "method": READ, }); assert_eq!(serde_json::to_value(&descriptor).unwrap(), json); assert_eq!( @@ -93,6 +94,8 @@ mod test { "messageTimestamp": message_timestamp, "filters": [crate::Filters::default()], "cursor": cursor, + "interface": MESSAGES, + "method": QUERY, }); assert_eq!(serde_json::to_value(&descriptor).unwrap(), json); assert_eq!( @@ -118,6 +121,8 @@ mod test { let json = json!({ "messageTimestamp": message_timestamp, "filters": [crate::Filters::default()], + "interface": MESSAGES, + "method": SUBSCRIBE }); assert_eq!(serde_json::to_value(&descriptor).unwrap(), json); assert_eq!( diff --git a/crates/dwn-rs-core/src/interfaces/messages/descriptors/mod.rs b/crates/dwn-rs-core/src/interfaces/messages/descriptors/mod.rs index 44c6ee5..14dae72 100644 --- a/crates/dwn-rs-core/src/interfaces/messages/descriptors/mod.rs +++ b/crates/dwn-rs-core/src/interfaces/messages/descriptors/mod.rs @@ -1,7 +1,12 @@ +pub mod general; pub mod messages; pub mod protocols; pub mod records; +use std::fmt::Debug; + +pub use general::*; + pub use messages::{ QueryDescriptor as MessagesQueryDescriptor, ReadDescriptor as MessagesReadDescriptor, SubscribeDescriptor as MessagesSubscribeDescriptor, @@ -11,39 +16,35 @@ pub use records::{ DeleteDescriptor, QueryDescriptor as RecordsQueryDescriptor, ReadDescriptor, SubscribeDescriptor, WriteDescriptor as RecordsWriteDescriptor, }; +use serde::{de::DeserializeOwned, Serialize}; -use serde::{Deserialize, Serialize}; +use super::fields::MessageFields; -/// Interfaces represent the different Decentralized Web Node message interface types. -/// See for more information. -#[derive(Serialize, Deserialize, Debug, PartialEq, Clone)] -#[serde(tag = "interface")] -pub enum Descriptor { - Records(Records), - Protocols(Protocols), - Messages(Messages), -} +pub const RECORDS: &str = "Records"; +pub const PROTOCOLS: &str = "Protocols"; +pub const MESSAGES: &str = "Messages"; -#[derive(Serialize, Deserialize, Debug, PartialEq, Clone)] -#[serde(tag = "method")] -pub enum Records { - Read(ReadDescriptor), - Query(RecordsQueryDescriptor), - Write(RecordsWriteDescriptor), - Delete(DeleteDescriptor), - Subscribe(SubscribeDescriptor), -} +pub const READ: &str = "Read"; +pub const QUERY: &str = "Query"; +pub const WRITE: &str = "Write"; +pub const DELETE: &str = "Delete"; +pub const SUBSCRIBE: &str = "Subscribe"; +pub const CONFIGURE: &str = "Configure"; -#[derive(Serialize, Deserialize, Debug, PartialEq, Clone)] -#[serde(tag = "method")] -pub enum Protocols { - Configure(ConfigureDescriptor), - Query(ProtocolQueryDescriptor), -} +/// MessageDescriptor is a trait that all message descriptors must implement. +/// It provides the interface and method for the message descriptor. The generic `Descriptor` +/// implements this trait for use when the concrete type is not known. Concrete Descriptor types +/// implement this trait directly (or use the derive macro). +pub trait MessageDescriptor: Serialize + DeserializeOwned + PartialEq { + type Fields: MessageFields + + Serialize + + DeserializeOwned + + Debug + + PartialEq + + Send + + Sync + + Clone; -#[derive(Serialize, Deserialize, Debug, PartialEq, Clone)] -pub enum Messages { - Read(MessagesReadDescriptor), - Query(MessagesQueryDescriptor), - Subscribe(MessagesSubscribeDescriptor), + fn interface(&self) -> &'static str; + fn method(&self) -> &'static str; } diff --git a/crates/dwn-rs-core/src/interfaces/messages/descriptors/protocols.rs b/crates/dwn-rs-core/src/interfaces/messages/descriptors/protocols.rs index effeab5..0aac84c 100644 --- a/crates/dwn-rs-core/src/interfaces/messages/descriptors/protocols.rs +++ b/crates/dwn-rs-core/src/interfaces/messages/descriptors/protocols.rs @@ -5,7 +5,11 @@ use serde_with::skip_serializing_none; use ssi_dids_core::DIDBuf; use ssi_jwk::JWK; -#[derive(Serialize, Deserialize, Debug, PartialEq, Clone)] +use crate::descriptors::MessageDescriptor; +use crate::interfaces::messages::descriptors::{CONFIGURE, PROTOCOLS, QUERY}; +use dwn_rs_message_derive::descriptor; + +#[descriptor(interface = PROTOCOLS, method = CONFIGURE, fields = crate::fields::AuthorizationDelegatedGrantFields)] pub struct ConfigureDescriptor { #[serde(rename = "messageTimestamp")] pub message_timestamp: chrono::DateTime, @@ -217,8 +221,7 @@ pub struct TagContains { pub max_length: Option, } -#[skip_serializing_none] -#[derive(Serialize, Deserialize, Debug, PartialEq, Clone)] +#[descriptor(interface = PROTOCOLS , method = QUERY, fields = crate::auth::Authorization)] pub struct QueryDescriptor { #[serde(rename = "message_timestamp")] pub message_timestamp: chrono::DateTime, @@ -259,6 +262,8 @@ mod test { "types": {}, "structure": {}, }, + "interface": PROTOCOLS, + "method": CONFIGURE, }); assert_eq!(serde_json::to_value(&descriptor).unwrap(), json); assert_eq!( diff --git a/crates/dwn-rs-core/src/interfaces/messages/descriptors/records.rs b/crates/dwn-rs-core/src/interfaces/messages/descriptors/records.rs index b49dbf7..07b1704 100644 --- a/crates/dwn-rs-core/src/interfaces/messages/descriptors/records.rs +++ b/crates/dwn-rs-core/src/interfaces/messages/descriptors/records.rs @@ -1,24 +1,26 @@ +use crate::descriptors::MessageDescriptor; use serde::{Deserialize, Serialize}; use serde_with::skip_serializing_none; +use crate::interfaces::messages::descriptors::{DELETE, QUERY, READ, RECORDS, SUBSCRIBE, WRITE}; use crate::{MapValue, Pagination}; +use dwn_rs_message_derive::descriptor; /// ReadDescriptor represents the RecordsRead interface method for reading a given /// record by ID. -#[derive(Serialize, Deserialize, Debug, Default, PartialEq, Clone)] +#[descriptor(interface = RECORDS, method = READ, fields = crate::fields::AuthorizationDelegatedGrantFields)] pub struct ReadDescriptor { #[serde( rename = "messageTimestamp", serialize_with = "crate::ser::serialize_datetime" )] pub message_timestamp: chrono::DateTime, - #[serde(rename = "recordId")] - pub record_id: String, + pub filter: crate::Filters, } // QueryDescriptor represents the RecordsQuery interface method for querying records. #[skip_serializing_none] -#[derive(Serialize, Deserialize, Debug, Default, PartialEq, Clone)] +#[descriptor(interface = RECORDS, method = QUERY, fields = crate::auth::Authorization)] pub struct QueryDescriptor { #[serde( rename = "messageTimestamp", @@ -49,7 +51,7 @@ pub enum DateSort { /// It can be represented with either no additional fields (`()`), or additional descriptor fields, /// as in the case for `encodedData`. #[skip_serializing_none] -#[derive(Serialize, Deserialize, Debug, Default, PartialEq, Clone)] +#[descriptor(interface = RECORDS, method = WRITE, fields = crate::fields::WriteFields)] pub struct WriteDescriptor { pub protocol: Option, #[serde(rename = "protocolPath")] @@ -83,7 +85,7 @@ pub struct WriteDescriptor { pub data_format: String, } -#[derive(Serialize, Deserialize, Debug, Default, PartialEq, Clone)] +#[descriptor(interface = RECORDS, method = SUBSCRIBE, fields = crate::fields::AuthorizationDelegatedGrantFields)] pub struct SubscribeDescriptor { #[serde( rename = "messageTimestamp", @@ -93,7 +95,7 @@ pub struct SubscribeDescriptor { pub filter: crate::Filters, } -#[derive(Serialize, Deserialize, Debug, PartialEq, Clone)] +#[descriptor(interface = RECORDS, method = DELETE, fields = crate::auth::Authorization)] pub struct DeleteDescriptor { #[serde( rename = "messageTimestamp", @@ -124,7 +126,7 @@ mod test { let rd = ReadDescriptor { message_timestamp, - record_id: "test".to_string(), + filter: crate::Filters::default(), }; let ser = serde_json::to_string(&rd).unwrap(); diff --git a/crates/dwn-rs-core/src/interfaces/messages/fields.rs b/crates/dwn-rs-core/src/interfaces/messages/fields.rs index f018ee3..5d96830 100644 --- a/crates/dwn-rs-core/src/interfaces/messages/fields.rs +++ b/crates/dwn-rs-core/src/interfaces/messages/fields.rs @@ -1,23 +1,57 @@ use serde::{Deserialize, Serialize}; +use serde_with::skip_serializing_none; use ssi_jwk::JWK; -use crate::auth::{ - authorization::{Authorization, AuthorizationDelegatedGrant, AuthorizationOwner}, - jws::JWS, +use crate::{ + auth::{ + authorization::{Authorization, AuthorizationDelegatedGrant, AuthorizationOwner}, + jws::JWS, + }, + Value, }; -use super::Message; +use super::{descriptors::records::WriteDescriptor, Message}; + +/// MessageFields is a trait that all message fields must implement. +/// It provides the interface and method for the message fields. The generic `Fields` +/// implements this trait for use when the concrete type is not known. +pub trait MessageFields { + /// encoded_data returns the encoded data for the message fields (if any), + /// and removes the encoded data fields from the Message Fields. + fn encoded_data(&mut self) -> Option { + None + } + + // encode_data encodes the data for the message + fn encode_data(&mut self, _data: Value) { + // no-op + } +} #[derive(Serialize, Deserialize, Debug, PartialEq, Clone)] #[serde(untagged)] pub enum Fields { - EncodedWrite(EncodedWriteField), Write(WriteFields), InitialWriteField(InitialWriteField), Authorization(Authorization), AuthorizationDelegatedGrant(AuthorizationDelegatedGrantFields), } +impl MessageFields for Fields { + fn encoded_data(&mut self) -> Option { + match self { + Fields::Write(encoded_write) => encoded_write.encoded_data.take().map(Value::String), + _ => None, + } + } + + fn encode_data(&mut self, data: Value) { + if let Fields::Write(encoded_write) = self { + encoded_write.encoded_data = Some(data.to_string()); + } + } +} + /// ReadFields are the message fields for the RecordsRead interface method. #[derive(Serialize, Deserialize, Debug, Default, PartialEq, Clone)] pub struct AuthorizationDelegatedGrantFields { @@ -25,37 +59,47 @@ pub struct AuthorizationDelegatedGrantFields { pub authorization: Option, } +impl MessageFields for AuthorizationDelegatedGrantFields {} + +// InitialWriteField represents the RecordsWrite interface method response that includes +// the `initialWrite` data field if the original record was not the initial write. +#[derive(Serialize, Deserialize, Debug, PartialEq, Clone)] +pub struct InitialWriteField { + #[serde(flatten)] + pub write_fields: WriteFields, + #[serde(rename = "initialWrite", skip_serializing_if = "Option::is_none")] + pub initial_write: Option>>, +} + +impl MessageFields for InitialWriteField {} + +#[skip_serializing_none] #[derive(Serialize, Deserialize, Debug, Default, PartialEq, Clone)] pub struct WriteFields { - #[serde(skip_serializing_if = "Option::is_none")] - pub authorization: Option, - #[serde(rename = "recordId", skip_serializing_if = "Option::is_none")] + pub authorization: AuthorizationOwner, + #[serde(rename = "recordId")] pub record_id: Option, - #[serde(rename = "contextId", skip_serializing_if = "Option::is_none")] + #[serde(rename = "contextId")] pub context_id: Option, - #[serde(skip_serializing_if = "Option::is_none")] pub encryption: Option, - #[serde(skip_serializing_if = "Option::is_none")] pub attestation: Option, -} -/// EncodedWriteField represents the RecordsWrite interface method for writing a record to -/// the DWN using the `encodedData` field for records data that is encoded in messages directly. -#[derive(Serialize, Deserialize, Default, Debug, PartialEq, Clone)] -pub struct EncodedWriteField { - #[serde(flatten)] - pub write_fields: WriteFields, - #[serde(rename = "encodedData", skip_serializing_if = "Option::is_none")] + #[serde(rename = "encodedData")] pub encoded_data: Option, } -// InitialWriteField represents the RecordsWrite interface method response that includes -// the `initialWrite` data field if the original record was not the initial write. -#[derive(Serialize, Deserialize, Debug, PartialEq, Clone)] -pub struct InitialWriteField { - #[serde(flatten)] - pub write_fields: EncodedWriteField, - #[serde(rename = "initialWrite", skip_serializing_if = "Option::is_none")] - pub initial_write: Option>, +impl MessageFields for WriteFields { + fn encoded_data(&mut self) -> Option { + Some( + self.encoded_data + .take() + .map(Value::String) + .unwrap_or(Value::Null), + ) + } + + fn encode_data(&mut self, data: Value) { + self.encoded_data = Some(data.to_string()); + } } /// EncryptionAlgorithm represents the encryption algorithm used for encrypting records. Currently @@ -118,17 +162,20 @@ pub struct Encryption { #[cfg(test)] mod tests { - use crate::{auth::jws::SignatureEntry, descriptors::Records, Descriptor, Message}; + use crate::descriptors::{RecordsWriteDescriptor, RECORDS, WRITE}; + use crate::{auth::jws::SignatureEntry, Message}; use super::*; - use serde_json; + use serde_json::{self, json}; use ssi_jwk::JWK; + use tracing_test::traced_test; #[test] fn test_fields_serialization() { use serde_json::json; let jwk = JWK::generate_ed25519().unwrap(); + let now = chrono::Utc::now(); // Define your test cases as structs or tuples. struct TestCase { @@ -139,27 +186,11 @@ mod tests { // Populate the vector with the test cases. let tests = vec![ TestCase { - fields: Fields::EncodedWrite(EncodedWriteField { - write_fields: WriteFields { - record_id: Some("record_id".to_string()), - context_id: Some("context_id".to_string()), - authorization: None, - encryption: Some(Encryption { - algorithm: EncryptionAlgorithm::A256CTR, - initialization_vector: "initialization_vector".to_string(), - key_encryption: vec![KeyEncryption { - algorithm: KeyEncryptionAlgorithm::EciesEs256k, - root_key_id: "root_key_id".to_string(), - derivation_scheme: DerivationScheme::DataFormats, - derived_public_key: None, - encrypted_key: "encrypted_key".to_string(), - initialization_vector: "initialization_vector".to_string(), - ephemeral_public_key: jwk.clone(), - message_authentication_code: "message_authentication_code" - .to_string(), - }], - }), - attestation: Some(JWS { + fields: Fields::Write(WriteFields { + record_id: Some("record_id".to_string()), + context_id: Some("context_id".to_string()), + authorization: AuthorizationOwner { + signature: JWS { payload: Some("payload".to_string()), signatures: Some(vec![SignatureEntry { payload: Some("payload".to_string()), @@ -168,8 +199,33 @@ mod tests { ..Default::default() }]), ..Default::default() - }), + }, + ..Default::default() }, + encryption: Some(Encryption { + algorithm: EncryptionAlgorithm::A256CTR, + initialization_vector: "initialization_vector".to_string(), + key_encryption: vec![KeyEncryption { + algorithm: KeyEncryptionAlgorithm::EciesEs256k, + root_key_id: "root_key_id".to_string(), + derivation_scheme: DerivationScheme::DataFormats, + derived_public_key: None, + encrypted_key: "encrypted_key".to_string(), + initialization_vector: "initialization_vector".to_string(), + ephemeral_public_key: jwk.clone(), + message_authentication_code: "message_authentication_code".to_string(), + }], + }), + attestation: Some(JWS { + payload: Some("payload".to_string()), + signatures: Some(vec![SignatureEntry { + payload: Some("payload".to_string()), + protected: Some("protected".to_string()), + signature: Some("signature".to_string()), + ..Default::default() + }]), + ..Default::default() + }), encoded_data: Some("encoded_data".to_string()), }), expected_json: format!( @@ -192,6 +248,18 @@ mod tests { }} ] }}, + "authorization": {{ + "signature": {{ + "payload": "payload", + "signatures": [ + {{ + "payload": "payload", + "protected": "protected", + "signature": "signature" + }} + ] + }} + }}, "attestation": {{ "payload": "payload", "signatures": [ @@ -211,26 +279,54 @@ mod tests { authorization: Some(AuthorizationDelegatedGrant { // fill in all the fields with fake details signature: JWS::default(), - author_delegated_grant: Some(Box::new(Message { - descriptor: Descriptor::Records(Records::Write( - crate::descriptors::RecordsWriteDescriptor::default(), - )), - fields: Fields::Write(WriteFields::default()), + author_delegated_grant: Some(Box::new(Message:: { + descriptor: crate::descriptors::RecordsWriteDescriptor { + data_cid: "data_cid".to_string(), + data_size: 0, + date_created: now, + message_timestamp: now, + data_format: "data_format".to_string(), + protocol: None, + recipient: None, + schema: None, + tags: None, + protocol_path: None, + parent_id: None, + published: None, + date_published: None, + }, + fields: WriteFields { + record_id: Some("record".to_string()), + context_id: Some("context".to_string()), + authorization: AuthorizationOwner::default(), + encryption: None, + attestation: None, + encoded_data: None, + }, })), }), }), expected_json: json!({ - "authorization": { - "signature": JWS::default(), - "authorDelegatedGrant": { - "descriptor": Descriptor::Records(Records::Write( - crate::descriptors::RecordsWriteDescriptor::default() - )), - // there are no Fields on this request for test serialization - } + "authorization": { + "signature": JWS::default(), + "authorDelegatedGrant": { + "descriptor": json!({ + "interface": RECORDS, + "method": WRITE, + "dataCid": "data_cid", + "dataSize": 0, + "dateCreated": now.to_rfc3339_opts(chrono::SecondsFormat::Micros, true), + "messageTimestamp": now.to_rfc3339_opts(chrono::SecondsFormat::Micros, true), + "dataFormat": "data_format", + }), + "authorization": { + "signature": JWS::default(), + }, + "recordId": "record", + "contextId": "context", }, - } - ) + }, + }) .to_string(), }, ]; @@ -267,6 +363,18 @@ mod tests { }} ] }}, + "authorization": {{ + "signature": {{ + "payload": "payload", + "signatures": [ + {{ + "payload": "payload", + "protected": "protected", + "signature": "signature" + }} + ] + }} + }}, "attestation": {{ "payload": "payload", "signatures": [ @@ -283,16 +391,20 @@ mod tests { ); let fields: Fields = serde_json::from_str(json).unwrap(); + println!("{:?}", fields); match fields { - Fields::EncodedWrite(EncodedWriteField { - write_fields, + Fields::Write(WriteFields { + record_id, + context_id, + encryption, + attestation, encoded_data, + .. }) => { - assert_eq!(write_fields.record_id, Some("record_id".to_string())); - assert_eq!(write_fields.context_id, Some("context_id".to_string())); - assert!(write_fields.authorization.is_none()); - assert!(write_fields.encryption.is_some()); - assert!(write_fields.attestation.is_some()); + assert_eq!(record_id, Some("record_id".to_string())); + assert_eq!(context_id, Some("context_id".to_string())); + assert!(encryption.is_some()); + assert!(attestation.is_some()); assert_eq!(encoded_data, Some("encoded_data".to_string())); } _ => unreachable!(), @@ -301,36 +413,42 @@ mod tests { #[test] fn test_fields_serialization_with_null_fields() { - let fields = Fields::EncodedWrite(EncodedWriteField { - write_fields: WriteFields { - record_id: None, - context_id: None, - authorization: None, - encryption: None, - attestation: None, - }, + let fields = Fields::Write(WriteFields { + record_id: Some("record_id".to_string()), + context_id: None, + authorization: AuthorizationOwner::default(), + encryption: None, + attestation: None, encoded_data: None, }); - let json = serde_json::to_string(&fields).unwrap(); - assert_eq!("{}", json); + let json = + serde_json::from_str::(&serde_json::to_string(&fields).unwrap()) + .unwrap(); + let expected = json!({"recordId":"record_id","authorization":{"signature":{}}}); + assert_eq!(expected, json); } #[test] + #[traced_test] fn test_fields_deserialization_with_null_fields() { - let json = "{}"; + let json = r#"{"recordId":"test", "authorization": {"signature":{}}}"#; let fields: Fields = serde_json::from_str(json).unwrap(); + match fields { - Fields::EncodedWrite(EncodedWriteField { - write_fields, + Fields::Write(WriteFields { + record_id, + context_id, + encryption, + attestation, encoded_data, + .. }) => { - assert!(write_fields.record_id.is_none()); - assert!(write_fields.context_id.is_none()); - assert!(write_fields.authorization.is_none()); - assert!(write_fields.encryption.is_none()); - assert!(write_fields.attestation.is_none()); + assert_eq!(record_id, Some("test".to_string())); + assert!(context_id.is_none()); + assert!(encryption.is_none()); + assert!(attestation.is_none()); assert!(encoded_data.is_none()); } _ => unreachable!(), diff --git a/crates/dwn-rs-message-derive/Cargo.toml b/crates/dwn-rs-message-derive/Cargo.toml new file mode 100644 index 0000000..e03cda8 --- /dev/null +++ b/crates/dwn-rs-message-derive/Cargo.toml @@ -0,0 +1,12 @@ +[package] +name = "dwn-rs-message-derive" +version = "0.1.0" +edition = "2021" + +[lib] +proc-macro = true + +[dependencies] +syn = { version = "*", features = ["full"] } +quote = "*" +proc-macro2 = "*" diff --git a/crates/dwn-rs-message-derive/src/derive/descriptor.rs b/crates/dwn-rs-message-derive/src/derive/descriptor.rs new file mode 100644 index 0000000..4cdabc8 --- /dev/null +++ b/crates/dwn-rs-message-derive/src/derive/descriptor.rs @@ -0,0 +1,256 @@ +pub use proc_macro2::TokenStream; +pub use quote::quote_spanned; +use quote::{format_ident, quote}; +pub use syn::{ + parse::{Parse, ParseStream, Parser, Result}, + parse2, + spanned::Spanned, + DeriveInput, Token, +}; +use syn::{Fields, FieldsNamed, Ident, ItemStruct, Path}; + +// parse the attribtutes (`interface`, `method`) from DeriveInput and return them +// as their Interface and related enum Method type +pub struct DescriptorAttr { + interface: Ident, + method: Ident, + fields: Path, +} + +impl Parse for DescriptorAttr { + fn parse(input: ParseStream) -> Result { + let mut interface = None; + let mut method = None; + let mut fields = None; + + while !input.is_empty() { + let ident: syn::Ident = input.parse()?; + input.parse::()?; + match ident.to_string().as_str() { + "interface" => { + interface = Some(input.parse()?); + } + "method" => { + method = Some(input.parse()?); + } + "fields" => { + fields = Some(input.parse()?); + } + _ => return Err(syn::Error::new(ident.span(), "unknown attribute")), + } + if input.peek(Token![,]) { + input.parse::()?; + } + } + + Ok(Self { + interface: interface + .ok_or_else(|| syn::Error::new(input.span(), "missing interface"))?, + method: method.ok_or_else(|| syn::Error::new(input.span(), "missing method"))?, + fields: fields.ok_or_else(|| syn::Error::new(input.span(), "missing fields"))?, + }) + } +} + +pub(crate) fn impl_descriptor_macro_attr(attrs: DescriptorAttr, input: TokenStream) -> TokenStream { + let ast: DeriveInput = parse2(input.clone()).expect("failed to parse input"); + let items: ItemStruct = parse2(input).expect("descriptor mus be a struct"); + + let generics = &ast.generics; + let where_clause = &generics.where_clause; + + let mut item_ser = items.clone(); + + let ident = &items.ident; + let item_ser_ident = format_ident!("{}Internal", &items.ident); + let interface = attrs.interface; + let method = attrs.method; + let fields = attrs.fields; + + let deserialize_message_ident = format_ident!("{}MessageInternal", ident); + + item_ser.ident = item_ser_ident.clone(); + + let mut into_idents: TokenStream = quote! {}; + let mut from_idents: TokenStream = quote! {}; + + if let Fields::Named(ref mut fields) = item_ser.fields { + let idents = move |from: Ident, fields: &FieldsNamed, ast: &DeriveInput| { + fields + .named + .iter() + .map(|field| { + let ident = field.ident.as_ref().expect("field must have an identifier"); + + quote_spanned! { ast.span() => + #ident: #from.#ident, + } + }) + .collect::() + }; + + into_idents = idents(format_ident!("from"), fields, &ast); + from_idents = idents(format_ident!("internal"), fields, &ast); + + fields.named.push( + syn::Field::parse_named + .parse2(quote_spanned!(ast.span() => + pub interface: String + )) + .expect("failed to parse fields"), + ); + fields.named.push( + syn::Field::parse_named + .parse2(quote_spanned!(ast.span() => + pub method: String + )) + .expect("failed to parse fields"), + ); + } + + let intofrom = format!("{}", &item_ser_ident); + + let output = quote_spanned! { ast.span() => + #[serde_with::skip_serializing_none] + #[derive(serde::Serialize, serde::Deserialize, Debug, PartialEq, Clone)] + #[serde(into = #intofrom, from = #intofrom)] + #items + + #[derive(serde::Deserialize, serde::Serialize, Clone)] + #item_ser + + impl From<#ident> for #item_ser_ident { + fn from(from: #ident) -> Self { + #item_ser_ident { + interface: from.interface().to_string(), + method: from.method().to_string(), + #into_idents + } + } + } + + impl From<#item_ser_ident> for #ident { + fn from(internal: #item_ser_ident) -> Self { + #ident { + #from_idents + } + } + } + + impl #generics MessageDescriptor for #ident #generics #where_clause { + type Fields = #fields; + + fn interface(&self) -> &'static str { + #interface + } + + fn method(&self) -> &'static str { + #method + } + } + + #[derive(serde::Deserialize)] + struct #deserialize_message_ident + where + D: crate::interfaces::messages::descriptors::MessageDescriptor + serde::de::DeserializeOwned, + { + descriptor: #ident, + #[serde(flatten)] + fields: D::Fields, + } + + impl<'de> serde::Deserialize<'de> for crate::Message<#ident> + { + fn deserialize(deserializer: Des) -> Result + where + Des: serde::Deserializer<'de>, + { + // Deserialize the internal struct + let inner: #deserialize_message_ident<#ident> = serde::Deserialize::deserialize(deserializer)?; + + // Return the message + Ok(crate::Message { + descriptor: inner.descriptor, + fields: inner.fields, + }) + } + } + + impl<'de>serde::Deserialize<'de> for crate::MessageEvent<#ident> { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + #[derive(serde::Deserialize)] + struct TempEvent { + pub message: crate::Message<#ident>, + #[serde(rename = "initialWrite")] + pub initial_write: Option>, + } + let temp_event = TempEvent::deserialize(deserializer)?; + + Ok(Self { + message: temp_event.message, + initial_write: temp_event.initial_write, + }) + } + } + }; + output +} + +#[cfg(test)] +mod tests { + use quote::quote; + use quote::ToTokens; + use syn::parse_quote; + + use super::*; + + #[test] + fn test_parse_descriptor_attr() { + const RECORDS: &str = "RECORDS"; + const READ: &str = "READ"; + let input = quote! { + interface = RECORDS, + method = READ, + fields = alloc::vec::Vec, + }; + + let attr: DescriptorAttr = parse2(input).unwrap(); + + assert_eq!(attr.interface.to_token_stream().to_string(), RECORDS); + assert_eq!(attr.method.to_token_stream().to_string(), READ); + assert_eq!( + attr.fields.to_token_stream().to_string(), + "alloc :: vec :: Vec < u32 >" + ); + } + + #[test] + fn test_impl_descriptor_macro_attr_with_fields() { + // Define the input struct as a token stream + let input: TokenStream = quote! { + pub struct Example { + pub name: String, + pub id: u32, + } + }; + + // Define macro implementation attributes + let attrs = DescriptorAttr { + interface: format_ident!("ExampleInterface"), + method: format_ident!("ExampleMethod"), + fields: parse_quote! { FieldsNamed }, + }; + + // Apply the macro + let output = impl_descriptor_macro_attr(attrs, input); + + // Check for key elements in the generated code + assert!(output.to_string().contains("ExampleInternal")); + assert!(output + .to_string() + .contains("impl MessageDescriptor for Example")); + } +} diff --git a/crates/dwn-rs-message-derive/src/derive/mod.rs b/crates/dwn-rs-message-derive/src/derive/mod.rs new file mode 100644 index 0000000..f0d46f3 --- /dev/null +++ b/crates/dwn-rs-message-derive/src/derive/mod.rs @@ -0,0 +1 @@ +pub mod descriptor; diff --git a/crates/dwn-rs-message-derive/src/lib.rs b/crates/dwn-rs-message-derive/src/lib.rs new file mode 100644 index 0000000..9483686 --- /dev/null +++ b/crates/dwn-rs-message-derive/src/lib.rs @@ -0,0 +1,13 @@ +mod derive; + +use derive::descriptor::impl_descriptor_macro_attr; +use proc_macro::TokenStream; +use syn::parse_macro_input; + +#[proc_macro_attribute] +pub fn descriptor(attr: TokenStream, item: TokenStream) -> TokenStream { + let attr = parse_macro_input!(attr); + let item = parse_macro_input!(item); + + impl_descriptor_macro_attr(attr, item).into() +} From 86d9c5f08146f0c4e15f2c778d466bc2965e5948 Mon Sep 17 00:00:00 2001 From: Dan Enman Date: Tue, 22 Oct 2024 21:30:05 -0300 Subject: [PATCH 06/23] feat: use `Message` as generic Message type --- crates/dwn-rs-core/src/events/emitter.rs | 33 +++- crates/dwn-rs-core/src/events/stream.rs | 173 +++++++++++++---- crates/dwn-rs-core/src/events/subscription.rs | 49 +++-- .../src/interfaces/messages/mod.rs | 175 ++++++++++++++---- crates/dwn-rs-core/src/stores.rs | 17 +- crates/dwn-rs-remote/src/client.rs | 9 +- crates/dwn-rs-remote/src/jsonrpc/dwn.rs | 10 +- crates/dwn-rs-stores/Cargo.toml | 1 + crates/dwn-rs-stores/src/cid.rs | 85 +++++++++ crates/dwn-rs-stores/src/lib.rs | 3 + .../src/surrealdb/message_store.rs | 62 ++++--- crates/dwn-rs-wasm/src/event_stream.rs | 20 +- crates/dwn-rs-wasm/src/events.rs | 8 +- crates/dwn-rs-wasm/src/message.rs | 22 ++- crates/dwn-rs-wasm/src/query.rs | 7 +- 15 files changed, 512 insertions(+), 162 deletions(-) create mode 100644 crates/dwn-rs-stores/src/cid.rs diff --git a/crates/dwn-rs-core/src/events/emitter.rs b/crates/dwn-rs-core/src/events/emitter.rs index e1076c2..12be4e7 100644 --- a/crates/dwn-rs-core/src/events/emitter.rs +++ b/crates/dwn-rs-core/src/events/emitter.rs @@ -1,22 +1,38 @@ +use std::fmt::Debug; + +use serde::{de::DeserializeOwned, Serialize}; use xtra::{Address, Mailbox}; use crate::{ + descriptors::MessageDescriptor, errors::{EventStreamError, StoreError}, - MapValue, + MapValue, Message, }; use tracing::{instrument, trace}; use super::{Emit, EventChannel, EventStream, MessageEvent, Shutdown, Subscribe, Subscription}; #[derive(Debug, Default)] -pub struct EventStreamer(Option>); -impl EventStreamer { +pub struct EventStreamer(Option>>) +where + Message: Serialize + DeserializeOwned, + D: MessageDescriptor + Clone + Debug + PartialEq + Send + 'static; + +impl EventStreamer +where + Message: Serialize + DeserializeOwned, + D: MessageDescriptor + Clone + Debug + PartialEq + Send + 'static, +{ + pub fn new() -> Self { + Self(None) + } + #[cfg(target_arch = "wasm32")] #[instrument] pub async fn open(&mut self) { trace!("opening EventStreamer (wasm)"); self.0 = Some(xtra::spawn_wasm_bindgen( - EventStream::default(), + EventStream::new(), Mailbox::unbounded(), )); } @@ -25,10 +41,7 @@ impl EventStreamer { #[instrument] pub async fn open(&mut self) { trace!("opening EventStreamer (tokio)"); - self.0 = Some(xtra::spawn_tokio( - EventStream::default(), - Mailbox::unbounded(), - )); + self.0 = Some(xtra::spawn_tokio(EventStream::new(), Mailbox::unbounded())); } #[instrument] @@ -39,7 +52,7 @@ impl EventStreamer { } #[instrument] - pub async fn emit(&self, ns: &str, evt: MessageEvent, indexes: MapValue) { + pub async fn emit(&self, ns: &str, evt: MessageEvent, indexes: MapValue) { if let Some(addr) = &self.0 { let _ = addr .send(Emit { @@ -56,7 +69,7 @@ impl EventStreamer { &self, ns: &str, id: &str, - listener: EventChannel, + listener: EventChannel, ) -> Result { if let Some(addr) = &self.0 { trace!("subscribing to event stream"); diff --git a/crates/dwn-rs-core/src/events/stream.rs b/crates/dwn-rs-core/src/events/stream.rs index ce75c00..509cc2f 100644 --- a/crates/dwn-rs-core/src/events/stream.rs +++ b/crates/dwn-rs-core/src/events/stream.rs @@ -1,39 +1,71 @@ -use std::{collections::BTreeMap, future::Future, pin::Pin}; +use std::{collections::BTreeMap, fmt::Debug, future::Future, pin::Pin}; use futures_util::future; -use serde::{Deserialize, Serialize}; +use serde::{de::DeserializeOwned, Deserialize, Serialize}; use tracing::{debug, info, instrument, trace, Instrument}; use xtra::{prelude::MessageChannel, Actor, Handler}; -use crate::{errors::EventStreamError, MapValue, Message}; +use crate::{ + descriptors::{records, MessageDescriptor}, + errors::EventStreamError, + Descriptor, MapValue, Message, +}; -pub type Event = (String, MessageEvent, MapValue); +pub type Event = (String, MessageEvent, MapValue); -pub type EventChannel = MessageChannel; +pub type EventChannel = MessageChannel, MessageEvent, xtra::refcount::Strong>; -#[derive(Debug, Default)] -pub struct EventStream { - listeners: BTreeMap<(String, String), EventChannel>, +#[derive(Debug)] +pub struct EventStream +where + Message: Serialize + DeserializeOwned, + D: MessageDescriptor + DeserializeOwned + Clone + Debug + PartialEq + Send + 'static, +{ + listeners: BTreeMap<(String, String), EventChannel>, } -impl EventStream { +impl EventStream +where + Message: Serialize + DeserializeOwned, + D: MessageDescriptor + DeserializeOwned + Clone + Debug + PartialEq + Send + 'static, +{ pub fn new() -> Self { - Self::default() + EventStream { + listeners: BTreeMap::new(), + } } } -#[derive(Debug)] -pub struct Emit { +impl Default for EventStream +where + Message: Serialize + DeserializeOwned, + D: MessageDescriptor + DeserializeOwned + Clone + Debug + PartialEq + Send + 'static, +{ + fn default() -> Self { + Self::new() + } +} + +#[derive(Debug, Clone)] +pub struct Emit +where + Message: Serialize + DeserializeOwned, + D: MessageDescriptor + DeserializeOwned + Clone + Debug + PartialEq + Send + 'static, +{ pub ns: String, - pub evt: MessageEvent, + pub evt: MessageEvent, pub indexes: MapValue, } #[derive(Debug)] -pub struct Subscribe { +pub struct Subscribe +where + Message: Serialize + DeserializeOwned, + D: MessageDescriptor + DeserializeOwned + Clone + Debug + PartialEq + Send + 'static, +{ pub ns: String, pub id: String, - pub listener: EventChannel, + pub listener: EventChannel, } #[derive(Debug, Clone)] @@ -45,14 +77,48 @@ pub struct Close { #[derive(Debug, Clone)] pub struct Shutdown; -#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] -pub struct MessageEvent { - message: Message, +#[derive(Debug, Serialize, Clone, PartialEq)] +pub struct MessageEvent +where + Message: Serialize + DeserializeOwned, + D: MessageDescriptor + DeserializeOwned + Clone + Debug + PartialEq + Send + 'static, +{ + pub message: Message, #[serde(rename = "initialWrite")] - initial_write: Option, // RecordsWrite message + pub initial_write: Option>, +} + +// This is a custom deserializer for the MessageEvent struct. It is necessary because the +// Message struct has a generic type parameter that is not known at compile time. This deserializer +// is the generalized version, which can deserialize any descriptor type. Individual +// Descriptors types implement their own deserializers via. the `MessageDescriptor` trait +// derivation. +impl<'de> Deserialize<'de> for MessageEvent { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + #[derive(Deserialize)] + struct TempEvent { + message: Message, + #[serde(rename = "initialWrite")] + initial_write: Option>, + } + + let temp_event = TempEvent::deserialize(deserializer)?; + + Ok(Self { + message: temp_event.message, + initial_write: temp_event.initial_write, + }) + } } -impl Actor for EventStream { +impl Actor for EventStream +where + Message: Serialize + DeserializeOwned, + D: MessageDescriptor + DeserializeOwned + Clone + Debug + PartialEq + Send + 'static, +{ type Stop = (); #[instrument] @@ -61,10 +127,14 @@ impl Actor for EventStream { } } -impl Handler for EventStream { +impl Handler> for EventStream +where + Message: Serialize + DeserializeOwned, + D: MessageDescriptor + DeserializeOwned + Clone + Debug + PartialEq + Send + 'static, +{ type Return = (); - async fn handle(&mut self, msg: Emit, _ctx: &mut xtra::Context) -> Self::Return { + async fn handle(&mut self, msg: Emit, _ctx: &mut xtra::Context) -> Self::Return { debug!("Emitting event"); future::join_all( self.listeners @@ -97,10 +167,14 @@ impl Handler for EventStream { } } -impl Handler for EventStream { +impl Handler> for EventStream +where + Message: Serialize + DeserializeOwned, + D: MessageDescriptor + DeserializeOwned + Clone + Debug + PartialEq + Send + 'static, +{ type Return = Subscription; - async fn handle(&mut self, msg: Subscribe, _ctx: &mut xtra::Context) -> Self::Return { + async fn handle(&mut self, msg: Subscribe, _ctx: &mut xtra::Context) -> Self::Return { debug!("handling event subscription"); let ns = msg.ns; let id = msg.id; @@ -117,7 +191,11 @@ impl Handler for EventStream { } } -impl Handler for EventStream { +impl Handler for EventStream +where + Message: Serialize + DeserializeOwned, + D: MessageDescriptor + DeserializeOwned + Clone + Debug + PartialEq + Send + 'static, +{ type Return = (); async fn handle(&mut self, close: Close, _ctx: &mut xtra::Context) -> Self::Return { @@ -131,7 +209,11 @@ impl Handler for EventStream { } } -impl Handler for EventStream { +impl Handler for EventStream +where + Message: Serialize + DeserializeOwned, + D: MessageDescriptor + DeserializeOwned + Clone + Debug + PartialEq + Send + 'static, +{ type Return = (); async fn handle(&mut self, _: Shutdown, _ctx: &mut xtra::Context) -> Self::Return { @@ -157,11 +239,15 @@ pub struct Subscription { #[allow(dead_code)] #[instrument] -fn make_close_task( +fn make_close_task( ns: String, id: String, - addr: xtra::Address, -) -> impl Fn() -> Pin> + Send>> + 'static { + addr: xtra::Address>, +) -> impl Fn() -> Pin> + Send>> + 'static +where + Message: Serialize + DeserializeOwned, + D: MessageDescriptor + DeserializeOwned + Clone + Debug + PartialEq + Send + 'static, +{ move || { let ns = ns.clone(); let id = id.clone(); @@ -184,17 +270,20 @@ mod test { use tracing_test::traced_test; use xtra::{spawn_tokio, Mailbox}; - use crate::{descriptors::Records, Descriptor, Fields}; + use crate::{descriptors::Records, Fields}; #[traced_test] #[tokio::test] async fn test_event_stream() { use super::*; - fn test_evt() -> MessageEvent { + fn test_evt() -> MessageEvent { MessageEvent { message: Message { - descriptor: Descriptor::Records(Records::Read(Default::default())), + descriptor: Descriptor::Records(Records::Read(records::ReadDescriptor { + message_timestamp: chrono::Utc::now(), + filter: Default::default(), + })), fields: Fields::Authorization(Default::default()), }, initial_write: None, @@ -207,21 +296,29 @@ mod test { MapValue::default() } - struct MessageReturner(Option); - impl Actor for MessageReturner { - type Stop = Option; + struct MessageReturner< + D: MessageDescriptor + DeserializeOwned + Clone + Debug + PartialEq + Send + 'static, + >(Option>) + where + Message: Serialize + DeserializeOwned; + impl Actor for MessageReturner + where + Message: Serialize + DeserializeOwned, + D: MessageDescriptor + DeserializeOwned + Clone + Debug + PartialEq + Send + 'static, + { + type Stop = Option>; async fn stopped(self) -> Self::Stop { self.0 } } - impl Handler for MessageReturner { - type Return = MessageEvent; + impl Handler> for MessageReturner { + type Return = MessageEvent; async fn handle( &mut self, - (ns, msg, indexes): (String, MessageEvent, MapValue), + (ns, msg, indexes): (String, MessageEvent, MapValue), _ctx: &mut xtra::Context, ) -> Self::Return { self.0 = Some(msg.clone()); diff --git a/crates/dwn-rs-core/src/events/subscription.rs b/crates/dwn-rs-core/src/events/subscription.rs index 469d620..f0ff646 100644 --- a/crates/dwn-rs-core/src/events/subscription.rs +++ b/crates/dwn-rs-core/src/events/subscription.rs @@ -1,23 +1,35 @@ +use std::fmt::Debug; + +use serde::{de::DeserializeOwned, Serialize}; use tracing::{info, instrument, trace}; use xtra::{Actor, Address, Handler}; -use crate::MapValue; +use crate::{descriptors::MessageDescriptor, MapValue, Message}; use super::{Event, EventChannel, MessageEvent}; -pub type HandleFn = fn(String, MessageEvent, MapValue); - -pub type SubscriptionFnAddress = Address; +pub type HandleFn = fn(String, MessageEvent, MapValue); +pub type SubscriptionFnAddress = Address>; +pub type BoxedSubscriptionFn = + Box, MapValue) + Send + Sync + 'static>; /// SubscriptionFn is an actor that subscribes to events and calls a function when an event is emitted. /// This is useful for functions that need to be called when an event is emitted, such as from /// the WASM world, or through WebSockets. -pub struct SubscriptionFn { +pub struct SubscriptionFn +where + Message: Serialize + DeserializeOwned, + D: MessageDescriptor + Clone + Debug + PartialEq + Send + 'static, +{ id: String, - f: Box, + f: BoxedSubscriptionFn, } -impl Actor for SubscriptionFn { +impl Actor for SubscriptionFn +where + Message: Serialize + DeserializeOwned, + D: MessageDescriptor + Clone + Debug + PartialEq + Send + 'static, +{ type Stop = (); async fn stopped(self) -> Self::Stop { @@ -25,10 +37,14 @@ impl Actor for SubscriptionFn { } } -impl Handler for SubscriptionFn { - type Return = MessageEvent; +impl Handler> for SubscriptionFn +where + Message: Serialize + DeserializeOwned, + D: MessageDescriptor + Clone + Debug + PartialEq + Send + 'static, +{ + type Return = MessageEvent; - async fn handle(&mut self, evt: Event, _: &mut xtra::Context) -> Self::Return { + async fn handle(&mut self, evt: Event, _: &mut xtra::Context) -> Self::Return { trace!("SubscriptionFn handling event"); (self.f)(evt.0, evt.1.clone(), evt.2); @@ -36,11 +52,12 @@ impl Handler for SubscriptionFn { } } -impl SubscriptionFn { - pub fn new( - id: &str, - f: Box, - ) -> Self { +impl SubscriptionFn +where + Message: Serialize + DeserializeOwned, + D: MessageDescriptor + Clone + Debug + PartialEq + Send + 'static, +{ + pub fn new(id: &str, f: BoxedSubscriptionFn) -> Self { Self { id: id.to_string(), f, @@ -60,7 +77,7 @@ impl SubscriptionFn { } #[instrument] - pub fn channel(addr: Address) -> EventChannel { + pub fn channel(addr: Address) -> EventChannel { trace!("adding SubscriptionFn to EventChannel"); EventChannel::new(addr) } diff --git a/crates/dwn-rs-core/src/interfaces/messages/mod.rs b/crates/dwn-rs-core/src/interfaces/messages/mod.rs index dd99f1d..7630ead 100644 --- a/crates/dwn-rs-core/src/interfaces/messages/mod.rs +++ b/crates/dwn-rs-core/src/interfaces/messages/mod.rs @@ -2,49 +2,158 @@ pub mod descriptors; pub mod fields; pub use descriptors::Descriptor; +use descriptors::MessageDescriptor; pub use fields::Fields; -use fields::InitialWriteField; -use serde::{Deserialize, Serialize}; -use serde_with::skip_serializing_none; +use serde::{de::DeserializeOwned, Deserialize, Serialize}; -use crate::{Cursor, SubscriptionID}; - -#[derive(Serialize, Deserialize, Debug, PartialEq, Clone)] -pub struct Message { - pub descriptor: Descriptor, - #[serde(flatten)] - pub fields: Fields, // Fields should be an Enum representing possible fields +#[derive(Debug, Clone, PartialEq)] +pub struct Message { + pub descriptor: D, + pub fields: D::Fields, } -#[derive(Serialize, Deserialize, Debug, PartialEq, Clone)] -pub enum ResponseEntries { - Message(Message), - String(String), +impl Message { + pub fn new(descriptor: D, fields: D::Fields) -> Self { + Self { descriptor, fields } + } } -#[derive(Serialize, Deserialize, Debug, PartialEq, Clone)] -pub enum Record {} +impl Serialize for Message +where + D: MessageDescriptor + Serialize, +{ + fn serialize(&self, serializer: S) -> Result + where + S: serde::ser::Serializer, + { + #[derive(Serialize)] + struct TempMessage<'a, D: MessageDescriptor> { + descriptor: &'a D, + #[serde(flatten)] + other: &'a D::Fields, + } + + let temp_message = TempMessage { + descriptor: &self.descriptor, + other: &self.fields, + }; -#[derive(Serialize, Deserialize, Debug, PartialEq, Clone)] -pub struct ReadReplyEntry { - pub cid: cid::Cid, - message: Message, + temp_message.serialize(serializer) + } } -#[derive(Serialize, Deserialize, Debug, PartialEq, Clone)] -pub struct Status { - pub code: i32, - pub detail: String, +// This is a custom deserializer for the Message struct. It is necessary because the Message +// struct has a generic type parameter that is not known at compile time. This deserializer +// is the generalized version, which can deserialize any descriptor type. Individual +// Descriptors types implement their own deserializers via. the `MessageDescriptor` trait +// derivation. +impl<'de> Deserialize<'de> for Message { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + #[derive(Deserialize)] + struct TempMessage { + descriptor: Descriptor, + #[serde(flatten)] + other: Fields, + } + + let temp_message = TempMessage::deserialize(deserializer)?; + + Ok(Self { + descriptor: temp_message.descriptor, + fields: temp_message.other, + }) + } } -#[skip_serializing_none] -#[derive(Serialize, Deserialize, Debug, PartialEq, Clone)] -pub struct Response { - pub status: Status, - pub entries: Option>, - pub entry: Option, - pub record: Option, - pub cursor: Option, - pub subscription: Option, +#[cfg(test)] +mod test { + + use chrono::Utc; + use descriptors::{ReadDescriptor, Records}; + use dwn_rs_message_derive::descriptor; + use fields::MessageFields; + use serde_json::json; + + use crate::{auth::Authorization, Filters}; + + use super::*; + + const INTERFACE: &str = "interface"; + const METHOD: &str = "method"; + #[descriptor(interface = INTERFACE, method = METHOD, fields = TestFields)] + struct TestDescriptor { + data: String, + } + + #[derive(Serialize, Deserialize, Clone, PartialEq, Debug)] + struct TestFields { + field1: String, + field2: i32, + } + impl MessageFields for TestFields {} + + #[test] + fn test_message_serialize() { + let desc = TestDescriptor { + data: "test".to_string(), + }; + let fields = TestFields { + field1: "test".to_string(), + field2: 42, + }; + + let message = Message::new(desc, fields); + + let serialized = serde_json::to_string(&message).unwrap(); + let expected = r#"{"descriptor":{"data":"test","interface":"interface","method":"method"},"field1":"test","field2":42}"#; + + assert_eq!(serialized, expected); + + let now = Utc::now(); + + let desc = Descriptor::Records(Records::Read(ReadDescriptor { + message_timestamp: now, + filter: Filters::default(), + })); + let fields = Fields::Authorization(Authorization { + ..Default::default() + }); + + let message = Message::new(desc, fields); + let serialized = json!(&message); + let expected = json!({ + "descriptor": { + "messageTimestamp": now, + "filter": Filters::default(), + "interface":"Records","method":"Read" + }, + "signature":{} + }); + + assert_eq!(serialized, expected); + } + + #[test] + fn test_message_deserialize() { + let serialized = r#"{"descriptor":{"data":"test","interface":"interface","method":"method"},"field1":"test","field2":42}"#; + + let message: Message = serde_json::from_str(serialized).unwrap(); + + let descriptor = TestDescriptor { + data: "test".to_string(), + }; + + let fields = TestFields { + field1: "test".to_string(), + field2: 42, + }; + + let expected = Message::new(descriptor, fields); + + assert_eq!(message, expected); + } } diff --git a/crates/dwn-rs-core/src/stores.rs b/crates/dwn-rs-core/src/stores.rs index af41a1e..fc6647e 100644 --- a/crates/dwn-rs-core/src/stores.rs +++ b/crates/dwn-rs-core/src/stores.rs @@ -7,6 +7,7 @@ use serde::{de::DeserializeOwned, Deserialize, Serialize}; use ulid::Ulid; use crate::{ + descriptors::MessageDescriptor, errors::{DataStoreError, EventLogError, MessageStoreError, ResumableTaskStoreError}, filters::filter_key::Filters, Cursor, MessageSort, Pagination, QueryReturn, @@ -18,27 +19,31 @@ pub trait MessageStore: Default { fn close(&mut self) -> impl Future; - fn put( + fn put( &self, tenant: &str, - message: Message, + message: Message, indexes: MapValue, tags: MapValue, ) -> impl Future> + Send; - fn get( + fn get( &self, tenant: &str, cid: &str, - ) -> impl Future> + Send; + ) -> impl Future, MessageStoreError>> + Send + where + Message: DeserializeOwned; - fn query( + fn query( &self, tenant: &str, filter: Filters, sort: Option, pagination: Option, - ) -> impl Future, MessageStoreError>> + Send; + ) -> impl Future>, MessageStoreError>> + Send + where + Message: DeserializeOwned; fn delete( &self, diff --git a/crates/dwn-rs-remote/src/client.rs b/crates/dwn-rs-remote/src/client.rs index efb10c7..1611cc8 100644 --- a/crates/dwn-rs-remote/src/client.rs +++ b/crates/dwn-rs-remote/src/client.rs @@ -41,12 +41,15 @@ where Ok(RemoteDWNInstance { rpc }) } - pub async fn process_message( + pub async fn process_message( &mut self, tenant: &str, - message: Message, + message: Message, data: Option, - ) -> ClientResult<(DWNResponse, Option>>)> { + ) -> ClientResult<(DWNResponse, Option>>)> + where + D: MessageDescriptor + DeserializeOwned + Serialize + Send + 'static, + { let res = self .rpc .request( diff --git a/crates/dwn-rs-remote/src/jsonrpc/dwn.rs b/crates/dwn-rs-remote/src/jsonrpc/dwn.rs index ab325c5..8041e6c 100644 --- a/crates/dwn-rs-remote/src/jsonrpc/dwn.rs +++ b/crates/dwn-rs-remote/src/jsonrpc/dwn.rs @@ -1,12 +1,12 @@ -use dwn_rs_core::Message; -use serde::{Deserialize, Serialize}; +use dwn_rs_core::{descriptors::MessageDescriptor, Message}; +use serde::Serialize; pub const PROCESS_MESSAGE: &str = "dwn.processMessage"; -#[derive(Debug, Serialize, Deserialize, PartialEq)] -pub struct ProcessMessageParams { +#[derive(Debug, Serialize)] +pub struct ProcessMessageParams { pub target: String, - pub message: Message, + pub message: Message, #[serde(skip_serializing_if = "Option::is_none", rename = "encodedData")] pub encoded_data: Option>, } diff --git a/crates/dwn-rs-stores/Cargo.toml b/crates/dwn-rs-stores/Cargo.toml index 4924a44..0c5735b 100644 --- a/crates/dwn-rs-stores/Cargo.toml +++ b/crates/dwn-rs-stores/Cargo.toml @@ -57,3 +57,4 @@ bytes = { version = "1.8.0", features = ["serde"] } [dev-dependencies] rand = "0.8.5" +tokio-util = "0.7.12" diff --git a/crates/dwn-rs-stores/src/cid.rs b/crates/dwn-rs-stores/src/cid.rs new file mode 100644 index 0000000..2ac7185 --- /dev/null +++ b/crates/dwn-rs-stores/src/cid.rs @@ -0,0 +1,85 @@ +use std::collections::TryReserveError; + +use cid::Cid; +use multihash_codetable::Code; +use multihash_codetable::MultihashDigest; +use serde_ipld_dagcbor::EncodeError; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncSeek}; + +pub fn generate_cid(data: B) -> Result> +where + B: AsRef<[u8]>, +{ + let mh = Code::Sha2_256.digest(data.as_ref()); + let cid = Cid::new_v1(multicodec::Codec::DagCbor.code(), mh); + + Ok(cid) +} + +pub async fn generate_cid_from_asyncreader( + reader: R, +) -> Result> +where + R: AsyncRead + AsyncSeek + Unpin, +{ + let mut buf = Vec::new(); + reader + .take(1024 * 1024) + .read_to_end(&mut buf) + .await + .map_err(EncodeError::Write) + .unwrap(); + + let mh = Code::Sha2_256.digest(&buf); + let cid = Cid::new_v1(multicodec::Codec::DagCbor.code(), mh); + + Ok(cid) +} + +#[cfg(test)] +mod tests { + use super::*; + use serde_json::json; + use std::io::Cursor; + use std::str::FromStr; + + #[test] + fn test_generate_cid() { + let data = json!({ + "hello": "world", + }); + + let cid = generate_cid(data.to_string()).unwrap(); + + assert_eq!( + cid, + Cid::from_str("bafyreietui4xdkiu4xvmx4fi2jivjtndbhb4drzpxomrjvd4mdz4w2avra").unwrap(), + ); + assert_eq!(cid.codec(), multicodec::Codec::DagCbor.code()); + } + + #[tokio::test] + async fn test_generate_cid_from_asyncreader() { + // Define some sample data to read + let data = b"Sample data to generate CID"; + + // Create a cursor over the data, which implements AsyncRead + AsyncSeek + let cursor = Cursor::new(data); + + // Call the function with the cursor + let cid = generate_cid_from_asyncreader(cursor).await; + assert!(cid.is_ok()); + let cid = cid.unwrap(); + + // Verify that the CID is generated correctly + // For a real test, you might compare the cid with a known value + assert_eq!(cid.version(), cid::Version::V1); + assert_eq!(cid.codec(), multicodec::Codec::DagCbor.code()); + + // For demonstration: hash the data using the same logic to get the expected hash + let expected_mh = multihash_codetable::Code::Sha2_256.digest(data); + + // Compare multihashes + assert_eq!(cid.hash(), &expected_mh); + } +} diff --git a/crates/dwn-rs-stores/src/lib.rs b/crates/dwn-rs-stores/src/lib.rs index a32e9da..9aa9770 100644 --- a/crates/dwn-rs-stores/src/lib.rs +++ b/crates/dwn-rs-stores/src/lib.rs @@ -1,3 +1,6 @@ +pub mod cid; +pub use cid::*; + #[cfg(feature = "surrealdb")] pub mod surrealdb; #[cfg(feature = "surrealdb")] diff --git a/crates/dwn-rs-stores/src/surrealdb/message_store.rs b/crates/dwn-rs-stores/src/surrealdb/message_store.rs index 183a52b..12fcc4a 100644 --- a/crates/dwn-rs-stores/src/surrealdb/message_store.rs +++ b/crates/dwn-rs-stores/src/surrealdb/message_store.rs @@ -1,17 +1,18 @@ use std::str::FromStr; use cid::Cid; -use multihash_codetable::{Code, MultihashDigest}; +use serde::{de::DeserializeOwned, Serialize}; use super::core::SurrealDB; -use crate::SurrealQuery; +use crate::{generate_cid, SurrealQuery}; use dwn_rs_core::{ + descriptors::MessageDescriptor, errors::{MessageStoreError, StoreError}, + fields::MessageFields, filters::{Filters, MessageSort, Pagination, Query, QueryReturn}, interfaces::Message, stores::MessageStore, - value::{MapValue, Value}, - Fields, + value::MapValue, }; use super::{ @@ -30,23 +31,20 @@ impl MessageStore for SurrealDB { self.close().await } - async fn put( + async fn put( &self, tenant: &str, - mut message: Message, + mut message: Message, indexes: MapValue, tags: MapValue, - ) -> Result { - let mut data: Option = None; - - if let Fields::EncodedWrite(ref mut ew) = message.fields { - data = ew.encoded_data.clone().map(Value::String); - ew.encoded_data = None; - } + ) -> Result + where + D: MessageDescriptor + Serialize + Send + 'static, + { + let data = message.fields.encoded_data(); let i = serde_ipld_dagcbor::to_vec(&message)?; - let mh = Code::Sha2_256.digest(i.as_slice()); - let cid = Cid::new_v1(multicodec::Codec::DagCbor.code(), mh); + let cid = generate_cid(&i)?; self.with_database(tenant, |db| async move { db.create::>((MESSAGES_TABLE, cid.to_string())) @@ -67,7 +65,11 @@ impl MessageStore for SurrealDB { Ok(cid) } - async fn get(&self, tenant: &str, cid: &str) -> Result { + async fn get(&self, tenant: &str, cid: &str) -> Result, MessageStoreError> + where + Message: DeserializeOwned, + D: MessageDescriptor + DeserializeOwned + Send + 'static, + { // fetch and decode the message from the db let encoded_message: GetEncodedMessage = self .with_database(tenant, |db| async move { @@ -84,24 +86,27 @@ impl MessageStore for SurrealDB { return Err(MessageStoreError::StoreError(StoreError::NotFound)); } - let mut from: Message = serde_ipld_dagcbor::from_slice(&encoded_message.encoded_message)?; + let mut from: Message = + serde_ipld_dagcbor::from_slice(&encoded_message.encoded_message)?; if let Some(data) = encoded_message.encoded_data { - if let Fields::EncodedWrite(ref mut ew) = from.fields { - ew.encoded_data = Some(data.to_string()); - }; - } + from.fields.encode_data(data); + }; Ok(from) } - async fn query( + async fn query( &self, tenant: &str, filters: Filters, sort: Option, pagination: Option, - ) -> Result, MessageStoreError> { + ) -> Result>, MessageStoreError> + where + Message: DeserializeOwned, + D: MessageDescriptor + DeserializeOwned + Send + 'static, + { let mut qb = self .with_database(tenant, |db| async move { Ok(SurrealQuery::::new(db)) @@ -125,24 +130,21 @@ impl MessageStore for SurrealDB { .filter(|m| m.tenant == tenant) .map(|m: GetEncodedMessage| { let cid = Cid::from_str(m.cid.as_str())?; - let mh = Code::Sha2_256.digest(&m.encoded_message); - let data_cid = Cid::new_v1(multicodec::Codec::DagCbor.code(), mh); + let data_cid = generate_cid(&m.encoded_message)?; if cid != data_cid { return Err(MessageStoreError::StoreError(StoreError::NotFound)); } - let mut msg: Message = serde_ipld_dagcbor::from_slice(&m.encoded_message)?; + let mut msg: Message = serde_ipld_dagcbor::from_slice(&m.encoded_message)?; if let Some(data) = m.encoded_data { - if let Fields::EncodedWrite(ref mut ew) = msg.fields { - ew.encoded_data = Some(data.to_string()); - }; + msg.fields.encode_data(data); } Ok(msg) }) - .collect::, MessageStoreError>>()?; + .collect::>, MessageStoreError>>()?; Ok(QueryReturn { items: r, cursor }) } diff --git a/crates/dwn-rs-wasm/src/event_stream.rs b/crates/dwn-rs-wasm/src/event_stream.rs index 4d88c87..5357123 100644 --- a/crates/dwn-rs-wasm/src/event_stream.rs +++ b/crates/dwn-rs-wasm/src/event_stream.rs @@ -2,7 +2,7 @@ use alloc::{boxed::Box, string::String}; use async_std::channel::unbounded; use dwn_rs_core::{ - emitter::EventStreamer, subscription::SubscriptionFn, MapValue, + emitter::EventStreamer, subscription::SubscriptionFn, Descriptor, MapValue, MessageEvent as CoreMessageEvent, }; use js_sys::Promise; @@ -16,9 +16,15 @@ use crate::{ }; #[wasm_bindgen] -#[derive(Debug, Default)] +#[derive(Debug)] pub struct EventStream { - events: EventStreamer, + events: EventStreamer, +} + +impl Default for EventStream { + fn default() -> Self { + Self::new() + } } #[wasm_bindgen] @@ -27,7 +33,9 @@ impl EventStream { pub fn new() -> Self { console_error_panic_hook::set_once(); - Self::default() + Self { + events: EventStreamer::new(), + } } #[wasm_bindgen] @@ -71,9 +79,9 @@ impl EventStream { async fn subscription_for_func( id: &str, listener: js_sys::Function, -) -> Result { +) -> Result, JsError> { trace!("creating subscription for js function"); - let (tx, rx) = unbounded::<(String, CoreMessageEvent, MapValue)>(); + let (tx, rx) = unbounded::<(String, CoreMessageEvent, MapValue)>(); spawn_local(async move { while let Ok((tenant, evt, indexes)) = rx.recv().await { diff --git a/crates/dwn-rs-wasm/src/events.rs b/crates/dwn-rs-wasm/src/events.rs index 2a20e7d..2fd9913 100644 --- a/crates/dwn-rs-wasm/src/events.rs +++ b/crates/dwn-rs-wasm/src/events.rs @@ -1,6 +1,6 @@ use alloc::boxed::Box; -use dwn_rs_core::{MessageEvent as CoreMessageEvent, Subscription}; +use dwn_rs_core::{Descriptor, MessageEvent as CoreMessageEvent, Subscription}; use futures_util::FutureExt; use js_sys::{Object, Promise, Reflect}; use serde::Serialize; @@ -19,7 +19,7 @@ extern "C" { pub type EventSubscription; } -impl From<&MessageEvent> for CoreMessageEvent { +impl From<&MessageEvent> for CoreMessageEvent { fn from(value: &MessageEvent) -> Self { if value.is_undefined() { throw_str("MessageEvent is undefined"); @@ -29,8 +29,8 @@ impl From<&MessageEvent> for CoreMessageEvent { } } -impl From for MessageEvent { - fn from(value: CoreMessageEvent) -> Self { +impl From> for MessageEvent { + fn from(value: CoreMessageEvent) -> Self { value .serialize(&crate::ser::serializer()) .expect_throw("unable to serialize event") diff --git a/crates/dwn-rs-wasm/src/message.rs b/crates/dwn-rs-wasm/src/message.rs index 578715c..cadb2f2 100644 --- a/crates/dwn-rs-wasm/src/message.rs +++ b/crates/dwn-rs-wasm/src/message.rs @@ -1,8 +1,9 @@ +use alloc::string::ToString; use alloc::vec::Vec; use serde::Serialize; use wasm_bindgen::{prelude::*, throw_str}; -use dwn_rs_core::Message; +use dwn_rs_core::{Descriptor, Message}; use crate::ser::serializer; @@ -18,18 +19,23 @@ extern "C" { pub type GenericMessageArray; } -impl From<&GenericMessage> for Message { +impl From<&GenericMessage> for Message { fn from(value: &GenericMessage) -> Self { if value.is_undefined() { throw_str("Message is undefined"); } - serde_wasm_bindgen::from_value(value.into()).expect_throw("unable to deserialize message") + let t = serde_wasm_bindgen::from_value(value.into()); + + match t { + Ok(m) => m, + Err(e) => throw_str(e.to_string().as_str()), + } } } -impl From for GenericMessage { - fn from(value: Message) -> Self { +impl From> for GenericMessage { + fn from(value: Message) -> Self { value .serialize(&serializer()) .expect_throw("unable to serialize message") @@ -37,14 +43,14 @@ impl From for GenericMessage { } } -impl From<&GenericMessageArray> for Vec { +impl From<&GenericMessageArray> for Vec> { fn from(value: &GenericMessageArray) -> Self { serde_wasm_bindgen::from_value(value.into()).expect_throw("unable to deserialize messages") } } -impl From> for GenericMessageArray { - fn from(value: Vec) -> Self { +impl From>> for GenericMessageArray { + fn from(value: Vec>) -> Self { value .serialize(&serializer()) .expect_throw("unable to serialize messages") diff --git a/crates/dwn-rs-wasm/src/query.rs b/crates/dwn-rs-wasm/src/query.rs index bbcf982..0325a80 100644 --- a/crates/dwn-rs-wasm/src/query.rs +++ b/crates/dwn-rs-wasm/src/query.rs @@ -5,6 +5,7 @@ use wasm_bindgen::prelude::*; use dwn_rs_core::{ filters::{Cursor, MessageSort, Pagination, QueryReturn}, interfaces::Message, + Descriptor, }; use crate::ser::serializer; @@ -37,11 +38,11 @@ extern "C" { pub type JSPaginationCursor; } -impl From> for JSQueryReturn { - fn from(value: QueryReturn) -> Self { +impl From>> for JSQueryReturn { + fn from(value: QueryReturn>) -> Self { #[derive(Serialize)] struct Wrapper<'a> { - messages: &'a [Message], + messages: &'a [Message], cursor: Option, } From 4c77bf526e45fa23225f76507ac8e9fe6608a6a4 Mon Sep 17 00:00:00 2001 From: Dan Enman Date: Tue, 22 Oct 2024 21:30:26 -0300 Subject: [PATCH 07/23] fix: add dwn-rs-message-derive crate --- Cargo.toml | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index 186cea5..df4535f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,12 @@ [workspace] resolver = "2" -members = ["crates/dwn-rs-core", "crates/dwn-rs-remote", "crates/dwn-rs-stores", "crates/dwn-rs-wasm"] +members = [ + "crates/dwn-rs-core", + "crates/dwn-rs-message-derive", + "crates/dwn-rs-remote", + "crates/dwn-rs-stores", + "crates/dwn-rs-wasm", +] [profile.release] lto = true From 5783628c9e7246d5bb66f1a14a0731dab982d249 Mon Sep 17 00:00:00 2001 From: Dan Enman Date: Tue, 22 Oct 2024 23:32:33 -0300 Subject: [PATCH 08/23] chore: add #[derive(Default)] on MessageDescriptors --- .../src/interfaces/messages/descriptors/protocols.rs | 2 +- crates/dwn-rs-message-derive/src/derive/descriptor.rs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/crates/dwn-rs-core/src/interfaces/messages/descriptors/protocols.rs b/crates/dwn-rs-core/src/interfaces/messages/descriptors/protocols.rs index 0aac84c..f6b2433 100644 --- a/crates/dwn-rs-core/src/interfaces/messages/descriptors/protocols.rs +++ b/crates/dwn-rs-core/src/interfaces/messages/descriptors/protocols.rs @@ -16,7 +16,7 @@ pub struct ConfigureDescriptor { pub definition: ProtocolDefinition, } -#[derive(Serialize, Deserialize, Debug, PartialEq, Clone)] +#[derive(Serialize, Deserialize, Debug, PartialEq, Clone, Default)] pub struct ProtocolDefinition { pub protocol: String, pub published: bool, diff --git a/crates/dwn-rs-message-derive/src/derive/descriptor.rs b/crates/dwn-rs-message-derive/src/derive/descriptor.rs index 4cdabc8..e5ec0e6 100644 --- a/crates/dwn-rs-message-derive/src/derive/descriptor.rs +++ b/crates/dwn-rs-message-derive/src/derive/descriptor.rs @@ -112,7 +112,7 @@ pub(crate) fn impl_descriptor_macro_attr(attrs: DescriptorAttr, input: TokenStre let output = quote_spanned! { ast.span() => #[serde_with::skip_serializing_none] - #[derive(serde::Serialize, serde::Deserialize, Debug, PartialEq, Clone)] + #[derive(serde::Serialize, serde::Deserialize, Default, Debug, PartialEq, Clone)] #[serde(into = #intofrom, from = #intofrom)] #items From 5fd5d5cb26280d018c730a02db8d9b01b3f4de78 Mon Sep 17 00:00:00 2001 From: Dan Enman Date: Tue, 22 Oct 2024 23:33:08 -0300 Subject: [PATCH 09/23] feat: add DWN response types --- crates/dwn-rs-core/src/interfaces/mod.rs | 2 + .../src/interfaces/replies/messages.rs | 26 ++++++++++++ .../dwn-rs-core/src/interfaces/replies/mod.rs | 42 +++++++++++++++++++ .../src/interfaces/replies/protocols.rs | 10 +++++ .../src/interfaces/replies/records.rs | 42 +++++++++++++++++++ 5 files changed, 122 insertions(+) create mode 100644 crates/dwn-rs-core/src/interfaces/replies/messages.rs create mode 100644 crates/dwn-rs-core/src/interfaces/replies/mod.rs create mode 100644 crates/dwn-rs-core/src/interfaces/replies/protocols.rs create mode 100644 crates/dwn-rs-core/src/interfaces/replies/records.rs diff --git a/crates/dwn-rs-core/src/interfaces/mod.rs b/crates/dwn-rs-core/src/interfaces/mod.rs index 8f05488..77e0836 100644 --- a/crates/dwn-rs-core/src/interfaces/mod.rs +++ b/crates/dwn-rs-core/src/interfaces/mod.rs @@ -1,3 +1,5 @@ pub mod messages; +pub mod replies; pub use messages::*; +pub use replies::{Reply, Response}; diff --git a/crates/dwn-rs-core/src/interfaces/replies/messages.rs b/crates/dwn-rs-core/src/interfaces/replies/messages.rs new file mode 100644 index 0000000..ef71981 --- /dev/null +++ b/crates/dwn-rs-core/src/interfaces/replies/messages.rs @@ -0,0 +1,26 @@ +use cid::Cid; +use serde::{Deserialize, Serialize}; +use serde_with::skip_serializing_none; + +use crate::{Cursor, Descriptor, Message}; + +#[skip_serializing_none] +#[derive(Serialize, Deserialize, Debug, PartialEq, Clone)] +pub struct ReadEntry { + #[serde(rename = "messageCid")] + pub cid: Cid, + pub message: Option>, +} + +#[skip_serializing_none] +#[derive(Serialize, Deserialize, Debug, PartialEq, Clone)] +pub struct Read { + pub entry: Option, +} + +#[skip_serializing_none] +#[derive(Serialize, Deserialize, Debug, PartialEq, Clone)] +pub struct Query { + pub entries: Option>, + pub cursor: Option, +} diff --git a/crates/dwn-rs-core/src/interfaces/replies/mod.rs b/crates/dwn-rs-core/src/interfaces/replies/mod.rs new file mode 100644 index 0000000..32babe7 --- /dev/null +++ b/crates/dwn-rs-core/src/interfaces/replies/mod.rs @@ -0,0 +1,42 @@ +pub mod messages; +pub mod protocols; +pub mod records; + +use serde::{Deserialize, Serialize}; +use serde_with::skip_serializing_none; + +use crate::SubscriptionID; + +#[derive(Serialize, Deserialize, Debug, PartialEq, Clone)] +pub struct Status { + pub code: i32, + pub detail: String, +} + +#[derive(Serialize, Deserialize, Debug, PartialEq, Clone)] +pub struct Response { + pub status: Status, + #[serde(flatten)] + pub reply: Reply, +} + +#[derive(Serialize, Deserialize, Debug, PartialEq, Clone)] +pub struct Empty {} + +#[skip_serializing_none] +#[derive(Serialize, Deserialize, Debug, PartialEq, Clone)] +pub struct Subscribe { + pub subscription: Option, +} + +#[derive(Serialize, Deserialize, Debug, PartialEq, Clone)] +#[serde(untagged)] +pub enum Reply { + Empty(Empty), + RecordsRead(records::Read), + RecordsQuery(records::Query), + MessageRead(messages::Read), + MessageQuery(messages::Query), + ProtocolsQuery(protocols::Query), + Subscribe(Subscribe), +} diff --git a/crates/dwn-rs-core/src/interfaces/replies/protocols.rs b/crates/dwn-rs-core/src/interfaces/replies/protocols.rs new file mode 100644 index 0000000..bde7879 --- /dev/null +++ b/crates/dwn-rs-core/src/interfaces/replies/protocols.rs @@ -0,0 +1,10 @@ +use serde::{Deserialize, Serialize}; +use serde_with::skip_serializing_none; + +use crate::{descriptors::protocols, Message}; + +#[skip_serializing_none] +#[derive(Serialize, Deserialize, Debug, PartialEq, Clone)] +pub struct Query { + pub entries: Option>, +} diff --git a/crates/dwn-rs-core/src/interfaces/replies/records.rs b/crates/dwn-rs-core/src/interfaces/replies/records.rs new file mode 100644 index 0000000..e0e359d --- /dev/null +++ b/crates/dwn-rs-core/src/interfaces/replies/records.rs @@ -0,0 +1,42 @@ +use serde::{Deserialize, Serialize}; +use serde_with::skip_serializing_none; + +use crate::{ + descriptors::{records::WriteDescriptor, DeleteDescriptor}, + Cursor, Message, +}; + +#[skip_serializing_none] +#[derive(Serialize, Deserialize, Debug, PartialEq, Clone)] +pub struct ReadEntry { + #[serde(rename = "recordsWrite")] + pub records_write: Option>, + #[serde(rename = "recordsDelete")] + pub records_delete: Option>, + #[serde(rename = "initialWrite")] + pub initial_write: Option>, +} + +#[skip_serializing_none] +#[derive(Serialize, Deserialize, Debug, PartialEq, Clone)] +pub struct Read { + pub entry: Option, +} + +#[skip_serializing_none] +#[derive(Serialize, Deserialize, Debug, PartialEq, Clone)] +pub struct QueryEntry { + #[serde(rename = "initialWrite")] + pub initial_write: Option>, + #[serde(rename = "encodedData")] + pub encoded_data: Option, + #[serde(flatten)] + pub message: Message, +} + +#[skip_serializing_none] +#[derive(Serialize, Deserialize, Debug, PartialEq, Clone)] +pub struct Query { + pub entries: Option>, + pub cursor: Option, +} From 327e926ffbbdf2749517314f5ca204663a567d8d Mon Sep 17 00:00:00 2001 From: Dan Enman Date: Tue, 19 Nov 2024 18:55:47 -0400 Subject: [PATCH 10/23] chore: move CID operations to dwn-rs-core --- crates/dwn-rs-core/Cargo.toml | 9 ++- .../src => dwn-rs-core/src/utils}/cid.rs | 30 +++++++++ crates/dwn-rs-core/src/utils/mod.rs | 66 +++++++++++++++++++ crates/dwn-rs-stores/Cargo.toml | 4 -- crates/dwn-rs-stores/src/lib.rs | 3 - .../src/surrealdb/message_store.rs | 3 +- crates/dwn-rs-stores/src/surrealdb/models.rs | 2 +- crates/dwn-rs-stores/src/surrealdb/query.rs | 4 +- 8 files changed, 109 insertions(+), 12 deletions(-) rename crates/{dwn-rs-stores/src => dwn-rs-core/src/utils}/cid.rs (75%) create mode 100644 crates/dwn-rs-core/src/utils/mod.rs diff --git a/crates/dwn-rs-core/Cargo.toml b/crates/dwn-rs-core/Cargo.toml index 0cc2d03..0e0a942 100644 --- a/crates/dwn-rs-core/Cargo.toml +++ b/crates/dwn-rs-core/Cargo.toml @@ -39,7 +39,14 @@ serde_json = "1.0.113" tracing = "0.1.40" tracing-test = { version = "0.2.5", features = ["no-env-filter"] } bytes = "1.8.0" +rand = "0.8.5" +secp256k1 = { version = "0.30", features = ["rand"] } +partially = { version = "0.2.1", features = ["derive"] } +derive_builder = "0.20.2" +multicodec = { git = "https://github.com/cryptidtech/rust-multicodec.git" } # Use moden fork, see gnunicorn/rust-multicodec#1 +multihash = { version = "0.19.1", features = ["serde"] } +multihash-codetable = { version = "0.1.4", features = ["serde", "sha2"] } +dwn-rs-message-derive = { path = "../dwn-rs-message-derive" } [dev-dependencies] serde_json = "1.0.113" -dwn-rs-message-derive = { path = "../dwn-rs-message-derive" } diff --git a/crates/dwn-rs-stores/src/cid.rs b/crates/dwn-rs-core/src/utils/cid.rs similarity index 75% rename from crates/dwn-rs-stores/src/cid.rs rename to crates/dwn-rs-core/src/utils/cid.rs index 2ac7185..7efc8af 100644 --- a/crates/dwn-rs-stores/src/cid.rs +++ b/crates/dwn-rs-core/src/utils/cid.rs @@ -1,6 +1,9 @@ +use bytes::Bytes; +use futures_util::TryStreamExt; use std::collections::TryReserveError; use cid::Cid; +use futures_util::TryStream; use multihash_codetable::Code; use multihash_codetable::MultihashDigest; use serde_ipld_dagcbor::EncodeError; @@ -16,6 +19,33 @@ where Ok(cid) } +pub async fn generate_cid_from_serialized( + data: T, +) -> Result> { + let serialized = serde_ipld_dagcbor::to_vec(&data)?; + generate_cid(serialized) +} + +pub async fn generate_cid_from_stream + Unpin>( + stream: S, +) -> Result> +where + S::Error: Into>, +{ + let mut buf = Vec::new(); + let _ = stream + .try_for_each(|chunk| { + buf.extend_from_slice(&chunk); + async { Ok(()) } + }) + .await; + + let mh = Code::Sha2_256.digest(&buf); + let cid = Cid::new_v1(multicodec::Codec::DagCbor.code(), mh); + + Ok(cid) +} + pub async fn generate_cid_from_asyncreader( reader: R, ) -> Result> diff --git a/crates/dwn-rs-core/src/utils/mod.rs b/crates/dwn-rs-core/src/utils/mod.rs new file mode 100644 index 0000000..3bb6af4 --- /dev/null +++ b/crates/dwn-rs-core/src/utils/mod.rs @@ -0,0 +1,66 @@ +pub mod cid; + +use partially::Partial; +use rand::{distributions::Alphanumeric, Rng}; +use secp256k1::{Keypair, Secp256k1}; +use ssi_dids_core::DIDBuf; +use std::str::FromStr; +use thiserror::Error; + +#[derive(Error, Debug)] +pub enum PersonaError { + #[error("DID error: {0}")] + DIDError(#[from] ssi_dids_core::InvalidDID), +} + +#[derive(Partial, Debug)] +#[partially(derive(Default))] +pub struct Persona { + pub did: DIDBuf, + pub key_id: String, + keypair: secp256k1::Keypair, +} + +impl Persona { + pub fn generate( + PartialPersona { + did, + key_id, + keypair, + }: PartialPersona, + ) -> Result { + let did = did.unwrap_or_else(|| { + let suffix = generate_random_string(32); + DIDBuf::from_str(&format!("did:example:{}", suffix)).unwrap() + }); + + let keypair = keypair.unwrap_or_else(|| { + let secp = Secp256k1::new(); + + Keypair::new(&secp, &mut rand::thread_rng()) + }); + + let key_id = key_id.unwrap_or_else(|| { + let suffix = generate_random_string(16); + format!("{}#{}", did, suffix) + }); + + Ok(Self { + did, + key_id, + keypair, + }) + } + + pub fn public_key(&self) -> secp256k1::PublicKey { + self.keypair.public_key() + } +} + +pub fn generate_random_string(len: usize) -> String { + rand::thread_rng() + .sample_iter(&Alphanumeric) + .take(len) + .map(char::from) + .collect() +} diff --git a/crates/dwn-rs-stores/Cargo.toml b/crates/dwn-rs-stores/Cargo.toml index 0c5735b..b89a867 100644 --- a/crates/dwn-rs-stores/Cargo.toml +++ b/crates/dwn-rs-stores/Cargo.toml @@ -25,10 +25,6 @@ no-std = [] futures-util = "0.3.30" chrono = { version = "0.4.37", features = ["serde", "wasmbind"] } cid = { version = "0.11.1", features = ["serde"] } -ipld-core = { version = "0.4.1", features = ["serde"] } -multicodec = { git = "https://github.com/cryptidtech/rust-multicodec.git" } # Use moden fork, see gnunicorn/rust-multicodec#1 -multihash = { version = "0.19.1", features = ["serde"] } -multihash-codetable = { version = "0.1.2", features = ["serde", "sha2"] } thiserror = "1.0.63" time = "0.3.36" tokio = { version = "1.39.2", features = ["io-util", "rt", "macros"] } diff --git a/crates/dwn-rs-stores/src/lib.rs b/crates/dwn-rs-stores/src/lib.rs index 9aa9770..a32e9da 100644 --- a/crates/dwn-rs-stores/src/lib.rs +++ b/crates/dwn-rs-stores/src/lib.rs @@ -1,6 +1,3 @@ -pub mod cid; -pub use cid::*; - #[cfg(feature = "surrealdb")] pub mod surrealdb; #[cfg(feature = "surrealdb")] diff --git a/crates/dwn-rs-stores/src/surrealdb/message_store.rs b/crates/dwn-rs-stores/src/surrealdb/message_store.rs index 12fcc4a..c948fd3 100644 --- a/crates/dwn-rs-stores/src/surrealdb/message_store.rs +++ b/crates/dwn-rs-stores/src/surrealdb/message_store.rs @@ -4,7 +4,8 @@ use cid::Cid; use serde::{de::DeserializeOwned, Serialize}; use super::core::SurrealDB; -use crate::{generate_cid, SurrealQuery}; +use crate::SurrealQuery; +use dwn_rs_core::utils::cid::generate_cid; use dwn_rs_core::{ descriptors::MessageDescriptor, errors::{MessageStoreError, StoreError}, diff --git a/crates/dwn-rs-stores/src/surrealdb/models.rs b/crates/dwn-rs-stores/src/surrealdb/models.rs index 855034e..f781bc6 100644 --- a/crates/dwn-rs-stores/src/surrealdb/models.rs +++ b/crates/dwn-rs-stores/src/surrealdb/models.rs @@ -1,10 +1,10 @@ use std::str::FromStr; +use cid::Cid; use dwn_rs_core::{ filters::{MessageSort, MessageWatermark, NoSort}, value::{MapValue, Value}, }; -use ipld_core::cid::Cid; use serde::{Deserialize, Serialize}; use surrealdb::{ sql::{Datetime, Value as SurrealValue}, diff --git a/crates/dwn-rs-stores/src/surrealdb/query.rs b/crates/dwn-rs-stores/src/surrealdb/query.rs index 8b745c1..0b2c638 100644 --- a/crates/dwn-rs-stores/src/surrealdb/query.rs +++ b/crates/dwn-rs-stores/src/surrealdb/query.rs @@ -4,6 +4,7 @@ use std::{ ops::{Bound, RangeBounds}, }; +use cid::Cid; use serde::de::DeserializeOwned; use surrealdb::sql::{value as surreal_value, Cond, Function, Idiom, Subquery}; use surrealdb::{ @@ -64,7 +65,7 @@ where } pub trait CursorValue { - fn cid(&self) -> ipld_core::cid::Cid; + fn cid(&self) -> Cid; fn cursor_value(&self, sort: T) -> dwn_rs_core::value::Value; } @@ -243,7 +244,6 @@ where None => stmt.cond, }; - let mut q = self .db .query(stmt.clone()) From 694c0739737d9abcf188b9e509a3aa213be530f9 Mon Sep 17 00:00:00 2001 From: Dan Enman Date: Tue, 19 Nov 2024 19:39:39 -0400 Subject: [PATCH 11/23] feat: add JWS signature calculation for Write fields --- Cargo.lock | 452 +++++++++++++++++++++++++---- crates/dwn-rs-core/Cargo.toml | 1 + crates/dwn-rs-core/src/auth/jws.rs | 142 ++++++++- crates/dwn-rs-core/src/auth/mod.rs | 3 +- 4 files changed, 534 insertions(+), 64 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index f4b4cbd..730a619 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -308,7 +308,7 @@ dependencies = [ "Inflector", "async-graphql-parser", "darling", - "proc-macro-crate 3.1.0", + "proc-macro-crate", "proc-macro2", "quote", "strum", @@ -576,6 +576,22 @@ version = "0.6.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "349f9b6a179ed607305526ca489b34ad0a41aed5f7980fa90eb03160b69598fb" +[[package]] +name = "bitcoin-io" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "340e09e8399c7bd8912f495af6aa58bea0c9214773417ffaa8f6460f93aaee56" + +[[package]] +name = "bitcoin_hashes" +version = "0.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bb18c03d0db0247e147a21a6faafd5a7eb851c743db062de72018b6b7e8e4d16" +dependencies = [ + "bitcoin-io", + "hex-conservative", +] + [[package]] name = "bitflags" version = "1.3.2" @@ -717,7 +733,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d7a8646f94ab393e43e8b35a2558b1624bed28b97ee09c5d15456e3c9463f46d" dependencies = [ "once_cell", - "proc-macro-crate 3.1.0", + "proc-macro-crate", "proc-macro2", "quote", "syn 2.0.77", @@ -1319,6 +1335,37 @@ dependencies = [ "serde", ] +[[package]] +name = "derive_builder" +version = "0.20.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "507dfb09ea8b7fa618fcf76e953f4f5e192547945816d5358edffe39f6f94947" +dependencies = [ + "derive_builder_macro", +] + +[[package]] +name = "derive_builder_core" +version = "0.20.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2d5bcf7b024d6835cfb3d473887cd966994907effbe9227e8c8219824d06c4e8" +dependencies = [ + "darling", + "proc-macro2", + "quote", + "syn 2.0.77", +] + +[[package]] +name = "derive_builder_macro" +version = "0.20.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ab63b0e2bf4d5928aff72e83a7dace85d7bba5fe12dcc3c5a572d78caffd3f3c" +dependencies = [ + "derive_builder_core", + "syn 2.0.77", +] + [[package]] name = "derive_more" version = "1.0.0" @@ -1388,6 +1435,17 @@ dependencies = [ "winapi", ] +[[package]] +name = "displaydoc" +version = "0.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "97369cbbc041bc366949bc74d34658d6cda5621039731c6310521892a3a20ae0" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.77", +] + [[package]] name = "dmp" version = "0.2.0" @@ -1411,14 +1469,23 @@ dependencies = [ "bytes", "chrono", "cid", + "derive_builder", "derive_more", + "dwn-rs-message-derive", "futures-util", "ipld-core", + "multicodec", + "multihash", + "multihash-codetable", + "partially", + "rand", + "secp256k1", "serde", "serde_ipld_dagcbor", "serde_json", "serde_repr", "serde_with", + "ssi-claims-core", "ssi-dids-core", "ssi-jwk 0.3.0", "ssi-jws 0.3.0", @@ -1431,12 +1498,23 @@ dependencies = [ "xtra", ] +[[package]] +name = "dwn-rs-message-derive" +version = "0.1.0" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.77", +] + [[package]] name = "dwn-rs-remote" version = "0.1.0" dependencies = [ "bytes", + "chrono", "dwn-rs-core", + "dwn-rs-stores", "futures-core", "futures-util", "http 1.1.0", @@ -1451,7 +1529,9 @@ dependencies = [ "tower", "tracing", "tracing-subscriber", + "tracing-test", "ulid", + "url", ] [[package]] @@ -1464,11 +1544,7 @@ dependencies = [ "cid", "dwn-rs-core", "futures-util", - "ipld-core", "memoize", - "multicodec", - "multihash", - "multihash-codetable", "rand", "serde", "serde_ipld_dagcbor", @@ -1478,6 +1554,7 @@ dependencies = [ "thiserror", "time", "tokio", + "tokio-util", "tracing", "ulid", "url", @@ -2364,6 +2441,15 @@ version = "0.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" +[[package]] +name = "hex-conservative" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5313b072ce3c597065a808dbf612c4c8e8590bdbf8b579508bf7a762c5eae6cd" +dependencies = [ + "arrayvec", +] + [[package]] name = "hex_fmt" version = "0.3.0" @@ -2600,6 +2686,124 @@ dependencies = [ "cc", ] +[[package]] +name = "icu_collections" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "db2fa452206ebee18c4b5c2274dbf1de17008e874b4dc4f0aea9d01ca79e4526" +dependencies = [ + "displaydoc", + "yoke", + "zerofrom", + "zerovec", +] + +[[package]] +name = "icu_locid" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "13acbb8371917fc971be86fc8057c41a64b521c184808a698c02acc242dbf637" +dependencies = [ + "displaydoc", + "litemap", + "tinystr", + "writeable", + "zerovec", +] + +[[package]] +name = "icu_locid_transform" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "01d11ac35de8e40fdeda00d9e1e9d92525f3f9d887cdd7aa81d727596788b54e" +dependencies = [ + "displaydoc", + "icu_locid", + "icu_locid_transform_data", + "icu_provider", + "tinystr", + "zerovec", +] + +[[package]] +name = "icu_locid_transform_data" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fdc8ff3388f852bede6b579ad4e978ab004f139284d7b28715f773507b946f6e" + +[[package]] +name = "icu_normalizer" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "19ce3e0da2ec68599d193c93d088142efd7f9c5d6fc9b803774855747dc6a84f" +dependencies = [ + "displaydoc", + "icu_collections", + "icu_normalizer_data", + "icu_properties", + "icu_provider", + "smallvec", + "utf16_iter", + "utf8_iter", + "write16", + "zerovec", +] + +[[package]] +name = "icu_normalizer_data" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8cafbf7aa791e9b22bec55a167906f9e1215fd475cd22adfcf660e03e989516" + +[[package]] +name = "icu_properties" +version = "1.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93d6020766cfc6302c15dbbc9c8778c37e62c14427cb7f6e601d849e092aeef5" +dependencies = [ + "displaydoc", + "icu_collections", + "icu_locid_transform", + "icu_properties_data", + "icu_provider", + "tinystr", + "zerovec", +] + +[[package]] +name = "icu_properties_data" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "67a8effbc3dd3e4ba1afa8ad918d5684b8868b3b26500753effea8d2eed19569" + +[[package]] +name = "icu_provider" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ed421c8a8ef78d3e2dbc98a973be2f3770cb42b606e3ab18d6237c4dfde68d9" +dependencies = [ + "displaydoc", + "icu_locid", + "icu_provider_macros", + "stable_deref_trait", + "tinystr", + "writeable", + "yoke", + "zerofrom", + "zerovec", +] + +[[package]] +name = "icu_provider_macros" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1ec89e9337638ecdc08744df490b221a7399bf8d164eb52a665454e60e075ad6" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.77", +] + [[package]] name = "ident_case" version = "1.0.1" @@ -2608,12 +2812,23 @@ checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39" [[package]] name = "idna" -version = "0.5.0" +version = "1.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "634d9b1461af396cad843f47fdba5597a4f9e6ddd4bfb6ff5d85028c25cb12f6" +checksum = "686f825264d630750a544639377bae737628043f20d38bbc029e8f29ea968a7e" dependencies = [ - "unicode-bidi", - "unicode-normalization", + "idna_adapter", + "smallvec", + "utf8_iter", +] + +[[package]] +name = "idna_adapter" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "daca1df1c957320b2cf139ac61e7bd64fed304c5040df000a745aa1de3b4ef71" +dependencies = [ + "icu_normalizer", + "icu_properties", ] [[package]] @@ -3250,6 +3465,12 @@ version = "0.4.14" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "78b3ae25bc7c8c38cec158d1f2757ee79e9b3740fbc7ccf0e59e4b08d793fa89" +[[package]] +name = "litemap" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "643cb0b8d4fcc284004d5fd0d67ccf61dfffadb7f75e1e71bc420f4688a3a704" + [[package]] name = "lock_api" version = "0.4.12" @@ -3506,20 +3727,20 @@ dependencies = [ [[package]] name = "multihash" -version = "0.19.1" +version = "0.19.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "076d548d76a0e2a0d4ab471d0b1c36c577786dfc4471242035d97a12a735c492" +checksum = "cc41f430805af9d1cf4adae4ed2149c759b877b01d909a1f40256188d09345d2" dependencies = [ "core2", "serde", - "unsigned-varint 0.7.2", + "unsigned-varint 0.8.0", ] [[package]] name = "multihash-codetable" -version = "0.1.2" +version = "0.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c11bbdf904c8be009e82ff968c4dab84388cbafc45dfaff61936eca4bf40f1b5" +checksum = "67996849749d25f1da9f238e8ace2ece8f9d6bdf3f9750aaf2ae7de3a5cad8ea" dependencies = [ "blake2b_simd", "blake2s_simd", @@ -3537,9 +3758,9 @@ dependencies = [ [[package]] name = "multihash-derive" -version = "0.9.0" +version = "0.9.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "890e72cb7396cb99ed98c1246a97b243cc16394470d94e0bc8b0c2c11d84290e" +checksum = "1f1b7edab35d920890b88643a765fc9bd295cf0201f4154dda231bef9b8404eb" dependencies = [ "core2", "multihash", @@ -3548,15 +3769,14 @@ dependencies = [ [[package]] name = "multihash-derive-impl" -version = "0.1.0" +version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d38685e08adb338659871ecfc6ee47ba9b22dcc8abcf6975d379cc49145c3040" +checksum = "e3dc7141bd06405929948754f0628d247f5ca1865be745099205e5086da957cb" dependencies = [ - "proc-macro-crate 1.3.1", - "proc-macro-error", + "proc-macro-crate", "proc-macro2", "quote", - "syn 1.0.109", + "syn 2.0.77", "synstructure", ] @@ -3933,6 +4153,27 @@ dependencies = [ "windows-targets 0.52.6", ] +[[package]] +name = "partially" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8c662358b50ce030ff0bf5c174541e37b2e29564732c34504e8ffb952389e3c1" +dependencies = [ + "partially_derive", +] + +[[package]] +name = "partially_derive" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e53059790c7b28bb3a618a75dbfa0c6e569515d0b8d3a4df94baca76d4e14b38" +dependencies = [ + "darling", + "proc-macro2", + "quote", + "syn 2.0.77", +] + [[package]] name = "password-hash" version = "0.5.0" @@ -4286,23 +4527,13 @@ dependencies = [ "uint", ] -[[package]] -name = "proc-macro-crate" -version = "1.3.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7f4c021e1093a56626774e81216a4ce732a735e5bad4868a03f3ed65ca0c3919" -dependencies = [ - "once_cell", - "toml_edit 0.19.15", -] - [[package]] name = "proc-macro-crate" version = "3.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6d37c51ca738a55da99dc0c4a34860fd675453b8b36209178c2249bb13651284" dependencies = [ - "toml_edit 0.21.1", + "toml_edit", ] [[package]] @@ -5209,6 +5440,26 @@ dependencies = [ "zeroize", ] +[[package]] +name = "secp256k1" +version = "0.30.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b50c5943d326858130af85e049f2661ba3c78b26589b8ab98e65e80ae44a1252" +dependencies = [ + "bitcoin_hashes", + "rand", + "secp256k1-sys", +] + +[[package]] +name = "secp256k1-sys" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d4387882333d3aa8cb20530a17c69a3752e97837832f34f6dccc760e715001d9" +dependencies = [ + "cc", +] + [[package]] name = "security-framework" version = "2.11.0" @@ -5646,6 +5897,8 @@ checksum = "226ed3bf73c05cc884b261171c276c426a3e49bbc8cbb46368fbff24abfbc72f" dependencies = [ "chrono", "educe 0.4.23", + "linked-data", + "serde", "ssi-core", "ssi-crypto", "ssi-eip712", @@ -6020,9 +6273,9 @@ dependencies = [ [[package]] name = "strobe-rs" -version = "0.8.1" +version = "0.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fabb238a1cccccfa4c4fb703670c0d157e1256c1ba695abf1b93bd2bb14bab2d" +checksum = "98fe17535ea31344936cc58d29fec9b500b0452ddc4cc24c429c8a921a0e84e5" dependencies = [ "bitflags 1.3.2", "byteorder", @@ -6274,14 +6527,13 @@ dependencies = [ [[package]] name = "synstructure" -version = "0.12.6" +version = "0.13.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f36bdaa60a83aca3921b5259d5400cbf5e90fc51931376a9bd4a0eb79aa7210f" +checksum = "c8af7666ab7b6390ab78131fb5b0fce11d6b7a6951602017c35fa82800708971" dependencies = [ "proc-macro2", "quote", - "syn 1.0.109", - "unicode-xid", + "syn 2.0.77", ] [[package]] @@ -6436,6 +6688,16 @@ dependencies = [ "crunchy", ] +[[package]] +name = "tinystr" +version = "0.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9117f5d4db391c1cf6927e7bea3db74b9a1c1add8f7eda9ffd5364f40f57b82f" +dependencies = [ + "displaydoc", + "zerovec", +] + [[package]] name = "tinyvec" version = "1.6.0" @@ -6537,17 +6799,6 @@ version = "0.6.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4badfd56924ae69bcc9039335b2e017639ce3f9b001c393c1b2d1ef846ce2cbf" -[[package]] -name = "toml_edit" -version = "0.19.15" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1b5bb770da30e5cbfde35a2d7b9b8a2c4b8ef89548a7a6aeab5c9a576e3e7421" -dependencies = [ - "indexmap 2.2.6", - "toml_datetime", - "winnow", -] - [[package]] name = "toml_edit" version = "0.21.1" @@ -6772,12 +7023,6 @@ dependencies = [ "version_check", ] -[[package]] -name = "unicode-bidi" -version = "0.3.15" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "08f95100a766bf4f8f28f90d77e0a5461bbdb219042e7679bebe79004fed8d75" - [[package]] name = "unicode-ident" version = "1.0.12" @@ -6847,9 +7092,9 @@ checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" [[package]] name = "url" -version = "2.5.0" +version = "2.5.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "31e6302e3bb753d46e83516cae55ae196fc0c309407cf11ab35cc51a4c2a4633" +checksum = "8d157f1b96d14500ffdc1f10ba712e780825526c03d9a49b4d0324b0d9113ada" dependencies = [ "form_urlencoded", "idna", @@ -6869,12 +7114,24 @@ version = "0.7.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9" +[[package]] +name = "utf16_iter" +version = "1.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c8232dd3cdaed5356e0f716d285e4b40b932ac434100fe9b7e0e8e935b9e6246" + [[package]] name = "utf8-decode" version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ca61eb27fa339aa08826a29f03e87b99b4d8f0fc2255306fd266bb1b6a9de498" +[[package]] +name = "utf8_iter" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be" + [[package]] name = "uuid" version = "1.10.0" @@ -7297,6 +7554,18 @@ dependencies = [ "windows-sys 0.48.0", ] +[[package]] +name = "write16" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d1890f4022759daae28ed4fe62859b1236caebfc61ede2f63ed4e695f3f6d936" + +[[package]] +name = "writeable" +version = "0.5.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e9df38ee2d2c3c5948ea468a8406ff0db0b29ae1ffde1bcf20ef305bcc95c51" + [[package]] name = "ws_stream_wasm" version = "0.7.4" @@ -7364,6 +7633,30 @@ dependencies = [ "wasm-bindgen-futures", ] +[[package]] +name = "yoke" +version = "0.7.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c5b1314b079b0930c31e3af543d8ee1757b1951ae1e1565ec704403a7240ca5" +dependencies = [ + "serde", + "stable_deref_trait", + "yoke-derive", + "zerofrom", +] + +[[package]] +name = "yoke-derive" +version = "0.7.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "28cc31741b18cb6f1d5ff12f5b7523e3d6eb0852bbbad19d73905511d9849b95" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.77", + "synstructure", +] + [[package]] name = "zerocopy" version = "0.7.34" @@ -7384,6 +7677,27 @@ dependencies = [ "syn 2.0.77", ] +[[package]] +name = "zerofrom" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "91ec111ce797d0e0784a1116d0ddcdbea84322cd79e5d5ad173daeba4f93ab55" +dependencies = [ + "zerofrom-derive", +] + +[[package]] +name = "zerofrom-derive" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0ea7b4a3637ea8669cedf0f1fd5c286a17f3de97b8dd5a70a6c167a1730e63a5" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.77", + "synstructure", +] + [[package]] name = "zeroize" version = "1.8.1" @@ -7403,3 +7717,25 @@ dependencies = [ "quote", "syn 2.0.77", ] + +[[package]] +name = "zerovec" +version = "0.10.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aa2b893d79df23bfb12d5461018d408ea19dfafe76c2c7ef6d4eba614f8ff079" +dependencies = [ + "yoke", + "zerofrom", + "zerovec-derive", +] + +[[package]] +name = "zerovec-derive" +version = "0.10.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6eafa6dfb17584ea3e2bd6e76e0cc15ad7af12b09abdd1ca55961bed9b1063c6" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.77", +] diff --git a/crates/dwn-rs-core/Cargo.toml b/crates/dwn-rs-core/Cargo.toml index 0e0a942..80dbe44 100644 --- a/crates/dwn-rs-core/Cargo.toml +++ b/crates/dwn-rs-core/Cargo.toml @@ -47,6 +47,7 @@ multicodec = { git = "https://github.com/cryptidtech/rust-multicodec.git" } # multihash = { version = "0.19.1", features = ["serde"] } multihash-codetable = { version = "0.1.4", features = ["serde", "sha2"] } dwn-rs-message-derive = { path = "../dwn-rs-message-derive" } +ssi-claims-core = "0.1.2" [dev-dependencies] serde_json = "1.0.113" diff --git a/crates/dwn-rs-core/src/auth/jws.rs b/crates/dwn-rs-core/src/auth/jws.rs index e18481c..423712d 100644 --- a/crates/dwn-rs-core/src/auth/jws.rs +++ b/crates/dwn-rs-core/src/auth/jws.rs @@ -1,5 +1,18 @@ -use crate::MapValue; +use base64::prelude::{BASE64_URL_SAFE_NO_PAD as base64url, *}; +use futures_util::{stream, StreamExt, TryStreamExt}; use serde::{Deserialize, Serialize}; +use ssi_jws::{Header, JwsSigner}; +use thiserror::Error; + +use crate::MapValue; + +#[derive(Error, Debug)] +pub enum JwsError { + #[error("Error parsing JWS: {0}")] + ParseError(#[from] serde_json::Error), + #[error("Error signing JWS: {0}")] + SignError(#[from] ssi_claims_core::SignatureError), +} #[derive(Serialize, Deserialize, Debug, Default, PartialEq, Clone)] pub struct JWS { @@ -9,18 +22,137 @@ pub struct JWS { pub signatures: Option>, #[serde(skip_serializing_if = "Option::is_none")] pub header: Option, - #[serde(flatten)] + #[serde(flatten)] // TODO: remove? pub extra: MapValue, } +pub struct NoSigner {} +impl JwsSigner for NoSigner { + async fn fetch_info(&self) -> Result { + Ok(ssi_jws::JwsSignerInfo { + key_id: None, + algorithm: ssi_jwk::Algorithm::None, + }) + } + + async fn sign_bytes( + &self, + _signing_bytes: &[u8], + ) -> Result, ssi_claims_core::SignatureError> { + Ok(Vec::new()) + } +} + +impl JWS { + pub async fn create(payload: Vec, signers: Option>) -> Result + where + S: JwsSigner, + { + let payload = base64url.encode(payload); + + if let Some(signers) = signers { + let signatures = Self::generate_signatures(signers, &payload).await?; + Ok(Self { + payload: Some(payload), + signatures: Some(signatures), + header: None, + extra: MapValue::default(), + }) + } else { + Ok(Self { + payload: Some(payload), + signatures: None, + header: None, + extra: MapValue::default(), + }) + } + } + + async fn generate_signatures( + signers: Vec, + payload_encoded: &str, + ) -> Result, JwsError> + where + S: JwsSigner, + { + stream::iter(signers) + .then(|signer| async move { + let result: Result = async { + let info = signer.fetch_info().await?; + let header = Header { + algorithm: info.algorithm, + key_id: info.key_id, + ..Default::default() + }; + let header = serde_json::to_vec(&header)?; + let protected_header = base64url.encode(header); + + let sign_input = format!("{}.{}", protected_header, payload_encoded); + + let signature = signer.sign(&sign_input).await?; + let signature = base64url.encode(signature.as_bytes()); + + Ok(SignatureEntry { + protected: Some(protected_header), + signature: Some(signature), + extra: MapValue::default(), + }) + } + .await; + + result + }) + .try_collect() + .await + } +} + #[derive(Serialize, Deserialize, Debug, Default, PartialEq, Clone)] pub struct SignatureEntry { - #[serde(skip_serializing_if = "Option::is_none")] - pub payload: Option, #[serde(skip_serializing_if = "Option::is_none")] pub protected: Option, #[serde(skip_serializing_if = "Option::is_none")] pub signature: Option, - #[serde(flatten)] + #[serde(flatten)] // TODO: remove? pub extra: MapValue, } + +#[cfg(test)] +mod tests { + use super::*; + use ssi_jwk::JWK; + + #[tokio::test] + async fn test_jws_create() { + let jwk = JWK::generate_ed25519().expect("could not generate key"); + let jws = JWS::create(b"hello world".to_vec(), Some(vec![jwk])) + .await + .expect("could not create JWS"); + + assert_eq!(jws.payload, Some("aGVsbG8gd29ybGQ".to_string())); + assert_eq!(jws.signatures.as_ref().unwrap().len(), 1); + assert_eq!( + jws.signatures.as_ref().unwrap()[0] + .protected + .as_ref() + .unwrap(), + "eyJhbGciOiJFZERTQSJ9" + ); + + assert!(!jws.signatures.as_ref().unwrap()[0] + .signature + .as_ref() + .unwrap() + .is_empty()); + } + + #[tokio::test] + async fn test_jws_create_no_signers() { + let jws = JWS::create::(b"hello world".to_vec(), None) + .await + .expect("could not create JWS"); + + assert_eq!(jws.payload, Some("aGVsbG8gd29ybGQ".to_string())); + assert!(jws.signatures.is_none()); + } +} diff --git a/crates/dwn-rs-core/src/auth/mod.rs b/crates/dwn-rs-core/src/auth/mod.rs index 0cb70af..ef4d1cf 100644 --- a/crates/dwn-rs-core/src/auth/mod.rs +++ b/crates/dwn-rs-core/src/auth/mod.rs @@ -1,5 +1,6 @@ pub mod authorization; +pub mod encryption; pub mod jws; pub use authorization::{Authorization, AuthorizationDelegatedGrant, AuthorizationOwner}; -pub use jws::JWS; +pub use jws::{JwsError, JWS}; // TODO: JWS -> Jws From 72b66bf10a09e9f18db0a0109e35b13fa8fe59c2 Mon Sep 17 00:00:00 2001 From: Dan Enman Date: Thu, 21 Nov 2024 15:49:55 -0400 Subject: [PATCH 12/23] fix: `pub mod utils` in `dwn-rs-core` --- crates/dwn-rs-core/src/lib.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/crates/dwn-rs-core/src/lib.rs b/crates/dwn-rs-core/src/lib.rs index 63dc748..6667908 100644 --- a/crates/dwn-rs-core/src/lib.rs +++ b/crates/dwn-rs-core/src/lib.rs @@ -35,3 +35,5 @@ pub use events::*; pub use filters::*; pub use interfaces::*; pub use value::*; + +pub mod utils; From 49926a8a6c3a79e1b9c85ca3cacbf5137ea29f8f Mon Sep 17 00:00:00 2001 From: Dan Enman Date: Fri, 22 Nov 2024 08:28:09 -0400 Subject: [PATCH 13/23] fix: core tests for event stream, field serialization --- crates/dwn-rs-core/src/events/stream.rs | 7 ++++++- crates/dwn-rs-core/src/interfaces/messages/fields.rs | 6 ------ 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/crates/dwn-rs-core/src/events/stream.rs b/crates/dwn-rs-core/src/events/stream.rs index 509cc2f..f294624 100644 --- a/crates/dwn-rs-core/src/events/stream.rs +++ b/crates/dwn-rs-core/src/events/stream.rs @@ -267,6 +267,7 @@ where #[cfg(test)] mod test { + use chrono::TimeZone; use tracing_test::traced_test; use xtra::{spawn_tokio, Mailbox}; @@ -278,10 +279,14 @@ mod test { use super::*; fn test_evt() -> MessageEvent { + let now = chrono::DateTime::::MIN_UTC.naive_utc(); MessageEvent { message: Message { descriptor: Descriptor::Records(Records::Read(records::ReadDescriptor { - message_timestamp: chrono::Utc::now(), + message_timestamp: chrono::DateTime::from_naive_utc_and_offset( + now, + chrono::Utc, + ), filter: Default::default(), })), fields: Fields::Authorization(Default::default()), diff --git a/crates/dwn-rs-core/src/interfaces/messages/fields.rs b/crates/dwn-rs-core/src/interfaces/messages/fields.rs index 5d96830..201b56c 100644 --- a/crates/dwn-rs-core/src/interfaces/messages/fields.rs +++ b/crates/dwn-rs-core/src/interfaces/messages/fields.rs @@ -193,7 +193,6 @@ mod tests { signature: JWS { payload: Some("payload".to_string()), signatures: Some(vec![SignatureEntry { - payload: Some("payload".to_string()), protected: Some("protected".to_string()), signature: Some("signature".to_string()), ..Default::default() @@ -219,7 +218,6 @@ mod tests { attestation: Some(JWS { payload: Some("payload".to_string()), signatures: Some(vec![SignatureEntry { - payload: Some("payload".to_string()), protected: Some("protected".to_string()), signature: Some("signature".to_string()), ..Default::default() @@ -253,7 +251,6 @@ mod tests { "payload": "payload", "signatures": [ {{ - "payload": "payload", "protected": "protected", "signature": "signature" }} @@ -264,7 +261,6 @@ mod tests { "payload": "payload", "signatures": [ {{ - "payload": "payload", "protected": "protected", "signature": "signature" }} @@ -368,7 +364,6 @@ mod tests { "payload": "payload", "signatures": [ {{ - "payload": "payload", "protected": "protected", "signature": "signature" }} @@ -379,7 +374,6 @@ mod tests { "payload": "payload", "signatures": [ {{ - "payload": "payload", "protected": "protected", "signature": "signature" }} From 82340f824775b1ab28193a3fa92f3ad1e1166e86 Mon Sep 17 00:00:00 2001 From: Dan Enman Date: Fri, 22 Nov 2024 13:47:05 -0400 Subject: [PATCH 14/23] chore: replace secp256k1 with k256 in Persona utils --- crates/dwn-rs-core/Cargo.toml | 11 ++++++++--- crates/dwn-rs-core/src/utils/mod.rs | 13 +++++++------ 2 files changed, 15 insertions(+), 9 deletions(-) diff --git a/crates/dwn-rs-core/Cargo.toml b/crates/dwn-rs-core/Cargo.toml index 80dbe44..d2c8648 100644 --- a/crates/dwn-rs-core/Cargo.toml +++ b/crates/dwn-rs-core/Cargo.toml @@ -30,7 +30,11 @@ tokio = { version = "1.39.2", features = ["io-util", "rt", "macros"] } derive_more = { version = "1.0", features = ["display", "from", "try_into"] } ssi-dids-core = "0.1.0" base64 = "0.22.1" -ssi-jwk = "0.3.0" +ssi-jwk = { version = "0.3.1", features = [ + "secp256r1", + "secp256k1", + "ed25519", +] } ssi-jws = "0.3.0" ulid = { version = "1.1.2", features = ["serde"] } url = { version = "2.5.0", features = ["serde"] } @@ -40,14 +44,15 @@ tracing = "0.1.40" tracing-test = { version = "0.2.5", features = ["no-env-filter"] } bytes = "1.8.0" rand = "0.8.5" -secp256k1 = { version = "0.30", features = ["rand"] } partially = { version = "0.2.1", features = ["derive"] } derive_builder = "0.20.2" -multicodec = { git = "https://github.com/cryptidtech/rust-multicodec.git" } # Use moden fork, see gnunicorn/rust-multicodec#1 +multicodec = { git = "https://github.com/cryptidtech/rust-multicodec.git" } # Use moden fork, see gnunicorn/rust-multicodec#1 multihash = { version = "0.19.1", features = ["serde"] } multihash-codetable = { version = "0.1.4", features = ["serde", "sha2"] } dwn-rs-message-derive = { path = "../dwn-rs-message-derive" } ssi-claims-core = "0.1.2" +k256 = "0.13.4" +hkdf = "0.12.4" [dev-dependencies] serde_json = "1.0.113" diff --git a/crates/dwn-rs-core/src/utils/mod.rs b/crates/dwn-rs-core/src/utils/mod.rs index 3bb6af4..7651265 100644 --- a/crates/dwn-rs-core/src/utils/mod.rs +++ b/crates/dwn-rs-core/src/utils/mod.rs @@ -1,8 +1,8 @@ pub mod cid; +use k256::{PublicKey, SecretKey}; use partially::Partial; use rand::{distributions::Alphanumeric, Rng}; -use secp256k1::{Keypair, Secp256k1}; use ssi_dids_core::DIDBuf; use std::str::FromStr; use thiserror::Error; @@ -18,7 +18,7 @@ pub enum PersonaError { pub struct Persona { pub did: DIDBuf, pub key_id: String, - keypair: secp256k1::Keypair, + keypair: (SecretKey, PublicKey), } impl Persona { @@ -35,9 +35,10 @@ impl Persona { }); let keypair = keypair.unwrap_or_else(|| { - let secp = Secp256k1::new(); + let rng = &mut rand::thread_rng(); + let secp = SecretKey::random(rng); - Keypair::new(&secp, &mut rand::thread_rng()) + (secp.clone(), secp.public_key()) }); let key_id = key_id.unwrap_or_else(|| { @@ -52,8 +53,8 @@ impl Persona { }) } - pub fn public_key(&self) -> secp256k1::PublicKey { - self.keypair.public_key() + pub fn public_key(&self) -> k256::PublicKey { + self.keypair.1 } } From c219a845b6b5812a0f4ab329c19ba91165e63f3d Mon Sep 17 00:00:00 2001 From: Dan Enman Date: Fri, 22 Nov 2024 13:58:34 -0400 Subject: [PATCH 15/23] feat: add HKDF key derivation for records encryption --- crates/dwn-rs-core/src/encryption/hd_keys.rs | 217 +++++++++++++++++++ crates/dwn-rs-core/src/encryption/mod.rs | 80 +++++++ crates/dwn-rs-core/src/lib.rs | 1 + 3 files changed, 298 insertions(+) create mode 100644 crates/dwn-rs-core/src/encryption/hd_keys.rs create mode 100644 crates/dwn-rs-core/src/encryption/mod.rs diff --git a/crates/dwn-rs-core/src/encryption/hd_keys.rs b/crates/dwn-rs-core/src/encryption/hd_keys.rs new file mode 100644 index 0000000..90f0aae --- /dev/null +++ b/crates/dwn-rs-core/src/encryption/hd_keys.rs @@ -0,0 +1,217 @@ +use k256::{sha2, SecretKey}; +use ssi_jwk::{secp256k1_parse_private, Params, JWK}; + +use super::DerivationScheme; +use thiserror::Error; + +const HKDF_KEY_LENGTH: usize = 32; // * 8; // 32 bytes = 256 bits + +#[derive(Debug, Error)] +pub enum Error { + #[error("Error getting secret key: {0}")] + SecretKeyError(#[from] ssi_jwk::Error), + #[error("Error deriving key, bad key length: {0}")] + DeriveKeyLengthError(hkdf::InvalidLength), + #[error("Error deriving key: {0}")] + DeriveKeyError(#[from] k256::elliptic_curve::Error), + #[error("Error encoding key: {0}")] + EncodeError(#[from] k256::pkcs8::der::Error), + #[error("Invalid path segment: {0}")] + InvalidPathSegment(String), + #[error("Unsupported hash algorithm: {0}")] + UnsupportedHashAlgorithm(String), + #[error("Unsupported key type")] + UnsupportedKeyType, +} + +/// DerivedPrivateJWK represents a derived private JWK, which includes the root key ID, derivation +/// scheme, derivation path, and the key itself. This is used for encrypting records with keys +/// derived from a root key. +#[derive(Debug)] +pub struct DerivedPrivateJWK { + pub root_key_id: String, + pub scheme: DerivationScheme, + pub path: Option>, + pub key: JWK, +} + +/// HashAlgorithm represents the hash algorithm used for key derivation. +#[derive(PartialEq)] +pub enum HashAlgorithm { + SHA256, + SHA384, + SHA512, +} + +impl DerivedPrivateJWK { + /// derive derives a new private key from the root key using the derivation path. + pub fn derive( + ancestor_key: DerivedPrivateJWK, + derivation_path: Vec, + ) -> Result { + let path: &[&str] = &derivation_path + .iter() + .map(|s| s.as_str()) + .collect::>(); + if let Params::EC(ecparam) = ancestor_key.key.params { + // TODO support x25519 + let sk: k256::SecretKey = (&ecparam).try_into()?; + let ancestor_path = ancestor_key.path.unwrap_or_default(); + + let derived_key = Self::derive_private_key(&sk, path)?; + + let mut pk: JWK = sk.public_key().into(); + let derived_jwk = secp256k1_parse_private(&derived_key.to_sec1_der()?)?; + pk.params = derived_jwk.params.clone(); + + return Ok(DerivedPrivateJWK { + root_key_id: ancestor_key.root_key_id, + scheme: ancestor_key.scheme, + path: Some([ancestor_path, derivation_path].concat()), + key: pk, + }); + }; + + Err(Error::UnsupportedKeyType) + } + + pub fn derive_public_key( + ancestor_key: DerivedPrivateJWK, + derivation_path: &[&str], + ) -> Result { + if let Params::EC(ecparam) = ancestor_key.key.params { + // TODO support x25519 + let sk: k256::SecretKey = (&ecparam).try_into()?; + + let derived_key = Self::derive_private_key(&sk, derivation_path)?; + let derived_jwk = derived_key.public_key().into(); + + return Ok(derived_jwk); + } + + Err(Error::UnsupportedKeyType) + } + pub fn derive_private_key( + ancestor_key: &SecretKey, + derivation_path: &[&str], + ) -> Result { + Self::validate_path(derivation_path)?; + + let sk = derivation_path.iter().try_fold( + ancestor_key.to_owned(), + |key, segment| -> Result { + let seg = segment.as_bytes(); + let key_material = key.to_bytes(); + Self::derive_hkdf_key(HashAlgorithm::SHA256, &key_material, seg) + }, + )?; + + Ok(sk) + } + + pub fn derive_hkdf_key( + hash_algo: HashAlgorithm, + initial_key_material: &[u8], + info: &[u8], + ) -> Result { + if hash_algo != HashAlgorithm::SHA256 { + // TODO support more algorithms + return Err(Error::UnsupportedHashAlgorithm( + "Unsupported hash algorithm".to_string(), + )); + } + + let mut okm = [0u8; HKDF_KEY_LENGTH]; + + hkdf::Hkdf::::new(None, initial_key_material) + .expand(info, &mut okm) + .map_err(Error::DeriveKeyLengthError)?; + + Ok(SecretKey::from_slice(&okm)?) + } + + fn validate_path(path: &[&str]) -> Result<(), Error> { + // check if any path segments are empty + if path.iter().any(|s| s.is_empty()) { + return Err(Error::InvalidPathSegment("Empty path segment".to_string())); + } + + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use ssi_jwk::JWK; + + #[test] + fn test_derive() { + let root_key = JWK::generate_secp256k1(); + let root_key_id = "root_key_id".to_string(); + let scheme = DerivationScheme::ProtocolPath; + let path = &["path"]; + let derived_key = DerivedPrivateJWK { + root_key_id: root_key_id.clone(), + scheme, + path: Some(path.iter().map(|s| s.to_string()).collect()), + key: root_key, + }; + let derived = DerivedPrivateJWK::derive(derived_key, vec!["path2".to_string()]).unwrap(); + + assert_eq!(derived.root_key_id, root_key_id); + assert_eq!(derived.scheme, DerivationScheme::ProtocolPath); + assert_eq!( + derived.path, + Some(vec!["path".to_string(), "path2".to_string()]) + ); + } + + #[test] + fn test_derive_public_key() { + let root_key = JWK::generate_secp256k1(); + let root_key_id = "root_key_id".to_string(); + let scheme = DerivationScheme::ProtocolPath; + let path = &["path"]; + let derived_key = DerivedPrivateJWK { + root_key_id: root_key_id.clone(), + scheme, + path: Some(path.iter().map(|s| s.to_string()).collect()), + key: root_key.clone(), + }; + + let derived = DerivedPrivateJWK::derive_public_key(derived_key, path).unwrap(); + + assert!(derived.params.is_public()); + } + + #[test] + fn test_derive_ancestor_chain_path() { + let root_key = k256::SecretKey::random(&mut rand::thread_rng()); + + let path_to_g = ["a", "b", "c", "d", "e", "f", "g"].as_slice(); + let path_to_d = ["a", "b", "c", "d"].as_slice(); + let path_e_to_g = ["e", "f", "g"].as_slice(); + + let keyg = DerivedPrivateJWK::derive_private_key(&root_key, path_to_g).unwrap(); + let keyd = DerivedPrivateJWK::derive_private_key(&root_key, path_to_d).unwrap(); + let keydg = DerivedPrivateJWK::derive_private_key(&keyd, path_e_to_g).unwrap(); + + assert_eq!(keyg, keydg); + assert_ne!(keyg, keyd); + } + + #[test] + fn test_invalid_path() { + let root_key = k256::SecretKey::random(&mut rand::thread_rng()); + let path = ["a", "", "c"].as_slice(); + + let result = DerivedPrivateJWK::derive_private_key(&root_key, path); + + assert!(result.is_err()); + assert_eq!( + result.unwrap_err().to_string(), + "Invalid path segment: Empty path segment" + ); + } +} diff --git a/crates/dwn-rs-core/src/encryption/mod.rs b/crates/dwn-rs-core/src/encryption/mod.rs new file mode 100644 index 0000000..e2b0389 --- /dev/null +++ b/crates/dwn-rs-core/src/encryption/mod.rs @@ -0,0 +1,80 @@ +pub mod hd_keys; + +pub use hd_keys::*; + +use serde::{Deserialize, Serialize}; +use ssi_jwk::JWK; + +/// EncryptionAlgorithm represents the encryption algorithm used for encrypting records. Currently +/// A256CTR is the only supported algorithm. +#[derive(Serialize, Deserialize, Debug, PartialEq, Clone)] +pub enum EncryptionAlgorithm { + A256CTR, +} + +// DerivationScheme represents the derivation scheme used for deriving keys for encryption. +#[derive(Serialize, Deserialize, Debug, PartialEq, Clone)] +pub enum DerivationScheme { + #[serde(rename = "dataFormats")] + DataFormats, + #[serde(rename = "protocolContext")] + ProtocolContext, + #[serde(rename = "protocolPath")] + ProtocolPath, + #[serde(rename = "schemas")] + Schemas, +} + +/// KeyEncryptionAlgorithm represents the key encryption algorithm used for encrypting keys. +#[derive(Serialize, Deserialize, Debug, PartialEq, Clone)] +#[serde(untagged)] +pub enum KeyEncryptionAlgorithm { + Asymmetric(KeyEncryptionAlgorithmAsymmetric), + Symmetric(KeyEncryptionAlgorithmSymmetric), +} + +#[derive(Serialize, Deserialize, Debug, PartialEq, Clone)] +pub enum KeyEncryptionAlgorithmAsymmetric { + #[serde(rename = "ECIES-ES256K")] + EciesSecp256k1, +} + +#[derive(Serialize, Deserialize, Debug, PartialEq, Clone)] +pub enum KeyEncryptionAlgorithmSymmetric { + #[serde(rename = "A256CTR")] + AES256CTR, + #[serde(rename = "A256GCM")] + AES256GCM, + #[serde(rename = "XSalsa20-Poly1305")] + XSalsa20Poly1305, +} + +/// KeyEncryption represents the key encryption used for encrypting keys. +#[derive(Serialize, Deserialize, Debug, PartialEq, Clone)] +pub struct KeyEncryption { + pub algorithm: KeyEncryptionAlgorithm, + #[serde(rename = "rootKeyId")] + pub root_key_id: String, + #[serde(rename = "derivationScheme")] + pub derivation_scheme: DerivationScheme, + #[serde(rename = "derivedPublicKey")] + pub derived_public_key: Option, + #[serde(rename = "encryptedKey")] + pub encrypted_key: String, + #[serde(rename = "initializationVector")] + pub initialization_vector: String, + #[serde(rename = "ephemeralPublicKey")] + pub ephemeral_public_key: JWK, + #[serde(rename = "messageAuthenticationCode")] + pub message_authentication_code: String, +} + +/// Encryption represents the encryption used for encrypting records. +#[derive(Serialize, Deserialize, Debug, PartialEq, Clone)] +pub struct Encryption { + pub algorithm: EncryptionAlgorithm, + #[serde(rename = "initializationVector")] + pub initialization_vector: String, + #[serde(rename = "keyEncryption")] + pub key_encryption: Vec, +} diff --git a/crates/dwn-rs-core/src/lib.rs b/crates/dwn-rs-core/src/lib.rs index 6667908..9211333 100644 --- a/crates/dwn-rs-core/src/lib.rs +++ b/crates/dwn-rs-core/src/lib.rs @@ -23,6 +23,7 @@ //! - [`messages::records::RecordsDelete`]: A descriptor for reading records. #![doc(issue_tracker_base_url = "https://github.com/enmand/dwn-rsissues/")] pub mod auth; +pub mod encryption; pub mod errors; pub mod events; pub mod filters; From 461eb14083438d2e20202ab85d904eae48392f1f Mon Sep 17 00:00:00 2001 From: Dan Enman Date: Mon, 2 Dec 2024 23:35:19 -0400 Subject: [PATCH 16/23] tests: use secp256k1 key for tests --- crates/dwn-rs-core/src/auth/jws.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/crates/dwn-rs-core/src/auth/jws.rs b/crates/dwn-rs-core/src/auth/jws.rs index 423712d..1144c27 100644 --- a/crates/dwn-rs-core/src/auth/jws.rs +++ b/crates/dwn-rs-core/src/auth/jws.rs @@ -124,7 +124,7 @@ mod tests { #[tokio::test] async fn test_jws_create() { - let jwk = JWK::generate_ed25519().expect("could not generate key"); + let jwk = JWK::generate_secp256k1(); let jws = JWS::create(b"hello world".to_vec(), Some(vec![jwk])) .await .expect("could not create JWS"); @@ -136,7 +136,7 @@ mod tests { .protected .as_ref() .unwrap(), - "eyJhbGciOiJFZERTQSJ9" + "eyJhbGciOiJFUzI1NksifQ" ); assert!(!jws.signatures.as_ref().unwrap()[0] From a1f4a8f70ee20ec067bac3dd2941fa41a0291b7e Mon Sep 17 00:00:00 2001 From: Dan Enman Date: Mon, 2 Dec 2024 23:36:24 -0400 Subject: [PATCH 17/23] (wip) feat: key derivation and ECIES asymmetric encryption for encrypted RecordsWrite --- crates/dwn-rs-core/src/auth/mod.rs | 1 - .../src/encryption/asymmetric/mod.rs | 12 + .../src/encryption/asymmetric/publickey.rs | 94 +++++++ .../src/encryption/asymmetric/secretkey.rs | 241 ++++++++++++++++++ crates/dwn-rs-core/src/encryption/errors.rs | 12 + crates/dwn-rs-core/src/encryption/hd_keys.rs | 226 +++++++++------- crates/dwn-rs-core/src/encryption/mod.rs | 6 +- 7 files changed, 503 insertions(+), 89 deletions(-) create mode 100644 crates/dwn-rs-core/src/encryption/asymmetric/mod.rs create mode 100644 crates/dwn-rs-core/src/encryption/asymmetric/publickey.rs create mode 100644 crates/dwn-rs-core/src/encryption/asymmetric/secretkey.rs create mode 100644 crates/dwn-rs-core/src/encryption/errors.rs diff --git a/crates/dwn-rs-core/src/auth/mod.rs b/crates/dwn-rs-core/src/auth/mod.rs index ef4d1cf..3c204fb 100644 --- a/crates/dwn-rs-core/src/auth/mod.rs +++ b/crates/dwn-rs-core/src/auth/mod.rs @@ -1,5 +1,4 @@ pub mod authorization; -pub mod encryption; pub mod jws; pub use authorization::{Authorization, AuthorizationDelegatedGrant, AuthorizationOwner}; diff --git a/crates/dwn-rs-core/src/encryption/asymmetric/mod.rs b/crates/dwn-rs-core/src/encryption/asymmetric/mod.rs new file mode 100644 index 0000000..f6fe363 --- /dev/null +++ b/crates/dwn-rs-core/src/encryption/asymmetric/mod.rs @@ -0,0 +1,12 @@ +pub mod publickey; +pub mod secretkey; + +use thiserror::Error; +#[derive(Error, Debug)] +pub enum ECIESError { + #[error("Invalid HKDF key length: {0}")] + InvalidHKDFKeyLength(hkdf::InvalidLength), +} + +pub use publickey::PublicKey; +pub use secretkey::{ParseError, PrivateKeyError, SecretKey}; diff --git a/crates/dwn-rs-core/src/encryption/asymmetric/publickey.rs b/crates/dwn-rs-core/src/encryption/asymmetric/publickey.rs new file mode 100644 index 0000000..d6804a3 --- /dev/null +++ b/crates/dwn-rs-core/src/encryption/asymmetric/publickey.rs @@ -0,0 +1,94 @@ +use k256::{elliptic_curve::sec1::ToEncodedPoint, sha2}; +use ssi_jwk::{Base64urlUInt, OctetParams, Params, JWK}; + +use super::{ECIESError, SecretKey}; +use thiserror::Error; + +#[derive(Error, Debug)] +pub enum Error { + #[error("Error parsing JWK: {0}")] + PublicKeyError(#[from] ssi_jwk::Error), + #[error("Unsupported Curve: {0}")] + InvalidCurve(String), + #[error("ECIES encryption error: {0}")] + ECIESError(#[from] ECIESError), +} + +#[derive(Clone, Debug, PartialEq)] +#[non_exhaustive] +pub enum PublicKey { + Secp256k1(k256::PublicKey), + X25519(x25519_dalek::PublicKey), +} + +impl PublicKey { + pub fn to_bytes(&self) -> Vec { + match self { + PublicKey::Secp256k1(pk) => pk.to_encoded_point(true).as_bytes().to_vec(), + PublicKey::X25519(pk) => pk.as_bytes().to_vec(), + } + } + + pub fn jwk(&self) -> JWK { + match self { + PublicKey::Secp256k1(pk) => (*pk).into(), + PublicKey::X25519(pk) => JWK::from(Params::OKP(OctetParams { + curve: "X25519".to_string(), + public_key: Base64urlUInt(pk.to_bytes().to_vec()), + private_key: None, + })), + } + } + + pub fn decapsulate(self, sk: &SecretKey) -> Result<[u8; 32], Error> { + match (self, sk) { + (PublicKey::Secp256k1(pk), SecretKey::Secp256k1(sk)) => { + let mut okm = [0u8; 32]; + k256::ecdh::diffie_hellman(sk.to_nonzero_scalar(), pk.as_affine()) + .extract::(None) + .expand(&[], &mut okm) + .map_err(ECIESError::InvalidHKDFKeyLength)?; + Ok(okm) + } + (PublicKey::X25519(pk), SecretKey::X25519(sk)) => Ok(sk.diffie_hellman(&pk).to_bytes()), + _ => Err(Error::InvalidCurve("Unsupported key type".to_string())), + } + } +} + +impl From<&SecretKey> for PublicKey { + fn from(sk: &SecretKey) -> Self { + match sk { + SecretKey::Secp256k1(sk) => PublicKey::Secp256k1(sk.public_key()), + SecretKey::X25519(sk) => PublicKey::X25519(sk.into()), + } + } +} + +impl TryFrom for PublicKey { + type Error = Error; + + fn try_from(jwk: JWK) -> Result { + match jwk.params { + Params::EC(ref ec) => Ok(PublicKey::Secp256k1(ec.try_into()?)), + Params::OKP(ref op) => match op.curve.to_lowercase().as_str() { + "x25519" => { + let mut sk = [0u8; 32]; + sk.copy_from_slice(&op.public_key.0); + Ok(PublicKey::X25519(x25519_dalek::PublicKey::from(sk))) + } + "ed25519" => { + let pk: ed25519_dalek::VerifyingKey = op.try_into()?; + Ok(PublicKey::X25519(x25519_dalek::PublicKey::from( + pk.to_montgomery().to_bytes(), + ))) + } + _ => Err(Error::InvalidCurve(format!( + "Unsupported curve: {}", + op.curve + ))), + }, + _ => Err(Error::InvalidCurve("Unsupported key type".to_string())), + } + } +} diff --git a/crates/dwn-rs-core/src/encryption/asymmetric/secretkey.rs b/crates/dwn-rs-core/src/encryption/asymmetric/secretkey.rs new file mode 100644 index 0000000..709efb9 --- /dev/null +++ b/crates/dwn-rs-core/src/encryption/asymmetric/secretkey.rs @@ -0,0 +1,241 @@ +use std::fmt::Debug; + +use k256::sha2; +use ssi_jwk::{secp256k1_parse_private, Base64urlUInt, OctetParams, Params, JWK}; +use thiserror::Error; + +use super::{ECIESError, PublicKey}; + +#[derive(Error, Debug)] +pub enum Error { + #[error("Error getting SecretKey from bytes: {0}")] + SecretKeyError(String), + #[error("Error parsing private key: {0}")] + PrivateKeyError(#[from] PrivateKeyError), + #[error("ECIES encryption error: {0}")] + ECIESError(#[from] ECIESError), +} + +#[derive(Error, Debug)] +pub enum PrivateKeyError { + #[error("Error encoding key: {0}")] + EncodeError(#[from] k256::pkcs8::der::Error), + #[error("Error parsing private key: {0}")] + PrivateKeyError(#[from] ssi_jwk::Error), + #[error("Error parsing private key: {0}")] + ParseError(#[from] ParseError), +} + +#[derive(Error, Debug)] +pub enum ParseError { + #[error("Error parsing secp256k1 private key: {0}")] + Secp256k1(#[from] k256::elliptic_curve::Error), + #[error("Error parsing x25519 private key: {0}")] + X25519(String), + #[error("Error parsing ed25519 private key: {0}")] + Ed25519(#[from] ed25519_dalek::SignatureError), +} + +/// SecretKey represents a private asymmetric key. Supported key types are: +/// - secp256k1 +/// - x25519 +/// - ed25519 (converted to x25519) +/// +/// secp256k1 keys are preferred, since the x26619 keys are converted from ed25519 keys. See +/// also: https://eprint.iacr.org/2021/509 +#[derive(Clone)] +#[non_exhaustive] +pub enum SecretKey { + Secp256k1(k256::SecretKey), + X25519(x25519_dalek::StaticSecret), +} + +impl PartialEq for SecretKey { + fn eq(&self, other: &Self) -> bool { + match (self, other) { + (SecretKey::Secp256k1(sk1), SecretKey::Secp256k1(sk2)) => sk1 == sk2, + (SecretKey::X25519(sk1), SecretKey::X25519(sk2)) => sk1.as_bytes() == sk2.as_bytes(), + _ => false, + } + } +} + +impl Debug for SecretKey { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + fn x25519_bytes(sk: &x25519_dalek::StaticSecret) -> [u8; 32] { + let pk: x25519_dalek::PublicKey = sk.into(); + pk.to_bytes() + } + + match self { + SecretKey::Secp256k1(sk) => write!(f, "Secp256k1(pub: {:?})", sk.public_key()), + SecretKey::X25519(sk) => write!(f, "X25519(pub: {:?})", x25519_bytes(sk)), + } + } +} + +impl SecretKey { + pub fn to_bytes(&self) -> Vec { + match self { + SecretKey::Secp256k1(sk) => sk.to_bytes().to_vec(), + SecretKey::X25519(sk) => sk.as_bytes().to_vec(), + } + } + + pub fn public_key(&self) -> PublicKey { + self.into() + } + + pub fn jwk(&self) -> Result { + match self { + SecretKey::Secp256k1(sk) => { + let mut jwk: JWK = sk.public_key().into(); + let pjwk = secp256k1_parse_private( + &sk.to_sec1_der().map_err(PrivateKeyError::EncodeError)?, + ) + .map_err(PrivateKeyError::PrivateKeyError)?; + jwk.params = pjwk.params.clone(); + + Ok(jwk) + } + SecretKey::X25519(sk) => { + let pk: x25519_dalek::PublicKey = sk.into(); + let jwk = JWK::from(Params::OKP(OctetParams { + curve: "X25519".to_string(), + public_key: Base64urlUInt(pk.as_bytes().to_vec()), + private_key: Some(Base64urlUInt(sk.as_bytes().to_vec())), + })); + + Ok(jwk) + } + } + } + + pub fn encapsulate(self, pk: PublicKey) -> Result<[u8; 32], Error> { + // TODO support key compression for secp256k1 hkdf key and ephemeral key + match (self, pk) { + (SecretKey::Secp256k1(sk), PublicKey::Secp256k1(pk)) => { + let mut okm = [0u8; 32]; + k256::ecdh::diffie_hellman(sk.to_nonzero_scalar(), pk.as_affine()) + .extract::(None) + .expand(&[], &mut okm) + .map_err(ECIESError::InvalidHKDFKeyLength)?; + + Ok(okm) + } + (SecretKey::X25519(sk), PublicKey::X25519(pk)) => Ok(sk.diffie_hellman(&pk).to_bytes()), + _ => Err(Error::SecretKeyError("Unsupported key type".to_string())), + } + } +} + +/// TryFrom (&SecretKey, &[u8; 32]) for SecretKey implements the conversion of a HKDF derived key +/// into a SecretKey of the same type +impl TryFrom<(&SecretKey, &[u8; 32])> for SecretKey { + type Error = Error; + + fn try_from(value: (&SecretKey, &[u8; 32])) -> Result { + let sk = match value.0 { + SecretKey::Secp256k1(_) => { + let sk: k256::SecretKey = k256::SecretKey::from_slice(value.1) + .map_err(|e| PrivateKeyError::ParseError(ParseError::Secp256k1(e)))?; + SecretKey::Secp256k1(sk) + } + SecretKey::X25519(_) => { + let mut sk_bytes = [0u8; 32]; + sk_bytes.copy_from_slice(value.1); + let sk: x25519_dalek::StaticSecret = x25519_dalek::StaticSecret::from(sk_bytes); + + SecretKey::X25519(sk) + } + }; + + Ok(sk) + } +} + +/// TryFrom for a SecretKey implements the converstion of a (private) JWK into a SecretKey +impl TryFrom for SecretKey { + type Error = Error; + fn try_from(jwk: JWK) -> Result { + match jwk.params { + Params::EC(ecparams) => { + let sk: k256::SecretKey = (&ecparams) + .try_into() + .map_err(PrivateKeyError::PrivateKeyError)?; + + Ok(SecretKey::Secp256k1(sk)) + } + Params::OKP(okpparams) => { + if okpparams.curve.to_lowercase() == "x25519" { + let sk: [u8; 32] = match okpparams.private_key.clone() { + Some(sk) => { + let mut sk_bytes = [0u8; 32]; + sk_bytes.copy_from_slice(sk.0.as_slice()); + sk_bytes + } + None => { + return Err(Error::SecretKeyError("Missing private key".to_string())) + } + }; + + Ok(SecretKey::X25519(x25519_dalek::StaticSecret::from(sk))) + } else if okpparams.curve.to_lowercase() == "ed25519" { + let edsk: ed25519_dalek::SigningKey = (&okpparams) + .try_into() + .map_err(PrivateKeyError::PrivateKeyError)?; + + Ok(SecretKey::X25519(x25519_dalek::StaticSecret::from( + edsk.to_scalar_bytes(), + ))) + } else { + Err(Error::SecretKeyError(format!( + "Unsupported curve type: {}", + okpparams.curve + ))) + } + } + _ => Err(Error::SecretKeyError("Unsupported key type".to_string())), + } + } +} + +impl From for SecretKey { + fn from(sk: ed25519_dalek::SigningKey) -> Self { + SecretKey::X25519(x25519_dalek::StaticSecret::from(sk.to_scalar_bytes())) + } +} + +impl TryFrom for JWK { + type Error = Error; + fn try_from(sk: SecretKey) -> Result { + sk.jwk() + } +} + +#[cfg(test)] +mod test { + use super::*; + use ssi_jwk::JWK; + use std::convert::TryInto; + + #[test] + fn test_secret_key() { + let sk = SecretKey::Secp256k1(k256::SecretKey::random(&mut rand::thread_rng())); + let jwk: JWK = sk.jwk().unwrap(); + let sk2: SecretKey = jwk.try_into().unwrap(); + assert_eq!(sk, sk2); + + let sk = SecretKey::X25519(x25519_dalek::StaticSecret::random_from_rng( + rand::thread_rng(), + )); + let jwk: JWK = sk.jwk().unwrap(); + let sk2: SecretKey = jwk.try_into().unwrap(); + assert_eq!(sk, sk2); + + let sk: SecretKey = ed25519_dalek::SigningKey::generate(&mut rand::thread_rng()).into(); + let jwk: JWK = sk.jwk().unwrap(); + let sk2: SecretKey = jwk.try_into().unwrap(); + assert_eq!(sk, sk2); + } +} diff --git a/crates/dwn-rs-core/src/encryption/errors.rs b/crates/dwn-rs-core/src/encryption/errors.rs new file mode 100644 index 0000000..cbcf28c --- /dev/null +++ b/crates/dwn-rs-core/src/encryption/errors.rs @@ -0,0 +1,12 @@ +use thiserror::Error; + +use super::asymmetric::secretkey::Error as AsymmetricSecretKeyError; +use super::hd_keys::Error as HDKeysError; + +#[derive(Error, Debug)] +pub enum Error { + #[error("Error getting JWK secret: {0}")] + JWKSecretKeyError(#[from] AsymmetricSecretKeyError), + #[error("Error deriving key: {0}")] + DeriveKeyError(#[from] HDKeysError), +} diff --git a/crates/dwn-rs-core/src/encryption/hd_keys.rs b/crates/dwn-rs-core/src/encryption/hd_keys.rs index 90f0aae..82bb851 100644 --- a/crates/dwn-rs-core/src/encryption/hd_keys.rs +++ b/crates/dwn-rs-core/src/encryption/hd_keys.rs @@ -1,15 +1,17 @@ -use k256::{sha2, SecretKey}; -use ssi_jwk::{secp256k1_parse_private, Params, JWK}; +use k256::sha2; +use ssi_jwk::JWK; -use super::DerivationScheme; -use thiserror::Error; +use super::{asymmetric, DerivationScheme, SecretKey}; +use thiserror::Error as ThisError; const HKDF_KEY_LENGTH: usize = 32; // * 8; // 32 bytes = 256 bits -#[derive(Debug, Error)] +#[derive(Debug, ThisError)] pub enum Error { - #[error("Error getting secret key: {0}")] - SecretKeyError(#[from] ssi_jwk::Error), + #[error("Error getting JWK secret key: {0}")] + JWKSecretKeyError(#[from] ssi_jwk::Error), + #[error("Error getting SecretKey from bytes: {0}")] + SecretKeyError(#[from] asymmetric::secretkey::Error), #[error("Error deriving key, bad key length: {0}")] DeriveKeyLengthError(hkdf::InvalidLength), #[error("Error deriving key: {0}")] @@ -53,45 +55,32 @@ impl DerivedPrivateJWK { .iter() .map(|s| s.as_str()) .collect::>(); - if let Params::EC(ecparam) = ancestor_key.key.params { - // TODO support x25519 - let sk: k256::SecretKey = (&ecparam).try_into()?; - let ancestor_path = ancestor_key.path.unwrap_or_default(); - - let derived_key = Self::derive_private_key(&sk, path)?; - - let mut pk: JWK = sk.public_key().into(); - let derived_jwk = secp256k1_parse_private(&derived_key.to_sec1_der()?)?; - pk.params = derived_jwk.params.clone(); - - return Ok(DerivedPrivateJWK { - root_key_id: ancestor_key.root_key_id, - scheme: ancestor_key.scheme, - path: Some([ancestor_path, derivation_path].concat()), - key: pk, - }); - }; - - Err(Error::UnsupportedKeyType) + + let sk: SecretKey = ancestor_key.key.try_into()?; + let ancestor_path = ancestor_key.path.unwrap_or_default(); + let derived_key = Self::derive_secret(&sk, path)?; + let pjwk: JWK = derived_key.jwk()?; + + Ok(DerivedPrivateJWK { + root_key_id: ancestor_key.root_key_id, + scheme: ancestor_key.scheme, + path: Some([ancestor_path, derivation_path].concat()), + key: pjwk, + }) } + /// derive_public_key derives a new public key from the root key using the derivation path. pub fn derive_public_key( ancestor_key: DerivedPrivateJWK, - derivation_path: &[&str], + derivation_path: Vec, ) -> Result { - if let Params::EC(ecparam) = ancestor_key.key.params { - // TODO support x25519 - let sk: k256::SecretKey = (&ecparam).try_into()?; - - let derived_key = Self::derive_private_key(&sk, derivation_path)?; - let derived_jwk = derived_key.public_key().into(); + let derived_key = Self::derive(ancestor_key, derivation_path)?; + let sk: SecretKey = derived_key.key.try_into()?; - return Ok(derived_jwk); - } - - Err(Error::UnsupportedKeyType) + Ok(sk.public_key().jwk()) } - pub fn derive_private_key( + + pub fn derive_secret( ancestor_key: &SecretKey, derivation_path: &[&str], ) -> Result { @@ -101,8 +90,7 @@ impl DerivedPrivateJWK { ancestor_key.to_owned(), |key, segment| -> Result { let seg = segment.as_bytes(); - let key_material = key.to_bytes(); - Self::derive_hkdf_key(HashAlgorithm::SHA256, &key_material, seg) + Self::derive_hkdf_key(HashAlgorithm::SHA256, &key, seg) }, )?; @@ -111,7 +99,7 @@ impl DerivedPrivateJWK { pub fn derive_hkdf_key( hash_algo: HashAlgorithm, - initial_key_material: &[u8], + initial_key_material: &SecretKey, info: &[u8], ) -> Result { if hash_algo != HashAlgorithm::SHA256 { @@ -123,11 +111,11 @@ impl DerivedPrivateJWK { let mut okm = [0u8; HKDF_KEY_LENGTH]; - hkdf::Hkdf::::new(None, initial_key_material) + hkdf::Hkdf::::new(None, initial_key_material.to_bytes().as_slice()) .expand(info, &mut okm) .map_err(Error::DeriveKeyLengthError)?; - Ok(SecretKey::from_slice(&okm)?) + Ok(SecretKey::try_from((initial_key_material, &okm))?) } fn validate_path(path: &[&str]) -> Result<(), Error> { @@ -145,68 +133,132 @@ mod tests { use super::*; use ssi_jwk::JWK; + struct JWKTestTable { + private_jwk: JWK, + } + + struct SecretKeyTestTable { + secret_key: SecretKey, + } + #[test] fn test_derive() { - let root_key = JWK::generate_secp256k1(); - let root_key_id = "root_key_id".to_string(); - let scheme = DerivationScheme::ProtocolPath; - let path = &["path"]; - let derived_key = DerivedPrivateJWK { - root_key_id: root_key_id.clone(), - scheme, - path: Some(path.iter().map(|s| s.to_string()).collect()), - key: root_key, - }; - let derived = DerivedPrivateJWK::derive(derived_key, vec!["path2".to_string()]).unwrap(); - - assert_eq!(derived.root_key_id, root_key_id); - assert_eq!(derived.scheme, DerivationScheme::ProtocolPath); - assert_eq!( - derived.path, - Some(vec!["path".to_string(), "path2".to_string()]) - ); + let tcs = vec![ + JWKTestTable { + private_jwk: JWK::generate_secp256k1(), + }, + JWKTestTable { + private_jwk: { + let sk = SecretKey::X25519(x25519_dalek::StaticSecret::random_from_rng( + rand::thread_rng(), + )); + sk.try_into().unwrap() + }, + }, + JWKTestTable { + private_jwk: JWK::generate_ed25519().expect("unable to gnenerate ed25519 key"), + }, + ]; + + for tc in tcs { + let root_key = tc.private_jwk.clone(); + let root_key_id = "root_key_id".to_string(); + let scheme = DerivationScheme::ProtocolPath; + let path = vec!["path".to_string()]; + let derived_key = DerivedPrivateJWK { + root_key_id: root_key_id.clone(), + scheme, + path: Some(path), + key: root_key, + }; + let derived = + DerivedPrivateJWK::derive(derived_key, vec!["path2".to_string()]).unwrap(); + + assert_eq!(derived.root_key_id, root_key_id); + assert_eq!(derived.scheme, DerivationScheme::ProtocolPath); + assert_eq!( + derived.path, + Some(vec!["path".to_string(), "path2".to_string()]) + ); + } } #[test] fn test_derive_public_key() { - let root_key = JWK::generate_secp256k1(); - let root_key_id = "root_key_id".to_string(); - let scheme = DerivationScheme::ProtocolPath; - let path = &["path"]; - let derived_key = DerivedPrivateJWK { - root_key_id: root_key_id.clone(), - scheme, - path: Some(path.iter().map(|s| s.to_string()).collect()), - key: root_key.clone(), - }; - - let derived = DerivedPrivateJWK::derive_public_key(derived_key, path).unwrap(); - - assert!(derived.params.is_public()); + let tcs = vec![ + JWKTestTable { + private_jwk: JWK::generate_secp256k1(), + }, + JWKTestTable { + private_jwk: { + let sk = SecretKey::X25519(x25519_dalek::StaticSecret::random_from_rng( + rand::thread_rng(), + )); + sk.try_into().unwrap() + }, + }, + JWKTestTable { + private_jwk: JWK::generate_ed25519().expect("unable to gnenerate ed25519 key"), + }, + ]; + + for tc in tcs { + let root_key = tc.private_jwk.clone(); + let root_key_id = "root_key_id".to_string(); + let scheme = DerivationScheme::ProtocolPath; + let path = vec!["path".to_string()]; + let derived_key = DerivedPrivateJWK { + root_key_id: root_key_id.clone(), + scheme, + path: Some(path.iter().map(|s| s.to_string()).collect()), + key: root_key.clone(), + }; + + let derived = DerivedPrivateJWK::derive_public_key(derived_key, path).unwrap(); + + assert!(derived.params.is_public()); + } } #[test] fn test_derive_ancestor_chain_path() { - let root_key = k256::SecretKey::random(&mut rand::thread_rng()); + let tcs = vec![ + SecretKeyTestTable { + secret_key: SecretKey::Secp256k1(k256::SecretKey::random(&mut rand::thread_rng())), + }, + SecretKeyTestTable { + secret_key: SecretKey::X25519(x25519_dalek::StaticSecret::random_from_rng( + rand::thread_rng(), + )), + }, + SecretKeyTestTable { + secret_key: ed25519_dalek::SigningKey::generate(&mut rand::thread_rng()).into(), + }, + ]; + + for tc in tcs { + let root_key = tc.secret_key.clone(); - let path_to_g = ["a", "b", "c", "d", "e", "f", "g"].as_slice(); - let path_to_d = ["a", "b", "c", "d"].as_slice(); - let path_e_to_g = ["e", "f", "g"].as_slice(); + let path_to_g = ["a", "b", "c", "d", "e", "f", "g"].as_slice(); + let path_to_d = ["a", "b", "c", "d"].as_slice(); + let path_e_to_g = ["e", "f", "g"].as_slice(); - let keyg = DerivedPrivateJWK::derive_private_key(&root_key, path_to_g).unwrap(); - let keyd = DerivedPrivateJWK::derive_private_key(&root_key, path_to_d).unwrap(); - let keydg = DerivedPrivateJWK::derive_private_key(&keyd, path_e_to_g).unwrap(); + let keyg = DerivedPrivateJWK::derive_secret(&root_key, path_to_g).unwrap(); + let keyd = DerivedPrivateJWK::derive_secret(&root_key, path_to_d).unwrap(); + let keydg = DerivedPrivateJWK::derive_secret(&keyd, path_e_to_g).unwrap(); - assert_eq!(keyg, keydg); - assert_ne!(keyg, keyd); + assert_eq!(keyg, keydg); + assert_ne!(keyg, keyd); + } } #[test] fn test_invalid_path() { - let root_key = k256::SecretKey::random(&mut rand::thread_rng()); + let root_key: SecretKey = + SecretKey::Secp256k1(k256::SecretKey::random(&mut rand::thread_rng())); let path = ["a", "", "c"].as_slice(); - let result = DerivedPrivateJWK::derive_private_key(&root_key, path); + let result = DerivedPrivateJWK::derive_secret(&root_key, path); assert!(result.is_err()); assert_eq!( diff --git a/crates/dwn-rs-core/src/encryption/mod.rs b/crates/dwn-rs-core/src/encryption/mod.rs index e2b0389..80ad438 100644 --- a/crates/dwn-rs-core/src/encryption/mod.rs +++ b/crates/dwn-rs-core/src/encryption/mod.rs @@ -1,6 +1,10 @@ +pub mod asymmetric; +pub mod errors; pub mod hd_keys; -pub use hd_keys::*; +pub use asymmetric::SecretKey; +pub use errors::Error; +pub use hd_keys::{DerivedPrivateJWK, HashAlgorithm}; use serde::{Deserialize, Serialize}; use ssi_jwk::JWK; From 547a86ee55f0851c6271774d62f7ccc8ccc617f4 Mon Sep 17 00:00:00 2001 From: Dan Enman Date: Fri, 6 Dec 2024 09:36:17 -0400 Subject: [PATCH 18/23] feat: AES-256-CTR symmetric encryption --- crates/dwn-rs-core/src/encryption/mod.rs | 1 + .../src/encryption/symmetric/aes_ctr.rs | 73 ++++++ .../src/encryption/symmetric/mod.rs | 247 ++++++++++++++++++ 3 files changed, 321 insertions(+) create mode 100644 crates/dwn-rs-core/src/encryption/symmetric/aes_ctr.rs create mode 100644 crates/dwn-rs-core/src/encryption/symmetric/mod.rs diff --git a/crates/dwn-rs-core/src/encryption/mod.rs b/crates/dwn-rs-core/src/encryption/mod.rs index 80ad438..66b6138 100644 --- a/crates/dwn-rs-core/src/encryption/mod.rs +++ b/crates/dwn-rs-core/src/encryption/mod.rs @@ -1,6 +1,7 @@ pub mod asymmetric; pub mod errors; pub mod hd_keys; +pub mod symmetric; pub use asymmetric::SecretKey; pub use errors::Error; diff --git a/crates/dwn-rs-core/src/encryption/symmetric/aes_ctr.rs b/crates/dwn-rs-core/src/encryption/symmetric/aes_ctr.rs new file mode 100644 index 0000000..2dfa1a5 --- /dev/null +++ b/crates/dwn-rs-core/src/encryption/symmetric/aes_ctr.rs @@ -0,0 +1,73 @@ +use aes::{ + cipher::{KeyIvInit, StreamCipher, StreamCipherSeek}, + Aes256, +}; +use bytes::{Bytes, BytesMut}; +use ctr::Ctr64BE; +use thiserror::Error; + +use super::Encryption; + +pub type CipherAES256CTR = Ctr64BE; + +pub struct AES256CTR(CipherAES256CTR, CipherAES256CTR); + +#[derive(Debug, Error)] +pub enum Error { + #[error("Invalid key length: {0}")] + InvalidKeyLength(#[from] aes::cipher::InvalidLength), + #[error("AES-256-CTR encryption error: {0}")] + EncryptError(aes::cipher::StreamCipherError), +} + +impl Encryption for AES256CTR { + fn new(key: &[u8; 32], iv: &[u8; 16]) -> Result { + let cipher = CipherAES256CTR::new_from_slices(key, iv).map_err(Error::InvalidKeyLength)?; + let mut dec_cipher = cipher.clone(); + dec_cipher.seek(0u32); + Ok(Self(cipher, dec_cipher)) + } + + fn encrypt(&mut self, data: &mut BytesMut) -> Result { + println!("applying encryption"); + self.0.apply_keystream(data); + Ok(data.clone().freeze()) + } + + fn decrypt(&mut self, data: &mut BytesMut) -> Result { + println!("applying decryption"); + self.1.apply_keystream(data); + println!("{:?}", data); + Ok(data.clone().freeze()) + } +} + +#[cfg(test)] +mod test { + + #[test] + fn test_aes256ctr() { + use super::*; + use bytes::Bytes; + + let key = [ + 0x2b, 0x7e, 0x15, 0x16, 0x28, 0xae, 0xd2, 0xa6, 0xab, 0xf7, 0x15, 0x88, 0x09, 0xcf, + 0x4f, 0x3c, 0x76, 0x3b, 0x61, 0x7b, 0x2e, 0x45, 0x8f, 0x17, 0x98, 0x4a, 0xc3, 0x5b, + 0x4d, 0xa4, 0x5c, 0x2a, + ]; + let iv = [ + 0xf0, 0xf1, 0xf2, 0xf3, 0xf4, 0xf5, 0xf6, 0xf7, 0xf8, 0xf9, 0xfa, 0xfb, 0xfc, 0xfd, + 0xfe, 0xff, + ]; + + let mut enc = AES256CTR::new(&key, &iv).expect("Failed to create AES256CTR"); + let data = Bytes::from_static(b"hello world! this is my plaintext."); + let enc_data = enc + .encrypt(&mut data.clone().into()) + .unwrap_or_else(|e| panic!("{}", e.to_string())); + + assert_eq!(enc_data.to_vec(), b"\xde\xf2\xc6\xe6t\xec#x\x80\xce\xdb\xb1\x940\xa2\x0c\xab0\xef\0\x05B\"\x1eE\x92\xa6\xa4\xbe\x8c\x8dk\x5f\xDD"); + let dec_data = enc.decrypt(&mut enc_data.into()).unwrap(); + assert_eq!(data, dec_data); + } +} diff --git a/crates/dwn-rs-core/src/encryption/symmetric/mod.rs b/crates/dwn-rs-core/src/encryption/symmetric/mod.rs new file mode 100644 index 0000000..7040b32 --- /dev/null +++ b/crates/dwn-rs-core/src/encryption/symmetric/mod.rs @@ -0,0 +1,247 @@ +use std::{pin::Pin, task::Poll}; + +use bytes::{Bytes, BytesMut}; +use futures_util::{ready, Stream}; +use pin_project_lite::pin_project; +use thiserror::Error; + +pub mod aes_ctr; +pub mod aes_gcm; +pub mod xsalsa20_poly1305; + +#[derive(Debug, Error)] +pub enum Error { + #[error("AES-256-CBC encryption error: {0}")] + AES256CTR(#[from] aes_ctr::Error), +} + +impl StreamEncryptionExt for T where T: Stream {} + +pub trait StreamEncryptionExt: Stream { + fn encrypt(self, key: &[u8; 32], iv: &[u8; 16]) -> Result, Error> + where + E: Encryption, + Self: Sized, + { + Encrypt::new(self, key, iv) + } + + fn decrypt(self, key: &[u8; 32], iv: &[u8; 16]) -> Result, Error> + where + E: Encryption, + Self: Sized, + { + Decrypt::new(self, key, iv) + } +} + +pub trait Encryption { + fn new(key: &[u8; 32], iv: &[u8; 16]) -> Result + where + Self: Sized; + fn encrypt(&mut self, data: &mut BytesMut) -> Result; + fn decrypt(&mut self, data: &mut BytesMut) -> Result; +} + +pin_project! { + #[must_use = "streams do nothing unless polled"] + pub struct Encrypt { + #[pin] + stream: D, + encryption: E, + } +} + +impl Encrypt +where + E: Encryption, +{ + pub fn new(stream: D, key: &[u8; 32], iv: &[u8; 16]) -> Result { + Ok(Self { + stream, + encryption: E::new(key, iv)?, + }) + } +} + +impl Stream for Encrypt +where + D: Stream>, + E: Encryption, +{ + type Item = Result; + + fn poll_next( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + let mut this = self.project(); + let res = ready!(this.stream.as_mut().poll_next(cx)); + let mut bytes = match res { + Some(Ok(bytes)) => bytes.into(), + Some(Err(e)) => return Poll::Ready(Some(Err(e))), + None => return Poll::Ready(None), + }; + Poll::Ready(Some(this.encryption.encrypt(&mut bytes))) + } +} + +pin_project! { + #[must_use = "streams do nothing unless polled"] + pub struct Decrypt { + #[pin] + stream: D, + encryption: E, + } +} + +impl Decrypt +where + E: Encryption, +{ + pub fn new(stream: D, key: &[u8; 32], iv: &[u8; 16]) -> Result { + Ok(Self { + stream, + encryption: E::new(key, iv)?, + }) + } +} + +impl Stream for Decrypt +where + D: Stream>, + E: Encryption, +{ + type Item = Result; + + fn poll_next( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + let mut this = self.project(); + let res = ready!(this.stream.as_mut().poll_next(cx)); + let mut bytes = match res { + Some(Ok(bytes)) => bytes.into(), + Some(Err(e)) => return Poll::Ready(Some(Err(e))), + None => return Poll::Ready(None), + }; + + Poll::Ready(Some(this.encryption.decrypt(&mut bytes))) + } +} + +#[cfg(test)] +mod test { + use super::{aes_ctr, Encryption, Error, StreamEncryptionExt}; + use bytes::Bytes; + use futures_util::{pin_mut, stream, StreamExt}; + + const KEY: [u8; 32] = [ + 0x2b, 0x7e, 0x15, 0x16, 0x28, 0xae, 0xd2, 0xa6, 0xab, 0xf7, 0x15, 0x88, 0x09, 0xcf, 0x4f, + 0x3c, 0x76, 0x3b, 0x61, 0x7b, 0x2e, 0x45, 0x8f, 0x17, 0x98, 0x4a, 0xc3, 0x5b, 0x4d, 0xa4, + 0x5c, 0x2a, + ]; + const IV: [u8; 16] = [ + 0xf0, 0xf1, 0xf2, 0xf3, 0xf4, 0xf5, 0xf6, 0xf7, 0xf8, 0xf9, 0xfa, 0xfb, 0xfc, 0xfd, 0xfe, + 0xff, + ]; + + #[tokio::test] + async fn test_aes256ctr_encrypt_stream() { + let msg_part_1 = Bytes::from_static(b"hello world! "); + let msg_part_2 = Bytes::from_static(b"this is my plaintext."); + let msg = Bytes::from([msg_part_1.clone(), msg_part_2.clone()].concat()); + + // Stream Encryption + let data_stream = stream::iter(vec![ + Ok::(msg_part_1.clone()), + Ok(msg_part_2.clone()), + ]) + .encrypt::(&KEY, &IV) + .expect("unable to generate encryption"); + + // Static encryption + let mut enc = aes_ctr::AES256CTR::new(&KEY, &IV).expect("Failed to create AES256CTR"); + let enc_data = enc + .encrypt(&mut msg.clone().into()) + .unwrap_or_else(|e| panic!("{}", e.to_string())); + + pin_mut!(data_stream); + let mut data: Vec = vec![]; + while let Some(c) = data_stream.next().await { + data.extend_from_slice(&c.unwrap()); + } + let data = Bytes::from(data); + + // Assert stream encryption and static encryption are equal + assert_eq!(data, enc_data); + // Assert the stream, and the static test do not match the original data + assert_ne!(data, msg); + } + + #[tokio::test] + async fn test_aes256ctr_decrypt_stream() { + let msg_part_1 = Bytes::from_static(b"hello world! "); + let msg_part_2 = Bytes::from_static(b"this is my plaintext."); + let msg = Bytes::from([msg_part_1.clone(), msg_part_2.clone()].concat()); + + // Stream Encryption + let data_stream = stream::iter(vec![ + Ok::(msg_part_1.clone()), + Ok(msg_part_2.clone()), + ]) + .encrypt::(&KEY, &IV) + .expect("unable to generate encryption") + .decrypt::(&KEY, &IV) + .expect("unable to generate decryption"); + + // Static encryption + let mut enc = aes_ctr::AES256CTR::new(&KEY, &IV).expect("Failed to create AES256CTR"); + let enc_data = enc + .encrypt(&mut msg.clone().into()) + .unwrap_or_else(|e| panic!("{}", e.to_string())); + let dec_data = enc + .decrypt(&mut enc_data.clone().into()) + .unwrap_or_else(|e| panic!("{}", e.to_string())); + + pin_mut!(data_stream); + let mut data: Vec = vec![]; + while let Some(c) = data_stream.next().await { + data.extend_from_slice(&c.unwrap()); + } + let data = Bytes::from(data); + + // Assert the stream, and the static test match the decrypted data + assert_eq!(data, dec_data); + // Assert the stream, and the static test match the original data + assert_eq!(data, msg); + } + + #[tokio::test] + async fn test_aes256ctr_encrypt_err() { + let msg = Bytes::from_static(b"hello world! "); + + // Stream Encryption + let data_stream = stream::iter(vec![ + Ok(msg.clone()), + Err(Error::AES256CTR(aes_ctr::Error::InvalidKeyLength( + aes::cipher::InvalidLength, + ))), + ]) + .encrypt::(&KEY, &IV) + .expect("unable to generate encryption"); + + pin_mut!(data_stream); + while let Some(c) = data_stream.next().await { + match c { + Ok(_) => {} + Err(e) => { + assert_eq!( + e.to_string(), + "AES-256-CBC encryption error: Invalid key length: Invalid Length" + ); + } + } + } + } +} From 8ff543eced9dbd114bb425ac5548dfa2ef1613c5 Mon Sep 17 00:00:00 2001 From: Dan Enman Date: Sat, 7 Dec 2024 23:21:42 -0400 Subject: [PATCH 19/23] chore: move IV init to IVEncryption trait --- .../src/encryption/symmetric/aes_ctr.rs | 110 +++++++++++++----- .../src/encryption/symmetric/mod.rs | 87 ++++++++++---- 2 files changed, 147 insertions(+), 50 deletions(-) diff --git a/crates/dwn-rs-core/src/encryption/symmetric/aes_ctr.rs b/crates/dwn-rs-core/src/encryption/symmetric/aes_ctr.rs index 2dfa1a5..9567213 100644 --- a/crates/dwn-rs-core/src/encryption/symmetric/aes_ctr.rs +++ b/crates/dwn-rs-core/src/encryption/symmetric/aes_ctr.rs @@ -1,66 +1,98 @@ use aes::{ - cipher::{KeyIvInit, StreamCipher, StreamCipherSeek}, + cipher::{generic_array::GenericArray, KeyIvInit, StreamCipher, StreamCipherSeek}, Aes256, }; use bytes::{Bytes, BytesMut}; use ctr::Ctr64BE; use thiserror::Error; -use super::Encryption; +use super::{Encryption, IVEncryption}; pub type CipherAES256CTR = Ctr64BE; -pub struct AES256CTR(CipherAES256CTR, CipherAES256CTR); +pub struct AES256CTR { + key: [u8; 32], + enc: Option, + dec: Option, +} #[derive(Debug, Error)] pub enum Error { #[error("Invalid key length: {0}")] InvalidKeyLength(#[from] aes::cipher::InvalidLength), - #[error("AES-256-CTR encryption error: {0}")] + #[error("AES-256-CTR encryption/decryption error: {0}")] EncryptError(aes::cipher::StreamCipherError), + #[error("AES-256-CTR IV error")] + NoIVError, } impl Encryption for AES256CTR { - fn new(key: &[u8; 32], iv: &[u8; 16]) -> Result { - let cipher = CipherAES256CTR::new_from_slices(key, iv).map_err(Error::InvalidKeyLength)?; - let mut dec_cipher = cipher.clone(); - dec_cipher.seek(0u32); - Ok(Self(cipher, dec_cipher)) + fn new(key: &[u8; 32]) -> Result { + Ok(Self { + key: *key, + enc: None, + dec: None, + }) } fn encrypt(&mut self, data: &mut BytesMut) -> Result { - println!("applying encryption"); - self.0.apply_keystream(data); - Ok(data.clone().freeze()) + if let Some(enc) = &mut self.enc { + enc.apply_keystream(data); + Ok(data.clone().freeze()) + } else { + Err(Error::NoIVError.into()) + } } fn decrypt(&mut self, data: &mut BytesMut) -> Result { - println!("applying decryption"); - self.1.apply_keystream(data); - println!("{:?}", data); - Ok(data.clone().freeze()) + if let Some(dec) = &mut self.dec { + dec.apply_keystream(data); + Ok(data.clone().freeze()) + } else { + Err(Error::NoIVError.into()) + } + } +} + +impl IVEncryption for AES256CTR { + type NonceSize = typenum::consts::U16; + + fn with_iv(&mut self, iv: GenericArray) -> Result { + let cipher = + CipherAES256CTR::new_from_slices(&self.key, &iv).map_err(Error::InvalidKeyLength)?; + let mut dec_cipher = cipher.clone(); + dec_cipher.seek(0u32); + + Ok(Self { + key: self.key, + enc: Some(cipher), + dec: Some(dec_cipher), + }) } } #[cfg(test)] mod test { + use super::*; + use bytes::Bytes; + + const KEY: [u8; 32] = [ + 0x2b, 0x7e, 0x15, 0x16, 0x28, 0xae, 0xd2, 0xa6, 0xab, 0xf7, 0x15, 0x88, 0x09, 0xcf, 0x4f, + 0x3c, 0x76, 0x3b, 0x61, 0x7b, 0x2e, 0x45, 0x8f, 0x17, 0x98, 0x4a, 0xc3, 0x5b, 0x4d, 0xa4, + 0x5c, 0x2a, + ]; + const IV: [u8; 16] = [ + 0xf0, 0xf1, 0xf2, 0xf3, 0xf4, 0xf5, 0xf6, 0xf7, 0xf8, 0xf9, 0xfa, 0xfb, 0xfc, 0xfd, 0xfe, + 0xff, + ]; #[test] fn test_aes256ctr() { - use super::*; - use bytes::Bytes; - - let key = [ - 0x2b, 0x7e, 0x15, 0x16, 0x28, 0xae, 0xd2, 0xa6, 0xab, 0xf7, 0x15, 0x88, 0x09, 0xcf, - 0x4f, 0x3c, 0x76, 0x3b, 0x61, 0x7b, 0x2e, 0x45, 0x8f, 0x17, 0x98, 0x4a, 0xc3, 0x5b, - 0x4d, 0xa4, 0x5c, 0x2a, - ]; - let iv = [ - 0xf0, 0xf1, 0xf2, 0xf3, 0xf4, 0xf5, 0xf6, 0xf7, 0xf8, 0xf9, 0xfa, 0xfb, 0xfc, 0xfd, - 0xfe, 0xff, - ]; - - let mut enc = AES256CTR::new(&key, &iv).expect("Failed to create AES256CTR"); + let mut enc = AES256CTR::new(&KEY) + .expect("Failed to create AES256CTR") + .with_iv(IV.into()) + .expect("Failed to set IV"); + let data = Bytes::from_static(b"hello world! this is my plaintext."); let enc_data = enc .encrypt(&mut data.clone().into()) @@ -70,4 +102,22 @@ mod test { let dec_data = enc.decrypt(&mut enc_data.into()).unwrap(); assert_eq!(data, dec_data); } + + #[test] + fn test_aes256ctr_no_iv() { + let mut enc = AES256CTR::new(&KEY).expect("Failed to create AES256CTR"); + + let data = Bytes::from_static(b"hello world! this is my plaintext."); + let enc_data = enc.encrypt(&mut data.clone().into()); + assert_eq!( + enc_data.unwrap_err().to_string(), + "AES-256-CBC encryption error: AES-256-CTR IV error" + ); + + let dec_data = enc.decrypt(&mut data.clone().into()); + assert_eq!( + dec_data.unwrap_err().to_string(), + "AES-256-CBC encryption error: AES-256-CTR IV error" + ); + } } diff --git a/crates/dwn-rs-core/src/encryption/symmetric/mod.rs b/crates/dwn-rs-core/src/encryption/symmetric/mod.rs index 7040b32..ff5aca2 100644 --- a/crates/dwn-rs-core/src/encryption/symmetric/mod.rs +++ b/crates/dwn-rs-core/src/encryption/symmetric/mod.rs @@ -1,5 +1,6 @@ use std::{pin::Pin, task::Poll}; +use aes::cipher::{generic_array::GenericArray, ArrayLength}; use bytes::{Bytes, BytesMut}; use futures_util::{ready, Stream}; use pin_project_lite::pin_project; @@ -11,38 +12,50 @@ pub mod xsalsa20_poly1305; #[derive(Debug, Error)] pub enum Error { - #[error("AES-256-CBC encryption error: {0}")] + #[error("AES-256-CTR encryption error: {0}")] AES256CTR(#[from] aes_ctr::Error), + #[error("AES-256-GCM encryption error: {0}")] + AES256GCM(#[from] aes_gcm::Error), + #[error("XSalsa20Poly1305 encryption error: {0}")] + XSalsa20Poly1305(#[from] xsalsa20_poly1305::Error), } impl StreamEncryptionExt for T where T: Stream {} pub trait StreamEncryptionExt: Stream { - fn encrypt(self, key: &[u8; 32], iv: &[u8; 16]) -> Result, Error> + fn encrypt(self, key: &[u8; 32]) -> Result, Error> where E: Encryption, Self: Sized, { - Encrypt::new(self, key, iv) + Encrypt::new(self, key) } - fn decrypt(self, key: &[u8; 32], iv: &[u8; 16]) -> Result, Error> + fn decrypt(self, key: &[u8; 32]) -> Result, Error> where E: Encryption, Self: Sized, { - Decrypt::new(self, key, iv) + Decrypt::new(self, key) } } pub trait Encryption { - fn new(key: &[u8; 32], iv: &[u8; 16]) -> Result + fn new(key: &[u8; 32]) -> Result where Self: Sized; fn encrypt(&mut self, data: &mut BytesMut) -> Result; fn decrypt(&mut self, data: &mut BytesMut) -> Result; } +pub trait IVEncryption: Encryption { + type NonceSize: ArrayLength; + + fn with_iv(&mut self, iv: GenericArray) -> Result + where + Self: Sized; +} + pin_project! { #[must_use = "streams do nothing unless polled"] pub struct Encrypt { @@ -56,14 +69,24 @@ impl Encrypt where E: Encryption, { - pub fn new(stream: D, key: &[u8; 32], iv: &[u8; 16]) -> Result { + pub fn new(stream: D, key: &[u8; 32]) -> Result { Ok(Self { stream, - encryption: E::new(key, iv)?, + encryption: E::new(key)?, }) } } +impl Encrypt +where + E: IVEncryption, +{ + pub fn with_iv(mut self, iv: GenericArray) -> Result { + self.encryption = self.encryption.with_iv(iv)?; + Ok(self) + } +} + impl Stream for Encrypt where D: Stream>, @@ -99,14 +122,24 @@ impl Decrypt where E: Encryption, { - pub fn new(stream: D, key: &[u8; 32], iv: &[u8; 16]) -> Result { + pub fn new(stream: D, key: &[u8; 32]) -> Result { Ok(Self { stream, - encryption: E::new(key, iv)?, + encryption: E::new(key)?, }) } } +impl Decrypt +where + E: IVEncryption, +{ + pub fn with_iv(mut self, iv: GenericArray) -> Result { + self.encryption = self.encryption.with_iv(iv)?; + Ok(self) + } +} + impl Stream for Decrypt where D: Stream>, @@ -132,7 +165,7 @@ where #[cfg(test)] mod test { - use super::{aes_ctr, Encryption, Error, StreamEncryptionExt}; + use super::{aes_ctr, Encryption, Error, IVEncryption, StreamEncryptionExt}; use bytes::Bytes; use futures_util::{pin_mut, stream, StreamExt}; @@ -157,11 +190,16 @@ mod test { Ok::(msg_part_1.clone()), Ok(msg_part_2.clone()), ]) - .encrypt::(&KEY, &IV) - .expect("unable to generate encryption"); + .encrypt::(&KEY) + .expect("unable to generate encryption") + .with_iv(IV.into()) + .expect("unable to set IV"); // Static encryption - let mut enc = aes_ctr::AES256CTR::new(&KEY, &IV).expect("Failed to create AES256CTR"); + let mut enc = aes_ctr::AES256CTR::new(&KEY) + .expect("Failed to create AES256CTR") + .with_iv(IV.into()) + .expect("Failed to set IV"); let enc_data = enc .encrypt(&mut msg.clone().into()) .unwrap_or_else(|e| panic!("{}", e.to_string())); @@ -190,13 +228,20 @@ mod test { Ok::(msg_part_1.clone()), Ok(msg_part_2.clone()), ]) - .encrypt::(&KEY, &IV) + .encrypt::(&KEY) .expect("unable to generate encryption") - .decrypt::(&KEY, &IV) - .expect("unable to generate decryption"); + .with_iv(IV.into()) + .expect("Unable to set IV") + .decrypt::(&KEY) + .expect("unable to generate decryption") + .with_iv(IV.into()) + .expect("Unable to set IV"); // Static encryption - let mut enc = aes_ctr::AES256CTR::new(&KEY, &IV).expect("Failed to create AES256CTR"); + let mut enc = aes_ctr::AES256CTR::new(&KEY) + .expect("Failed to create AES256CTR") + .with_iv(IV.into()) + .expect("Unable to set IV"); let enc_data = enc .encrypt(&mut msg.clone().into()) .unwrap_or_else(|e| panic!("{}", e.to_string())); @@ -228,8 +273,10 @@ mod test { aes::cipher::InvalidLength, ))), ]) - .encrypt::(&KEY, &IV) - .expect("unable to generate encryption"); + .encrypt::(&KEY) + .expect("unable to generate encryption") + .with_iv(IV.into()) + .expect("Unable to set IV"); pin_mut!(data_stream); while let Some(c) = data_stream.next().await { From a7275c345875cfba1ca24d74e1546b0edb93fb73 Mon Sep 17 00:00:00 2001 From: Dan Enman Date: Sat, 7 Dec 2024 23:22:05 -0400 Subject: [PATCH 20/23] feat: add AES-GCM and XSalsa20-Poly1305 symmetric algorithms --- .../src/encryption/symmetric/aes_gcm.rs | 133 ++++++++++++++++++ .../encryption/symmetric/xsalsa20_poly1305.rs | 62 ++++++++ 2 files changed, 195 insertions(+) create mode 100644 crates/dwn-rs-core/src/encryption/symmetric/aes_gcm.rs create mode 100644 crates/dwn-rs-core/src/encryption/symmetric/xsalsa20_poly1305.rs diff --git a/crates/dwn-rs-core/src/encryption/symmetric/aes_gcm.rs b/crates/dwn-rs-core/src/encryption/symmetric/aes_gcm.rs new file mode 100644 index 0000000..ba17912 --- /dev/null +++ b/crates/dwn-rs-core/src/encryption/symmetric/aes_gcm.rs @@ -0,0 +1,133 @@ +use aes::cipher::generic_array::GenericArray; +use aes_gcm::{ + aead::{AeadMutInPlace, Buffer}, + Aes256Gcm, KeyInit, +}; +use bytes::{Bytes, BytesMut}; +use thiserror::Error; + +use super::{Encryption, IVEncryption}; + +pub(super) struct AESBuffer<'a>(pub(crate) &'a mut BytesMut); + +impl<'a> Buffer for AESBuffer<'a> { + fn extend_from_slice(&mut self, other: &[u8]) -> aes_gcm::aead::Result<()> { + self.0.extend_from_slice(other); + + Ok(()) + } + + fn truncate(&mut self, len: usize) { + self.0.truncate(len); + } +} + +impl<'a> AsRef<[u8]> for AESBuffer<'a> { + fn as_ref(&self) -> &[u8] { + self.0.as_ref() + } +} + +impl<'a> AsMut<[u8]> for AESBuffer<'a> { + fn as_mut(&mut self) -> &mut [u8] { + self.0.as_mut() + } +} + +pub struct AES256GCM { + cipher: Aes256Gcm, + iv: Option>, +} + +#[derive(Debug, Error)] +pub enum Error { + #[error("AES-256-GCM encryption/decryption error: {0}")] + EncryptError(aes_gcm::Error), + #[error("AES-256-GCM Initialization Vector error")] + NoIVError, +} + +impl Encryption for AES256GCM { + fn new(key: &[u8; 32]) -> Result { + let cipher = Aes256Gcm::new(key.into()); + Ok(Self { cipher, iv: None }) + } + + fn encrypt(&mut self, data: &mut BytesMut) -> Result { + let mut data = AESBuffer(data); + if let Some(iv) = &self.iv { + self.cipher + .encrypt_in_place(iv, b"", &mut data) + .map_err(Error::EncryptError)?; + Ok(data.0.clone().freeze()) + } else { + Err(Error::NoIVError.into()) + } + } + + fn decrypt(&mut self, data: &mut BytesMut) -> Result { + let mut data = AESBuffer(data); + if let Some(iv) = &self.iv { + self.cipher + .decrypt_in_place(iv, b"", &mut data) + .map_err(Error::EncryptError)?; + Ok(data.0.clone().freeze()) + } else { + Err(Error::NoIVError.into()) + } + } +} + +impl IVEncryption for AES256GCM { + type NonceSize = typenum::consts::U12; + + fn with_iv(&mut self, iv: GenericArray) -> Result { + Ok(Self { + cipher: self.cipher.clone(), + iv: Some(iv), + }) + } +} + +#[cfg(test)] +mod test { + use super::*; + + const KEY: [u8; 32] = [ + 0x2b, 0x7e, 0x15, 0x16, 0x28, 0xae, 0xd2, 0xa6, 0xab, 0xf7, 0x15, 0x88, 0x09, 0xcf, 0x4f, + 0x3c, 0x76, 0x3b, 0x61, 0x7b, 0x2e, 0x45, 0x8f, 0x17, 0x98, 0x4a, 0xc3, 0x5b, 0x4d, 0xa4, + 0x5c, 0x2a, + ]; + const IV: [u8; 12] = [ + 0xf0, 0xf1, 0xf2, 0xf3, 0xf4, 0xf5, 0xf6, 0xf7, 0xf8, 0xf9, 0xfa, 0xfb, + ]; + + #[test] + fn test_aes256gcm() { + let mut enc = AES256GCM::new(&KEY) + .unwrap() + .with_iv(IV.into()) + .expect("IV error"); + + let data = BytesMut::from("Hello, world!"); + + let enc_data = enc.encrypt(&mut data.clone()).unwrap(); + let dec_data = enc.decrypt(&mut enc_data.clone().into()).unwrap(); + + assert_ne!(data, enc_data); + assert_eq!(data, dec_data); + } + + #[test] + fn test_aes256gcm_no_iv() { + let mut enc = AES256GCM::new(&KEY).unwrap(); + + let data = BytesMut::from("Hello, world!"); + + let enc_data = enc.encrypt(&mut data.clone()); + let dec_data = enc.decrypt(&mut data.clone()); + + assert!(enc_data.is_err()); + assert!(dec_data.is_err()); + } +} diff --git a/crates/dwn-rs-core/src/encryption/symmetric/xsalsa20_poly1305.rs b/crates/dwn-rs-core/src/encryption/symmetric/xsalsa20_poly1305.rs new file mode 100644 index 0000000..6e84655 --- /dev/null +++ b/crates/dwn-rs-core/src/encryption/symmetric/xsalsa20_poly1305.rs @@ -0,0 +1,62 @@ +use aes::cipher::generic_array::GenericArray; +use aes_gcm::{aead::AeadMutInPlace, KeyInit}; +use bytes::{Bytes, BytesMut}; +use crypto_secretbox::XSalsa20Poly1305 as XSalsa20Poly1305Cipher; +use thiserror::Error; + +use super::{aes_gcm::AESBuffer, Encryption, IVEncryption}; + +pub struct XSalsa20Poly1305 { + cipher: XSalsa20Poly1305Cipher, + iv: Option>, +} + +#[derive(Error, Debug)] +pub enum Error { + #[error("XSalsa20Poly1305 encryption/decryption error: {0}")] + EncryptError(crypto_secretbox::Error), + #[error("XSalsa20Poly1305 Initialization Vector error")] + NoIVError, +} + +impl Encryption for XSalsa20Poly1305 { + fn new(key: &[u8; 32]) -> Result { + let cipher = XSalsa20Poly1305Cipher::new(key.into()); + Ok(Self { cipher, iv: None }) + } + + fn encrypt(&mut self, data: &mut BytesMut) -> Result { + let mut data = AESBuffer(data); + if let Some(iv) = &self.iv { + self.cipher + .encrypt_in_place(iv, b"", &mut data) + .map_err(Error::EncryptError)?; + Ok(data.0.clone().freeze()) + } else { + Err(Error::NoIVError.into()) + } + } + + fn decrypt(&mut self, data: &mut BytesMut) -> Result { + let mut data = AESBuffer(data); + if let Some(iv) = &self.iv { + self.cipher + .decrypt_in_place(iv, b"", &mut data) + .map_err(Error::EncryptError)?; + Ok(data.0.clone().freeze()) + } else { + Err(Error::NoIVError.into()) + } + } +} + +impl IVEncryption for XSalsa20Poly1305 { + type NonceSize = typenum::consts::U24; + + fn with_iv(&mut self, iv: GenericArray) -> Result { + Ok(Self { + cipher: self.cipher.clone(), + iv: Some(iv), + }) + } +} From 53ee4cfe6e0d7f65254d309d94a5117dd2ad8425 Mon Sep 17 00:00:00 2001 From: Dan Enman Date: Sun, 8 Dec 2024 00:08:12 -0400 Subject: [PATCH 21/23] feat: consolidate AES-GCM and XSalsa20-Poly1305 into AEAD --- .../symmetric/{aes_gcm.rs => aead.rs} | 82 ++++++++++++++----- .../src/encryption/symmetric/aes_ctr.rs | 14 ++-- .../src/encryption/symmetric/mod.rs | 38 +++++---- .../encryption/symmetric/xsalsa20_poly1305.rs | 62 -------------- 4 files changed, 91 insertions(+), 105 deletions(-) rename crates/dwn-rs-core/src/encryption/symmetric/{aes_gcm.rs => aead.rs} (56%) delete mode 100644 crates/dwn-rs-core/src/encryption/symmetric/xsalsa20_poly1305.rs diff --git a/crates/dwn-rs-core/src/encryption/symmetric/aes_gcm.rs b/crates/dwn-rs-core/src/encryption/symmetric/aead.rs similarity index 56% rename from crates/dwn-rs-core/src/encryption/symmetric/aes_gcm.rs rename to crates/dwn-rs-core/src/encryption/symmetric/aead.rs index ba17912..170178e 100644 --- a/crates/dwn-rs-core/src/encryption/symmetric/aes_gcm.rs +++ b/crates/dwn-rs-core/src/encryption/symmetric/aead.rs @@ -1,16 +1,22 @@ -use aes::cipher::generic_array::GenericArray; +use aes::cipher::{generic_array::GenericArray, ArrayLength}; use aes_gcm::{ aead::{AeadMutInPlace, Buffer}, Aes256Gcm, KeyInit, }; use bytes::{Bytes, BytesMut}; +use crypto_secretbox::XSalsa20Poly1305 as XSalsa20Poly1305Cipher; use thiserror::Error; use super::{Encryption, IVEncryption}; -pub(super) struct AESBuffer<'a>(pub(crate) &'a mut BytesMut); +pub struct AEAD { + cipher: C, + iv: Option>, +} + +pub(super) struct AEADBufferBytesMut<'a>(&'a mut BytesMut); -impl<'a> Buffer for AESBuffer<'a> { +impl<'a> Buffer for AEADBufferBytesMut<'a> { fn extend_from_slice(&mut self, other: &[u8]) -> aes_gcm::aead::Result<()> { self.0.extend_from_slice(other); @@ -22,23 +28,18 @@ impl<'a> Buffer for AESBuffer<'a> { } } -impl<'a> AsRef<[u8]> for AESBuffer<'a> { +impl<'a> AsRef<[u8]> for AEADBufferBytesMut<'a> { fn as_ref(&self) -> &[u8] { self.0.as_ref() } } -impl<'a> AsMut<[u8]> for AESBuffer<'a> { +impl<'a> AsMut<[u8]> for AEADBufferBytesMut<'a> { fn as_mut(&mut self) -> &mut [u8] { self.0.as_mut() } } -pub struct AES256GCM { - cipher: Aes256Gcm, - iv: Option>, -} - #[derive(Debug, Error)] pub enum Error { #[error("AES-256-GCM encryption/decryption error: {0}")] @@ -47,14 +48,19 @@ pub enum Error { NoIVError, } -impl Encryption for AES256GCM { - fn new(key: &[u8; 32]) -> Result { - let cipher = Aes256Gcm::new(key.into()); +impl Encryption for AEAD +where + C::NonceSize: ArrayLength, +{ + type KeySize = C::KeySize; + + fn new(key: GenericArray) -> Result { + let cipher = C::new(&key); Ok(Self { cipher, iv: None }) } fn encrypt(&mut self, data: &mut BytesMut) -> Result { - let mut data = AESBuffer(data); + let mut data = AEADBufferBytesMut(data); if let Some(iv) = &self.iv { self.cipher .encrypt_in_place(iv, b"", &mut data) @@ -66,7 +72,7 @@ impl Encryption for AES256GCM { } fn decrypt(&mut self, data: &mut BytesMut) -> Result { - let mut data = AESBuffer(data); + let mut data = AEADBufferBytesMut(data); if let Some(iv) = &self.iv { self.cipher .decrypt_in_place(iv, b"", &mut data) @@ -78,8 +84,8 @@ impl Encryption for AES256GCM { } } -impl IVEncryption for AES256GCM { - type NonceSize = typenum::consts::U12; +impl IVEncryption for AEAD { + type NonceSize = C::NonceSize; fn with_iv(&mut self, iv: GenericArray) -> Result { Ok(Self { @@ -89,8 +95,13 @@ impl IVEncryption for AES256GCM { } } +pub type AES256GCM = AEAD; +pub type XSalsa20Poly1305 = AEAD; + #[cfg(test)] mod test { + use aes_gcm::Aes256Gcm; + use super::*; const KEY: [u8; 32] = [ @@ -101,10 +112,14 @@ mod test { const IV: [u8; 12] = [ 0xf0, 0xf1, 0xf2, 0xf3, 0xf4, 0xf5, 0xf6, 0xf7, 0xf8, 0xf9, 0xfa, 0xfb, ]; + const SALSA_IV: [u8; 24] = [ + 0xf0, 0xf1, 0xf2, 0xf3, 0xf4, 0xf5, 0xf6, 0xf7, 0xf8, 0xf9, 0xfa, 0xfb, 0xfc, 0xfd, 0xfe, + 0xff, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, + ]; #[test] fn test_aes256gcm() { - let mut enc = AES256GCM::new(&KEY) + let mut enc = AEAD::::new(KEY.into()) .unwrap() .with_iv(IV.into()) .expect("IV error"); @@ -120,7 +135,36 @@ mod test { #[test] fn test_aes256gcm_no_iv() { - let mut enc = AES256GCM::new(&KEY).unwrap(); + let mut enc = AEAD::::new(KEY.into()).unwrap(); + + let data = BytesMut::from("Hello, world!"); + + let enc_data = enc.encrypt(&mut data.clone()); + let dec_data = enc.decrypt(&mut data.clone()); + + assert!(enc_data.is_err()); + assert!(dec_data.is_err()); + } + + #[test] + fn test_xsalsa20poly1305() { + let mut enc = XSalsa20Poly1305::new(KEY.into()) + .unwrap() + .with_iv(SALSA_IV.into()) + .expect("IV error"); + + let data = BytesMut::from("Hello, world!"); + + let enc_data = enc.encrypt(&mut data.clone()).unwrap(); + let dec_data = enc.decrypt(&mut enc_data.clone().into()).unwrap(); + + assert_ne!(data, enc_data); + assert_eq!(data, dec_data); + } + + #[test] + fn test_xsalsa20poly1305_no_iv() { + let mut enc = XSalsa20Poly1305::new(KEY.into()).unwrap(); let data = BytesMut::from("Hello, world!"); diff --git a/crates/dwn-rs-core/src/encryption/symmetric/aes_ctr.rs b/crates/dwn-rs-core/src/encryption/symmetric/aes_ctr.rs index 9567213..53b0854 100644 --- a/crates/dwn-rs-core/src/encryption/symmetric/aes_ctr.rs +++ b/crates/dwn-rs-core/src/encryption/symmetric/aes_ctr.rs @@ -27,9 +27,11 @@ pub enum Error { } impl Encryption for AES256CTR { - fn new(key: &[u8; 32]) -> Result { + type KeySize = typenum::consts::U32; + + fn new(key: GenericArray) -> Result { Ok(Self { - key: *key, + key: key.into(), enc: None, dec: None, }) @@ -88,7 +90,7 @@ mod test { #[test] fn test_aes256ctr() { - let mut enc = AES256CTR::new(&KEY) + let mut enc = AES256CTR::new(KEY.into()) .expect("Failed to create AES256CTR") .with_iv(IV.into()) .expect("Failed to set IV"); @@ -105,19 +107,19 @@ mod test { #[test] fn test_aes256ctr_no_iv() { - let mut enc = AES256CTR::new(&KEY).expect("Failed to create AES256CTR"); + let mut enc = AES256CTR::new(KEY.into()).expect("Failed to create AES256CTR"); let data = Bytes::from_static(b"hello world! this is my plaintext."); let enc_data = enc.encrypt(&mut data.clone().into()); assert_eq!( enc_data.unwrap_err().to_string(), - "AES-256-CBC encryption error: AES-256-CTR IV error" + "AES-256-CTR encryption error: AES-256-CTR IV error" ); let dec_data = enc.decrypt(&mut data.clone().into()); assert_eq!( dec_data.unwrap_err().to_string(), - "AES-256-CBC encryption error: AES-256-CTR IV error" + "AES-256-CTR encryption error: AES-256-CTR IV error" ); } } diff --git a/crates/dwn-rs-core/src/encryption/symmetric/mod.rs b/crates/dwn-rs-core/src/encryption/symmetric/mod.rs index ff5aca2..b14f7f0 100644 --- a/crates/dwn-rs-core/src/encryption/symmetric/mod.rs +++ b/crates/dwn-rs-core/src/encryption/symmetric/mod.rs @@ -6,34 +6,33 @@ use futures_util::{ready, Stream}; use pin_project_lite::pin_project; use thiserror::Error; +pub mod aead; pub mod aes_ctr; -pub mod aes_gcm; -pub mod xsalsa20_poly1305; #[derive(Debug, Error)] pub enum Error { #[error("AES-256-CTR encryption error: {0}")] AES256CTR(#[from] aes_ctr::Error), - #[error("AES-256-GCM encryption error: {0}")] - AES256GCM(#[from] aes_gcm::Error), - #[error("XSalsa20Poly1305 encryption error: {0}")] - XSalsa20Poly1305(#[from] xsalsa20_poly1305::Error), + #[error("AEAD encryption error: {0}")] + AEAD(#[from] aead::Error), } impl StreamEncryptionExt for T where T: Stream {} pub trait StreamEncryptionExt: Stream { - fn encrypt(self, key: &[u8; 32]) -> Result, Error> + fn encrypt(self, key: GenericArray) -> Result, Error> where E: Encryption, + E::KeySize: ArrayLength, Self: Sized, { Encrypt::new(self, key) } - fn decrypt(self, key: &[u8; 32]) -> Result, Error> + fn decrypt(self, key: GenericArray) -> Result, Error> where E: Encryption, + E::KeySize: ArrayLength, Self: Sized, { Decrypt::new(self, key) @@ -41,7 +40,9 @@ pub trait StreamEncryptionExt: Stream { } pub trait Encryption { - fn new(key: &[u8; 32]) -> Result + type KeySize: ArrayLength; + + fn new(key: GenericArray) -> Result where Self: Sized; fn encrypt(&mut self, data: &mut BytesMut) -> Result; @@ -68,8 +69,9 @@ pin_project! { impl Encrypt where E: Encryption, + E::KeySize: ArrayLength, { - pub fn new(stream: D, key: &[u8; 32]) -> Result { + pub fn new(stream: D, key: GenericArray) -> Result { Ok(Self { stream, encryption: E::new(key)?, @@ -122,7 +124,7 @@ impl Decrypt where E: Encryption, { - pub fn new(stream: D, key: &[u8; 32]) -> Result { + pub fn new(stream: D, key: GenericArray) -> Result { Ok(Self { stream, encryption: E::new(key)?, @@ -190,13 +192,13 @@ mod test { Ok::(msg_part_1.clone()), Ok(msg_part_2.clone()), ]) - .encrypt::(&KEY) + .encrypt::(KEY.into()) .expect("unable to generate encryption") .with_iv(IV.into()) .expect("unable to set IV"); // Static encryption - let mut enc = aes_ctr::AES256CTR::new(&KEY) + let mut enc = aes_ctr::AES256CTR::new(KEY.into()) .expect("Failed to create AES256CTR") .with_iv(IV.into()) .expect("Failed to set IV"); @@ -228,17 +230,17 @@ mod test { Ok::(msg_part_1.clone()), Ok(msg_part_2.clone()), ]) - .encrypt::(&KEY) + .encrypt::(KEY.into()) .expect("unable to generate encryption") .with_iv(IV.into()) .expect("Unable to set IV") - .decrypt::(&KEY) + .decrypt::(KEY.into()) .expect("unable to generate decryption") .with_iv(IV.into()) .expect("Unable to set IV"); // Static encryption - let mut enc = aes_ctr::AES256CTR::new(&KEY) + let mut enc = aes_ctr::AES256CTR::new(KEY.into()) .expect("Failed to create AES256CTR") .with_iv(IV.into()) .expect("Unable to set IV"); @@ -273,7 +275,7 @@ mod test { aes::cipher::InvalidLength, ))), ]) - .encrypt::(&KEY) + .encrypt::(KEY.into()) .expect("unable to generate encryption") .with_iv(IV.into()) .expect("Unable to set IV"); @@ -285,7 +287,7 @@ mod test { Err(e) => { assert_eq!( e.to_string(), - "AES-256-CBC encryption error: Invalid key length: Invalid Length" + "AES-256-CTR encryption error: Invalid key length: Invalid Length" ); } } diff --git a/crates/dwn-rs-core/src/encryption/symmetric/xsalsa20_poly1305.rs b/crates/dwn-rs-core/src/encryption/symmetric/xsalsa20_poly1305.rs deleted file mode 100644 index 6e84655..0000000 --- a/crates/dwn-rs-core/src/encryption/symmetric/xsalsa20_poly1305.rs +++ /dev/null @@ -1,62 +0,0 @@ -use aes::cipher::generic_array::GenericArray; -use aes_gcm::{aead::AeadMutInPlace, KeyInit}; -use bytes::{Bytes, BytesMut}; -use crypto_secretbox::XSalsa20Poly1305 as XSalsa20Poly1305Cipher; -use thiserror::Error; - -use super::{aes_gcm::AESBuffer, Encryption, IVEncryption}; - -pub struct XSalsa20Poly1305 { - cipher: XSalsa20Poly1305Cipher, - iv: Option>, -} - -#[derive(Error, Debug)] -pub enum Error { - #[error("XSalsa20Poly1305 encryption/decryption error: {0}")] - EncryptError(crypto_secretbox::Error), - #[error("XSalsa20Poly1305 Initialization Vector error")] - NoIVError, -} - -impl Encryption for XSalsa20Poly1305 { - fn new(key: &[u8; 32]) -> Result { - let cipher = XSalsa20Poly1305Cipher::new(key.into()); - Ok(Self { cipher, iv: None }) - } - - fn encrypt(&mut self, data: &mut BytesMut) -> Result { - let mut data = AESBuffer(data); - if let Some(iv) = &self.iv { - self.cipher - .encrypt_in_place(iv, b"", &mut data) - .map_err(Error::EncryptError)?; - Ok(data.0.clone().freeze()) - } else { - Err(Error::NoIVError.into()) - } - } - - fn decrypt(&mut self, data: &mut BytesMut) -> Result { - let mut data = AESBuffer(data); - if let Some(iv) = &self.iv { - self.cipher - .decrypt_in_place(iv, b"", &mut data) - .map_err(Error::EncryptError)?; - Ok(data.0.clone().freeze()) - } else { - Err(Error::NoIVError.into()) - } - } -} - -impl IVEncryption for XSalsa20Poly1305 { - type NonceSize = typenum::consts::U24; - - fn with_iv(&mut self, iv: GenericArray) -> Result { - Ok(Self { - cipher: self.cipher.clone(), - iv: Some(iv), - }) - } -} From afe81da4ba0408550c9b79fb1a12dccf57cbdc0a Mon Sep 17 00:00:00 2001 From: Dan Enman Date: Mon, 16 Dec 2024 11:52:10 -0400 Subject: [PATCH 22/23] chore: separate secp256k1 and x25519 key implementations with wrapper --- .../src/encryption/asymmetric/mod.rs | 106 +++++++- .../src/encryption/asymmetric/publickey.rs | 115 ++++----- .../src/encryption/asymmetric/secp256k1.rs | 103 ++++++++ .../src/encryption/asymmetric/secretkey.rs | 226 ++++++------------ .../src/encryption/asymmetric/x25519.rs | 124 ++++++++++ crates/dwn-rs-core/src/encryption/errors.rs | 4 +- crates/dwn-rs-core/src/encryption/hd_keys.rs | 70 ++---- crates/dwn-rs-core/src/encryption/mod.rs | 2 +- 8 files changed, 494 insertions(+), 256 deletions(-) create mode 100644 crates/dwn-rs-core/src/encryption/asymmetric/secp256k1.rs create mode 100644 crates/dwn-rs-core/src/encryption/asymmetric/x25519.rs diff --git a/crates/dwn-rs-core/src/encryption/asymmetric/mod.rs b/crates/dwn-rs-core/src/encryption/asymmetric/mod.rs index f6fe363..2814a18 100644 --- a/crates/dwn-rs-core/src/encryption/asymmetric/mod.rs +++ b/crates/dwn-rs-core/src/encryption/asymmetric/mod.rs @@ -1,12 +1,114 @@ pub mod publickey; +pub(crate) mod secp256k1; pub mod secretkey; +pub(crate) mod x25519; +use aes::cipher::{generic_array::GenericArray, ArrayLength}; +use k256::sha2; +use ssi_jwk::JWK; use thiserror::Error; + +use super::HashAlgorithm; #[derive(Error, Debug)] pub enum ECIESError { #[error("Invalid HKDF key length: {0}")] InvalidHKDFKeyLength(hkdf::InvalidLength), } -pub use publickey::PublicKey; -pub use secretkey::{ParseError, PrivateKeyError, SecretKey}; +#[derive(Error, Debug)] +pub enum Error { + #[error("Error getting SecretKey from bytes: {0}")] + SecretKeyError(String), + #[error("Error parsing private key: {0}")] + PrivateKeyError(#[from] PrivateKeyError), + #[error("Error parsing private key: {0}")] + PublicKeyError(#[from] PublicKeyError), + #[error("ECIES encryption error: {0}")] + ECIESError(#[from] ECIESError), + #[error("Error deriving key: unsupported hash algorithm: {0}")] + UnsupportedHashAlgorithm(String), + #[error("Error deriving key, bad key length: {0}")] + DeriveKeyLengthError(hkdf::InvalidLength), +} + +#[derive(Error, Debug)] +pub enum PublicKeyError { + #[error("Error parsing JWK: {0}")] + PublicKeyError(#[from] ssi_jwk::Error), + #[error("Curve error: {0}")] + CurveError(#[from] k256::elliptic_curve::Error), + #[error("Unsupported Curve: {0}")] + InvalidCurve(String), + #[error("ECIES encryption error: {0}")] + ECIESError(#[from] ECIESError), + #[error("Error parsing public key. Invalid length provided")] + InvalidKey, +} + +#[derive(Error, Debug)] +pub enum PrivateKeyError { + #[error("Error encoding key: {0}")] + EncodeError(#[from] k256::pkcs8::der::Error), + #[error("Error parsing private key: {0}")] + PrivateKeyError(#[from] ssi_jwk::Error), + #[error("Error parsing private key: {0}")] + ParseError(#[from] ParseError), + #[error("Error parsing private key. Invalid length provided")] + InvalidKeyLength, +} + +#[derive(Error, Debug)] +pub enum ParseError { + #[error("Error parsing secp256k1 private key: {0}")] + Secp256k1(#[from] k256::elliptic_curve::Error), + #[error("Error parsing x25519 private key: {0}")] + X25519(String), + #[error("Error parsing ed25519 private key: {0}")] + Ed25519(#[from] ed25519_dalek::SignatureError), +} + +const HKDF_KEY_LENGTH: usize = 32; // * 8 (without sign); // 32 bytes = 256 bits + +trait SecretKeyTrait: Sized { + type KeySize: ArrayLength; + type PublicKey: PublicKeyTrait; + + fn from_bytes(bytes: &[u8]) -> Result; + fn to_bytes(&self) -> Vec; + fn public_key(&self) -> Self::PublicKey; + fn jwk(&self) -> Result; + fn encapsulate(self, pk: Self::PublicKey) -> Result, Error>; + fn decrypt(&self, data: &[u8]) -> Result, Error>; +} + +trait PublicKeyTrait: Sized { + type KeySize: ArrayLength; + type SecretKey: SecretKeyTrait; + + fn from_bytes(bytes: GenericArray) -> Result; + fn to_bytes(&self) -> GenericArray; + fn jwk(&self) -> JWK; + fn decapsulate(self, sk: Self::SecretKey) -> Result, Error>; +} + +trait DeriveKey: SecretKeyTrait { + fn derive_hkdf_key( + &self, + hash_algo: HashAlgorithm, + salt: &[u8], + info: &[u8], + ) -> Result { + if hash_algo != crate::encryption::HashAlgorithm::SHA256 { + return Err(Error::UnsupportedHashAlgorithm( + "Unsupported hash algorithm".to_string(), + )); + } + let mut okm: [u8; HKDF_KEY_LENGTH] = [0; HKDF_KEY_LENGTH]; + + let hkdf = hkdf::Hkdf::::new(Some(salt), &self.to_bytes()); + hkdf.expand(info, &mut okm) + .map_err(ECIESError::InvalidHKDFKeyLength)?; + + Self::from_bytes(okm.as_slice()) + } +} diff --git a/crates/dwn-rs-core/src/encryption/asymmetric/publickey.rs b/crates/dwn-rs-core/src/encryption/asymmetric/publickey.rs index d6804a3..bcb5a69 100644 --- a/crates/dwn-rs-core/src/encryption/asymmetric/publickey.rs +++ b/crates/dwn-rs-core/src/encryption/asymmetric/publickey.rs @@ -1,94 +1,95 @@ -use k256::{elliptic_curve::sec1::ToEncodedPoint, sha2}; -use ssi_jwk::{Base64urlUInt, OctetParams, Params, JWK}; +use aes::cipher::generic_array::GenericArray; +use ssi_jwk::{Params, JWK}; -use super::{ECIESError, SecretKey}; -use thiserror::Error; +use super::{secp256k1, secretkey, x25519, Error, PublicKeyError, PublicKeyTrait}; -#[derive(Error, Debug)] -pub enum Error { - #[error("Error parsing JWK: {0}")] - PublicKeyError(#[from] ssi_jwk::Error), - #[error("Unsupported Curve: {0}")] - InvalidCurve(String), - #[error("ECIES encryption error: {0}")] - ECIESError(#[from] ECIESError), +impl From for PublicKey { + fn from(pk: secp256k1::PublicKey) -> Self { + PublicKey::Secp256k1(pk) + } +} + +impl From for PublicKey { + fn from(pk: x25519::PublicKey) -> Self { + PublicKey::X25519(pk) + } } -#[derive(Clone, Debug, PartialEq)] -#[non_exhaustive] pub enum PublicKey { - Secp256k1(k256::PublicKey), - X25519(x25519_dalek::PublicKey), + Secp256k1(secp256k1::PublicKey), + X25519(x25519::PublicKey), } +// Maximum potential size utilized here based on known key sizes. +static MAX_PUBLIC_KEY_SIZE: usize = 33; + impl PublicKey { - pub fn to_bytes(&self) -> Vec { - match self { - PublicKey::Secp256k1(pk) => pk.to_encoded_point(true).as_bytes().to_vec(), - PublicKey::X25519(pk) => pk.as_bytes().to_vec(), + pub fn from_bytes(bytes: &[u8]) -> Result { + let ga = GenericArray::from_slice(bytes); + match bytes.len() { + 33 => Ok(PublicKey::Secp256k1(secp256k1::PublicKey::from_bytes(*ga)?)), + 32 => { + let mut x = [0u8; 32]; + x.copy_from_slice(bytes); + let ga = GenericArray::from_slice(&x); + + Ok(PublicKey::X25519(x25519::PublicKey::from_bytes(*ga)?)) + } + _ => Err(PublicKeyError::InvalidKey.into()), } } - pub fn jwk(&self) -> JWK { + pub fn to_bytes(&self) -> Vec { + // Direct handling, interpolation balancing satisfies binary exact match self { - PublicKey::Secp256k1(pk) => (*pk).into(), - PublicKey::X25519(pk) => JWK::from(Params::OKP(OctetParams { - curve: "X25519".to_string(), - public_key: Base64urlUInt(pk.to_bytes().to_vec()), - private_key: None, - })), + PublicKey::Secp256k1(pk) => pk.to_bytes().to_vec(), + PublicKey::X25519(pk) => pk.to_bytes().to_vec(), } } - pub fn decapsulate(self, sk: &SecretKey) -> Result<[u8; 32], Error> { - match (self, sk) { - (PublicKey::Secp256k1(pk), SecretKey::Secp256k1(sk)) => { - let mut okm = [0u8; 32]; - k256::ecdh::diffie_hellman(sk.to_nonzero_scalar(), pk.as_affine()) - .extract::(None) - .expand(&[], &mut okm) - .map_err(ECIESError::InvalidHKDFKeyLength)?; - Ok(okm) - } - (PublicKey::X25519(pk), SecretKey::X25519(sk)) => Ok(sk.diffie_hellman(&pk).to_bytes()), - _ => Err(Error::InvalidCurve("Unsupported key type".to_string())), + pub fn jwk(&self) -> JWK { + match self { + PublicKey::Secp256k1(pk) => pk.jwk(), + PublicKey::X25519(pk) => pk.jwk(), } } -} -impl From<&SecretKey> for PublicKey { - fn from(sk: &SecretKey) -> Self { - match sk { - SecretKey::Secp256k1(sk) => PublicKey::Secp256k1(sk.public_key()), - SecretKey::X25519(sk) => PublicKey::X25519(sk.into()), + pub fn decapsulate(self, sk: secretkey::SecretKey) -> Result, Error> { + match self { + PublicKey::Secp256k1(pk) => pk.decapsulate(sk.into()).map(|ga| ga.to_vec()), + PublicKey::X25519(pk) => pk.decapsulate(sk.into()).map(|ga| ga.to_vec()), } } } impl TryFrom for PublicKey { - type Error = Error; + type Error = PublicKeyError; fn try_from(jwk: JWK) -> Result { match jwk.params { - Params::EC(ref ec) => Ok(PublicKey::Secp256k1(ec.try_into()?)), + Params::EC(ref ec) => Ok(PublicKey::Secp256k1(secp256k1::PublicKey { + pk: ec.try_into().map_err(PublicKeyError::PublicKeyError)?, + })), Params::OKP(ref op) => match op.curve.to_lowercase().as_str() { "x25519" => { let mut sk = [0u8; 32]; sk.copy_from_slice(&op.public_key.0); - Ok(PublicKey::X25519(x25519_dalek::PublicKey::from(sk))) + Ok(PublicKey::X25519(x25519::PublicKey { + pk: x25519_dalek::PublicKey::from(sk), + })) } "ed25519" => { - let pk: ed25519_dalek::VerifyingKey = op.try_into()?; - Ok(PublicKey::X25519(x25519_dalek::PublicKey::from( - pk.to_montgomery().to_bytes(), - ))) + let pk: ed25519_dalek::VerifyingKey = + op.try_into().map_err(PublicKeyError::PublicKeyError)?; + Ok(PublicKey::X25519(x25519::PublicKey { + pk: x25519_dalek::PublicKey::from(pk.to_montgomery().to_bytes()), + })) } - _ => Err(Error::InvalidCurve(format!( - "Unsupported curve: {}", - op.curve - ))), + _ => Err( + PublicKeyError::InvalidCurve(format!("Unsupported curve: {}", op.curve)).into(), + ), }, - _ => Err(Error::InvalidCurve("Unsupported key type".to_string())), + _ => Err(PublicKeyError::InvalidCurve("Unsupported key type".to_string()).into()), } } } diff --git a/crates/dwn-rs-core/src/encryption/asymmetric/secp256k1.rs b/crates/dwn-rs-core/src/encryption/asymmetric/secp256k1.rs new file mode 100644 index 0000000..4236358 --- /dev/null +++ b/crates/dwn-rs-core/src/encryption/asymmetric/secp256k1.rs @@ -0,0 +1,103 @@ +use std::fmt::Debug; + +use aes::cipher::generic_array::GenericArray; +use k256::{elliptic_curve::sec1::ToEncodedPoint, sha2}; +use ssi_jwk::{secp256k1_parse_private, JWK}; +use tracing::error; +use typenum::U33; + +use super::{ + DeriveKey, ECIESError, Error, ParseError, PrivateKeyError, PublicKeyError, PublicKeyTrait, + SecretKeyTrait, +}; + +pub struct PublicKey { + pub pk: k256::PublicKey, +} + +impl PublicKeyTrait for PublicKey { + type KeySize = U33; + type SecretKey = SecretKey; + + fn from_bytes(bytes: GenericArray) -> Result { + let pk = k256::PublicKey::from_sec1_bytes(&bytes).map_err(PublicKeyError::CurveError)?; + Ok(Self { pk }) + } + + fn to_bytes(&self) -> GenericArray { + let v = self.pk.to_encoded_point(false).to_bytes().to_vec(); + GenericArray::from_iter(v[..32].iter().copied()) + } + + fn jwk(&self) -> JWK { + self.pk.into() + } + + fn decapsulate(self, sk: Self::SecretKey) -> Result, Error> { + sk.encapsulate(self) + } +} + +#[derive(Clone, Debug, PartialEq)] +pub struct SecretKey { + sk: k256::SecretKey, +} + +impl DeriveKey for SecretKey {} + +impl SecretKeyTrait for SecretKey { + type KeySize = U33; + type PublicKey = PublicKey; + + fn from_bytes(bytes: &[u8]) -> Result { + let sk: k256::SecretKey = k256::SecretKey::from_slice(bytes).map_err(|e| { + error!("Error parsing secp256k1 private key: {:?}", e); + PrivateKeyError::ParseError(ParseError::Secp256k1(e)) + })?; + Ok(SecretKey { sk }) + } + + fn to_bytes(&self) -> Vec { + self.sk.to_bytes().to_vec() + } + + fn public_key(&self) -> Self::PublicKey { + let pk = self.sk.public_key(); + PublicKey { pk } + } + + fn jwk(&self) -> Result { + let mut jwk: JWK = self.sk.public_key().into(); + let pjwk = secp256k1_parse_private( + &self + .sk + .to_sec1_der() + .map_err(PrivateKeyError::EncodeError)?, + ) + .map_err(PrivateKeyError::PrivateKeyError)?; + jwk.params = pjwk.params.clone(); + + Ok(jwk) + } + + fn encapsulate(self, pk: Self::PublicKey) -> Result, Error> { + let mut okm: GenericArray = GenericArray::default(); + + k256::ecdh::diffie_hellman(self.sk.to_nonzero_scalar(), pk.pk.as_affine()) + .extract::(None) + .expand(&[], &mut okm) + .map_err(ECIESError::InvalidHKDFKeyLength)?; + + Ok(okm) + } + + fn decrypt(&self, data: &[u8]) -> Result, Error> { + todo!(); + } +} + +impl From for SecretKey { + fn from(sk: k256::SecretKey) -> Self { + SecretKey { sk } + } +} diff --git a/crates/dwn-rs-core/src/encryption/asymmetric/secretkey.rs b/crates/dwn-rs-core/src/encryption/asymmetric/secretkey.rs index 709efb9..b35ae6e 100644 --- a/crates/dwn-rs-core/src/encryption/asymmetric/secretkey.rs +++ b/crates/dwn-rs-core/src/encryption/asymmetric/secretkey.rs @@ -1,156 +1,95 @@ -use std::fmt::Debug; - -use k256::sha2; -use ssi_jwk::{secp256k1_parse_private, Base64urlUInt, OctetParams, Params, JWK}; -use thiserror::Error; - -use super::{ECIESError, PublicKey}; - -#[derive(Error, Debug)] -pub enum Error { - #[error("Error getting SecretKey from bytes: {0}")] - SecretKeyError(String), - #[error("Error parsing private key: {0}")] - PrivateKeyError(#[from] PrivateKeyError), - #[error("ECIES encryption error: {0}")] - ECIESError(#[from] ECIESError), -} +use ssi_jwk::{Params, JWK}; -#[derive(Error, Debug)] -pub enum PrivateKeyError { - #[error("Error encoding key: {0}")] - EncodeError(#[from] k256::pkcs8::der::Error), - #[error("Error parsing private key: {0}")] - PrivateKeyError(#[from] ssi_jwk::Error), - #[error("Error parsing private key: {0}")] - ParseError(#[from] ParseError), -} +use crate::encryption::HashAlgorithm; -#[derive(Error, Debug)] -pub enum ParseError { - #[error("Error parsing secp256k1 private key: {0}")] - Secp256k1(#[from] k256::elliptic_curve::Error), - #[error("Error parsing x25519 private key: {0}")] - X25519(String), - #[error("Error parsing ed25519 private key: {0}")] - Ed25519(#[from] ed25519_dalek::SignatureError), -} +use super::{ + publickey::PublicKey, secp256k1, x25519, DeriveKey, Error, PrivateKeyError, SecretKeyTrait, +}; -/// SecretKey represents a private asymmetric key. Supported key types are: -/// - secp256k1 -/// - x25519 -/// - ed25519 (converted to x25519) -/// -/// secp256k1 keys are preferred, since the x26619 keys are converted from ed25519 keys. See -/// also: https://eprint.iacr.org/2021/509 -#[derive(Clone)] -#[non_exhaustive] +#[derive(Debug, PartialEq, Clone)] pub enum SecretKey { - Secp256k1(k256::SecretKey), - X25519(x25519_dalek::StaticSecret), -} - -impl PartialEq for SecretKey { - fn eq(&self, other: &Self) -> bool { - match (self, other) { - (SecretKey::Secp256k1(sk1), SecretKey::Secp256k1(sk2)) => sk1 == sk2, - (SecretKey::X25519(sk1), SecretKey::X25519(sk2)) => sk1.as_bytes() == sk2.as_bytes(), - _ => false, - } - } -} - -impl Debug for SecretKey { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - fn x25519_bytes(sk: &x25519_dalek::StaticSecret) -> [u8; 32] { - let pk: x25519_dalek::PublicKey = sk.into(); - pk.to_bytes() - } - - match self { - SecretKey::Secp256k1(sk) => write!(f, "Secp256k1(pub: {:?})", sk.public_key()), - SecretKey::X25519(sk) => write!(f, "X25519(pub: {:?})", x25519_bytes(sk)), - } - } + Secp256k1(secp256k1::SecretKey), + X25519(x25519::SecretKey), } impl SecretKey { pub fn to_bytes(&self) -> Vec { match self { SecretKey::Secp256k1(sk) => sk.to_bytes().to_vec(), - SecretKey::X25519(sk) => sk.as_bytes().to_vec(), + SecretKey::X25519(sk) => sk.to_bytes().to_vec(), } } pub fn public_key(&self) -> PublicKey { - self.into() + match self { + SecretKey::Secp256k1(sk) => PublicKey::Secp256k1(sk.public_key()), + SecretKey::X25519(sk) => PublicKey::X25519(sk.public_key()), + } } pub fn jwk(&self) -> Result { match self { - SecretKey::Secp256k1(sk) => { - let mut jwk: JWK = sk.public_key().into(); - let pjwk = secp256k1_parse_private( - &sk.to_sec1_der().map_err(PrivateKeyError::EncodeError)?, - ) - .map_err(PrivateKeyError::PrivateKeyError)?; - jwk.params = pjwk.params.clone(); - - Ok(jwk) - } - SecretKey::X25519(sk) => { - let pk: x25519_dalek::PublicKey = sk.into(); - let jwk = JWK::from(Params::OKP(OctetParams { - curve: "X25519".to_string(), - public_key: Base64urlUInt(pk.as_bytes().to_vec()), - private_key: Some(Base64urlUInt(sk.as_bytes().to_vec())), - })); - - Ok(jwk) - } + SecretKey::Secp256k1(sk) => sk.jwk(), + SecretKey::X25519(sk) => sk.jwk(), } } - pub fn encapsulate(self, pk: PublicKey) -> Result<[u8; 32], Error> { - // TODO support key compression for secp256k1 hkdf key and ephemeral key + pub fn encapsulate(self, pk: PublicKey) -> Result, Error> { match (self, pk) { (SecretKey::Secp256k1(sk), PublicKey::Secp256k1(pk)) => { - let mut okm = [0u8; 32]; - k256::ecdh::diffie_hellman(sk.to_nonzero_scalar(), pk.as_affine()) - .extract::(None) - .expand(&[], &mut okm) - .map_err(ECIESError::InvalidHKDFKeyLength)?; + sk.encapsulate(pk).map(|ga| ga.to_vec()) + } + (SecretKey::X25519(sk), PublicKey::X25519(pk)) => { + sk.encapsulate(pk).map(|ga| ga.to_vec()) + } + _ => Err(Error::SecretKeyError("Invalid key pair".to_string())), + } + } - Ok(okm) + pub fn derive_hkdf( + &self, + hash_algo: HashAlgorithm, + salt: &[u8], + info: &[u8], + ) -> Result { + match self { + SecretKey::Secp256k1(sk) => { + Ok(Self::Secp256k1(sk.derive_hkdf_key(hash_algo, salt, info)?)) } - (SecretKey::X25519(sk), PublicKey::X25519(pk)) => Ok(sk.diffie_hellman(&pk).to_bytes()), - _ => Err(Error::SecretKeyError("Unsupported key type".to_string())), + SecretKey::X25519(sk) => Ok(Self::X25519(sk.derive_hkdf_key(hash_algo, salt, info)?)), } } -} -/// TryFrom (&SecretKey, &[u8; 32]) for SecretKey implements the conversion of a HKDF derived key -/// into a SecretKey of the same type -impl TryFrom<(&SecretKey, &[u8; 32])> for SecretKey { - type Error = Error; + pub fn decrypt(&self, data: &[u8]) -> Result, Error> { + todo!() + } +} - fn try_from(value: (&SecretKey, &[u8; 32])) -> Result { - let sk = match value.0 { - SecretKey::Secp256k1(_) => { - let sk: k256::SecretKey = k256::SecretKey::from_slice(value.1) - .map_err(|e| PrivateKeyError::ParseError(ParseError::Secp256k1(e)))?; - SecretKey::Secp256k1(sk) - } - SecretKey::X25519(_) => { - let mut sk_bytes = [0u8; 32]; - sk_bytes.copy_from_slice(value.1); - let sk: x25519_dalek::StaticSecret = x25519_dalek::StaticSecret::from(sk_bytes); +impl From for secp256k1::SecretKey { + fn from(sk: SecretKey) -> Self { + match sk { + SecretKey::Secp256k1(sk) => sk, + _ => panic!("Invalid conversion"), + } + } +} - SecretKey::X25519(sk) - } - }; +impl From for x25519::SecretKey { + fn from(sk: SecretKey) -> Self { + match sk { + SecretKey::X25519(sk) => sk, + _ => panic!("Invalid conversion"), + } + } +} - Ok(sk) +impl TryFrom for JWK { + type Error = Error; + fn try_from(sk: SecretKey) -> Result { + match sk { + SecretKey::Secp256k1(sk) => sk.jwk(), + SecretKey::X25519(sk) => sk.jwk(), + } } } @@ -164,30 +103,28 @@ impl TryFrom for SecretKey { .try_into() .map_err(PrivateKeyError::PrivateKeyError)?; - Ok(SecretKey::Secp256k1(sk)) + Ok(SecretKey::Secp256k1(sk.into())) } Params::OKP(okpparams) => { if okpparams.curve.to_lowercase() == "x25519" { - let sk: [u8; 32] = match okpparams.private_key.clone() { + let sk: x25519_dalek::StaticSecret = match okpparams.private_key.clone() { Some(sk) => { let mut sk_bytes = [0u8; 32]; sk_bytes.copy_from_slice(sk.0.as_slice()); - sk_bytes + sk_bytes.into() } None => { return Err(Error::SecretKeyError("Missing private key".to_string())) } }; - Ok(SecretKey::X25519(x25519_dalek::StaticSecret::from(sk))) + Ok(SecretKey::X25519(sk.into())) } else if okpparams.curve.to_lowercase() == "ed25519" { let edsk: ed25519_dalek::SigningKey = (&okpparams) .try_into() .map_err(PrivateKeyError::PrivateKeyError)?; - Ok(SecretKey::X25519(x25519_dalek::StaticSecret::from( - edsk.to_scalar_bytes(), - ))) + Ok(SecretKey::X25519(edsk.into())) } else { Err(Error::SecretKeyError(format!( "Unsupported curve type: {}", @@ -200,42 +137,31 @@ impl TryFrom for SecretKey { } } -impl From for SecretKey { - fn from(sk: ed25519_dalek::SigningKey) -> Self { - SecretKey::X25519(x25519_dalek::StaticSecret::from(sk.to_scalar_bytes())) - } -} - -impl TryFrom for JWK { - type Error = Error; - fn try_from(sk: SecretKey) -> Result { - sk.jwk() - } -} - #[cfg(test)] mod test { + use crate::encryption::asymmetric::SecretKeyTrait; + use super::*; use ssi_jwk::JWK; use std::convert::TryInto; #[test] fn test_secret_key() { - let sk = SecretKey::Secp256k1(k256::SecretKey::random(&mut rand::thread_rng())); + let sk: secp256k1::SecretKey = k256::SecretKey::random(&mut rand::thread_rng()).into(); let jwk: JWK = sk.jwk().unwrap(); let sk2: SecretKey = jwk.try_into().unwrap(); - assert_eq!(sk, sk2); + assert_eq!(sk, sk2.into()); - let sk = SecretKey::X25519(x25519_dalek::StaticSecret::random_from_rng( - rand::thread_rng(), - )); + let sk: x25519::SecretKey = + x25519_dalek::StaticSecret::random_from_rng(rand::thread_rng()).into(); let jwk: JWK = sk.jwk().unwrap(); let sk2: SecretKey = jwk.try_into().unwrap(); - assert_eq!(sk, sk2); + assert_eq!(sk, sk2.into()); - let sk: SecretKey = ed25519_dalek::SigningKey::generate(&mut rand::thread_rng()).into(); + let sk: x25519::SecretKey = + ed25519_dalek::SigningKey::generate(&mut rand::thread_rng()).into(); let jwk: JWK = sk.jwk().unwrap(); let sk2: SecretKey = jwk.try_into().unwrap(); - assert_eq!(sk, sk2); + assert_eq!(sk, sk2.into()); } } diff --git a/crates/dwn-rs-core/src/encryption/asymmetric/x25519.rs b/crates/dwn-rs-core/src/encryption/asymmetric/x25519.rs new file mode 100644 index 0000000..b1338eb --- /dev/null +++ b/crates/dwn-rs-core/src/encryption/asymmetric/x25519.rs @@ -0,0 +1,124 @@ +use std::fmt::Debug; + +use aes::cipher::generic_array::GenericArray; +use ssi_jwk::{Base64urlUInt, OctetParams, Params, JWK}; +use typenum::U32; + +use super::{DeriveKey, Error, PrivateKeyError, PublicKeyTrait, SecretKeyTrait}; + +pub struct PublicKey { + pub pk: x25519_dalek::PublicKey, +} + +impl PublicKeyTrait for PublicKey { + type KeySize = U32; + type SecretKey = SecretKey; + + fn from_bytes(bytes: GenericArray) -> Result { + let mut pk = [0u8; 32]; + pk.copy_from_slice(&bytes); + Ok(Self { + pk: x25519_dalek::PublicKey::from(pk), + }) + } + + fn to_bytes(&self) -> GenericArray { + let v = self.pk.as_bytes().to_vec(); + GenericArray::from_iter(v.iter().copied()) + } + + fn jwk(&self) -> JWK { + JWK::from(Params::OKP(OctetParams { + curve: "X25519".to_string(), + public_key: Base64urlUInt(self.to_bytes().to_vec()), + private_key: None, + })) + } + + fn decapsulate(self, sk: Self::SecretKey) -> Result, Error> { + todo!() + } +} + +#[derive(Clone)] +pub struct SecretKey { + sk: x25519_dalek::StaticSecret, +} + +impl DeriveKey for SecretKey {} + +impl Debug for SecretKey { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("X25519") + .field("pk", &self.public_key().to_bytes()) + .finish() + } +} + +impl PartialEq for SecretKey { + fn eq(&self, other: &Self) -> bool { + self.sk.to_bytes() == other.sk.to_bytes() + } +} + +impl SecretKeyTrait for SecretKey { + type KeySize = U32; + type PublicKey = PublicKey; + + fn from_bytes(bytes: &[u8]) -> Result { + if bytes.len() != 32 { + return Err(PrivateKeyError::InvalidKeyLength.into()); + } + + let mut key = [0u8; 32]; + key.copy_from_slice(bytes); + + Ok(SecretKey { + sk: x25519_dalek::StaticSecret::from(key), + }) + } + + fn to_bytes(&self) -> Vec { + self.sk.as_bytes().to_vec() + } + + fn public_key(&self) -> Self::PublicKey { + let pk = x25519_dalek::PublicKey::from(&self.sk); + PublicKey { pk } + } + + fn jwk(&self) -> Result { + let pk: x25519_dalek::PublicKey = (&self.sk).into(); + let jwk = JWK::from(Params::OKP(OctetParams { + curve: "X25519".to_string(), + public_key: Base64urlUInt(pk.as_bytes().to_vec()), + private_key: Some(Base64urlUInt(self.sk.as_bytes().to_vec())), + })); + + Ok(jwk) + } + + fn encapsulate(self, pk: Self::PublicKey) -> Result, Error> { + let shared = self.sk.diffie_hellman(&pk.pk).to_bytes(); + + Ok(GenericArray::from_iter(shared[..32].iter().copied())) + } + + fn decrypt(&self, data: &[u8]) -> Result, Error> { + todo!(); + } +} + +impl From for SecretKey { + fn from(sk: ed25519_dalek::SigningKey) -> Self { + SecretKey { + sk: x25519_dalek::StaticSecret::from(sk.to_scalar_bytes()), + } + } +} + +impl From for SecretKey { + fn from(sk: x25519_dalek::StaticSecret) -> Self { + SecretKey { sk } + } +} diff --git a/crates/dwn-rs-core/src/encryption/errors.rs b/crates/dwn-rs-core/src/encryption/errors.rs index cbcf28c..a6c95aa 100644 --- a/crates/dwn-rs-core/src/encryption/errors.rs +++ b/crates/dwn-rs-core/src/encryption/errors.rs @@ -1,12 +1,12 @@ use thiserror::Error; -use super::asymmetric::secretkey::Error as AsymmetricSecretKeyError; +use super::asymmetric; use super::hd_keys::Error as HDKeysError; #[derive(Error, Debug)] pub enum Error { #[error("Error getting JWK secret: {0}")] - JWKSecretKeyError(#[from] AsymmetricSecretKeyError), + JWKSecretKeyError(#[from] asymmetric::Error), #[error("Error deriving key: {0}")] DeriveKeyError(#[from] HDKeysError), } diff --git a/crates/dwn-rs-core/src/encryption/hd_keys.rs b/crates/dwn-rs-core/src/encryption/hd_keys.rs index 82bb851..82dc364 100644 --- a/crates/dwn-rs-core/src/encryption/hd_keys.rs +++ b/crates/dwn-rs-core/src/encryption/hd_keys.rs @@ -1,27 +1,23 @@ -use k256::sha2; use ssi_jwk::JWK; -use super::{asymmetric, DerivationScheme, SecretKey}; +use super::{ + asymmetric::{self, secretkey::SecretKey}, + DerivationScheme, +}; use thiserror::Error as ThisError; -const HKDF_KEY_LENGTH: usize = 32; // * 8; // 32 bytes = 256 bits - #[derive(Debug, ThisError)] pub enum Error { #[error("Error getting JWK secret key: {0}")] JWKSecretKeyError(#[from] ssi_jwk::Error), #[error("Error getting SecretKey from bytes: {0}")] - SecretKeyError(#[from] asymmetric::secretkey::Error), - #[error("Error deriving key, bad key length: {0}")] - DeriveKeyLengthError(hkdf::InvalidLength), + SecretKeyError(#[from] asymmetric::Error), #[error("Error deriving key: {0}")] DeriveKeyError(#[from] k256::elliptic_curve::Error), #[error("Error encoding key: {0}")] EncodeError(#[from] k256::pkcs8::der::Error), #[error("Invalid path segment: {0}")] InvalidPathSegment(String), - #[error("Unsupported hash algorithm: {0}")] - UnsupportedHashAlgorithm(String), #[error("Unsupported key type")] UnsupportedKeyType, } @@ -90,34 +86,14 @@ impl DerivedPrivateJWK { ancestor_key.to_owned(), |key, segment| -> Result { let seg = segment.as_bytes(); - Self::derive_hkdf_key(HashAlgorithm::SHA256, &key, seg) + key.derive_hkdf(HashAlgorithm::SHA256, &[], seg) + .map_err(Error::SecretKeyError) }, )?; Ok(sk) } - pub fn derive_hkdf_key( - hash_algo: HashAlgorithm, - initial_key_material: &SecretKey, - info: &[u8], - ) -> Result { - if hash_algo != HashAlgorithm::SHA256 { - // TODO support more algorithms - return Err(Error::UnsupportedHashAlgorithm( - "Unsupported hash algorithm".to_string(), - )); - } - - let mut okm = [0u8; HKDF_KEY_LENGTH]; - - hkdf::Hkdf::::new(None, initial_key_material.to_bytes().as_slice()) - .expand(info, &mut okm) - .map_err(Error::DeriveKeyLengthError)?; - - Ok(SecretKey::try_from((initial_key_material, &okm))?) - } - fn validate_path(path: &[&str]) -> Result<(), Error> { // check if any path segments are empty if path.iter().any(|s| s.is_empty()) { @@ -132,6 +108,7 @@ impl DerivedPrivateJWK { mod tests { use super::*; use ssi_jwk::JWK; + use tracing_test::traced_test; struct JWKTestTable { private_jwk: JWK, @@ -141,6 +118,7 @@ mod tests { secret_key: SecretKey, } + #[traced_test] #[test] fn test_derive() { let tcs = vec![ @@ -149,9 +127,9 @@ mod tests { }, JWKTestTable { private_jwk: { - let sk = SecretKey::X25519(x25519_dalek::StaticSecret::random_from_rng( - rand::thread_rng(), - )); + let sk = SecretKey::X25519( + x25519_dalek::StaticSecret::random_from_rng(rand::thread_rng()).into(), + ); sk.try_into().unwrap() }, }, @@ -191,9 +169,9 @@ mod tests { }, JWKTestTable { private_jwk: { - let sk = SecretKey::X25519(x25519_dalek::StaticSecret::random_from_rng( - rand::thread_rng(), - )); + let sk = SecretKey::X25519( + x25519_dalek::StaticSecret::random_from_rng(rand::thread_rng()).into(), + ); sk.try_into().unwrap() }, }, @@ -224,20 +202,24 @@ mod tests { fn test_derive_ancestor_chain_path() { let tcs = vec![ SecretKeyTestTable { - secret_key: SecretKey::Secp256k1(k256::SecretKey::random(&mut rand::thread_rng())), + secret_key: SecretKey::Secp256k1( + k256::SecretKey::random(&mut rand::thread_rng()).into(), + ), }, SecretKeyTestTable { - secret_key: SecretKey::X25519(x25519_dalek::StaticSecret::random_from_rng( - rand::thread_rng(), - )), + secret_key: SecretKey::X25519( + x25519_dalek::StaticSecret::random_from_rng(rand::thread_rng()).into(), + ), }, SecretKeyTestTable { - secret_key: ed25519_dalek::SigningKey::generate(&mut rand::thread_rng()).into(), + secret_key: SecretKey::X25519( + ed25519_dalek::SigningKey::generate(&mut rand::thread_rng()).into(), + ), }, ]; for tc in tcs { - let root_key = tc.secret_key.clone(); + let root_key = tc.secret_key; let path_to_g = ["a", "b", "c", "d", "e", "f", "g"].as_slice(); let path_to_d = ["a", "b", "c", "d"].as_slice(); @@ -255,7 +237,7 @@ mod tests { #[test] fn test_invalid_path() { let root_key: SecretKey = - SecretKey::Secp256k1(k256::SecretKey::random(&mut rand::thread_rng())); + SecretKey::Secp256k1(k256::SecretKey::random(&mut rand::thread_rng()).into()); let path = ["a", "", "c"].as_slice(); let result = DerivedPrivateJWK::derive_secret(&root_key, path); diff --git a/crates/dwn-rs-core/src/encryption/mod.rs b/crates/dwn-rs-core/src/encryption/mod.rs index 66b6138..9b4d8a5 100644 --- a/crates/dwn-rs-core/src/encryption/mod.rs +++ b/crates/dwn-rs-core/src/encryption/mod.rs @@ -3,7 +3,7 @@ pub mod errors; pub mod hd_keys; pub mod symmetric; -pub use asymmetric::SecretKey; +pub use asymmetric::secretkey::SecretKey; pub use errors::Error; pub use hd_keys::{DerivedPrivateJWK, HashAlgorithm}; From 8dfb96b6c10cde7564f9a393b0acbb05fe2a212e Mon Sep 17 00:00:00 2001 From: Dan Enman Date: Wed, 29 Jan 2025 20:00:02 -0400 Subject: [PATCH 23/23] feat: (EC)IES implementation for shared key encryption --- .../src/encryption/asymmetric/mod.rs | 242 +++++++++++++++++- .../src/encryption/asymmetric/publickey.rs | 28 +- .../src/encryption/asymmetric/secp256k1.rs | 239 ++++++++++++++++- .../src/encryption/asymmetric/secretkey.rs | 9 +- .../src/encryption/asymmetric/x25519.rs | 212 ++++++++++++++- crates/dwn-rs-core/src/encryption/mod.rs | 2 + .../src/encryption/symmetric/aead.rs | 6 +- .../src/encryption/symmetric/aes_ctr.rs | 6 + .../src/encryption/symmetric/mod.rs | 6 + 9 files changed, 703 insertions(+), 47 deletions(-) diff --git a/crates/dwn-rs-core/src/encryption/asymmetric/mod.rs b/crates/dwn-rs-core/src/encryption/asymmetric/mod.rs index 2814a18..e7813fe 100644 --- a/crates/dwn-rs-core/src/encryption/asymmetric/mod.rs +++ b/crates/dwn-rs-core/src/encryption/asymmetric/mod.rs @@ -1,14 +1,17 @@ +use crate::encryption::symmetric::{Encryption, IVEncryption}; pub mod publickey; pub(crate) mod secp256k1; pub mod secretkey; pub(crate) mod x25519; use aes::cipher::{generic_array::GenericArray, ArrayLength}; +use bytes::BytesMut; use k256::sha2; use ssi_jwk::JWK; use thiserror::Error; +use typenum::Unsigned; -use super::HashAlgorithm; +use super::{symmetric, HashAlgorithm}; #[derive(Error, Debug)] pub enum ECIESError { #[error("Invalid HKDF key length: {0}")] @@ -29,6 +32,8 @@ pub enum Error { UnsupportedHashAlgorithm(String), #[error("Error deriving key, bad key length: {0}")] DeriveKeyLengthError(hkdf::InvalidLength), + #[error("Error encrypting symmetric key: {0}")] + EncryptionError(#[from] symmetric::Error), } #[derive(Error, Debug)] @@ -71,24 +76,83 @@ const HKDF_KEY_LENGTH: usize = 32; // * 8 (without sign); // 32 bytes = 256 bits trait SecretKeyTrait: Sized { type KeySize: ArrayLength; - type PublicKey: PublicKeyTrait; + type PublicKey: PublicKeyTrait; + + fn generate_keypair() -> (Self, Self::PublicKey); fn from_bytes(bytes: &[u8]) -> Result; fn to_bytes(&self) -> Vec; fn public_key(&self) -> Self::PublicKey; fn jwk(&self) -> Result; - fn encapsulate(self, pk: Self::PublicKey) -> Result, Error>; - fn decrypt(&self, data: &[u8]) -> Result, Error>; + + #[allow(clippy::type_complexity)] + fn encapsulate( + &self, + pk: &Self::PublicKey, + ) -> Result< + GenericArray< + u8, + <::SymmetricEncryption as Encryption>::KeySize, + >, + Error, + >; + + // (EC)IES decryption for SecretKey + fn decrypt(&self, data: &[u8]) -> Result, Error> { + let ephemeral_pk_len = <::KeySize as Unsigned>::USIZE; + let ephemeral_pk = Self::PublicKey::from_bytes(&data[0..ephemeral_pk_len])?; + + let nonce_size = <<::SymmetricEncryption as IVEncryption>::NonceSize as Unsigned>::USIZE; + let nonce_start = ephemeral_pk_len; + let nonce_end = nonce_start + nonce_size; + + let nonce = GenericArray::from_slice(&data[nonce_start..nonce_end]).to_owned(); + let ciphertext = &data[nonce_end..]; + + let key = self.encapsulate(&ephemeral_pk)?; + + let mut ciper = ::SymmetricEncryption::new(key)?; + let mut buf = BytesMut::from(ciphertext); + let plaintext = ciper.with_iv(nonce)?.decrypt(&mut buf)?; + + Ok(plaintext.to_vec()) + } } trait PublicKeyTrait: Sized { type KeySize: ArrayLength; - type SecretKey: SecretKeyTrait; + type SecretKey: SecretKeyTrait; + type SymmetricEncryption: Encryption + IVEncryption; - fn from_bytes(bytes: GenericArray) -> Result; - fn to_bytes(&self) -> GenericArray; + fn from_bytes(bytes: &[u8]) -> Result; + fn to_bytes(&self) -> Vec; fn jwk(&self) -> JWK; - fn decapsulate(self, sk: Self::SecretKey) -> Result, Error>; + fn decapsulate( + &self, + sk: Self::SecretKey, + ) -> Result< + GenericArray::SymmetricEncryption as Encryption>::KeySize>, + Error, + >; + + // (EC)IES encryption for PublicKey + fn encrypt(&self, data: &[u8]) -> Result, Error> { + let (empheral_sk, ephemeral_pk) = Self::SecretKey::generate_keypair(); + let key = empheral_sk.encapsulate(self)?; + + let mut cipher = Self::SymmetricEncryption::new(key)?; + let nonce = cipher.nonce(); + + let mut buf = BytesMut::from(data); + let ciphertext = cipher.with_iv(nonce.clone())?.encrypt(&mut buf)?; + + let mut res = Vec::new(); + res.extend_from_slice(&ephemeral_pk.to_bytes()); + res.extend_from_slice(&nonce); + res.extend_from_slice(&ciphertext); // ciphertext includes tag + + Ok(res) + } } trait DeriveKey: SecretKeyTrait { @@ -112,3 +176,165 @@ trait DeriveKey: SecretKeyTrait { Self::from_bytes(okm.as_slice()) } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_encrypt_decrypt() { + // Generate a key pair + let (sk, pk) = secp256k1::SecretKey::generate_keypair(); + + // Plaintext to encrypt + let plaintext = b"Hello, world!"; + + // Encrypt using the public key + let ciphertext = pk.encrypt(plaintext).unwrap(); + + // Decrypt using the secret key + let decrypted = sk.decrypt(&ciphertext).unwrap(); + + // Ensure the decrypted data matches the original plaintext + assert_eq!(decrypted, plaintext); + } + + #[test] + fn test_encrypt_decrypt_large_data() { + // Generate a key pair + let (sk, pk) = x25519::SecretKey::generate_keypair(); + + // Large plaintext to encrypt + let plaintext = vec![0u8; 1024 * 1024]; // 1 MB of data + + // Encrypt using the public key + let ciphertext = pk.encrypt(&plaintext).unwrap(); + + // Decrypt using the secret key + let decrypted = sk.decrypt(&ciphertext).unwrap(); + + // Ensure the decrypted data matches the original plaintext + assert_eq!(decrypted, plaintext); + } + + #[test] + fn test_encrypt_decrypt_empty_data() { + // Generate a key pair + let (sk, pk) = secp256k1::SecretKey::generate_keypair(); + + // Empty plaintext to encrypt + let plaintext = b""; + + // Encrypt using the public key + let ciphertext = pk.encrypt(plaintext).unwrap(); + + // Decrypt using the secret key + let decrypted = sk.decrypt(&ciphertext).unwrap(); + + // Ensure the decrypted data matches the original plaintext + assert_eq!(decrypted, plaintext); + } + + #[test] + fn test_encrypt_decrypt_invalid_ciphertext() { + // Generate a key pair + let (sk, _) = secp256k1::SecretKey::generate_keypair(); + + // Invalid ciphertext (too short) + let invalid_ciphertext = vec![0u8; 47]; // Less than 48 bytes (ephemeral_pk + nonce) + + // Attempt to decrypt + let result = sk.decrypt(&invalid_ciphertext); + + // Ensure the operation fails with the correct error + assert!(matches!(result, Err(Error::PublicKeyError(_)))); + } + + #[test] + fn test_encrypt_decrypt_invalid_ephemeral_public_key() { + // Generate a key pair + let (sk, _) = secp256k1::SecretKey::generate_keypair(); + + // Invalid ciphertext with invalid ephemeral public key + let mut invalid_ciphertext = vec![0u8; 48]; // 32 bytes (ephemeral_pk) + 16 bytes (nonce) + invalid_ciphertext.extend_from_slice(b"invalid ciphertext"); + + // Attempt to decrypt + let result = sk.decrypt(&invalid_ciphertext); + + // Ensure the operation fails with the correct error + assert!(matches!(result, Err(Error::PublicKeyError(_)))); + } + + #[test] + fn test_encrypt_decrypt_invalid_nonce() { + // Generate a key pair + let (sk, pk) = secp256k1::SecretKey::generate_keypair(); + + // Encrypt valid data + let plaintext = b"Hello, world!"; + let mut ciphertext = pk.encrypt(plaintext).unwrap(); + + // Corrupt the nonce in the ciphertext + ciphertext[32..48].copy_from_slice(&[0u8; 16]); // Overwrite nonce with zeros + + // Attempt to decrypt + let result = sk.decrypt(&ciphertext); + + // Ensure the operation fails with the correct error + assert!(matches!(result, Err(Error::EncryptionError(_)))); + } + + #[test] + fn test_encrypt_decrypt_invalid_ciphertext_tag() { + // Generate a key pair + let (sk, pk) = secp256k1::SecretKey::generate_keypair(); + + // Encrypt valid data + let plaintext = b"Hello, world!"; + let mut ciphertext = pk.encrypt(plaintext).unwrap(); + + // Corrupt the tag in the ciphertext + let len = ciphertext.len(); + ciphertext[len - 1] ^= 0xFF; // Flip the last byte of the tag + + // Attempt to decrypt + let result = sk.decrypt(&ciphertext); + + // Ensure the operation fails with the correct error + assert!(matches!(result, Err(Error::EncryptionError(_)))); + } + + #[test] + fn test_encrypt_decrypt_key_mismatch() { + // Generate two key pairs + let (_, pk1) = secp256k1::SecretKey::generate_keypair(); + let (sk2, _) = secp256k1::SecretKey::generate_keypair(); + + // Encrypt using the first public key + let plaintext = b"Hello, world!"; + let ciphertext = pk1.encrypt(plaintext).unwrap(); + + // Attempt to decrypt using the second secret key + let result = sk2.decrypt(&ciphertext); + + // Ensure the operation fails with the correct error + assert!(matches!(result, Err(Error::EncryptionError(_)))); + } + + #[test] + fn test_encapsulate_decapsulate() { + // Generate two key pairs + let (sk1, pk1) = secp256k1::SecretKey::generate_keypair(); + let (sk2, pk2) = secp256k1::SecretKey::generate_keypair(); + + // Encapsulate using sk1 and pk2 + let shared_secret1 = sk1.encapsulate(&pk2).unwrap(); + + // Encapsulate using sk2 and pk1 + let shared_secret2 = sk2.encapsulate(&pk1).unwrap(); + + // Ensure the shared secrets match + assert_eq!(shared_secret1, shared_secret2); + } +} diff --git a/crates/dwn-rs-core/src/encryption/asymmetric/publickey.rs b/crates/dwn-rs-core/src/encryption/asymmetric/publickey.rs index bcb5a69..585408c 100644 --- a/crates/dwn-rs-core/src/encryption/asymmetric/publickey.rs +++ b/crates/dwn-rs-core/src/encryption/asymmetric/publickey.rs @@ -1,4 +1,3 @@ -use aes::cipher::generic_array::GenericArray; use ssi_jwk::{Params, JWK}; use super::{secp256k1, secretkey, x25519, Error, PublicKeyError, PublicKeyTrait}; @@ -21,19 +20,18 @@ pub enum PublicKey { } // Maximum potential size utilized here based on known key sizes. -static MAX_PUBLIC_KEY_SIZE: usize = 33; impl PublicKey { pub fn from_bytes(bytes: &[u8]) -> Result { - let ga = GenericArray::from_slice(bytes); match bytes.len() { - 33 => Ok(PublicKey::Secp256k1(secp256k1::PublicKey::from_bytes(*ga)?)), + 33 => Ok(PublicKey::Secp256k1(secp256k1::PublicKey::from_bytes( + bytes, + )?)), 32 => { let mut x = [0u8; 32]; x.copy_from_slice(bytes); - let ga = GenericArray::from_slice(&x); - Ok(PublicKey::X25519(x25519::PublicKey::from_bytes(*ga)?)) + Ok(PublicKey::X25519(x25519::PublicKey::from_bytes(bytes)?)) } _ => Err(PublicKeyError::InvalidKey.into()), } @@ -60,6 +58,13 @@ impl PublicKey { PublicKey::X25519(pk) => pk.decapsulate(sk.into()).map(|ga| ga.to_vec()), } } + + pub fn encrypt(&self, data: &[u8]) -> Result, Error> { + match self { + PublicKey::Secp256k1(pk) => pk.encrypt(data), + PublicKey::X25519(pk) => pk.encrypt(data), + } + } } impl TryFrom for PublicKey { @@ -85,11 +90,14 @@ impl TryFrom for PublicKey { pk: x25519_dalek::PublicKey::from(pk.to_montgomery().to_bytes()), })) } - _ => Err( - PublicKeyError::InvalidCurve(format!("Unsupported curve: {}", op.curve)).into(), - ), + _ => Err(PublicKeyError::InvalidCurve(format!( + "Unsupported curve: {}", + op.curve + ))), }, - _ => Err(PublicKeyError::InvalidCurve("Unsupported key type".to_string()).into()), + _ => Err(PublicKeyError::InvalidCurve( + "Unsupported key type".to_string(), + )), } } } diff --git a/crates/dwn-rs-core/src/encryption/asymmetric/secp256k1.rs b/crates/dwn-rs-core/src/encryption/asymmetric/secp256k1.rs index 4236358..3ddd676 100644 --- a/crates/dwn-rs-core/src/encryption/asymmetric/secp256k1.rs +++ b/crates/dwn-rs-core/src/encryption/asymmetric/secp256k1.rs @@ -1,10 +1,12 @@ use std::fmt::Debug; use aes::cipher::generic_array::GenericArray; -use k256::{elliptic_curve::sec1::ToEncodedPoint, sha2}; +use k256::sha2; use ssi_jwk::{secp256k1_parse_private, JWK}; use tracing::error; -use typenum::U33; +use typenum::{U32, U33}; + +use crate::encryption::symmetric::{self, Encryption}; use super::{ DeriveKey, ECIESError, Error, ParseError, PrivateKeyError, PublicKeyError, PublicKeyTrait, @@ -18,22 +20,31 @@ pub struct PublicKey { impl PublicKeyTrait for PublicKey { type KeySize = U33; type SecretKey = SecretKey; + type SymmetricEncryption = symmetric::aead::AES256GCM; - fn from_bytes(bytes: GenericArray) -> Result { - let pk = k256::PublicKey::from_sec1_bytes(&bytes).map_err(PublicKeyError::CurveError)?; + fn from_bytes(bytes: &[u8]) -> Result { + let pk = k256::PublicKey::from_sec1_bytes(bytes).map_err(PublicKeyError::CurveError)?; Ok(Self { pk }) } - fn to_bytes(&self) -> GenericArray { - let v = self.pk.to_encoded_point(false).to_bytes().to_vec(); - GenericArray::from_iter(v[..32].iter().copied()) + fn to_bytes(&self) -> Vec { + self.pk.to_sec1_bytes().to_vec() } fn jwk(&self) -> JWK { self.pk.into() } - fn decapsulate(self, sk: Self::SecretKey) -> Result, Error> { + fn decapsulate( + &self, + sk: Self::SecretKey, + ) -> Result< + GenericArray< + u8, + <::SymmetricEncryption as symmetric::Encryption>::KeySize, + >, + Error, + > { sk.encapsulate(self) } } @@ -46,7 +57,7 @@ pub struct SecretKey { impl DeriveKey for SecretKey {} impl SecretKeyTrait for SecretKey { - type KeySize = U33; + type KeySize = U32; type PublicKey = PublicKey; fn from_bytes(bytes: &[u8]) -> Result { @@ -80,8 +91,17 @@ impl SecretKeyTrait for SecretKey { Ok(jwk) } - fn encapsulate(self, pk: Self::PublicKey) -> Result, Error> { - let mut okm: GenericArray = GenericArray::default(); + fn encapsulate( + &self, + pk: &Self::PublicKey, + ) -> Result< + GenericArray::SymmetricEncryption as symmetric::Encryption>::KeySize>, + Error, + >{ + let mut okm: GenericArray< + u8, + <::SymmetricEncryption as Encryption>::KeySize, + > = GenericArray::default(); k256::ecdh::diffie_hellman(self.sk.to_nonzero_scalar(), pk.pk.as_affine()) .extract::(None) @@ -91,8 +111,16 @@ impl SecretKeyTrait for SecretKey { Ok(okm) } - fn decrypt(&self, data: &[u8]) -> Result, Error> { - todo!(); + fn generate_keypair() -> (Self, Self::PublicKey) { + let sk = k256::SecretKey::random(&mut rand::thread_rng()); + let pk = sk.public_key(); + (sk.into(), pk.into()) + } +} + +impl From for PublicKey { + fn from(pk: k256::PublicKey) -> Self { + PublicKey { pk } } } @@ -101,3 +129,188 @@ impl From for SecretKey { SecretKey { sk } } } + +#[cfg(test)] +mod tests { + use crate::encryption::HashAlgorithm; + + use super::*; + use k256::SecretKey as K256SecretKey; + use rand::thread_rng; + + // Helper function to generate a random SecretKey + fn generate_random_secret_key() -> SecretKey { + let sk = K256SecretKey::random(&mut thread_rng()); + SecretKey { sk } + } + + #[test] + fn test_secret_key_from_bytes() { + // Generate a random secret key + let sk = generate_random_secret_key(); + let sk_bytes = sk.to_bytes(); + + // Test parsing from bytes + let parsed_sk = SecretKey::from_bytes(&sk_bytes).unwrap(); + assert_eq!(sk.to_bytes(), parsed_sk.to_bytes()); + } + + #[test] + fn test_secret_key_to_bytes() { + // Generate a random secret key + let sk = generate_random_secret_key(); + let sk_bytes = sk.to_bytes(); + + // Ensure the bytes are not all zeros + assert_ne!(sk_bytes, vec![0u8; 32]); + } + + #[test] + fn test_secret_key_public_key() { + // Generate a random secret key + let sk = generate_random_secret_key(); + let pk = sk.public_key(); + + // Ensure the public key is derived correctly + let expected_pk_bytes = sk.sk.public_key().to_sec1_bytes(); + assert_eq!(pk.to_bytes().as_slice(), &expected_pk_bytes[..33]); + } + + #[test] + fn test_secret_key_jwk() { + // Generate a random secret key + let sk = generate_random_secret_key(); + let jwk = sk.jwk().unwrap(); + + // Ensure the JWK contains the correct curve and key type + match jwk.params { + ssi_jwk::Params::EC(ec) => { + assert_eq!(ec.curve, Some("secp256k1".to_string())); + } + _ => panic!("Invalid JWK params"), + } + } + + #[test] + fn test_public_key_from_bytes() { + // Generate a random secret key and its corresponding public key + let sk = generate_random_secret_key(); + let pk = sk.public_key(); + let pk_bytes = pk.to_bytes(); + + // Test parsing from bytes + let parsed_pk = PublicKey::from_bytes(&pk_bytes).unwrap(); + assert_eq!(pk.to_bytes(), parsed_pk.to_bytes()); + } + + #[test] + fn test_public_key_to_bytes() { + // Generate a random secret key and its corresponding public key + let sk = generate_random_secret_key(); + let pk = sk.public_key(); + let pk_bytes = pk.to_bytes(); + + // Ensure the bytes are not all zeros + assert_ne!(pk_bytes, vec![0u8; 33]); + } + + #[test] + fn test_public_key_jwk() { + // Generate a random secret key and its corresponding public key + let sk = generate_random_secret_key(); + let pk = sk.public_key(); + let jwk = pk.jwk(); + + // Ensure the JWK contains the correct curve and key type + match jwk.params { + ssi_jwk::Params::EC(ec) => { + assert_eq!(ec.curve, Some("secp256k1".to_string())); + } + _ => panic!("Invalid JWK params"), + } + } + + #[test] + fn test_encapsulate_decapsulate() { + // Generate two key pairs + let sk1 = generate_random_secret_key(); + let pk1 = sk1.public_key(); + let sk2 = generate_random_secret_key(); + let pk2 = sk2.public_key(); + + // Encapsulate using sk1 and pk2 + let shared_secret1 = sk1.encapsulate(&pk2).unwrap(); + + // Encapsulate using sk2 and pk1 + let shared_secret2 = sk2.encapsulate(&pk1).unwrap(); + + // Ensure the shared secrets match + assert_eq!(shared_secret1, shared_secret2); + } + + #[test] + fn test_generate_keypair() { + // Generate a key pair + let (sk, pk) = SecretKey::generate_keypair(); + + // Ensure the secret key and public key are consistent + let derived_pk = sk.public_key(); + assert_eq!(pk.to_bytes(), derived_pk.to_bytes()); + } + + #[test] + fn test_invalid_secret_key_from_bytes() { + // Test parsing an invalid secret key (wrong length) + let invalid_bytes = vec![0u8; 31]; // 31 bytes instead of 32 + let result = SecretKey::from_bytes(&invalid_bytes); + assert!(result.is_err()); + + // Test parsing an invalid secret key (zero key) + let zero_bytes = vec![0u8; 32]; + let result = SecretKey::from_bytes(&zero_bytes); + assert!(result.is_err()); + } + + #[test] + fn test_invalid_public_key_from_bytes() { + // Test parsing an invalid public key (wrong length) + let invalid_bytes = vec![0u8; 31]; // 31 bytes instead of 32 + let result = PublicKey::from_bytes(&invalid_bytes); + assert!(result.is_err()); + + // Test parsing an invalid public key (invalid encoding) + let invalid_encoding = vec![0u8; 34]; + let result = PublicKey::from_bytes(&invalid_encoding); + assert!(result.is_err()); + } + + #[test] + fn test_derive_hkdf_key() { + // Generate a random secret key + let sk = generate_random_secret_key(); + + // Derive a new key using HKDF + let salt = b"salt"; + let info = b"info"; + let derived_sk = sk + .derive_hkdf_key(HashAlgorithm::SHA256, salt, info) + .unwrap(); + + // Ensure the derived key is different from the original key + assert_ne!(sk.to_bytes(), derived_sk.to_bytes()); + } + + #[test] + fn test_unsupported_hash_algorithm() { + // Generate a random secret key + let sk = generate_random_secret_key(); + + // Attempt to derive a key using an unsupported hash algorithm + let salt = b"salt"; + let info = b"info"; + let result = sk.derive_hkdf_key(HashAlgorithm::SHA512, salt, info); + + // Ensure the operation fails with the correct error + assert!(matches!(result, Err(Error::UnsupportedHashAlgorithm(_)))); + } +} diff --git a/crates/dwn-rs-core/src/encryption/asymmetric/secretkey.rs b/crates/dwn-rs-core/src/encryption/asymmetric/secretkey.rs index b35ae6e..b384a55 100644 --- a/crates/dwn-rs-core/src/encryption/asymmetric/secretkey.rs +++ b/crates/dwn-rs-core/src/encryption/asymmetric/secretkey.rs @@ -37,10 +37,10 @@ impl SecretKey { pub fn encapsulate(self, pk: PublicKey) -> Result, Error> { match (self, pk) { (SecretKey::Secp256k1(sk), PublicKey::Secp256k1(pk)) => { - sk.encapsulate(pk).map(|ga| ga.to_vec()) + sk.encapsulate(&pk).map(|ga| ga.to_vec()) } (SecretKey::X25519(sk), PublicKey::X25519(pk)) => { - sk.encapsulate(pk).map(|ga| ga.to_vec()) + sk.encapsulate(&pk).map(|ga| ga.to_vec()) } _ => Err(Error::SecretKeyError("Invalid key pair".to_string())), } @@ -61,7 +61,10 @@ impl SecretKey { } pub fn decrypt(&self, data: &[u8]) -> Result, Error> { - todo!() + match self { + SecretKey::Secp256k1(sk) => sk.decrypt(data), + SecretKey::X25519(sk) => sk.decrypt(data), + } } } diff --git a/crates/dwn-rs-core/src/encryption/asymmetric/x25519.rs b/crates/dwn-rs-core/src/encryption/asymmetric/x25519.rs index b1338eb..3c466ca 100644 --- a/crates/dwn-rs-core/src/encryption/asymmetric/x25519.rs +++ b/crates/dwn-rs-core/src/encryption/asymmetric/x25519.rs @@ -4,6 +4,8 @@ use aes::cipher::generic_array::GenericArray; use ssi_jwk::{Base64urlUInt, OctetParams, Params, JWK}; use typenum::U32; +use crate::encryption::symmetric; + use super::{DeriveKey, Error, PrivateKeyError, PublicKeyTrait, SecretKeyTrait}; pub struct PublicKey { @@ -13,18 +15,27 @@ pub struct PublicKey { impl PublicKeyTrait for PublicKey { type KeySize = U32; type SecretKey = SecretKey; + type SymmetricEncryption = symmetric::aead::AES256GCM; - fn from_bytes(bytes: GenericArray) -> Result { + fn from_bytes(bytes: &[u8]) -> Result { let mut pk = [0u8; 32]; - pk.copy_from_slice(&bytes); + + if bytes.len() != 32 { + return Err(PrivateKeyError::InvalidKeyLength.into()); + } + + if bytes.iter().all(|&x| x == 0) { + return Err(PrivateKeyError::InvalidKeyLength.into()); + } + + pk.copy_from_slice(bytes); Ok(Self { pk: x25519_dalek::PublicKey::from(pk), }) } - fn to_bytes(&self) -> GenericArray { - let v = self.pk.as_bytes().to_vec(); - GenericArray::from_iter(v.iter().copied()) + fn to_bytes(&self) -> Vec { + self.pk.as_bytes().to_vec() } fn jwk(&self) -> JWK { @@ -35,8 +46,17 @@ impl PublicKeyTrait for PublicKey { })) } - fn decapsulate(self, sk: Self::SecretKey) -> Result, Error> { - todo!() + fn decapsulate( + &self, + sk: Self::SecretKey, + ) -> Result< + GenericArray< + u8, + <::SymmetricEncryption as symmetric::Encryption>::KeySize, + >, + Error, + > { + sk.encapsulate(self) } } @@ -70,6 +90,10 @@ impl SecretKeyTrait for SecretKey { return Err(PrivateKeyError::InvalidKeyLength.into()); } + if bytes.iter().all(|&x| x == 0) { + return Err(PrivateKeyError::InvalidKeyLength.into()); + } + let mut key = [0u8; 32]; key.copy_from_slice(bytes); @@ -98,14 +122,22 @@ impl SecretKeyTrait for SecretKey { Ok(jwk) } - fn encapsulate(self, pk: Self::PublicKey) -> Result, Error> { + fn encapsulate( + &self, + pk: &Self::PublicKey, + ) -> Result< + GenericArray::SymmetricEncryption as symmetric::Encryption>::KeySize>, + Error, + >{ let shared = self.sk.diffie_hellman(&pk.pk).to_bytes(); Ok(GenericArray::from_iter(shared[..32].iter().copied())) } - fn decrypt(&self, data: &[u8]) -> Result, Error> { - todo!(); + fn generate_keypair() -> (Self, Self::PublicKey) { + let sk = x25519_dalek::StaticSecret::random_from_rng(rand::thread_rng()); + let pk = x25519_dalek::PublicKey::from(&sk); + (sk.into(), pk.into()) } } @@ -122,3 +154,163 @@ impl From for SecretKey { SecretKey { sk } } } + +impl From for PublicKey { + fn from(pk: x25519_dalek::PublicKey) -> Self { + PublicKey { pk } + } +} + +#[cfg(test)] +mod tests { + use crate::encryption::HashAlgorithm; + + use super::*; + use rand::thread_rng; + use x25519_dalek::{PublicKey as X25519PublicKey, StaticSecret}; + + // Helper function to generate a random SecretKey + fn generate_random_secret_key() -> SecretKey { + let sk = StaticSecret::random_from_rng(thread_rng()); + SecretKey { sk } + } + + #[test] + fn test_secret_key_from_bytes() { + // Generate a random secret key + let sk = generate_random_secret_key(); + let sk_bytes = sk.to_bytes(); + + // Test parsing from bytes + let parsed_sk = SecretKey::from_bytes(&sk_bytes).unwrap(); + assert_eq!(sk.to_bytes(), parsed_sk.to_bytes()); + } + + #[test] + fn test_secret_key_to_bytes() { + // Generate a random secret key + let sk = generate_random_secret_key(); + let sk_bytes = sk.to_bytes(); + + // Ensure the bytes are not all zeros + assert_ne!(sk_bytes, vec![0u8; 32]); + } + + #[test] + fn test_secret_key_public_key() { + // Generate a random secret key + let sk = generate_random_secret_key(); + let pk = sk.public_key(); + + // Ensure the public key is derived correctly + let expected_pk_bytes = X25519PublicKey::from(&sk.sk).to_bytes(); + assert_eq!(pk.to_bytes().as_slice(), expected_pk_bytes.as_slice()); + } + + #[test] + fn test_public_key_from_bytes() { + // Generate a random secret key and its corresponding public key + let sk = generate_random_secret_key(); + let pk = sk.public_key(); + let pk_bytes = pk.to_bytes(); + + // Test parsing from bytes + let parsed_pk = PublicKey::from_bytes(&pk_bytes).unwrap(); + assert_eq!(pk.to_bytes(), parsed_pk.to_bytes()); + } + + #[test] + fn test_public_key_to_bytes() { + // Generate a random secret key and its corresponding public key + let sk = generate_random_secret_key(); + let pk = sk.public_key(); + let pk_bytes = pk.to_bytes(); + + // Ensure the bytes are not all zeros + assert_ne!(pk_bytes, vec![0u8; 32]); + } + + #[test] + fn test_encapsulate_decapsulate() { + // Generate two key pairs + let sk1 = generate_random_secret_key(); + let pk1 = sk1.public_key(); + let sk2 = generate_random_secret_key(); + let pk2 = sk2.public_key(); + + // Encapsulate using sk1 and pk2 + let shared_secret1 = sk1.encapsulate(&pk2).unwrap(); + + // Encapsulate using sk2 and pk1 + let shared_secret2 = sk2.encapsulate(&pk1).unwrap(); + + // Ensure the shared secrets match + assert_eq!(shared_secret1, shared_secret2); + } + + #[test] + fn test_generate_keypair() { + // Generate a key pair + let (sk, pk) = SecretKey::generate_keypair(); + + // Ensure the secret key and public key are consistent + let derived_pk = sk.public_key(); + assert_eq!(pk.to_bytes(), derived_pk.to_bytes()); + } + + #[test] + fn test_invalid_secret_key_from_bytes() { + // Test parsing an invalid secret key (wrong length) + let invalid_bytes = vec![0u8; 31]; // 31 bytes instead of 32 + let result = SecretKey::from_bytes(&invalid_bytes); + assert!(result.is_err()); + + // Test parsing an invalid secret key (zero key) + let zero_bytes = vec![0u8; 32]; + let result = SecretKey::from_bytes(&zero_bytes); + assert!(result.is_err()); + } + + #[test] + fn test_invalid_public_key_from_bytes() { + // Test parsing an invalid public key (wrong length) + let invalid_bytes = vec![0u8; 31]; // 31 bytes instead of 32 + let result = PublicKey::from_bytes(&invalid_bytes); + assert!(result.is_err()); + + // Test parsing an invalid public key (invalid encoding) + let invalid_encoding = vec![0u8; 32]; // All zeros + let result = PublicKey::from_bytes(&invalid_encoding); + assert!(result.is_err()); + } + + #[test] + fn test_derive_hkdf_key() { + // Generate a random secret key + let sk = generate_random_secret_key(); + + // Derive a new key using HKDF + let salt = b"salt"; + let info = b"info"; + let derived_sk = sk + .derive_hkdf_key(HashAlgorithm::SHA256, salt, info) + .unwrap(); + + // Ensure the derived key is different from the original key + assert_ne!(sk.to_bytes(), derived_sk.to_bytes()); + } + + #[test] + fn test_unsupported_hash_algorithm() { + // Generate a random secret key + let sk = generate_random_secret_key(); + + // Attempt to derive a key using an unsupported hash algorithm + let salt = b"salt"; + let info = b"info"; + let result = sk.derive_hkdf_key(HashAlgorithm::SHA512, salt, info); + + // Ensure the operation fails with the correct error + assert!(matches!(result, Err(Error::UnsupportedHashAlgorithm(_)))); + } +} diff --git a/crates/dwn-rs-core/src/encryption/mod.rs b/crates/dwn-rs-core/src/encryption/mod.rs index 9b4d8a5..e5e2f04 100644 --- a/crates/dwn-rs-core/src/encryption/mod.rs +++ b/crates/dwn-rs-core/src/encryption/mod.rs @@ -42,6 +42,8 @@ pub enum KeyEncryptionAlgorithm { pub enum KeyEncryptionAlgorithmAsymmetric { #[serde(rename = "ECIES-ES256K")] EciesSecp256k1, + #[serde(rename = "X25519")] + X25519, } #[derive(Serialize, Deserialize, Debug, PartialEq, Clone)] diff --git a/crates/dwn-rs-core/src/encryption/symmetric/aead.rs b/crates/dwn-rs-core/src/encryption/symmetric/aead.rs index 170178e..6a695c3 100644 --- a/crates/dwn-rs-core/src/encryption/symmetric/aead.rs +++ b/crates/dwn-rs-core/src/encryption/symmetric/aead.rs @@ -16,7 +16,7 @@ pub struct AEAD { pub(super) struct AEADBufferBytesMut<'a>(&'a mut BytesMut); -impl<'a> Buffer for AEADBufferBytesMut<'a> { +impl Buffer for AEADBufferBytesMut<'_> { fn extend_from_slice(&mut self, other: &[u8]) -> aes_gcm::aead::Result<()> { self.0.extend_from_slice(other); @@ -28,13 +28,13 @@ impl<'a> Buffer for AEADBufferBytesMut<'a> { } } -impl<'a> AsRef<[u8]> for AEADBufferBytesMut<'a> { +impl AsRef<[u8]> for AEADBufferBytesMut<'_> { fn as_ref(&self) -> &[u8] { self.0.as_ref() } } -impl<'a> AsMut<[u8]> for AEADBufferBytesMut<'a> { +impl AsMut<[u8]> for AEADBufferBytesMut<'_> { fn as_mut(&mut self) -> &mut [u8] { self.0.as_mut() } diff --git a/crates/dwn-rs-core/src/encryption/symmetric/aes_ctr.rs b/crates/dwn-rs-core/src/encryption/symmetric/aes_ctr.rs index 53b0854..c463bfb 100644 --- a/crates/dwn-rs-core/src/encryption/symmetric/aes_ctr.rs +++ b/crates/dwn-rs-core/src/encryption/symmetric/aes_ctr.rs @@ -10,6 +10,9 @@ use super::{Encryption, IVEncryption}; pub type CipherAES256CTR = Ctr64BE; +#[deprecated( + note = "Use `AEAD` for AES-GCM or Xsalsa20Poly1305Cipher. No message authentication is provided by AES-CTR." +)] pub struct AES256CTR { key: [u8; 32], enc: Option, @@ -26,6 +29,7 @@ pub enum Error { NoIVError, } +#[allow(deprecated)] impl Encryption for AES256CTR { type KeySize = typenum::consts::U32; @@ -56,6 +60,7 @@ impl Encryption for AES256CTR { } } +#[allow(deprecated)] impl IVEncryption for AES256CTR { type NonceSize = typenum::consts::U16; @@ -89,6 +94,7 @@ mod test { ]; #[test] + #[allow(deprecated)] fn test_aes256ctr() { let mut enc = AES256CTR::new(KEY.into()) .expect("Failed to create AES256CTR") diff --git a/crates/dwn-rs-core/src/encryption/symmetric/mod.rs b/crates/dwn-rs-core/src/encryption/symmetric/mod.rs index b14f7f0..f7ec170 100644 --- a/crates/dwn-rs-core/src/encryption/symmetric/mod.rs +++ b/crates/dwn-rs-core/src/encryption/symmetric/mod.rs @@ -55,6 +55,12 @@ pub trait IVEncryption: Encryption { fn with_iv(&mut self, iv: GenericArray) -> Result where Self: Sized; + + fn nonce(&self) -> GenericArray { + let mut nonce = GenericArray::default(); + nonce.iter_mut().for_each(|b| *b = rand::random()); + nonce + } } pin_project! {