diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index fc6e538b1..51c1bbf58 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -122,6 +122,7 @@ jobs: - json - value - serde + - api env: RUSTFLAGS: -D warnings steps: @@ -130,9 +131,27 @@ jobs: - run: cargo check -p musli --no-default-features --features ${{matrix.base}} - run: cargo check -p musli --no-default-features --features ${{matrix.base}},alloc - run: cargo check -p musli --no-default-features --features ${{matrix.base}},std + - run: cargo check -p musli --no-default-features --features ${{matrix.base}},std,alloc - run: cargo check -p musli --no-default-features --features ${{matrix.base}},simdutf8 - run: cargo check -p musli --no-default-features --features ${{matrix.base}},parse-full + crate_features: + needs: [rustfmt, clippy] + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + crate: + - musli-axum + env: + RUSTFLAGS: -D warnings + steps: + - uses: actions/checkout@v4 + - uses: dtolnay/rust-toolchain@stable + - run: cargo build -p ${{matrix.crate}} --no-default-features + - run: cargo build -p ${{matrix.crate}} --no-default-features --features alloc + - run: cargo build -p ${{matrix.crate}} --no-default-features --features std + recursive: runs-on: ubuntu-latest steps: diff --git a/crates/musli-axum/Cargo.toml b/crates/musli-axum/Cargo.toml new file mode 100644 index 000000000..febcb588d --- /dev/null +++ b/crates/musli-axum/Cargo.toml @@ -0,0 +1,33 @@ +[package] +name = "musli-axum" +version = "0.0.124" +authors = ["John-John Tedro "] +edition = "2021" +description = """ +Types for integrating Müsli with axum. +""" +documentation = "https://docs.rs/musli" +readme = "README.md" +homepage = "https://github.com/udoprog/musli" +repository = "https://github.com/udoprog/musli" +license = "MIT OR Apache-2.0" +keywords = ["framework", "http", "web"] +categories = ["asynchronous", "network-programming", "web-programming::http-server"] + +[features] +default = ["alloc", "std", "ws", "json"] +alloc = ["musli/alloc"] +std = ["musli/std"] +json = ["musli/json", "axum/json", "dep:bytes", "dep:mime"] +ws = ["axum/ws", "dep:rand", "tokio/time", "dep:tokio-stream"] + +[dependencies] +musli = { path = "../musli", version = "0.0.124", default-features = false, features = ["api"] } + +axum = { version = "0.7.5", default-features = false, optional = true } +bytes = { version = "1.6.0", optional = true } +mime = { version = "0.3.17", default-features = false, optional = true } +rand = { version = "0.8.5", default-features = false, optional = true, features = ["small_rng"] } +tracing = { version = "0.1.40", default-features = false } +tokio = { version = "1.37.0", default-features = false, features = ["time"], optional = true } +tokio-stream = { version = "0.1.15", default-features = false, optional = true } diff --git a/crates/musli-axum/README.md b/crates/musli-axum/README.md new file mode 100644 index 000000000..e54f5c6a4 --- /dev/null +++ b/crates/musli-axum/README.md @@ -0,0 +1,11 @@ +# musli-axum + +[github](https://github.com/udoprog/musli) +[crates.io](https://crates.io/crates/musli-axum) +[docs.rs](https://docs.rs/musli-axum) +[build status](https://github.com/udoprog/musli/actions?query=branch%3Amain) + +This crate provides a set of utilities for working with [Axum] and [Müsli]. + +[Axum]: https://github.com/tokio-rs/axum +[Müsli]: https://github.com/udoprog/musli diff --git a/crates/musli-axum/src/json.rs b/crates/musli-axum/src/json.rs new file mode 100644 index 000000000..9606a9139 --- /dev/null +++ b/crates/musli-axum/src/json.rs @@ -0,0 +1,155 @@ +use alloc::boxed::Box; +use alloc::string::{String, ToString}; + +use axum::async_trait; +use axum::extract::rejection::BytesRejection; +use axum::extract::{FromRequest, Request}; +use axum::http::header::{self, HeaderValue}; +use axum::http::{HeaderMap, StatusCode}; +use axum::response::{IntoResponse, Response}; +use bytes::{BufMut, Bytes, BytesMut}; +use musli::de::DecodeOwned; +use musli::json::Encoding; +use musli::mode::Text; +use musli::Encode; + +const ENCODING: Encoding = Encoding::new(); + +/// A rejection from the JSON extractor. +pub enum JsonRejection { + ContentType, + Report(String), + BytesRejection(BytesRejection), +} + +impl From for JsonRejection { + #[inline] + fn from(rejection: BytesRejection) -> Self { + JsonRejection::BytesRejection(rejection) + } +} + +impl IntoResponse for JsonRejection { + fn into_response(self) -> Response { + let status; + let body; + + match self { + JsonRejection::ContentType => { + status = StatusCode::UNSUPPORTED_MEDIA_TYPE; + body = String::from("Expected request with `Content-Type: application/json`"); + } + JsonRejection::Report(report) => { + status = StatusCode::BAD_REQUEST; + body = report; + } + JsonRejection::BytesRejection(rejection) => { + return rejection.into_response(); + } + } + + ( + status, + [( + header::CONTENT_TYPE, + HeaderValue::from_static(mime::TEXT_PLAIN_UTF_8.as_ref()), + )], + body, + ) + .into_response() + } +} + +/// Encode the given value as JSON. +pub struct Json(pub T); + +#[async_trait] +impl FromRequest for Json +where + T: DecodeOwned, + S: Send + Sync, +{ + type Rejection = JsonRejection; + + async fn from_request(req: Request, state: &S) -> Result { + if !json_content_type(req.headers()) { + return Err(JsonRejection::ContentType); + } + + let bytes = Bytes::from_request(req, state).await?; + Self::from_bytes(&bytes) + } +} + +fn json_content_type(headers: &HeaderMap) -> bool { + let content_type = if let Some(content_type) = headers.get(header::CONTENT_TYPE) { + content_type + } else { + return false; + }; + + let content_type = if let Ok(content_type) = content_type.to_str() { + content_type + } else { + return false; + }; + + let mime = if let Ok(mime) = content_type.parse::() { + mime + } else { + return false; + }; + + let is_json_content_type = mime.type_() == "application" + && (mime.subtype() == "json" || mime.suffix().map_or(false, |name| name == "json")); + + is_json_content_type +} + +impl IntoResponse for Json +where + T: Encode, +{ + fn into_response(self) -> Response { + // Use a small initial capacity of 128 bytes like serde_json::to_vec + // https://docs.rs/serde_json/1.0.82/src/serde_json/ser.rs.html#2189 + let mut buf = BytesMut::with_capacity(128).writer(); + + match ENCODING.to_writer(&mut buf, &self.0) { + Ok(()) => ( + [( + header::CONTENT_TYPE, + HeaderValue::from_static(mime::APPLICATION_JSON.as_ref()), + )], + buf.into_inner().freeze(), + ) + .into_response(), + Err(err) => ( + StatusCode::INTERNAL_SERVER_ERROR, + [( + header::CONTENT_TYPE, + HeaderValue::from_static(mime::TEXT_PLAIN_UTF_8.as_ref()), + )], + err.to_string(), + ) + .into_response(), + } + } +} + +impl Json +where + T: DecodeOwned, +{ + fn from_bytes(bytes: &[u8]) -> Result { + let cx = musli::context::DefaultContext::default(); + + if let Ok(value) = ENCODING.from_slice_with(&cx, bytes) { + return Ok(Json(value)); + } + + let report = cx.report(); + let report = report.to_string(); + Err(JsonRejection::Report(report)) + } +} diff --git a/crates/musli-axum/src/lib.rs b/crates/musli-axum/src/lib.rs new file mode 100644 index 000000000..6aec34584 --- /dev/null +++ b/crates/musli-axum/src/lib.rs @@ -0,0 +1,24 @@ +//! [github](https://github.com/udoprog/musli) +//! [crates.io](https://crates.io/crates/musli-axum) +//! [docs.rs](https://docs.rs/musli-axum) +//! +//! This crate provides a set of utilities for working with [Axum] and [Müsli]. +//! +//! [Axum]: https://github.com/tokio-rs/axum +//! [Müsli]: https://github.com/udoprog/musli + +#![no_std] + +#[cfg(feature = "std")] +extern crate std; + +#[cfg(feature = "alloc")] +extern crate alloc; + +#[cfg(all(feature = "json", feature = "alloc"))] +mod json; +#[cfg(all(feature = "json", feature = "alloc"))] +pub use self::json::Json; + +#[cfg(all(feature = "ws", feature = "alloc"))] +pub mod ws; diff --git a/crates/musli-axum/src/ws.rs b/crates/musli-axum/src/ws.rs new file mode 100644 index 000000000..95da0fe33 --- /dev/null +++ b/crates/musli-axum/src/ws.rs @@ -0,0 +1,370 @@ +use core::fmt::{self, Write}; + +use alloc::borrow::Cow; +use alloc::boxed::Box; +use alloc::string::String; +use alloc::vec::Vec; + +use axum::extract::ws::{CloseFrame, Message, WebSocket}; +use musli::mode::Binary; +use musli::reader::SliceReader; +use musli::{api, Decode, Encode}; +use rand::prelude::*; +use rand::rngs::SmallRng; +use tokio::time::Duration; +use tokio_stream::StreamExt; + +const MAX_CAPACITY: usize = 1048576; + +enum OneOf { + Error(Error), + Handler(E), +} + +impl From for OneOf { + #[inline] + fn from(error: Error) -> Self { + Self::Error(error) + } +} + +impl From for OneOf { + #[inline] + fn from(error: musli::api::Error) -> Self { + Self::Error(Error::from(error)) + } +} + +impl fmt::Display for OneOf +where + E: fmt::Display, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + OneOf::Error(error) => error.fmt(f), + OneOf::Handler(error) => error.fmt(f), + } + } +} + +#[derive(Debug)] +enum ErrorKind { + Axum(axum::Error), + Api(musli::api::Error), + UnknownRequest(Box), + FormatError, +} + +#[derive(Debug)] +pub struct Error { + kind: ErrorKind, +} + +impl Error { + const fn new(kind: ErrorKind) -> Self { + Self { kind } + } +} + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match &self.kind { + ErrorKind::Axum(..) => write!(f, "Error in axum"), + ErrorKind::Api(..) => write!(f, "Encoding error"), + ErrorKind::UnknownRequest(kind) => { + write!(f, "Unknown request kind `{kind}`") + } + ErrorKind::FormatError => write!(f, "Error formatting error response"), + } + } +} + +#[cfg(feature = "std")] +impl std::error::Error for Error { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match &self.kind { + ErrorKind::Axum(error) => Some(error), + ErrorKind::Api(error) => Some(error), + _ => None, + } + } +} + +impl From for Error { + #[inline] + fn from(error: axum::Error) -> Self { + Self::new(ErrorKind::Axum(error)) + } +} + +impl From for Error { + #[inline] + fn from(error: musli::api::Error) -> Self { + Self::new(ErrorKind::Api(error)) + } +} + +type Result = core::result::Result; + +/// A handler for incoming requests. +#[axum::async_trait] +pub trait Handler: Send + Sync { + /// Error returned by handler. + type Error: 'static + Send + Sync; + + async fn handle( + &mut self, + incoming: &mut Incoming<'_>, + outgoing: &mut Outgoing<'_>, + kind: &str, + ) -> Result<(), Self::Error>; +} + +/// The server side of the websocket connection. +pub struct Server<'a, E> { + buf: Buf, + error: String, + socket: WebSocket, + handler: Box + 'a>, +} + +impl<'a, E> Server<'a, E> +where + E: 'static + Send + Sync + fmt::Display, +{ + /// Construct a new server with the specified handler. + pub fn new(socket: WebSocket, handler: H) -> Self + where + H: 'a + Handler, + { + Self { + buf: Buf::default(), + error: String::new(), + socket, + handler: Box::new(handler), + } + } + + /// Run the server. + pub async fn run(&mut self) -> Result<(), Error> { + tracing::trace!("Accepted"); + + const CLOSE_NORMAL: u16 = 1000; + const CLOSE_PROTOCOL_ERROR: u16 = 1002; + const CLOSE_TIMEOUT: Duration = Duration::from_secs(30); + const PING_TIMEOUT: Duration = Duration::from_secs(10); + + let mut last_ping = None::; + let mut rng = SmallRng::seed_from_u64(0x404241112); + let mut close_interval = tokio::time::interval(CLOSE_TIMEOUT); + close_interval.reset(); + + let mut ping_interval = tokio::time::interval(PING_TIMEOUT); + ping_interval.reset(); + + let close_here = loop { + tokio::select! { + _ = close_interval.tick() => { + break Some((CLOSE_NORMAL, "connection timed out")); + } + _ = ping_interval.tick() => { + let payload = rng.gen::(); + last_ping = Some(payload); + let data = payload.to_ne_bytes().into_iter().collect::>(); + tracing::trace!(data = ?&data[..], "Sending ping"); + self.socket.send(Message::Ping(data)).await?; + ping_interval.reset(); + } + message = self.socket.next() => { + let Some(message) = message else { + break None; + }; + + match message? { + Message::Text(_) => break Some((CLOSE_PROTOCOL_ERROR, "unsupported message")), + Message::Binary(bytes) => { + let mut reader = SliceReader::new(&bytes); + + let header = match musli::api::decode(&mut reader) { + Ok(header) => header, + Err(error) => { + tracing::warn!(?error, "Failed to decode request header"); + break Some((CLOSE_PROTOCOL_ERROR, "invalid request")); + } + }; + + match self.handle_request(reader, header).await { + Ok(()) => { + self.flush().await?; + }, + Err(error) => { + if write!(self.error, "{error}").is_err() { + return Err(Error::new(ErrorKind::FormatError)); + } + + self.buf.buffer.clear(); + + self.buf.write(api::ResponseHeader { + serial: header.serial, + broadcast: None, + error: Some(self.error.as_str()), + })?; + + self.flush().await?; + } + } + }, + Message::Ping(payload) => { + self.socket.send(Message::Pong(payload)).await?; + continue; + }, + Message::Pong(data) => { + tracing::trace!(data = ?&data[..], "Pong"); + + let Some(expected) = last_ping else { + continue; + }; + + if expected.to_ne_bytes()[..] != data[..] { + continue; + } + + close_interval.reset(); + ping_interval.reset(); + last_ping = None; + }, + Message::Close(_) => break None, + } + } + } + }; + + if let Some((code, reason)) = close_here { + tracing::trace!(code, reason, "Closing websocket with reason"); + + self.socket + .send(Message::Close(Some(CloseFrame { + code, + reason: Cow::Borrowed(reason), + }))) + .await?; + } else { + tracing::trace!("Closing websocket"); + }; + + Ok(()) + } + + async fn flush(&mut self) -> Result<()> { + self.socket + .send(Message::Binary(self.buf.buffer.to_vec())) + .await?; + self.error.clear(); + self.buf.buffer.clear(); + self.buf.buffer.shrink_to(MAX_CAPACITY); + Ok(()) + } + + async fn handle_request( + &mut self, + reader: SliceReader<'_>, + header: api::RequestHeader<'_>, + ) -> Result<(), OneOf> { + tracing::trace!("Got request: {header:?}"); + + self.buf.write(api::ResponseHeader { + serial: header.serial, + broadcast: None, + error: None, + })?; + + let mut incoming = Incoming { + error: None, + reader, + }; + + let mut outgoing = Outgoing { + error: None, + written: false, + buf: &mut self.buf, + }; + + if let Err(error) = self + .handler + .handle(&mut incoming, &mut outgoing, header.kind) + .await + { + return Err(OneOf::Handler(error)); + } + + if let Some(error) = incoming.error.take() { + return Err(OneOf::Error(error)); + } + + if !outgoing.written { + return Err(OneOf::Error(Error::new(ErrorKind::UnknownRequest( + header.kind.into(), + )))); + } + + Ok(()) + } +} + +/// An incoming request. +pub struct Incoming<'de> { + error: Option, + reader: SliceReader<'de>, +} + +impl<'de> Incoming<'de> { + /// Read a request. + pub fn read(&mut self) -> Option + where + T: Decode<'de, Binary>, + { + match musli::api::decode(&mut self.reader) { + Ok(value) => Some(value), + Err(error) => { + self.error = Some(error.into()); + None + } + } + } +} + +/// Handler for an outgoing buffer. +pub struct Outgoing<'a> { + error: Option, + written: bool, + buf: &'a mut Buf, +} + +impl Outgoing<'_> { + /// Write a response. + pub fn write(&mut self, value: T) + where + T: Encode, + { + if let Err(error) = self.buf.write(value) { + self.error = Some(error); + } else { + self.written = true; + } + } +} + +#[derive(Default)] +pub struct Buf { + buffer: Vec, +} + +impl Buf { + fn write(&mut self, value: T) -> Result<(), Error> + where + T: Encode, + { + musli::api::encode(musli::wrap::wrap(&mut self.buffer), &value)?; + Ok(()) + } +} diff --git a/crates/musli-macros/Cargo.toml b/crates/musli-macros/Cargo.toml index de4e4ffac..3c07ac500 100644 --- a/crates/musli-macros/Cargo.toml +++ b/crates/musli-macros/Cargo.toml @@ -21,6 +21,7 @@ path = "src/lib.rs" [features] verbose = [] +api = [] [dependencies] proc-macro2 = "1.0.79" diff --git a/crates/musli-macros/src/api.rs b/crates/musli-macros/src/api.rs new file mode 100644 index 000000000..49537cf68 --- /dev/null +++ b/crates/musli-macros/src/api.rs @@ -0,0 +1,211 @@ +use proc_macro2::{Span, TokenStream}; +use quote::quote; +use syn::parse::ParseStream; +use syn::Token; + +/// Expand endpoint. +pub(super) fn endpoint( + input: syn::DeriveInput, + crate_name: &str, + module_name: &str, +) -> syn::Result { + let mut crate_path = None; + let mut response = None; + let mut response_lt = None; + + for attr in input.attrs { + if !attr.path().is_ident("endpoint") { + continue; + } + + attr.parse_args_with(|p: ParseStream<'_>| { + while !p.is_empty() { + let path = p.parse::()?; + + if let Some(lt) = as_ident(&path, "response") { + p.parse::()?; + response = Some(p.parse::()?); + response_lt = lt; + } else if path.is_ident("crate") { + parse_crate(p, &mut crate_path)?; + } else { + return Err(syn::Error::new_spanned(path, "unknown attribute")); + } + + if p.parse::>()?.is_none() { + break; + } + } + + Ok(()) + })?; + } + + let crate_path = match crate_path { + Some(path) => path, + None => syn::parse_quote!(#crate_name), + }; + + let endpoint_t = path(&crate_path, [module_name, "Endpoint"]); + + let Some(response) = response else { + return Err(syn::Error::new( + Span::call_site(), + "missing `#[endpoint(response = )]` attribute", + )); + }; + + let lt = match response_lt { + Some(lt) => lt, + None => syn::Lifetime::new("'__de", Span::call_site()), + }; + + let ident = &input.ident; + let name = name_from_ident(&ident.to_string()); + + Ok(quote! { + impl #endpoint_t for #ident { + const KIND: &'static str = #name; + type Response<#lt> = #response; + } + }) +} + +/// Expand request impl. +pub(super) fn request( + input: syn::DeriveInput, + crate_name: &str, + module_name: &str, +) -> syn::Result { + let mut crate_path = None; + let mut endpoint = None; + + for attr in input.attrs { + if !attr.path().is_ident("request") { + continue; + } + + attr.parse_args_with(|p: ParseStream<'_>| { + while !p.is_empty() { + let path = p.parse::()?; + + if let Some(lt) = as_ident(&path, "endpoint") { + if let Some(lt) = lt { + return Err(syn::Error::new_spanned( + lt, + "lifetimes are not supported for endpoints", + )); + } + + p.parse::()?; + endpoint = Some(p.parse::()?); + } else if path.is_ident("crate") { + parse_crate(p, &mut crate_path)?; + } else { + return Err(syn::Error::new_spanned(path, "unknown attribute")); + } + + if p.parse::>()?.is_none() { + break; + } + } + + Ok(()) + })?; + } + + let crate_path = match crate_path { + Some(path) => path, + None => syn::parse_quote!(#crate_name), + }; + + let request_t = path(&crate_path, [module_name, "Request"]); + + let Some(endpoint) = endpoint else { + return Err(syn::Error::new( + Span::call_site(), + "missing `#[request(endpoint = )]` attribute", + )); + }; + + let ident = &input.ident; + + Ok(quote! { + impl #request_t for #ident { + type Endpoint = #endpoint; + } + }) +} + +fn parse_crate(p: ParseStream<'_>, crate_path: &mut Option) -> syn::Result<()> { + if let Some(existing) = crate_path { + return Err(syn::Error::new_spanned( + existing, + "duplicate `crate` attribute", + )); + } + + *crate_path = if p.parse::>()?.is_some() { + Some(p.parse::()?) + } else { + Some(syn::Path::from(syn::PathSegment::from( + ::default(), + ))) + }; + + Ok(()) +} + +fn as_ident(path: &syn::Path, expect: &str) -> Option> { + let one = path.segments.first()?; + + if path.segments.len() != 1 || path.leading_colon.is_some() { + return None; + } + + if one.ident != expect { + return None; + } + + match &one.arguments { + syn::PathArguments::AngleBracketed(lt) => { + let first = lt.args.first()?; + + if lt.args.len() != 1 { + return None; + } + + match first { + syn::GenericArgument::Lifetime(lt) => Some(Some(lt.clone())), + _ => None, + } + } + syn::PathArguments::None => Some(None), + _ => None, + } +} + +fn name_from_ident(ident: &str) -> String { + let mut name = String::with_capacity(ident.len()); + + for c in ident.chars() { + if c.is_uppercase() && !name.is_empty() { + name.push('-'); + } + + name.extend(c.to_lowercase()); + } + + name +} + +fn path(base: &syn::Path, segments: [&str; N]) -> syn::Path { + let mut path = base.clone(); + + for segment in segments { + path.segments + .push(syn::Ident::new(segment, Span::call_site()).into()); + } + + path +} diff --git a/crates/musli-macros/src/lib.rs b/crates/musli-macros/src/lib.rs index 8455fcc77..d8298c779 100644 --- a/crates/musli-macros/src/lib.rs +++ b/crates/musli-macros/src/lib.rs @@ -12,6 +12,8 @@ #![allow(clippy::needless_late_init)] #![allow(missing_docs)] +#[cfg(feature = "api")] +mod api; mod de; mod en; mod expander; @@ -22,6 +24,30 @@ use proc_macro::TokenStream; const CRATE_DEFAULT: &str = "musli"; +#[doc(hidden)] +#[proc_macro_derive(Endpoint, attributes(endpoint))] +#[cfg(feature = "api")] +pub fn endpoint(input: TokenStream) -> TokenStream { + let input = syn::parse_macro_input!(input as syn::DeriveInput); + + match api::endpoint(input, CRATE_DEFAULT, "api") { + Ok(tokens) => tokens.into(), + Err(error) => error.to_compile_error().into(), + } +} + +#[doc(hidden)] +#[proc_macro_derive(Request, attributes(request))] +#[cfg(feature = "api")] +pub fn request(input: TokenStream) -> TokenStream { + let input = syn::parse_macro_input!(input as syn::DeriveInput); + + match api::request(input, CRATE_DEFAULT, "api") { + Ok(tokens) => tokens.into(), + Err(error) => error.to_compile_error().into(), + } +} + /// Derive which automatically implements the [`Encode` trait]. /// /// See the [`derives` module] for detailed documentation. diff --git a/crates/musli-yew/Cargo.toml b/crates/musli-yew/Cargo.toml new file mode 100644 index 000000000..145703a1a --- /dev/null +++ b/crates/musli-yew/Cargo.toml @@ -0,0 +1,29 @@ +[package] +name = "musli-yew" +version = "0.0.124" +authors = ["John-John Tedro "] +edition = "2021" +description = """ +Types for integrating Müsli with yew. +""" +documentation = "https://docs.rs/musli" +readme = "README.md" +homepage = "https://github.com/udoprog/musli" +repository = "https://github.com/udoprog/musli" +license = "MIT OR Apache-2.0" +keywords = ["javascript", "web", "webasm"] +categories = ["gui", "wasm", "web-programming"] + +[features] +default = ["log"] +log = ["dep:log"] + +[dependencies] +musli = { version = "0.0.124", path = "../musli", default-features = false, features = ["api", "alloc"] } + +log = { version = "0.4.21", optional = true } +slab = "0.4.9" +wasm-bindgen = "0.2.92" +web-sys = { version = "0.3.69", features = ["WebSocket", "MessageEvent", "Performance"] } +yew = "0.21.0" +gloo = { version = "0.11.0", features = ["timers"], default-features = false } diff --git a/crates/musli-yew/README.md b/crates/musli-yew/README.md new file mode 100644 index 000000000..9d1470d60 --- /dev/null +++ b/crates/musli-yew/README.md @@ -0,0 +1,11 @@ +# musli-yew + +[github](https://github.com/udoprog/musli) +[crates.io](https://crates.io/crates/musli-yew) +[docs.rs](https://docs.rs/musli-yew) +[build status](https://github.com/udoprog/musli/actions?query=branch%3Amain) + +This crate provides a set of utilities for working with [Yew] and [Müsli]. + +[Yew]: https://yew.rs +[Müsli]: https://github.com/udoprog/musli diff --git a/crates/musli-yew/src/lib.rs b/crates/musli-yew/src/lib.rs new file mode 100644 index 000000000..da9667545 --- /dev/null +++ b/crates/musli-yew/src/lib.rs @@ -0,0 +1,10 @@ +//! [github](https://github.com/udoprog/musli) +//! [crates.io](https://crates.io/crates/musli-yew) +//! [docs.rs](https://docs.rs/musli-yew) +//! +//! This crate provides a set of utilities for working with [Yew] and [Müsli]. +//! +//! [Yew]: https://yew.rs +//! [Müsli]: https://github.com/udoprog/musli + +pub mod ws; diff --git a/crates/musli-yew/src/ws.rs b/crates/musli-yew/src/ws.rs new file mode 100644 index 000000000..5efd3c1cf --- /dev/null +++ b/crates/musli-yew/src/ws.rs @@ -0,0 +1,922 @@ +use core::fmt; +use std::cell::{Cell, RefCell}; +use std::collections::{hash_map, HashMap}; +use std::marker::PhantomData; +use std::mem::take; +use std::rc::Rc; + +use gloo::timers::callback::Timeout; +use musli::mode::Binary; +use musli::{api, Decode, Encode}; +use slab::Slab; +use wasm_bindgen::closure::Closure; +use wasm_bindgen::{JsCast, JsValue}; +use web_sys::js_sys::{ArrayBuffer, Uint8Array}; +use web_sys::{window, BinaryType, CloseEvent, ErrorEvent, MessageEvent, WebSocket}; +use yew::html::{ImplicitClone, Scope}; +use yew::{Component, Context}; + +#[cfg(feature = "log")] +use log::{error, trace}; + +#[cfg(not(feature = "log"))] +macro_rules! dummy { + ($msg:literal $(, $expr:expr)* $(,)?) => { $(_ = $expr;)* }; +} + +#[cfg(not(feature = "log"))] +use {dummy as trace, dummy as error}; + +const MAX_CAPACITY: usize = 1048576; + +/// The state of the connection. +/// +/// A listener for state changes can be set up through +/// [`Handle::state_changes`]. +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +#[non_exhaustive] +pub enum State { + /// The connection is open. + Open, + /// The connection is closed. + Closed, +} + +/// Error type for the WebSocket service. +#[derive(Debug)] +pub struct Error { + kind: ErrorKind, +} + +#[derive(Debug)] +enum ErrorKind { + Message(String), + Api(musli::api::Error), + Overflow(usize, usize), +} + +impl Error { + #[inline] + fn new(kind: ErrorKind) -> Self { + Self { kind } + } +} + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match &self.kind { + ErrorKind::Message(message) => write!(f, "{message}"), + ErrorKind::Api(error) => write!(f, "Encoding error: {error}"), + ErrorKind::Overflow(at, len) => { + write!(f, "Internal packet overflow, {at} not in range 0-{len}") + } + } + } +} + +impl std::error::Error for Error { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match &self.kind { + ErrorKind::Api(error) => Some(error), + _ => None, + } + } +} + +impl From for Error { + fn from(error: musli::api::Error) -> Self { + Self::new(ErrorKind::Api(error)) + } +} + +impl From for Error { + fn from(error: JsValue) -> Self { + Self::new(ErrorKind::Message(format!("{error:?}"))) + } +} + +impl From<&str> for Error { + fn from(error: &str) -> Self { + Self::new(ErrorKind::Message(error.to_string())) + } +} + +type Result = core::result::Result; + +const INITIAL_TIMEOUT: u32 = 250; +const MAX_TIMEOUT: u32 = 16000; + +struct ClientRequest<'a> { + header: api::RequestHeader<'a>, + body: Vec, +} + +enum MsgKind { + Reconnect, + Open, + Close(CloseEvent), + Message(MessageEvent), + Error(ErrorEvent), + ClientRequest(ClientRequest<'static>), +} + +/// A message passed into the WebSocket service. +pub struct Msg { + kind: MsgKind, +} + +impl Msg { + #[inline] + const fn new(kind: MsgKind) -> Self { + Self { kind } + } +} + +/// The websocket service. +/// +/// This needs to be wired up with messages inside of a component in yew, see +/// the example for how this can be done. +/// +/// Interaction with the service is done through the [`Handle`] type. +/// +/// # Examples +/// +/// ```no_run +/// use yew::prelude::*; +/// use musli_yew::ws; +/// +/// struct App { +/// ws: ws::Service, +/// handle: ws::Handle, +/// } +/// +/// pub(crate) enum Msg { +/// WebSocket(ws::Msg), +/// Error(ws::Error), +/// } +/// +/// impl From for Msg { +/// #[inline] +/// fn from(error: ws::Msg) -> Self { +/// Self::WebSocket(error) +/// } +/// } +/// +/// impl From for Msg { +/// #[inline] +/// fn from(error: ws::Error) -> Self { +/// Self::Error(error) +/// } +/// } +/// +/// impl Component for App { +/// type Message = Msg; +/// type Properties = (); +/// +/// fn create(ctx: &Context) -> Self { +/// let (ws, handle) = ws::Service::new(ctx); +/// let mut this = Self { ws, handle }; +/// this.ws.connect(ctx); +/// this +/// } +/// +/// fn update(&mut self, ctx: &Context, msg: Self::Message) -> bool { +/// match msg { +/// Msg::WebSocket(msg) => { +/// self.ws.update(ctx, msg); +/// false +/// } +/// Msg::Error(error) => { +/// log::error!("Websocket Error: {error}"); +/// false +/// } +/// } +/// } +/// +/// fn view(&self, ctx: &Context) -> Html { +/// html! { +/// "Hello World" +/// } +/// } +/// } +/// ``` +pub struct Service { + shared: Rc, + socket: Option, + opened: Option, + state: State, + buffer: Vec>, + output: Vec, + timeout: u32, + on_open: Closure, + on_close: Closure, + on_message: Closure, + on_error: Closure, + _timeout: Option, + _ping_timeout: Option, + _marker: PhantomData, +} + +impl Service +where + C: Component + From>, +{ + /// Construct a new websocket service, and return it and an associated + /// handle to it. + pub fn new(ctx: &Context) -> (Self, Handle) { + let link = ctx.link().clone(); + + let shared = Rc::new(Shared { + serial: Cell::new(0), + onmessage: Box::new(move |request| { + link.send_message(Msg::new(MsgKind::ClientRequest(request))) + }), + requests: RefCell::new(Slab::new()), + broadcasts: RefCell::new(HashMap::new()), + state_changes: RefCell::new(Slab::new()), + }); + + let on_open = { + let link = ctx.link().clone(); + + let cb: Box = Box::new(move || { + link.send_message(Msg::new(MsgKind::Open)); + }); + + Closure::wrap(cb) + }; + + let on_close = { + let link = ctx.link().clone(); + + let cb: Box = Box::new(move |e: CloseEvent| { + link.send_message(Msg::new(MsgKind::Close(e))); + }); + + Closure::wrap(cb) + }; + + let on_message = { + let link = ctx.link().clone(); + + let cb: Box = Box::new(move |e: MessageEvent| { + link.send_message(Msg::new(MsgKind::Message(e))); + }); + + Closure::wrap(cb) + }; + + let on_error = { + let link = ctx.link().clone(); + + let cb: Box = Box::new(move |e: ErrorEvent| { + link.send_message(Msg::new(MsgKind::Error(e))); + }); + + Closure::wrap(cb) + }; + + let this = Self { + shared: shared.clone(), + socket: None, + opened: None, + state: State::Closed, + buffer: Vec::new(), + output: Vec::new(), + timeout: INITIAL_TIMEOUT, + on_open, + on_close, + on_message, + on_error, + _timeout: None, + _ping_timeout: None, + _marker: PhantomData, + }; + + let handle = Handle { shared }; + + (this, handle) + } + + /// Send a client message. + fn send_client_request(&mut self, request: ClientRequest<'_>) -> Result<()> { + let Some(socket) = &self.socket else { + return Err("Socket is not connected".into()); + }; + + self.output.clear(); + musli::api::encode(musli::wrap::wrap(&mut self.output), &request.header)?; + self.output.extend_from_slice(request.body.as_slice()); + socket.send_with_u8_array(&self.output)?; + self.output.shrink_to(MAX_CAPACITY); + Ok(()) + } + + fn message(&mut self, e: MessageEvent) -> Result<()> { + let Ok(array_buffer) = e.data().dyn_into::() else { + return Err("Expected message as ArrayBuffer".into()); + }; + + let body = Rc::from(Uint8Array::new(&array_buffer).to_vec()); + let mut reader = musli::reader::SliceReader::new(&body); + + let header: api::ResponseHeader<'_> = musli::api::decode(&mut reader)?; + + match header.broadcast { + Some(kind) => { + let broadcasts = self.shared.broadcasts.borrow(); + let at = body.len() - reader.remaining(); + + if let Some(broadcasts) = broadcasts.get(kind) { + let mut it = broadcasts.iter(); + + let last = it.next_back(); + + let raw = RawPacket { + body: body.clone(), + at, + }; + + for (_, callback) in it { + (callback)(raw.clone()); + } + + if let Some((_, callback)) = last { + (callback)(raw); + } + } + } + None => { + let (index, serial) = unpack(header.serial); + trace!("Got response index={index}, serial={serial}"); + + let requests = self.shared.requests.borrow(); + + let Some(pending) = requests.get(index) else { + return Ok(()); + }; + + if pending.serial == serial { + if let Some(error) = header.error { + pending.callback.error(error); + } else { + let at = body.len() - reader.remaining(); + pending.callback.packet(RawPacket { body, at }); + } + } + } + } + + Ok(()) + } + + fn set_open(&mut self) { + trace!("Set open"); + self.opened = Some(Opened { at: now() }); + self.emit_state_change(State::Open); + } + + fn is_open_for_a_while(&self) -> bool { + let Some(opened) = self.opened else { + return false; + }; + + let Some(at) = opened.at else { + return false; + }; + + let Some(now) = now() else { + return false; + }; + + (now - at) >= 250.0 + } + + fn set_closed(&mut self, ctx: &Context) { + trace!( + "Set closed timeout={}, opened={:?}", + self.timeout, + self.opened + ); + + if !self.is_open_for_a_while() { + if self.timeout < MAX_TIMEOUT { + self.timeout *= 2; + } + } else { + self.timeout = INITIAL_TIMEOUT; + } + + self.opened = None; + self.reconnect(ctx); + self.emit_state_change(State::Closed); + } + + fn emit_state_change(&mut self, state: State) { + if self.state != state { + let callbacks = self.shared.state_changes.borrow(); + + for (_, callback) in callbacks.iter() { + callback(state); + } + + self.state = state; + } + } + + /// Handle an update message. + pub fn update(&mut self, ctx: &Context, message: Msg) { + match message.kind { + MsgKind::Reconnect => { + trace!("Reconnect"); + + if let Err(error) = self.inner_connect() { + ctx.link().send_message(error); + self.inner_reconnect(ctx); + } + } + MsgKind::Open => { + trace!("Open"); + self.set_open(); + + let buffer = take(&mut self.buffer); + + for request in buffer { + if let Err(error) = self.send_client_request(request) { + ctx.link().send_message(error); + } + } + } + MsgKind::Close(e) => { + trace!("Close: {} ({})", e.code(), e.reason()); + self.set_closed(ctx); + } + MsgKind::Message(e) => { + if let Err(error) = self.message(e) { + ctx.link().send_message(error); + } + } + MsgKind::Error(e) => { + error!("{}", e.message()); + self.set_closed(ctx); + } + MsgKind::ClientRequest(request) => { + if self.opened.is_none() { + self.buffer.push(request); + return; + } + + if let Err(error) = self.send_client_request(request) { + ctx.link().send_message(error); + } + } + } + } + + pub(crate) fn reconnect(&mut self, ctx: &Context) { + if let Some(old) = self.socket.take() { + if let Err(error) = old.close() { + ctx.link().send_message(Error::from(error)); + } + } + + let link = ctx.link().clone(); + + self._timeout = Some(Timeout::new(self.timeout, move || { + link.send_message(Msg::new(MsgKind::Reconnect)); + })); + } + + /// Attempt to establish a websocket connection. + pub fn connect(&mut self, ctx: &Context) { + if let Err(error) = self.inner_connect() { + ctx.link().send_message(error); + self.inner_reconnect(ctx); + } + } + + fn inner_connect(&mut self) -> Result<()> { + let window = window().ok_or("No window")?; + let port = window.location().port()?; + let url = format!("ws://127.0.0.1:{port}/ws"); + let ws = WebSocket::new(&url)?; + + ws.set_binary_type(BinaryType::Arraybuffer); + ws.set_onopen(Some(self.on_open.as_ref().unchecked_ref())); + ws.set_onclose(Some(self.on_close.as_ref().unchecked_ref())); + ws.set_onmessage(Some(self.on_message.as_ref().unchecked_ref())); + ws.set_onerror(Some(self.on_error.as_ref().unchecked_ref())); + + if let Some(old) = self.socket.replace(ws) { + old.close()?; + } + + Ok(()) + } + + fn inner_reconnect(&mut self, ctx: &Context) { + let link = ctx.link().clone(); + + self._timeout = Some(Timeout::new(1000, move || { + link.send_message(Msg::new(MsgKind::Reconnect)); + })); + } +} + +/// The handle for a pending request. Dropping this handle cancels the request. +pub struct Request { + inner: Option<(Rc, u32)>, + _marker: PhantomData, +} + +impl Request { + /// An empty request handler. + #[inline] + pub fn empty() -> Self { + Self::default() + } +} + +impl Default for Request { + #[inline] + fn default() -> Self { + Self { + inner: None, + _marker: PhantomData, + } + } +} + +impl Drop for Request { + #[inline] + fn drop(&mut self) { + if let Some((shared, index)) = self.inner.take() { + shared.requests.borrow_mut().try_remove(index as usize); + } + } +} + +/// The handle for a pending request. Dropping this handle cancels the request. +pub struct Listener { + kind: &'static str, + index: usize, + shared: Rc, + _marker: PhantomData, +} + +impl Drop for Listener { + #[inline] + fn drop(&mut self) { + let mut broadcast = self.shared.broadcasts.borrow_mut(); + + if let hash_map::Entry::Occupied(mut e) = broadcast.entry(self.kind) { + e.get_mut().try_remove(self.index); + + if e.get().is_empty() { + e.remove(); + } + } + } +} + +/// The handle for state change listening. Dropping this handle cancels the request. +pub struct StateListener { + index: usize, + shared: Rc, +} + +impl Drop for StateListener { + #[inline] + fn drop(&mut self) { + self.shared + .state_changes + .borrow_mut() + .try_remove(self.index); + } +} + +#[derive(Clone)] +struct RawPacket { + body: Rc<[u8]>, + at: usize, +} + +/// A packet of data. +pub struct Packet { + raw: RawPacket, + _marker: PhantomData, +} + +impl Packet { + /// Handle a broadcast packet. + pub fn decode( + &self, + ctx: &Context>>, + f: impl FnOnce(T::Response<'_>), + ) where + for<'de> T: api::Endpoint: Decode<'de, Binary>>, + { + let Some(bytes) = self.raw.body.get(self.raw.at..) else { + ctx.link().send_message(Error::new(ErrorKind::Overflow( + self.raw.at, + self.raw.body.len(), + ))); + return; + }; + + match musli::api::from_slice(bytes) { + Ok(value) => { + f(value); + } + Err(error) => { + ctx.link().send_message(Error::from(error)); + } + } + } +} + +/// A handle to the WebSocket [`Service`]. +/// +/// Through a handle you can initialize a [`Request`] with [`Handle::request`]. +/// This can conveniently be constructed through a [`Default`] implementation of +/// no request is pending. Dropping this handle will cancel the request. +/// +/// You can also listen to broadcast events through [`Handle::listen`]. +/// Similarly here, dropping the handle will cancel the listener. +/// +/// # Examples +/// +/// ``` +/// use musli::{Encode, Decode}; +/// use musli::api::{Endpoint, Request}; +/// use musli_yew::ws; +/// use yew::prelude::*; +/// +/// #[derive(Encode, Decode)] +/// pub struct MessageOfTheDayResponse<'a> { +/// pub message_of_the_day: &'a str, +/// } +/// +/// #[derive(Request, Encode, Decode)] +/// #[request(endpoint = MessageOfTheDay)] +/// pub struct MessageOfTheDayRequest; +/// +/// #[derive(Endpoint)] +/// #[endpoint(response<'de> = MessageOfTheDayResponse<'de>)] +/// pub enum MessageOfTheDay {} +/// +/// pub(crate) struct Dashboard { +/// message_of_the_day: String, +/// _initialize: ws::Request, +/// } +/// +/// pub(crate) enum Msg { +/// MessageOfTheDay(ws::Packet), +/// Error(ws::Error), +/// } +/// +/// impl From> for Msg { +/// #[inline] +/// fn from(packet: ws::Packet) -> Self { +/// Self::MessageOfTheDay(packet) +/// } +/// } +/// +/// impl From for Msg { +/// #[inline] +/// fn from(error: ws::Error) -> Self { +/// Self::Error(error) +/// } +/// } +/// +/// #[derive(Properties, PartialEq)] +/// pub(crate) struct Props { +/// pub(crate) ws: ws::Handle, +/// pub(crate) onerror: Callback, +/// } +/// +/// impl Component for Dashboard { +/// type Message = Msg; +/// type Properties = Props; +/// +/// fn create(ctx: &Context) -> Self { +/// Self { +/// message_of_the_day: String::new(), +/// _initialize: ctx.props().ws.request(ctx, MessageOfTheDayRequest), +/// } +/// } +/// +/// fn update(&mut self, ctx: &Context, msg: Self::Message) -> bool { +/// match msg { +/// Msg::MessageOfTheDay(packet) => { +/// packet.decode(ctx, |update| { +/// self.message_of_the_day = update.message_of_the_day.to_owned(); +/// }); +/// +/// true +/// } +/// Msg::Error(error) => { +/// ctx.props().onerror.emit(error); +/// false +/// } +/// } +/// } +/// +/// fn view(&self, _: &Context) -> Html { +/// html! { +/// {self.message_of_the_day.clone()} +/// } +/// } +/// } +/// ``` +#[derive(Clone)] +pub struct Handle { + shared: Rc, +} + +impl Handle { + /// Send a request of type `T` and returns a handle for the request. + /// + /// If the handle is dropped, the request is cancelled. + /// + /// See [`Handle`] for an example. + pub fn request( + &self, + ctx: &Context> + From>>, + request: T, + ) -> Request + where + T: api::Request + Encode, + { + struct CallbackImpl + where + C: Component, + { + link: Scope, + _marker: PhantomData, + } + + impl Callback for CallbackImpl + where + C: Component> + From>, + { + fn packet(&self, raw: RawPacket) { + self.link.send_message(Packet { + raw, + _marker: PhantomData, + }) + } + + fn error(&self, error: &str) { + self.link.send_message(Error::from(error)); + } + } + + let body = match musli::api::to_vec(&request) { + Ok(body) => body, + Err(error) => { + ctx.link().send_message(Error::from(error)); + return Request::default(); + } + }; + + let serial = self.shared.serial.get(); + self.shared.serial.set(serial.wrapping_add(1)); + + let pending = Pending { + serial, + callback: Box::new(CallbackImpl { + link: ctx.link().clone(), + _marker: PhantomData, + }), + }; + + let index = self.shared.requests.borrow_mut().insert(pending) as u32; + + (self.shared.onmessage)(ClientRequest { + header: api::RequestHeader { + serial: pack(index, serial), + kind: ::KIND, + }, + body, + }); + + Request { + inner: Some((self.shared.clone(), index)), + _marker: PhantomData, + } + } + + /// List for broadcasts of type `T` and returns a handle for the listener. + /// + /// If the handle is dropped, the listener is cancelled. + /// + /// See [`Handle`] for an example. + pub fn listen( + &self, + ctx: &Context> + From>>, + ) -> Listener + where + T: api::Endpoint, + { + let mut broadcasts = self.shared.broadcasts.borrow_mut(); + + let slots = broadcasts.entry(T::KIND).or_default(); + let link = ctx.link().clone(); + + let index = slots.insert(Box::new(move |raw| { + link.send_message(Packet { + raw, + _marker: PhantomData, + }); + })); + + Listener { + kind: T::KIND, + index, + shared: self.shared.clone(), + _marker: PhantomData, + } + } + + /// Listen for state changes to the underlying connection. + pub fn state_changes( + &self, + ctx: &Context>>, + ) -> StateListener { + let link = ctx.link().clone(); + let mut state = self.shared.state_changes.borrow_mut(); + + let index = state.insert(Box::new(move |state| link.send_message(state))); + + StateListener { + index, + shared: self.shared.clone(), + } + } +} + +impl ImplicitClone for Handle { + #[inline] + fn implicit_clone(&self) -> Self { + self.clone() + } +} + +impl PartialEq for Handle { + #[inline] + fn eq(&self, _: &Self) -> bool { + true + } +} + +fn now() -> Option { + Some(window()?.performance()?.now()) +} + +trait Callback { + /// Handle a packet. + fn packet(&self, packet: RawPacket); + + /// Handle an error. + fn error(&self, error: &str); +} + +struct Pending { + serial: u32, + callback: Box, +} + +type Broadcasts = HashMap<&'static str, Slab>>; +type OnMessageCallback = dyn Fn(ClientRequest<'static>); +type StateCallback = dyn Fn(State); + +struct Shared { + serial: Cell, + onmessage: Box, + requests: RefCell>, + broadcasts: RefCell, + state_changes: RefCell>>, +} + +#[derive(Debug, Clone, Copy)] +struct Opened { + at: Option, +} + +#[inline] +fn pack(a: u32, b: u32) -> u64 { + ((a as u64) << 32) | (b as u64) +} + +#[inline] +fn unpack(serial: u64) -> (usize, u32) { + ((serial >> 32) as usize, serial as u32) +} + +#[test] +fn test_pack() { + assert_eq!(unpack(pack(0, 0)), (0, 0)); + assert_eq!(unpack(pack(1, 0)), (1, 0)); + assert_eq!(unpack(pack(u32::MAX, 0)), (u32::MAX as usize, 0)); +} diff --git a/crates/musli/Cargo.toml b/crates/musli/Cargo.toml index 8672c7c38..79f9caa57 100644 --- a/crates/musli/Cargo.toml +++ b/crates/musli/Cargo.toml @@ -12,7 +12,7 @@ readme = "README.md" homepage = "https://github.com/udoprog/musli" repository = "https://github.com/udoprog/musli" license = "MIT OR Apache-2.0" -keywords = ["no_std", "serialization"] +keywords = ["binary", "json", "no_std", "serialization"] categories = ["encoding", "no-std", "no-std::no-alloc"] [package.metadata.docs.rs] @@ -34,11 +34,13 @@ json = ["value", "dep:itoa", "dep:ryu"] parse-full = [] value = [] serde = ["dep:serde"] +api = ["musli-macros/api", "storage"] test = ["storage", "wire", "descriptive", "json", "parse-full", "value", "serde"] [dependencies] musli-core = { version = "=0.0.124", path = "../musli-core", default-features = false } +musli-macros = { version = "=0.0.124", path = "../musli-macros", features = [], optional = true } simdutf8 = { version = "0.1.4", optional = true, default-features = false } itoa = { version = "1.0.10", optional = true } @@ -49,7 +51,7 @@ serde = { version = "1.0.198", optional = true, default-features = false} loom = "0.7.2" [dev-dependencies] -musli = { path = ".", features = ["test"] } +musli = { path = ".", features = ["test", "api"] } tests = { path = "../../tests" } rand = "0.8.5" diff --git a/crates/musli/src/api.rs b/crates/musli/src/api.rs new file mode 100644 index 000000000..007a420af --- /dev/null +++ b/crates/musli/src/api.rs @@ -0,0 +1,169 @@ +//! API definitions for Musli. +//! +//! This provides types and traits for defining a simple binary strictly-typed +//! interchange API. + +use crate::{Decode, Encode}; + +#[cfg(feature = "storage")] +mod encoding; +#[cfg(feature = "storage")] +pub use self::encoding::*; + +#[cfg(test)] +mod tests; + +/// Define an endpoint with a well-known name and a request and response type. +/// +/// This derive requires one unique identifiers and a couple of types to be +/// designated. Once provider it defines an enum marker type that implements the +/// [`Endpoint`] trait. +/// * The unique identifier of the request, like `"ping"`, if this is not +/// specified using the `name = "..."` attribute it will default to the name +/// of the endpoint type in lower kebab-case. +/// * The response type which must implement `Encode` and `Decode` and can +/// optionally take a lifetime. This is specified with the +/// `#[endpoint(response = )]` or `#[endpoint(response<'de> = )]` +/// attribute. Responses can have a lifetime since it allows for local buffers +/// to be re-used and avoiding copies. +/// +/// The overall structure of the derive is as follows: +/// +/// ``` +/// # use musli::{Encode, Decode}; +/// # use musli::api::Request; +/// # #[derive(Request, Encode, Decode)] +/// # #[request(endpoint = Hello)] +/// # pub struct HelloRequest; +/// # #[derive(Encode, Decode)] +/// # pub struct HelloResponse<'de> { _marker: core::marker::PhantomData<&'de ()> } +/// use musli::api::Endpoint; +/// +/// #[derive(Endpoint)] +/// #[endpoint(response<'de> = HelloResponse<'de>)] +/// pub enum Hello {} +/// ``` +/// +/// # Examples +/// +/// ``` +/// use musli::{Encode, Decode}; +/// use musli::api::{Endpoint, Request}; +/// +/// #[derive(Endpoint)] +/// #[endpoint(response = Pong)] +/// pub enum PingPong {} +/// +/// #[derive(Request, Encode, Decode)] +/// #[request(endpoint = PingPong)] +/// pub struct Ping(u32); +/// +/// #[derive(Encode, Decode)] +/// pub struct Pong(u32); +/// +/// #[derive(Encode, Decode)] +/// pub struct MessageOfTheDayResponse<'de> { +/// pub message_of_the_day: &'de str, +/// } +/// +/// #[derive(Request, Encode)] +/// #[request(endpoint = MessageOfTheDay)] +/// pub struct MessageOfTheDayRequest; +/// +/// #[derive(Endpoint)] +/// #[endpoint(response<'de> = MessageOfTheDayResponse<'de>)] +/// pub enum MessageOfTheDay {} +/// ``` +#[doc(inline)] +pub use musli_macros::Endpoint; + +/// Define an request and the endpoint it is associated with. +/// +/// The overall structure of the derive is as follows: +/// +/// ``` +/// # #[derive(Encode, Decode)] +/// # pub struct HelloResponse<'de> { _marker: core::marker::PhantomData<&'de ()> } +/// use musli::{Encode, Decode}; +/// use musli::api::{Endpoint, Request}; +/// +/// #[derive(Request, Encode, Decode)] +/// #[request(endpoint = Hello)] +/// pub struct HelloRequest; +/// +/// #[derive(Endpoint)] +/// #[endpoint(response<'de> = HelloResponse<'de>)] +/// pub enum Hello {} +/// ``` +/// +/// # Examples +/// +/// ``` +/// use musli::{Encode, Decode}; +/// use musli::api::{Endpoint, Request}; +/// +/// #[derive(Endpoint)] +/// #[endpoint(response = Pong)] +/// pub enum PingPong {} +/// +/// #[derive(Request, Encode, Decode)] +/// #[request(endpoint = PingPong)] +/// pub struct Ping(u32); +/// +/// #[derive(Encode, Decode)] +/// pub struct Pong(u32); +/// +/// #[derive(Encode, Decode)] +/// pub struct MessageOfTheDayResponse<'de> { +/// pub message_of_the_day: &'de str, +/// } +/// +/// #[derive(Request, Encode)] +/// #[request(endpoint = MessageOfTheDay)] +/// pub struct MessageOfTheDayRequest; +/// +/// #[derive(Endpoint)] +/// #[endpoint(response<'de> = MessageOfTheDayResponse<'de>)] +/// pub enum MessageOfTheDay {} +/// ``` +#[doc(inline)] +pub use musli_macros::Request; + +/// A marker indicating a decodable type. +pub trait Endpoint: 'static { + /// The name of the endpoint. + const KIND: &'static str; + + /// The response type of the endpoint. + type Response<'de>; +} + +/// Trait governing requests. +pub trait Request { + /// The endpoint the request is associated with. + type Endpoint: Endpoint; +} + +/// The API header of a request. +#[derive(Debug, Clone, Copy, Encode, Decode)] +#[musli(crate)] +pub struct RequestHeader<'a> { + /// Identifier of the request. + pub serial: u64, + /// The kind of the request. + pub kind: &'a str, +} + +/// The API header of a response. +#[derive(Debug, Clone, Encode, Decode)] +#[musli(crate)] +pub struct ResponseHeader<'de> { + /// Identifier of the request this response belongs to. + pub serial: u64, + /// The response is a broadcast belonging to the given type. + #[musli(default, skip_encoding_if = Option::is_none)] + pub broadcast: Option<&'de str>, + /// An error message in the response. + #[musli(default, skip_encoding_if = Option::is_none)] + pub error: Option<&'de str>, +} diff --git a/crates/musli/src/api/encoding.rs b/crates/musli/src/api/encoding.rs new file mode 100644 index 000000000..2376a6a70 --- /dev/null +++ b/crates/musli/src/api/encoding.rs @@ -0,0 +1,63 @@ +use core::fmt; + +use crate::mode::Binary; +use crate::{Decode, Encode, IntoReader, Writer}; + +#[cfg(feature = "alloc")] +use rust_alloc::vec::Vec; + +/// Errors raised during api serialization. +pub struct Error(crate::storage::Error); + +impl fmt::Debug for Error { + #[inline] + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + self.0.fmt(f) + } +} + +impl fmt::Display for Error { + #[inline] + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + self.0.fmt(f) + } +} + +impl core::error::Error for Error {} + +/// Encode an API frame. +pub fn encode(writer: W, value: &T) -> Result<(), Error> +where + W: Writer, + T: ?Sized + Encode, +{ + crate::storage::encode(writer, value).map_err(Error) +} + +/// Encode an API frame into an allocated vector. +#[cfg(feature = "alloc")] +pub fn to_vec(value: &T) -> Result, Error> +where + T: ?Sized + Encode, +{ + crate::storage::to_vec(value).map_err(Error) +} + +/// Decode an API frame. +#[inline] +pub fn decode<'de, R, T>(reader: R) -> Result +where + R: IntoReader<'de>, + T: Decode<'de, Binary>, +{ + crate::storage::decode(reader).map_err(Error) +} + +/// Decode an API frame from a slice. +#[cfg(feature = "alloc")] +pub fn from_slice<'de, T>(bytes: &'de [u8]) -> Result +where + T: Decode<'de, Binary>, +{ + crate::storage::from_slice(bytes).map_err(Error) +} diff --git a/crates/musli/src/api/tests.rs b/crates/musli/src/api/tests.rs new file mode 100644 index 000000000..443a0db13 --- /dev/null +++ b/crates/musli/src/api/tests.rs @@ -0,0 +1,15 @@ +use crate::api::Endpoint; + +struct Pong; + +#[derive(Endpoint)] +#[endpoint(crate, response = Pong)] +struct PingPong; + +#[test] +fn test_match() { + match PingPong::KIND { + PingPong::KIND => {} + _ => panic!(), + } +} diff --git a/crates/musli/src/lib.rs b/crates/musli/src/lib.rs index 4ec1a472f..e7749d334 100644 --- a/crates/musli/src/lib.rs +++ b/crates/musli/src/lib.rs @@ -592,5 +592,8 @@ pub use self::writer::Writer; pub mod no_std; +#[cfg(feature = "api")] +pub mod api; + mod int; mod str; diff --git a/crates/musli/src/wrap.rs b/crates/musli/src/wrap.rs index 5e708a936..ab6267c8c 100644 --- a/crates/musli/src/wrap.rs +++ b/crates/musli/src/wrap.rs @@ -3,9 +3,9 @@ //! The main methods in this module is the [`wrap`] function which constructs an //! adapter around an I/O type to work with musli. -#[cfg(feature = "std")] +#[cfg(any(feature = "alloc", feature = "std"))] use crate::alloc::Vec; -#[cfg(feature = "std")] +#[cfg(any(feature = "alloc", feature = "std"))] use crate::Context; /// Wrap a type so that it implements [`Reader`] and [`Writer`]. @@ -15,7 +15,7 @@ use crate::Context; /// [`Reader`]: crate::reader::Reader /// [`Writer`]: crate::writer::Writer pub struct Wrap { - #[cfg_attr(not(feature = "std"), allow(unused))] + #[cfg_attr(not(any(feature = "alloc", feature = "std")), allow(unused))] inner: T, } @@ -44,7 +44,6 @@ where where C: ?Sized + Context, { - // SAFETY: the buffer never outlives this function call. self.write_bytes(cx, buffer.as_slice()) } @@ -58,3 +57,31 @@ where Ok(()) } } + +#[cfg(all(feature = "alloc", not(feature = "std")))] +impl crate::writer::Writer for Wrap<&mut rust_alloc::vec::Vec> { + type Mut<'this> = &'this mut Self where Self: 'this; + + #[inline] + fn borrow_mut(&mut self) -> Self::Mut<'_> { + self + } + + #[inline] + fn extend(&mut self, cx: &C, buffer: Vec<'_, u8, C::Allocator>) -> Result<(), C::Error> + where + C: ?Sized + Context, + { + self.write_bytes(cx, buffer.as_slice()) + } + + #[inline] + fn write_bytes(&mut self, cx: &C, bytes: &[u8]) -> Result<(), C::Error> + where + C: ?Sized + Context, + { + self.inner.extend_from_slice(bytes); + cx.advance(bytes.len()); + Ok(()) + } +}