From 77691310bc9d9c5146de050af1a6531fa6b2a65f Mon Sep 17 00:00:00 2001 From: LimpidCrypto Date: Sun, 18 Aug 2024 19:05:15 +0000 Subject: [PATCH] current state of refactoring --- .devcontainer/devcontainer.json | 31 +++ .github/dependabot.yml | 12 + Cargo.toml | 120 +++++----- .../websocket/async_websocket_client.rs | 3 +- src/client/websocket/errors.rs | 44 ---- src/client/websocket/exceptions.rs | 49 +++++ src/client/websocket/mod.rs | 78 ++++++- src/core/dns/mod.rs | 25 ++- src/core/dns/queries/a/mod.rs | 36 ++- src/core/dns/queries/aaaa/mod.rs | 36 ++- src/core/dns/queries/mod.rs | 6 +- src/core/framed/errors.rs | 8 +- src/core/framed/framed_impl.rs | 4 +- src/core/framed/mod.rs | 1 - src/core/io/async_read.rs | 98 ++++----- src/core/io/async_write.rs | 128 +++++------ src/core/mod.rs | 9 +- src/core/tcp/adapters.rs | 120 ---------- src/core/tcp/errors.rs | 10 - src/core/tcp/mod.rs | 143 ++---------- src/core/tls/errors.rs | 14 -- src/core/tls/exceptions.rs | 35 +++ src/core/tls/mod.rs | 207 +++++++----------- src/lib.rs | 2 - tests/common/constants.rs | 2 +- 25 files changed, 556 insertions(+), 665 deletions(-) create mode 100644 .devcontainer/devcontainer.json create mode 100644 .github/dependabot.yml delete mode 100644 src/client/websocket/errors.rs create mode 100644 src/client/websocket/exceptions.rs delete mode 100644 src/core/tcp/adapters.rs delete mode 100644 src/core/tcp/errors.rs delete mode 100644 src/core/tls/errors.rs create mode 100644 src/core/tls/exceptions.rs diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json new file mode 100644 index 0000000..490cdfc --- /dev/null +++ b/.devcontainer/devcontainer.json @@ -0,0 +1,31 @@ +// For format details, see https://aka.ms/devcontainer.json. For config options, see the +// README at: https://github.com/devcontainers/templates/tree/main/src/rust +{ + "name": "Rust", + // Or use a Dockerfile or Docker Compose file. More info: https://containers.dev/guide/dockerfile + "image": "mcr.microsoft.com/devcontainers/rust:1-1-bullseye", + // Use 'mounts' to make the cargo cache persistent in a Docker Volume. + // "mounts": [ + // { + // "source": "devcontainer-cargo-cache-${devcontainerId}", + // "target": "/usr/local/cargo", + // "type": "volume" + // } + // ] + // Features to add to the dev container. More info: https://containers.dev/features. + // "features": {}, + // Use 'forwardPorts' to make a list of ports inside the container available locally. + // "forwardPorts": [], + // Use 'postCreateCommand' to run commands after the container is created. + // "postCreateCommand": "rustc --version", + // Configure tool-specific properties. + "customizations": { + "vscode": { + "extensions": [ + "swellaby.rust-pack" + ] + } + } + // Uncomment to connect as root instead. More info: https://aka.ms/dev-containers-non-root. + // "remoteUser": "root" +} \ No newline at end of file diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 0000000..f33a02c --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,12 @@ +# To get started with Dependabot version updates, you'll need to specify which +# package ecosystems to update and where the package manifests are located. +# Please see the documentation for more information: +# https://docs.github.com/github/administering-a-repository/configuration-options-for-dependency-updates +# https://containers.dev/guide/dependabot + +version: 2 +updates: + - package-ecosystem: "devcontainers" + directory: "/" + schedule: + interval: weekly diff --git a/Cargo.toml b/Cargo.toml index 3297af0..3c1bebb 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,81 +10,79 @@ name = "em_as_net" crate-type = ["lib"] [dependencies] -anyhow = { version = "1.0.68", default-features = false } -heapless = { version = "0.7.16", default-features = false } +heapless = { version = "0.8.0", default-features = false } libc = { version = "0.2.139", default-features = false } -rand = { version = "0.8.5", default-features = false, features = ["getrandom"] } -rand_core = { version = "0.6.4", default-features = false } -static_cell = { version = "1.0", default-features = false } +static_cell = { version = "2.1.0", default-features = false } +# Error handling thiserror-no-std = { version = "2.0.2", default-features = false } -futures = { version = "0.3.25", default-features = false } -embedded-tls = { version = "0.14.1", default-features = false, features = ["async"], optional = true } -reqwless = "0.5.0" -tokio = { version = "1.27.0", default-features = false, optional = true } -async-std = { version = "1.12.0", features = ["attributes", "tokio1"], default-features = false, optional = true } -tokio-rustls = { version = "0.24.1", optional = true } -tokio-util = { version = "0.7.7", optional = true } -bytes = { version = "1.4.0", default-features = false } -embedded-io = { version = "0.4.0", features = ["async"] } -pin-project-lite = "0.2.9" -strum_macros = { version = "0.25.1", default-features = false } +anyhow = { version = "1.0.68", default-features = false } +# URL url = { version = "2.3.1", default-features = false } -embedded-nal-async = "0.4.0" -tokio-tungstenite = { version = "0.20.0", optional = true } +# TCP +tokio = { version = "1.39.2", features = [ + "macros", + "rt-multi-thread", +], optional = true } +embassy-net = { version = "0.4.0", features = [ + "tcp", + "medium-ethernet", + "dhcpv4", + "proto-ipv6", +] } +embedded-io-async = "0.6.1" +embedded-io-adapters = "0.6.1" +# DNS +embedded-nal-async = "0.7.1" +# TLS +rustls = { version = "0.23.12", default-features = false, features = ["tls12"] } +tokio-rustls = { version = "0.26.0", optional = true } +# Web Socket +embedded-websocket = { git = "https://github.com/LimpidCrypto/embedded-websocket", branch = "update-dep-and-embedded-io-async", features = [ + "embedded-io-async", +], optional = true } +# JSON-RPC +reqwless = "0.12.1" +webpki-roots = { version = "0.26.3", optional = true } +rand_core = "0.6.4" -[dependencies.embedded-websocket] -# git version needed to use `framer_async` -git = "https://github.com/ninjasource/embedded-websocket" -version = "0.9.2" -default-features = false -[dependencies.embassy-net] -git = "https://github.com/embassy-rs/embassy" -package = "embassy-net" -version = "0.1.0" -rev = "5d5cd2371504915a531e669dce3558485a51a2e1" -features = ["nightly", "tcp", "medium-ethernet", "dhcpv4", "proto-ipv6"] - -[dependencies.embassy-net-driver] -git = "https://github.com/embassy-rs/embassy" -package = "embassy-net-driver" -version = "0.1.0" -rev = "83ff3cbc69875f93c5a9bb36825c12df39f04f71" - -[dependencies.embassy-futures] -git = "https://github.com/embassy-rs/embassy" -package = "embassy-futures" -version = "0.1.0" -rev = "9e8de5f596ffa9036c2343ccc1e69f471a4770eb" - -[dependencies.embassy-time] -git = "https://github.com/embassy-rs/embassy" -package = "embassy-time" -version = "0.1.0" -rev = "dff9bd9711205fd4cd5a91384072ab6aa2335d18" +# strum_macros = { version = "0.26.4", default-features = false } +# async-std = { version = "1.12.0", features = [ +# "attributes", +# "tokio1", +# ], default-features = false, optional = true } +# tokio-util = { version = "0.7.7", optional = true } +# bytes = { version = "1.4.0", default-features = false } +# pin-project-lite = "0.2.9" +# embassy-net-driver = "0.2.0" +# embassy-futures = "0.1.1" +# embassy-time = "0.3.1" +# futures = { version = "0.3.25", default-features = false } +# rand = { version = "0.8.5", default-features = false, features = ["getrandom"] } +# rand_core = { version = "0.6.4", default-features = false } [dev-dependencies] tokio = { version = "1.27.0", features = ["full"] } [features] -default = ["std", "dns", "websocket", "json-rpc"] # TODO: Add tls as soon as it's working +default = [ + "std", + "tcp", + "dns", + "tls", + "websocket", + "json-rpc", +] # TODO: Add tls as soon as it's working +tcp = [] dns = ["embassy-net/dns"] -tls = ["embedded-tls"] +tls = [] websocket = [] json-rpc = [] std = [ - "tokio/full", + "dep:tokio", + "embedded-io-adapters/tokio-1", "embedded-websocket/std", - "embedded-tls/std", - "embedded-tls/tokio", - "async-std", "tokio-rustls", - "tokio-util/codec", - "embassy-net/std", - "embassy-time/std", - "embassy-time/generic-queue", - "rand/std", - "rand/std_rng", - "futures/std", - "tokio-tungstenite/native-tls", + "webpki-roots", ] +webpki-roots = ["dep:webpki-roots"] diff --git a/src/client/websocket/async_websocket_client.rs b/src/client/websocket/async_websocket_client.rs index e21e0b6..b36098e 100644 --- a/src/client/websocket/async_websocket_client.rs +++ b/src/client/websocket/async_websocket_client.rs @@ -1,5 +1,6 @@ use crate::{client::websocket::errors::WebsocketError, Err}; +use alloc::string::ToString; use anyhow::Result; use core::{ fmt::{Debug, Display}, @@ -141,7 +142,7 @@ impl WebsocketOpen, >, > { - let (websocket_stream, _) = tungstenite_connect_async(uri).await.unwrap(); + let (websocket_stream, _) = tungstenite_connect_async(uri.to_string()).await.unwrap(); Ok(AsyncWebsocketClient { inner: websocket_stream, diff --git a/src/client/websocket/errors.rs b/src/client/websocket/errors.rs deleted file mode 100644 index f104e1e..0000000 --- a/src/client/websocket/errors.rs +++ /dev/null @@ -1,44 +0,0 @@ -use super::async_websocket_client::EmbeddedWebsocketFramerError; -use core::fmt::Debug; -use core::str::Utf8Error; -use thiserror_no_std::Error; - -#[derive(Debug, PartialEq, Eq, Error)] -pub enum WebsocketError { - #[error("Stream is not connected.")] - NotConnected, - // FramerError - #[error("I/O error: {0:?}")] - Io(E), - #[error("Frame too large (size: {0:?})")] - FrameTooLarge(usize), - #[error("Failed to interpret u8 to string (error: {0:?})")] - Utf8(Utf8Error), - #[error("Invalid HTTP header")] - HttpHeader, - #[error("Websocket error: {0:?}")] - WebSocket(embedded_websocket::Error), - #[error("Disconnected")] - Disconnected, - #[error("Read buffer is too small (size: {0:?})")] - RxBufferTooSmall(usize), -} - -impl From> for WebsocketError { - fn from(value: EmbeddedWebsocketFramerError) -> Self { - match value { - EmbeddedWebsocketFramerError::Io(e) => WebsocketError::Io(e), - EmbeddedWebsocketFramerError::FrameTooLarge(e) => WebsocketError::FrameTooLarge(e), - EmbeddedWebsocketFramerError::Utf8(e) => WebsocketError::Utf8(e), - EmbeddedWebsocketFramerError::HttpHeader(_) => WebsocketError::HttpHeader, - EmbeddedWebsocketFramerError::WebSocket(e) => WebsocketError::WebSocket(e), - EmbeddedWebsocketFramerError::Disconnected => WebsocketError::Disconnected, - EmbeddedWebsocketFramerError::RxBufferTooSmall(e) => { - WebsocketError::RxBufferTooSmall(e) - } - } - } -} - -#[cfg(feature = "std")] -impl alloc::error::Error for WebsocketError {} diff --git a/src/client/websocket/exceptions.rs b/src/client/websocket/exceptions.rs new file mode 100644 index 0000000..b394643 --- /dev/null +++ b/src/client/websocket/exceptions.rs @@ -0,0 +1,49 @@ +use anyhow::anyhow; +use core::fmt::Debug; +use core::str::Utf8Error; +use embedded_websocket::framer_async::FramerError; +use thiserror_no_std::Error; + +#[derive(Debug, PartialEq, Eq, Error)] +pub enum WebsocketError { + #[error("Stream is not connected.")] + NotConnected, + // FramerError + #[error("I/O error: {0:?}")] + Io(E), + #[error("Frame too large (size: {0:?})")] + FrameTooLarge(usize), + #[error("Failed to interpret u8 to string (error: {0:?})")] + Utf8(Utf8Error), + #[error("Invalid HTTP header")] + HttpHeader, + #[error("Websocket error: {0:?}")] + WebSocket(embedded_websocket::Error), + #[error("Disconnected")] + Disconnected, + #[error("Read buffer is too small (size: {0:?})")] + RxBufferTooSmall(usize), +} + +impl From> for WebsocketError { + fn from(e: FramerError) -> Self { + match e { + FramerError::Io(e) => WebsocketError::Io(e), + FramerError::FrameTooLarge(size) => WebsocketError::FrameTooLarge(size), + FramerError::Utf8(e) => WebsocketError::Utf8(e), + FramerError::HttpHeader(_) => WebsocketError::HttpHeader, + FramerError::WebSocket(e) => WebsocketError::WebSocket(e), + FramerError::Disconnected => WebsocketError::Disconnected, + FramerError::RxBufferTooSmall(size) => WebsocketError::RxBufferTooSmall(size), + } + } +} + +impl Into for WebsocketError { + fn into(self) -> anyhow::Error { + anyhow!(self) + } +} + +#[cfg(feature = "std")] +impl alloc::error::Error for WebsocketError {} diff --git a/src/client/websocket/mod.rs b/src/client/websocket/mod.rs index e2e0e6b..36866de 100644 --- a/src/client/websocket/mod.rs +++ b/src/client/websocket/mod.rs @@ -1,4 +1,76 @@ -mod async_websocket_client; -pub mod errors; +pub mod exceptions; -pub use async_websocket_client::*; +use core::marker::PhantomData; + +use alloc::string::ToString; +use anyhow::Result; +use embedded_io_async::{Read, Write}; +use embedded_websocket::{framer_async::Framer, Client, WebSocketClient, WebSocketOptions}; +use exceptions::WebsocketError; +use rand_core::RngCore; +use url::Url; + +use crate::Err; + +pub struct WebsocketClosed; +pub struct WebsocketOpen; + +pub struct AsyncWebsocketClient { + inner: Framer, + status: PhantomData, +} + +impl AsyncWebsocketClient { + pub fn is_open(&self) -> bool { + core::any::type_name::() == core::any::type_name::() + } +} + +impl AsyncWebsocketClient { + pub async fn open( + buf: &mut [u8], + stream: &mut S, + uri: Url, + rng: T, + ) -> Result> + where + S: Read + Write + Unpin, + { + // replace the scheme with http or https + let scheme = match uri.scheme() { + "wss" => "https", + "ws" => "http", + _ => uri.scheme(), + }; + let port = match uri.port() { + Some(port) => port, + None => match uri.scheme() { + "wss" => 443, + "ws" => 80, + _ => 80, + }, + } + .to_string(); + let path = uri.path(); + let host = match uri.host_str() { + Some(host) => host, + None => return Err(WebsocketError::Disconnected.into()), + }; + let origin = scheme.to_string() + "://" + host + ":" + &port + path; + let websocket_options = WebSocketOptions { + path, + host, + origin: &origin, + sub_protocols: None, + additional_headers: None, + }; + let websocket = Framer::new(WebSocketClient::new_client(rng)); + match websocket.connect(stream, buf, &websocket_options).await { + Ok(_) => Ok(AsyncWebsocketClient { + inner: websocket, + status: PhantomData::, + }), + Err(e) => Err!(e), + } + } +} diff --git a/src/core/dns/mod.rs b/src/core/dns/mod.rs index f648c16..980cea2 100644 --- a/src/core/dns/mod.rs +++ b/src/core/dns/mod.rs @@ -1,14 +1,15 @@ mod queries; -use crate::core::dns::queries::{Aaaa, Lookup, A}; -use alloc::borrow::Cow; +pub use queries::DnsError; +use queries::{Aaaa, Lookup, A}; + use anyhow::Result; use core::marker::PhantomData; use embedded_nal_async::{IpAddr, Ipv4Addr, Ipv6Addr}; -pub use queries::DnsError; +use url::Url; /// Tries to look up IPv6 addresses first. If it fails it then tries to look up IPv4 addresses. -pub async fn lookup(url: Cow<'_, str>) -> Result { +pub async fn lookup(url: Url) -> Result { let dns_a = Dns::::new(url.clone()); let dns_aaaa = Dns::::new(url); @@ -18,13 +19,13 @@ pub async fn lookup(url: Cow<'_, str>) -> Result { } } -pub struct Dns<'a, T = Aaaa> { - url: Cow<'a, str>, +pub struct Dns { + url: Url, record_type: PhantomData, } -impl<'a, T> Dns<'a, T> { - pub fn new(url: Cow<'a, str>) -> Self { +impl Dns { + pub fn new(url: Url) -> Self { Self { url, record_type: PhantomData, @@ -32,14 +33,14 @@ impl<'a, T> Dns<'a, T> { } } -impl<'a> Dns<'a, A> { +impl Dns { pub async fn lookup(&self) -> Result { - A::lookup(self.url.clone()).await + A::lookup(&self.url).await } } -impl<'a> Dns<'a, Aaaa> { +impl Dns { pub async fn lookup(&self) -> Result { - Aaaa::lookup(self.url.clone()).await + Aaaa::lookup(&self.url).await } } diff --git a/src/core/dns/queries/a/mod.rs b/src/core/dns/queries/a/mod.rs index 0ba0e8e..bbb38e9 100644 --- a/src/core/dns/queries/a/mod.rs +++ b/src/core/dns/queries/a/mod.rs @@ -1,24 +1,25 @@ -use super::errors::DnsError; use crate::core::dns::queries::Lookup; -use alloc::borrow::Cow; -use alloc::vec::Vec; use anyhow::Result; -use core::net::SocketAddr; use embedded_nal_async::Ipv4Addr; #[derive(Debug)] pub struct A; #[cfg(feature = "std")] -mod if_std { +mod impl_lookup { + use core::net::SocketAddr; + use super::*; - use crate::Err; + use crate::{core::dns::DnsError, Err}; + use alloc::{string::ToString, vec::Vec}; use tokio::net::lookup_host; + use url::Url; - impl<'a> Lookup<'a, Ipv4Addr> for A { - async fn lookup(url: Cow<'a, str>) -> Result { + impl Lookup for A { + async fn lookup(url: &Url) -> Result { + let url = url.to_string(); let addresses = match lookup_host(&*url).await { - Err(_) => return Err!(DnsError::LookupError(url.clone())), + Err(_) => return Err!(DnsError::LookupError(url.into())), Ok(socket_addrs_iter) => socket_addrs_iter, }; return match addresses @@ -27,9 +28,22 @@ mod if_std { .first() { Some(SocketAddr::V4(addrs)) => Ok(Ipv4Addr::from(addrs.ip().octets())), - None => Err!(DnsError::LookupIpv4Error(url.clone())), - _ => Err!(DnsError::LookupIpv4Error(url.clone())), + None => Err!(DnsError::LookupIpv4Error(url.into())), + _ => Err!(DnsError::LookupIpv4Error(url.into())), }; } } } + +#[cfg(not(feature = "std"))] +mod impl_lookup { + use url::Url; + + use super::*; + + impl Lookup for A { + async fn lookup(_url: &Url) -> Result { + todo!("Implement lookup for A record type without std") + } + } +} diff --git a/src/core/dns/queries/aaaa/mod.rs b/src/core/dns/queries/aaaa/mod.rs index df499c2..d0c4817 100644 --- a/src/core/dns/queries/aaaa/mod.rs +++ b/src/core/dns/queries/aaaa/mod.rs @@ -1,24 +1,27 @@ -use super::errors::DnsError; -use alloc::borrow::Cow; -use alloc::vec::Vec; use anyhow::Result; -use core::net::SocketAddr; use embedded_nal_async::Ipv6Addr; #[derive(Debug)] pub struct Aaaa; #[cfg(feature = "std")] -mod if_std { +mod impl_lookup { + use core::net::SocketAddr; + use super::*; use crate::core::dns::queries::Lookup; + use crate::core::dns::DnsError; use crate::Err; + use alloc::string::ToString; + use alloc::vec::Vec; use tokio::net::lookup_host; + use url::Url; - impl<'a> Lookup<'a, Ipv6Addr> for Aaaa { - async fn lookup(url: Cow<'a, str>) -> Result { + impl Lookup for Aaaa { + async fn lookup(url: &Url) -> Result { + let url = url.to_string(); let addresses = match lookup_host(&*url).await { - Err(_) => return Err!(DnsError::LookupError(url.clone())), + Err(_) => return Err!(DnsError::LookupError(url.into())), Ok(socket_addrs_iter) => socket_addrs_iter, }; return match addresses @@ -27,9 +30,22 @@ mod if_std { .first() { Some(SocketAddr::V6(addrs)) => Ok(Ipv6Addr::from(addrs.ip().octets())), - None => Err!(DnsError::LookupIpv6Error(url.clone())), - _ => Err!(DnsError::LookupIpv6Error(url.clone())), + None => Err!(DnsError::LookupIpv6Error(url.into())), + _ => Err!(DnsError::LookupIpv6Error(url.into())), }; } } } + +#[cfg(not(feature = "std"))] +mod impl_lookup { + use super::*; + use crate::core::dns::queries::Lookup; + use url::Url; + + impl Lookup for Aaaa { + async fn lookup(_url: &Url) -> Result { + todo!("Implement lookup for Aaaa record type without std") + } + } +} diff --git a/src/core/dns/queries/mod.rs b/src/core/dns/queries/mod.rs index 7d9a6c6..c9e45e2 100644 --- a/src/core/dns/queries/mod.rs +++ b/src/core/dns/queries/mod.rs @@ -1,14 +1,14 @@ mod a; pub use a::A; -use alloc::borrow::Cow; mod aaaa; pub use aaaa::Aaaa; mod errors; pub use errors::DnsError; use anyhow::Result; +use url::Url; -pub trait Lookup<'a, T> { - async fn lookup(url: Cow<'a, str>) -> Result; +pub trait Lookup { + async fn lookup(url: &Url) -> Result; } diff --git a/src/core/framed/errors.rs b/src/core/framed/errors.rs index 2007f81..312ac19 100644 --- a/src/core/framed/errors.rs +++ b/src/core/framed/errors.rs @@ -53,7 +53,7 @@ pub enum IoError { // embedded_io errors #[error("{0:?}")] - Io(embedded_io::ErrorKind), + Io(embedded_io_async::ErrorKind), // Tls errors during IO #[cfg(feature = "tls")] @@ -61,11 +61,11 @@ pub enum IoError { TlsRead(embedded_tls::TlsError), } -impl embedded_io::Error for IoError { - fn kind(&self) -> embedded_io::ErrorKind { +impl embedded_io_async::Error for IoError { + fn kind(&self) -> embedded_io_async::ErrorKind { match self { Self::Io(k) => *k, - _ => embedded_io::ErrorKind::Other, + _ => embedded_io_async::ErrorKind::Other, } } } diff --git a/src/core/framed/framed_impl.rs b/src/core/framed/framed_impl.rs index 4ee64ba..a419d0d 100644 --- a/src/core/framed/framed_impl.rs +++ b/src/core/framed/framed_impl.rs @@ -25,9 +25,9 @@ pin_project! { pub(crate) struct FramedImpl { #[pin] - pub(crate) inner: T, + pub inner: T, pub(crate) state: State, - pub(crate) codec: C, + pub codec: C, } } diff --git a/src/core/framed/mod.rs b/src/core/framed/mod.rs index 9ffa4b2..f776221 100644 --- a/src/core/framed/mod.rs +++ b/src/core/framed/mod.rs @@ -8,7 +8,6 @@ mod framed_impl; use framed_impl::{FramedImpl, RWFrames, ReadFrame, WriteFrame}; pub mod errors; -pub use errors::*; use bytes::BytesMut; use core::fmt; diff --git a/src/core/io/async_read.rs b/src/core/io/async_read.rs index 5c8c4fc..60aa60a 100644 --- a/src/core/io/async_read.rs +++ b/src/core/io/async_read.rs @@ -1,8 +1,8 @@ -use crate::Err; -use alloc::boxed::Box; +// use crate::Err; +// use alloc::boxed::Box; use anyhow::Result; use core::fmt::{Debug, Display}; -use core::ops::DerefMut; +// use core::ops::DerefMut; use core::pin::Pin; use core::task::{Context, Poll}; @@ -11,7 +11,7 @@ use crate::core::io::ReadBuf; #[cfg(feature = "std")] use tokio::io::ReadBuf; -use crate::core::framed::IoError; +// use crate::core::framed::IoError; pub trait AsyncRead { type Error: Debug + Display; @@ -23,54 +23,54 @@ pub trait AsyncRead { ) -> Poll>; } -macro_rules! deref_async_read { - () => { - fn poll_read( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut ReadBuf<'_>, - ) -> Poll> { - match Pin::new(&mut **self).poll_read(cx, buf) { - Poll::Ready(result) => match result { - Ok(_) => Poll::Ready(Ok(())), - Err(_) => Poll::Ready(Err!(IoError::DecodeWhileReadError)), - }, - Poll::Pending => Poll::Pending, - } - } - }; -} +// macro_rules! deref_async_read { +// () => { +// fn poll_read( +// mut self: Pin<&mut Self>, +// cx: &mut Context<'_>, +// buf: &mut ReadBuf<'_>, +// ) -> Poll> { +// match Pin::new(&mut **self).poll_read(cx, buf) { +// Poll::Ready(result) => match result { +// Ok(_) => Poll::Ready(Ok(())), +// Err(_) => Poll::Ready(Err!(IoError::DecodeWhileReadError)), +// }, +// Poll::Pending => Poll::Pending, +// } +// } +// }; +// } -impl AsyncRead for Box { - type Error = anyhow::Error; +// impl AsyncRead for Box { +// type Error = anyhow::Error; - deref_async_read!(); -} +// deref_async_read!(); +// } -impl AsyncRead for &mut T { - type Error = anyhow::Error; +// impl AsyncRead for &mut T { +// type Error = anyhow::Error; - deref_async_read!(); -} +// deref_async_read!(); +// } -impl

AsyncRead for Pin

-where - P: DerefMut + Unpin, - P::Target: AsyncRead, -{ - type Error = anyhow::Error; +// impl

AsyncRead for Pin

+// where +// P: DerefMut + Unpin, +// P::Target: AsyncRead, +// { +// type Error = anyhow::Error; - fn poll_read( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut ReadBuf<'_>, - ) -> Poll> { - match self.get_mut().as_mut().poll_read(cx, buf) { - Poll::Ready(result) => match result { - Ok(()) => Poll::Ready(Ok(())), - Err(err) => Poll::Ready(Err!(err)), - }, - Poll::Pending => Poll::Pending, - } - } -} +// fn poll_read( +// self: Pin<&mut Self>, +// cx: &mut Context<'_>, +// buf: &mut ReadBuf<'_>, +// ) -> Poll> { +// match self.get_mut().as_mut().poll_read(cx, buf) { +// Poll::Ready(result) => match result { +// Ok(()) => Poll::Ready(Ok(())), +// Err(err) => Poll::Ready(Err!(err)), +// }, +// Poll::Pending => Poll::Pending, +// } +// } +// } diff --git a/src/core/io/async_write.rs b/src/core/io/async_write.rs index 86d276f..6ff20b7 100644 --- a/src/core/io/async_write.rs +++ b/src/core/io/async_write.rs @@ -1,6 +1,6 @@ -use alloc::boxed::Box; +// use alloc::boxed::Box; use anyhow::Result; -use core::ops::DerefMut; +// use core::ops::DerefMut; use core::pin::Pin; use core::task::{Context, Poll}; @@ -30,75 +30,75 @@ pub trait AsyncWrite { } } -macro_rules! deref_async_write { - () => { - fn poll_write( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - Pin::new(&mut **self).poll_write(cx, buf) - } - - fn poll_write_vectored( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - bufs: &[IoSlice<'_>], - ) -> Poll> { - Pin::new(&mut **self).poll_write_vectored(cx, bufs) - } - - fn is_write_vectored(&self) -> bool { - (**self).is_write_vectored() - } - - fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut **self).poll_flush(cx) - } - - fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut **self).poll_shutdown(cx) - } - }; -} +// macro_rules! deref_async_write { +// () => { +// fn poll_write( +// mut self: Pin<&mut Self>, +// cx: &mut Context<'_>, +// buf: &[u8], +// ) -> Poll> { +// Pin::new(&mut **self).poll_write(cx, buf) +// } + +// fn poll_write_vectored( +// mut self: Pin<&mut Self>, +// cx: &mut Context<'_>, +// bufs: &[IoSlice<'_>], +// ) -> Poll> { +// Pin::new(&mut **self).poll_write_vectored(cx, bufs) +// } + +// fn is_write_vectored(&self) -> bool { +// (**self).is_write_vectored() +// } + +// fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { +// Pin::new(&mut **self).poll_flush(cx) +// } + +// fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { +// Pin::new(&mut **self).poll_shutdown(cx) +// } +// }; +// } -impl AsyncWrite for Box { - deref_async_write!(); -} +// impl AsyncWrite for Box { +// deref_async_write!(); +// } -impl AsyncWrite for &mut T { - deref_async_write!(); -} +// impl AsyncWrite for &mut T { +// deref_async_write!(); +// } -impl

AsyncWrite for Pin

-where - P: DerefMut + Unpin, - P::Target: AsyncWrite, -{ - fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll> { - self.get_mut().as_mut().poll_write(cx, buf) - } +// impl

AsyncWrite for Pin

+// where +// P: DerefMut + Unpin, +// P::Target: AsyncWrite, +// { +// fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll> { +// self.get_mut().as_mut().poll_write(cx, buf) +// } - fn poll_write_vectored( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - bufs: &[IoSlice<'_>], - ) -> Poll> { - self.get_mut().as_mut().poll_write_vectored(cx, bufs) - } +// fn poll_write_vectored( +// self: Pin<&mut Self>, +// cx: &mut Context<'_>, +// bufs: &[IoSlice<'_>], +// ) -> Poll> { +// self.get_mut().as_mut().poll_write_vectored(cx, bufs) +// } - fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.get_mut().as_mut().poll_flush(cx) - } +// fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { +// self.get_mut().as_mut().poll_flush(cx) +// } - fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.get_mut().as_mut().poll_shutdown(cx) - } +// fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { +// self.get_mut().as_mut().poll_shutdown(cx) +// } - fn is_write_vectored(&self) -> bool { - (**self).is_write_vectored() - } -} +// fn is_write_vectored(&self) -> bool { +// (**self).is_write_vectored() +// } +// } // TODO: implement if needed, otherwise delete // impl AsyncWrite for Vec { diff --git a/src/core/mod.rs b/src/core/mod.rs index 6fabff6..c96eef2 100644 --- a/src/core/mod.rs +++ b/src/core/mod.rs @@ -1,8 +1,9 @@ #[cfg(feature = "dns")] pub mod dns; -pub mod framed; -pub mod io; +// mod framed; +// mod io; +#[cfg(feature = "tcp")] pub mod tcp; // TODO: uncomment and make tls public as soon as it's working -// #[cfg(feature = "tls")] -// mod tls; +#[cfg(feature = "tls")] +pub mod tls; diff --git a/src/core/tcp/adapters.rs b/src/core/tcp/adapters.rs deleted file mode 100644 index f6e3940..0000000 --- a/src/core/tcp/adapters.rs +++ /dev/null @@ -1,120 +0,0 @@ -#[cfg(feature = "std")] -pub use std_adapters::TcpAdapterTokio; - -#[cfg(feature = "std")] -mod std_adapters { - use crate::core::io; - use crate::core::tcp::errors::TcpError; - use crate::Err; - use anyhow::Result; - use core::pin::Pin; - use core::task::{Context, Poll}; - use tokio::io::ReadBuf; - use tokio::io::{AsyncRead, AsyncWrite}; - use tokio::net::{TcpStream, ToSocketAddrs}; - #[derive(Debug)] - pub struct TcpAdapterTokio { - pub(crate) inner: TcpStream, - } - - impl TcpAdapterTokio { - pub async fn connect(ip: impl ToSocketAddrs) -> Result { - match TcpStream::connect(ip).await { - Err(_) => Err!(TcpError::UnableToConnect), // TODO: return the error returned by `tokio::net::TcpStream` - Ok(stream) => Ok(Self { inner: stream }), - } - } - } - - impl io::AsyncRead for TcpAdapterTokio { - type Error = anyhow::Error; - - fn poll_read( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut ReadBuf<'_>, - ) -> Poll> { - match Pin::new(&mut self.inner).poll_read(cx, buf) { - Poll::Ready(result) => match result { - Ok(()) => Poll::Ready(Ok(())), - Err(error) => Poll::Ready(Err!(error)), - }, - Poll::Pending => Poll::Pending, - } - } - } - - impl io::AsyncWrite for TcpAdapterTokio { - fn poll_write( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - match Pin::new(&mut self.inner).poll_write(cx, buf) { - Poll::Ready(result) => match result { - Ok(size) => Poll::Ready(Ok(size)), - Err(error) => Poll::Ready(Err!(error)), - }, - Poll::Pending => Poll::Pending, - } - } - - fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - match Pin::new(&mut self.inner).poll_flush(cx) { - Poll::Ready(result) => match result { - Ok(()) => Poll::Ready(Ok(())), - Err(error) => Poll::Ready(Err!(error)), - }, - Poll::Pending => Poll::Pending, - } - } - - fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - match Pin::new(&mut self.inner).poll_shutdown(cx) { - Poll::Ready(result) => match result { - Ok(()) => Poll::Ready(Ok(())), - Err(error) => Poll::Ready(Err!(error)), - }, - Poll::Pending => Poll::Pending, - } - } - } -} - -// mod no_std_adapters { -// use alloc::borrow::Cow; -// use core::cell::RefCell; -// use core::pin::Pin; -// use core::task::{Context, Poll}; -// use core::borrow::BorrowMut; -// use anyhow::Result; -// use embassy_net::tcp::TcpSocket; -// use embassy_net_driver::Driver; -// use super::AdapterConnect; -// use crate::core::io::ReadBuf; -// use crate::core::framed::IoError; -// use crate::core::io; -// use crate::core::tcp::errors::TcpError; -// use crate::Err; -// -// pub struct TcpAdapterEmbassy<'a> { -// inner: RefCell>>, -// } -// -// impl<'a> TcpAdapterEmbassy<'a> { -// pub fn new(socket: TcpSocket<'a>) -> Self { -// Self { inner: socket } -// } -// } -// -// impl<'a> AdapterConnect<'a> for TcpAdapterEmbassy<'a> { -// async fn connect(&self, ip: Cow<'a, str>) -> Result<()> { -// match self.inner.borrow_mut().as_mut() { -// None => Err!(TcpError::UnableToConnect), -// Ok(socket) => { -// socket.connect -// } -// } -// } -// } -// } diff --git a/src/core/tcp/errors.rs b/src/core/tcp/errors.rs deleted file mode 100644 index 7e393b9..0000000 --- a/src/core/tcp/errors.rs +++ /dev/null @@ -1,10 +0,0 @@ -use thiserror_no_std::Error; - -#[derive(Debug, Clone, PartialEq, Eq, Error)] -pub enum TcpError { - #[error("Unable to connect to host")] - UnableToConnect, -} - -#[cfg(feature = "std")] -impl alloc::error::Error for TcpError {} diff --git a/src/core/tcp/mod.rs b/src/core/tcp/mod.rs index 5dcd66e..777e189 100644 --- a/src/core/tcp/mod.rs +++ b/src/core/tcp/mod.rs @@ -1,124 +1,19 @@ -// use crate::core::framed::IoError; -// use crate::Err; -// use alloc::borrow::Cow; -// use anyhow::Result; -// use core::borrow::BorrowMut; -// use core::fmt::Debug; -// use core::future::poll_fn; -// use core::pin::Pin; -// use core::task::{Context, Poll}; -// use embedded_io::asynch::{Read, Write}; -// use embedded_io::Io; - -// #[cfg(not(feature = "std"))] -// use crate::core::io::ReadBuf; -// use crate::core::io::{AsyncRead, AsyncWrite}; -// #[cfg(feature = "std")] -// use tokio::io::ReadBuf; - -pub mod adapters; -pub mod errors; - -// // TODO: utilize to check `state` -// pub struct Socket; -// pub struct Stream; - -// #[derive(Debug)] -// pub struct TcpSocket { -// pub(crate) socket: T, -// } - -// impl TcpSocket { -// pub fn new(socket: T) -> Self { -// Self { socket } -// } -// } - -// impl<'a, T> TcpConnect<'a> for TcpSocket -// where -// T: AdapterConnect<'a>, -// { -// async fn connect(&mut self, socket_address: Cow<'a, str>) -> Result<()> { -// // TODO: `socket_address` should be of type `SocketAddr` -// self.socket.connect(socket_address).await -// } -// } - -// impl AsyncRead for TcpSocket -// where -// T: AsyncRead + Unpin, -// { -// type Error = anyhow::Error; - -// fn poll_read( -// mut self: Pin<&mut Self>, -// cx: &mut Context<'_>, -// buf: &mut ReadBuf<'_>, -// ) -> Poll> { -// match Pin::new(&mut self.socket).poll_read(cx, buf) { -// Poll::Ready(Ok(())) => Poll::Ready(Ok(())), -// Poll::Ready(Err(error)) => Poll::Ready(Err!(error)), -// Poll::Pending => Poll::Pending, -// } -// } -// } - -// impl AsyncWrite for TcpSocket -// where -// T: AsyncWrite + Unpin, -// { -// fn poll_write( -// mut self: Pin<&mut Self>, -// cx: &mut Context<'_>, -// buf: &[u8], -// ) -> Poll> { -// Pin::new(&mut self.socket).poll_write(cx, buf) -// } - -// fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { -// Pin::new(&mut self.socket).poll_flush(cx) -// } - -// fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { -// Pin::new(&mut self.socket).poll_shutdown(cx) -// } -// } - -// impl Io for TcpSocket { -// type Error = IoError; -// } - -// impl Read for TcpSocket -// where -// T: AsyncRead + Unpin, -// { -// async fn read(&mut self, buf: &mut [u8]) -> core::result::Result { -// let size = buf.len(); -// poll_fn(|cx| Pin::new(&mut self.socket).poll_read(cx, ReadBuf::new(buf).borrow_mut())) -// .await -// .map_err(|_| IoError::UnableToRead)?; - -// Ok(size) -// } -// } - -// impl Write for TcpSocket -// where -// T: AsyncWrite + Unpin, -// { -// async fn write(&mut self, buf: &[u8]) -> core::result::Result { -// let size = buf.len(); -// poll_fn(|cx| Pin::new(&mut self.socket).poll_write(cx, buf)) -// .await -// .map_err(|_| IoError::UnableToWrite)?; -// poll_fn(|cx| Pin::new(&mut self.socket).poll_flush(cx)) -// .await -// .map_err(|_| IoError::UnableToFlush)?; - -// Ok(size) -// } -// } - -// pub trait TcpConnect<'a> { -// async fn connect(&mut self, ip: Cow<'a, str>) -> Result<()>; -// } +#[cfg(not(feature = "std"))] +pub use _embassy::*; +#[cfg(feature = "std")] +pub use _tokio::*; + +#[cfg(not(feature = "std"))] +mod _embassy { + use embassy_net::tcp::TcpSocket as EmbassyTcpSocket; + + pub type TcpSocket<'a> = EmbassyTcpSocket<'a>; +} + +#[cfg(feature = "std")] +mod _tokio { + use embedded_io_adapters::tokio_1::FromTokio; + use tokio::net::{TcpListener as TokioTcpListener, TcpStream as TokioTcpStream}; + pub type TcpStream = FromTokio; + pub type TcpListener = FromTokio; +} diff --git a/src/core/tls/errors.rs b/src/core/tls/errors.rs deleted file mode 100644 index a5ad9b1..0000000 --- a/src/core/tls/errors.rs +++ /dev/null @@ -1,14 +0,0 @@ -use thiserror_no_std::Error; - -#[derive(Debug, Error)] -pub enum TlsError { - #[error("Tls is not connected (`inner` is not defined)")] - NotConnected, - #[error("Failed to establish tls handshake.")] - FailedToOpen, - #[error("{0:?}")] - Other(embedded_tls::TlsError), -} - -#[cfg(feature = "std")] -impl alloc::error::Error for TlsError {} diff --git a/src/core/tls/exceptions.rs b/src/core/tls/exceptions.rs new file mode 100644 index 0000000..93bced3 --- /dev/null +++ b/src/core/tls/exceptions.rs @@ -0,0 +1,35 @@ +use anyhow::anyhow; +use embedded_io_async::ErrorKind; +use thiserror_no_std::Error; + +#[derive(Debug, Error)] +pub enum TlsException { + #[error("I/O error: {0}")] + IoError(alloc::io::Error), + #[error("No domain")] + NoDomain, + #[error("Embedded IO async error")] + EmbeddedIoAsyncError(ErrorKind), +} + +impl From for TlsException { + fn from(e: alloc::io::Error) -> Self { + TlsException::IoError(e) + } +} + +impl Into for TlsException { + fn into(self) -> anyhow::Error { + anyhow!(self) + } +} + +impl embedded_io_async::Error for TlsException { + fn kind(&self) -> embedded_io_async::ErrorKind { + match self { + TlsException::IoError(_) => ErrorKind::Other, + TlsException::NoDomain => ErrorKind::Other, + TlsException::EmbeddedIoAsyncError(e) => *e, + } + } +} diff --git a/src/core/tls/mod.rs b/src/core/tls/mod.rs index f858894..edc8341 100644 --- a/src/core/tls/mod.rs +++ b/src/core/tls/mod.rs @@ -1,146 +1,103 @@ -pub mod errors; +mod exceptions; -use alloc::boxed::Box; -use anyhow::Result; -use core::future::Future; -use core::net::SocketAddr; -use core::pin::Pin; -use core::task::{Context, Poll}; +use embedded_io_adapters::tokio_1::FromTokio; +pub use exceptions::*; +use tokio_rustls::client::TlsStream; -use embedded_io::asynch::{Read, Write}; -use embedded_tls::{TlsCipherSuite, TlsConnection, TlsContext, TlsVerifier}; -use rand_core::{CryptoRng, RngCore}; +use anyhow::Result; +use embedded_io_async::{Read, Write}; #[cfg(not(feature = "std"))] -use crate::core::io::ReadBuf; +use rustls::{ClientConnection, ServerConnection}; #[cfg(feature = "std")] -use tokio::io::ReadBuf; - -use crate::core::framed::IoError; -use crate::core::io; -use crate::core::tcp::TcpConnect; -use errors::TlsError; - -use crate::Err; - -// exports -pub use embedded_tls::{ - blocking::{Aes128GcmSha256, Aes256GcmSha384}, - webpki::CertVerifier, - NoVerify, TlsConfig, -}; - -#[derive(Default)] -pub struct TlsSocket<'a, Socket, Cipher> -where - Socket: Read + Write + 'a, - Cipher: TlsCipherSuite + 'static, -{ - // TODO: This is just optional so that the `shutdown` works. - // TODO: `Option` should be required when I found an elegant solution to `shutdown` the TLS connection - inner: Option>, -} +use tokio_rustls::TlsConnector; -impl<'a, Socket, Cipher> TlsSocket<'a, Socket, Cipher> -where - Socket: Read + Write + TcpConnect + 'a, - Cipher: TlsCipherSuite + 'static, -{ - pub async fn connect>( - mut socket: Socket, - record_read_buf: &'a mut [u8], - record_write_buf: &'a mut [u8], - rng: &'a mut Rng, - config: &'a TlsConfig<'a, Cipher>, - socket_addr: SocketAddr, - ) -> Result { - socket.connect(socket_addr).await?; - let mut tls_connection = TlsConnection::new(socket, record_read_buf, record_write_buf); - if let Err(err) = tls_connection - .open::(TlsContext::new(config, rng)) - .await - { - return Err!(TlsError::Other(err)); - } +#[cfg(not(feature = "std"))] +pub struct TlsSocketClient(ClientConnection); +#[cfg(not(feature = "std"))] +pub struct TlsSocketServer(ServerConnection); - Ok(Self { - inner: Some(tls_connection), - }) - } -} +#[cfg(feature = "std")] +pub struct TlsSocket(FromTokio>); -impl<'a, Socket, Cipher> io::AsyncRead for TlsSocket<'a, Socket, Cipher> -where - Socket: Read + Write + Unpin + 'a, - Cipher: TlsCipherSuite + Unpin + 'static, - ::Hash: Unpin, - <<::Hash as crypto_common::OutputSizeUser>::OutputSize as generic_array::ArrayLength>::ArrayType: Unpin, - <<::Hash as crypto_common::BlockSizeUser>::BlockSize as generic_array::ArrayLength>::ArrayType: Unpin, -{ - type Error = IoError; - - fn poll_read( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut ReadBuf<'_>, - ) -> Poll> { - match self.inner.as_mut() { - None => { Poll::Ready(Err(IoError::ReadNotConnected)) } - Some(tls_connection) => match Pin::new(&mut Box::pin(tls_connection.read(buf.filled_mut()))).poll(cx) { - Poll::Ready(result) => match result { - Ok(0) => { - // no data ready - Poll::Pending - } - Ok(_) => Poll::Ready(Ok(())), - Err(e) => Poll::Ready(Err(IoError::TlsRead(e))), - }, - Poll::Pending => Poll::Pending, - } +#[cfg(feature = "std")] +mod tokio_tls_client { + use alloc::sync::Arc; + use embedded_io_async::ErrorType; + use rustls::{pki_types::ServerName, ClientConfig, RootCertStore}; + use tokio::io::{AsyncRead, AsyncWrite}; + use url::Url; + + use crate::core::tcp::TcpStream; + + use super::*; + + impl TlsSocket + where + S: AsyncRead + AsyncWrite + Unpin, + { + pub async fn connect<'a>(url: Url) -> Result> { + let stream = TcpStream::new(inner) + let mut root_cert_store = RootCertStore::empty(); + root_cert_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned()); + let config = ClientConfig::builder() + .with_root_certificates(root_cert_store) + .with_no_client_auth(); + let connector = TlsConnector::from(Arc::new(config)); + + let stream = connector + .connect(server_name, stream) + .await + .map_err(|e| TlsException::IoError(e).into())?; + + Ok(TlsSocket(FromTokio::new(stream))) } } -} -impl<'a, Socket, Cipher> io::AsyncWrite for TlsSocket<'a, Socket, Cipher> -where - Socket: Read + Write + Unpin + 'a, - Cipher: TlsCipherSuite + Unpin + 'static, - ::Hash: Unpin, - <<::Hash as crypto_common::OutputSizeUser>::OutputSize as generic_array::ArrayLength>::ArrayType: Unpin, - <<::Hash as crypto_common::BlockSizeUser>::BlockSize as generic_array::ArrayLength>::ArrayType: Unpin, -{ - fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll> { - match self.inner.as_mut() { - None => { Poll::Ready(Err!(IoError::WriteNotConnected)) } - Some(tls_connection) => match Pin::new(&mut Box::pin(tls_connection.write(buf))).poll(cx) { - Poll::Ready(result) => match result { - Ok(size) => Poll::Ready(Ok(size)), - Err(_) => Poll::Ready(Err!(IoError::UnableToWrite)), - }, - Poll::Pending => Poll::Pending, - } + impl TlsSocket + where + S: AsyncRead + AsyncWrite + Unpin, + { + pub async fn accept(stream: S, url: &Url) -> Result { + todo!("Implement accept as TlsListener"); } + } + impl ErrorType for TlsSocket + where + S: AsyncRead + AsyncWrite + Unpin, + { + type Error = TlsException; } - fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - match self.inner.as_mut() { - None => { Poll::Ready(Err!(IoError::WriteNotConnected)) } - Some(tls_connection) => match Pin::new(&mut Box::pin(tls_connection.flush())).poll(cx) { - Poll::Ready(result) => match result { - Ok(_) => Poll::Ready(Ok(())), - Err(_) => Poll::Ready(Err!(IoError::UnableToFlush)), - }, - Poll::Pending => Poll::Pending, - } + impl Read for TlsSocket + where + S: AsyncRead + AsyncWrite + Unpin, + { + async fn read(&mut self, buf: &mut [u8]) -> core::result::Result { + self.0 + .read(buf) + .await + .map_err(|e| TlsException::IoError(e).into()) } } - fn poll_shutdown(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { - let tls_connection = core::mem::take(&mut self.inner).unwrap(); + impl Write for TlsSocket + where + S: AsyncRead + AsyncWrite + Unpin, + { + async fn write(&mut self, buf: &[u8]) -> core::result::Result { + self.0 + .write(buf) + .await + .map_err(|e| TlsException::IoError(e).into()) + } - // TODO: Find an elegant solution - let _ = tls_connection.close(); - Poll::Ready(Ok(())) + async fn flush(&mut self) -> core::result::Result<(), Self::Error> { + self.0 + .flush() + .await + .map_err(|e| TlsException::IoError(e).into()) + } } } diff --git a/src/lib.rs b/src/lib.rs index bd58ad2..600433b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,8 +1,6 @@ #![no_std] // #![cfg_attr(not(feature = "std"), no_std)] #![allow(incomplete_features)] -#![feature(async_fn_in_trait)] -#![feature(ip_in_core)] #![allow(dead_code)] // Remove eventually #[cfg(not(feature = "std"))] diff --git a/tests/common/constants.rs b/tests/common/constants.rs index 843323d..f6faa35 100644 --- a/tests/common/constants.rs +++ b/tests/common/constants.rs @@ -1,3 +1,3 @@ pub const ECHO_WS_SERVER: &'static str = "ws://ws.vi-server.org/mirror/"; -pub const ECHO_WS_AS_IP_SERVER: &'static str = "192.236.209.31:80"; pub const ECHO_WSS_SERVER: &'static str = "wss://ws.vi-server.org/mirror/"; +pub const ECHO_WS_AS_IP_SERVER: &'static str = "192.236.209.31:80";