Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Bump hyper version to 1.0
Browse files Browse the repository at this point in the history
Signed-off-by: csh <[email protected]>
L-jasmine committed Jan 22, 2024
1 parent 2dc4198 commit 6edb00a
Showing 9 changed files with 635 additions and 102 deletions.
5 changes: 3 additions & 2 deletions .cargo/config.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
[build]
target="wasm32-wasi"
target = "wasm32-wasi"
rustflags = "--cfg tokio_unstable"

[target.wasm32-wasi]
runner = "wasmedge"
runner = "wasmedge"
15 changes: 10 additions & 5 deletions client-https/Cargo.toml
Original file line number Diff line number Diff line change
@@ -6,9 +6,14 @@ edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
hyper_wasi = { version = "0.15", features = ["full"]}
http-body-util = "0.1.0-rc.2"
tokio_wasi = { version = "1", features = ["rt", "macros", "net", "time", "io-util"]}
hyper = { version = "1", features = ["full"] }
tokio = { version = "1", features = ["rt", "macros", "net", "time", "io-util"] }
pretty_env_logger = "0.4.0"
wasmedge_rustls_api = { version = "0.1", features = [ "tokio_async" ] }
wasmedge_hyper_rustls = "0.1.0"

wasmedge_wasi_socket = "0.5"
pin-project = "1.1.3"
http-body-util = "0.1.0"

tokio-rustls = "0.25.0"
webpki-roots = "0.26.0"
rustls = "0.22.2"
162 changes: 149 additions & 13 deletions client-https/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,30 +1,166 @@
use hyper::Client;
#![deny(warnings)]
#![warn(rust_2018_idioms)]

type Result<T> = std::result::Result<T, Box<dyn std::error::Error + Send + Sync>>;
// use tokio::io::{self, AsyncWriteExt as _};

use std::{
os::fd::{FromRawFd, IntoRawFd},
pin::Pin,
sync::Arc,
task::{Context, Poll},
};

use http_body_util::{BodyExt, Empty};
use hyper::{body::Bytes, Request};

use rustls::pki_types::ServerName;
use tokio::net::TcpStream;

type MainResult<T> = std::result::Result<T, Box<dyn std::error::Error + Send + Sync>>;

#[tokio::main(flavor = "current_thread")]
async fn main() {
async fn main() -> MainResult<()> {
pretty_env_logger::init();

let url = "https://httpbin.org/get?msg=WasmEdge"
.parse::<hyper::Uri>()
.unwrap();
fetch_https_url(url).await.unwrap();
fetch_https_url(url).await
}

use pin_project::pin_project;
use tokio_rustls::TlsConnector;

#[pin_project]
#[derive(Debug)]
struct TokioIo<T> {
#[pin]
inner: T,
}

impl<T> TokioIo<T> {
pub fn new(inner: T) -> Self {
Self { inner }
}

#[allow(dead_code)]
pub fn inner(self) -> T {
self.inner
}
}

impl<T> hyper::rt::Read for TokioIo<T>
where
T: tokio::io::AsyncRead,
{
fn poll_read(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
mut buf: hyper::rt::ReadBufCursor<'_>,
) -> std::task::Poll<Result<(), std::io::Error>> {
let n = unsafe {
let mut tbuf = tokio::io::ReadBuf::uninit(buf.as_mut());
match tokio::io::AsyncRead::poll_read(self.project().inner, cx, &mut tbuf) {
Poll::Ready(Ok(())) => tbuf.filled().len(),
other => return other,
}
};

unsafe {
buf.advance(n);
}
Poll::Ready(Ok(()))
}
}

async fn fetch_https_url(url: hyper::Uri) -> Result<()> {
let https = wasmedge_hyper_rustls::connector::new_https_connector(
wasmedge_rustls_api::ClientConfig::default(),
);
let client = Client::builder().build::<_, hyper::Body>(https);
impl<T> hyper::rt::Write for TokioIo<T>
where
T: tokio::io::AsyncWrite,
{
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, std::io::Error>> {
tokio::io::AsyncWrite::poll_write(self.project().inner, cx, buf)
}

fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
tokio::io::AsyncWrite::poll_flush(self.project().inner, cx)
}

fn poll_shutdown(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Result<(), std::io::Error>> {
tokio::io::AsyncWrite::poll_shutdown(self.project().inner, cx)
}

fn is_write_vectored(&self) -> bool {
tokio::io::AsyncWrite::is_write_vectored(&self.inner)
}

fn poll_write_vectored(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &[std::io::IoSlice<'_>],
) -> Poll<std::prelude::v1::Result<usize, std::io::Error>> {
tokio::io::AsyncWrite::poll_write_vectored(self.project().inner, cx, bufs)
}
}

async fn fetch_https_url(url: hyper::Uri) -> MainResult<()> {
let host = url.host().expect("uri has no host");
let port = url.port_u16().unwrap_or(443);
let addr = format!("{}:{}", host, port);

let mut root_store = rustls::RootCertStore::empty();
root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());

let config = rustls::ClientConfig::builder()
.with_root_certificates(root_store)
.with_no_client_auth();

let connector = TlsConnector::from(Arc::new(config));
let stream = unsafe {
let fd = wasmedge_wasi_socket::TcpStream::connect(addr)?.into_raw_fd();
TcpStream::from_std(std::net::TcpStream::from_raw_fd(fd))?
};

let domain = ServerName::try_from(host.to_string()).unwrap();
let stream = connector.connect(domain, stream).await.unwrap();

let io = TokioIo::new(stream);

let (mut sender, conn) = hyper::client::conn::http1::handshake(io).await?;
tokio::task::spawn(async move {
if let Err(err) = conn.await {
println!("Connection failed: {:?}", err);
}
});

let authority = url.authority().unwrap().clone();

let path = url.path();
let req = Request::builder()
.uri(path)
.header(hyper::header::HOST, authority.as_str())
.body(Empty::<Bytes>::new())?;

let res = client.get(url).await?;
let mut res = sender.send_request(req).await?;

println!("Response: {}", res.status());
println!("Headers: {:#?}\n", res.headers());

let body = hyper::body::to_bytes(res.into_body()).await.unwrap();
println!("{}", String::from_utf8(body.into()).unwrap());
let mut resp_data = Vec::new();
while let Some(next) = res.frame().await {
let frame = next?;
if let Some(chunk) = frame.data_ref() {
resp_data.extend_from_slice(&chunk);
}
}

println!("\n\nDone!");
println!("{}", String::from_utf8_lossy(&resp_data));

Ok(())
}
13 changes: 5 additions & 8 deletions client/Cargo.toml
Original file line number Diff line number Diff line change
@@ -4,12 +4,9 @@ version = "0.1.0"
edition = "2021"

[dependencies]
hyper_wasi = { version = "0.15", features = ["full"] }
tokio_wasi = { version = "1", features = [
"rt",
"macros",
"net",
"time",
"io-util",
] }
hyper = { version = "1", features = ["full"] }
tokio = { version = "1", features = ["rt", "macros", "net", "time", "io-util"] }
pretty_env_logger = "0.4.0"
wasmedge_wasi_socket = "0.5"
pin-project = "1.1.3"
http-body-util = "0.1.0"
193 changes: 161 additions & 32 deletions client/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,24 @@
#![deny(warnings)]
#![warn(rust_2018_idioms)]
use hyper::{body::HttpBody as _, Client};
use hyper::{Body, Method, Request};

// use tokio::io::{self, AsyncWriteExt as _};

use std::{
os::fd::{FromRawFd, IntoRawFd},
pin::Pin,
task::{Context, Poll},
};

use http_body_util::{BodyExt, Empty};
use hyper::{body::Bytes, Request};

use tokio::net::TcpStream;

// A simple type alias so as to DRY.
type Result<T> = std::result::Result<T, Box<dyn std::error::Error + Send + Sync>>;
type MainResult<T> = std::result::Result<T, Box<dyn std::error::Error + Send + Sync>>;

#[tokio::main(flavor = "current_thread")]
async fn main() -> Result<()> {
async fn main() -> MainResult<()> {
pretty_env_logger::init();

let url_str = "http://eu.httpbin.org/get?msg=Hello";
@@ -24,7 +34,7 @@ async fn main() -> Result<()> {
let url_str = "http://eu.httpbin.org/get?msg=WasmEdge";
println!("\nGET and get result as string: {}", url_str);
let url = url_str.parse::<hyper::Uri>().unwrap();
fetch_url_return_str(url).await?;
fetch_url(url).await?;
// tokio::time::sleep(std::time::Duration::from_secs(5)).await;

let url_str = "http://eu.httpbin.org/post";
@@ -35,51 +45,170 @@ async fn main() -> Result<()> {
post_url_return_str(url, post_body_str.as_bytes()).await
}

async fn fetch_url(url: hyper::Uri) -> Result<()> {
let client = Client::new();
let mut res = client.get(url).await?;
use pin_project::pin_project;

println!("Response: {}", res.status());
println!("Headers: {:#?}\n", res.headers());
#[pin_project]
#[derive(Debug)]
struct TokioIo<T> {
#[pin]
inner: T,
}

// Stream the body, writing each chunk to stdout as we get it
// (instead of buffering and printing at the end).
while let Some(next) = res.data().await {
let chunk = next?;
println!("{:#?}", chunk);
// io::stdout().write_all(&chunk).await?;
impl<T> TokioIo<T> {
pub fn new(inner: T) -> Self {
Self { inner }
}

Ok(())
#[allow(dead_code)]
pub fn inner(self) -> T {
self.inner
}
}

impl<T> hyper::rt::Read for TokioIo<T>
where
T: tokio::io::AsyncRead,
{
fn poll_read(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
mut buf: hyper::rt::ReadBufCursor<'_>,
) -> std::task::Poll<Result<(), std::io::Error>> {
let n = unsafe {
let mut tbuf = tokio::io::ReadBuf::uninit(buf.as_mut());
match tokio::io::AsyncRead::poll_read(self.project().inner, cx, &mut tbuf) {
Poll::Ready(Ok(())) => tbuf.filled().len(),
other => return other,
}
};

unsafe {
buf.advance(n);
}
Poll::Ready(Ok(()))
}
}

impl<T> hyper::rt::Write for TokioIo<T>
where
T: tokio::io::AsyncWrite,
{
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, std::io::Error>> {
tokio::io::AsyncWrite::poll_write(self.project().inner, cx, buf)
}

fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
tokio::io::AsyncWrite::poll_flush(self.project().inner, cx)
}

fn poll_shutdown(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Result<(), std::io::Error>> {
tokio::io::AsyncWrite::poll_shutdown(self.project().inner, cx)
}

fn is_write_vectored(&self) -> bool {
tokio::io::AsyncWrite::is_write_vectored(&self.inner)
}

fn poll_write_vectored(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &[std::io::IoSlice<'_>],
) -> Poll<std::prelude::v1::Result<usize, std::io::Error>> {
tokio::io::AsyncWrite::poll_write_vectored(self.project().inner, cx, bufs)
}
}

async fn fetch_url_return_str(url: hyper::Uri) -> Result<()> {
let client = Client::new();
let mut res = client.get(url).await?;
async fn fetch_url(url: hyper::Uri) -> MainResult<()> {
let host = url.host().expect("uri has no host");
let port = url.port_u16().unwrap_or(80);
let addr = format!("{}:{}", host, port);
let stream = unsafe {
let fd = wasmedge_wasi_socket::TcpStream::connect(addr)?.into_raw_fd();
TcpStream::from_std(std::net::TcpStream::from_raw_fd(fd))?
};

let io = TokioIo::new(stream);

let (mut sender, conn) = hyper::client::conn::http1::handshake(io).await?;
tokio::task::spawn(async move {
if let Err(err) = conn.await {
println!("Connection failed: {:?}", err);
}
});

let authority = url.authority().unwrap().clone();

let path = url.path();
let req = Request::builder()
.uri(path)
.header(hyper::header::HOST, authority.as_str())
.body(Empty::<Bytes>::new())?;

let mut res = sender.send_request(req).await?;

println!("Response: {}", res.status());
println!("Headers: {:#?}\n", res.headers());

let mut resp_data = Vec::new();
while let Some(next) = res.data().await {
let chunk = next?;
resp_data.extend_from_slice(&chunk);
while let Some(next) = res.frame().await {
let frame = next?;
if let Some(chunk) = frame.data_ref() {
resp_data.extend_from_slice(&chunk);
}
}

println!("{}", String::from_utf8_lossy(&resp_data));

Ok(())
}

async fn post_url_return_str(url: hyper::Uri, post_body: &'static [u8]) -> Result<()> {
let client = Client::new();
async fn post_url_return_str(url: hyper::Uri, post_body: &'static [u8]) -> MainResult<()> {
let host = url.host().expect("uri has no host");
let port = url.port_u16().unwrap_or(80);
let addr = format!("{}:{}", host, port);
let stream = unsafe {
let fd = wasmedge_wasi_socket::TcpStream::connect(addr)?.into_raw_fd();
TcpStream::from_std(std::net::TcpStream::from_raw_fd(fd))?
};

let io = TokioIo::new(stream);

let (mut sender, conn) = hyper::client::conn::http1::handshake(io).await?;
tokio::task::spawn(async move {
if let Err(err) = conn.await {
println!("Connection failed: {:?}", err);
}
});

let authority = url.authority().unwrap().clone();

let path = url.path();
let req = Request::builder()
.method(Method::POST)
.uri(url)
.body(Body::from(post_body))?;
let mut res = client.request(req).await?;
.uri(path)
.method("POST")
.header(hyper::header::HOST, authority.as_str())
.body(http_body_util::Full::new(post_body))?;

let mut res = sender.send_request(req).await?;

println!("Response: {}", res.status());
println!("Headers: {:#?}\n", res.headers());

let mut resp_data = Vec::new();
while let Some(next) = res.data().await {
let chunk = next?;
resp_data.extend_from_slice(&chunk);
while let Some(next) = res.frame().await {
let frame = next?;
if let Some(chunk) = frame.data_ref() {
resp_data.extend_from_slice(&chunk);
}
}

println!("{}", String::from_utf8_lossy(&resp_data));

Ok(())
20 changes: 17 additions & 3 deletions server-tflite/Cargo.toml
Original file line number Diff line number Diff line change
@@ -4,8 +4,22 @@ version = "0.1.0"
edition = "2021"

[dependencies]
hyper_wasi = { version = "0.15", features = ["full"]}
tokio_wasi = { version = "1", features = ["rt", "macros", "net", "time", "io-util"]}
image = { version = "0.23.14", default-features = false, features = ["gif", "jpeg", "ico", "png", "tiff", "webp", "bmp"] }
hyper = { version = "1", features = ["full"] }
tokio = { version = "1", features = ["rt", "macros", "net", "time", "io-util"] }
wasmedge_wasi_socket = "0.5"

image = { version = "0.23.14", default-features = false, features = [
"gif",
"jpeg",
"ico",
"png",
"tiff",
"webp",
"bmp",
] }
wasi-nn = "0.4.0"
anyhow = "1.0"

pin-project = "1.1.3"
http-body-util = "0.1.0"
bytes = "1"
154 changes: 129 additions & 25 deletions server-tflite/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,19 +1,29 @@
use hyper::service::{make_service_fn, service_fn};
use hyper::{Body, Method, Request, Response, StatusCode, Server};
use std::convert::Infallible;
use std::net::SocketAddr;
use std::result::Result;
use std::io::Cursor;
use bytes::Bytes;
use http_body_util::{combinators::BoxBody, BodyExt, Full};
use hyper::server::conn::http1;
use hyper::service::service_fn;
use hyper::{Method, Request, Response, StatusCode};
use image::io::Reader;
use image::DynamicImage;
use wasi_nn::{GraphBuilder, GraphEncoding, ExecutionTarget, TensorType};
use std::io::Cursor;
use std::net::SocketAddr;
use std::os::fd::{FromRawFd, IntoRawFd};
use std::pin::Pin;
use std::result::Result;
use std::task::{Context, Poll};
use tokio::net::TcpListener;
use wasi_nn::{ExecutionTarget, GraphBuilder, GraphEncoding, TensorType};

/// This is our service handler. It receives a Request, routes on its
/// path, and returns a Future of a Response.
async fn classify(req: Request<Body>) -> Result<Response<Body>, anyhow::Error> {
let model_data: &[u8] = include_bytes!("models/mobilenet_v1_1.0_224/mobilenet_v1_1.0_224_quant.tflite");
async fn classify(
req: Request<hyper::body::Incoming>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, anyhow::Error> {
let model_data: &[u8] =
include_bytes!("models/mobilenet_v1_1.0_224/mobilenet_v1_1.0_224_quant.tflite");
let labels = include_str!("models/mobilenet_v1_1.0_224/labels_mobilenet_quant_v1_224.txt");
let graph = GraphBuilder::new(GraphEncoding::TensorflowLite, ExecutionTarget::CPU).build_from_bytes(&[model_data])?;
let graph = GraphBuilder::new(GraphEncoding::TensorflowLite, ExecutionTarget::CPU)
.build_from_bytes(&[model_data])?;
let mut ctx = graph.init_execution_context()?;
/*
let graph = unsafe {
@@ -29,14 +39,15 @@ async fn classify(req: Request<Body>) -> Result<Response<Body>, anyhow::Error> {

match (req.method(), req.uri().path()) {
// Serve some instructions at /
(&Method::GET, "/") => Ok(Response::new(Body::from(
(&Method::GET, "/") => Ok(Response::new(full(
"Try POSTing data to /classify such as: `curl http://localhost:3000/classify -X POST --data-binary '@grace_hopper.jpg'`",
))),

(&Method::POST, "/classify") => {
let buf = hyper::body::to_bytes(req.into_body()).await?;
let buf = req.collect().await?.to_bytes();

let tensor_data = image_to_tensor(&buf, 224, 224);
ctx.set_input(0, TensorType::U8, &[1, 224, 224, 3], &tensor_data)?;
ctx.set_input(0, TensorType::U8, &[1, 224, 224, 3], &tensor_data)?;
/*
let tensor = wasi_nn::Tensor {
dimensions: &[1, 224, 224, 3],
@@ -79,7 +90,7 @@ async fn classify(req: Request<Body>) -> Result<Response<Body>, anyhow::Error> {
let class_name = labels.lines().nth(results[0].0).unwrap_or("Unknown");
println!("result: {} {}", class_name, results[0].1);

Ok(Response::new(Body::from(format!("{} is detected with {}/255 confidence", class_name, results[0].1))))
Ok(Response::new(full(format!("{} is detected with {}/255 confidence", class_name, results[0].1))))
}

// Return the 404 Not Found for other routes.
@@ -91,21 +102,114 @@ async fn classify(req: Request<Body>) -> Result<Response<Body>, anyhow::Error> {
}
}

use pin_project::pin_project;

#[pin_project]
#[derive(Debug)]
struct TokioIo<T> {
#[pin]
inner: T,
}

impl<T> TokioIo<T> {
pub fn new(inner: T) -> Self {
Self { inner }
}

#[allow(dead_code)]
pub fn inner(self) -> T {
self.inner
}
}

impl<T> hyper::rt::Read for TokioIo<T>
where
T: tokio::io::AsyncRead,
{
fn poll_read(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
mut buf: hyper::rt::ReadBufCursor<'_>,
) -> std::task::Poll<Result<(), std::io::Error>> {
let n = unsafe {
let mut tbuf = tokio::io::ReadBuf::uninit(buf.as_mut());
match tokio::io::AsyncRead::poll_read(self.project().inner, cx, &mut tbuf) {
Poll::Ready(Ok(())) => tbuf.filled().len(),
other => return other,
}
};

unsafe {
buf.advance(n);
}
Poll::Ready(Ok(()))
}
}

impl<T> hyper::rt::Write for TokioIo<T>
where
T: tokio::io::AsyncWrite,
{
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, std::io::Error>> {
tokio::io::AsyncWrite::poll_write(self.project().inner, cx, buf)
}

fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
tokio::io::AsyncWrite::poll_flush(self.project().inner, cx)
}

fn poll_shutdown(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Result<(), std::io::Error>> {
tokio::io::AsyncWrite::poll_shutdown(self.project().inner, cx)
}

fn is_write_vectored(&self) -> bool {
tokio::io::AsyncWrite::is_write_vectored(&self.inner)
}

fn poll_write_vectored(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &[std::io::IoSlice<'_>],
) -> Poll<std::prelude::v1::Result<usize, std::io::Error>> {
tokio::io::AsyncWrite::poll_write_vectored(self.project().inner, cx, bufs)
}
}

fn full<T: Into<Bytes>>(chunk: T) -> BoxBody<Bytes, hyper::Error> {
Full::new(chunk.into())
.map_err(|never| match never {})
.boxed()
}

#[tokio::main(flavor = "current_thread")]
async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let addr = SocketAddr::from(([0, 0, 0, 0], 8080));
let make_svc = make_service_fn(|_| {
async move {
Ok::<_, Infallible>(service_fn(move |req| {
classify(req)
}))
}
});
let server = Server::bind(&addr).serve(make_svc);
if let Err(e) = server.await {
eprintln!("server error: {}", e);
let listener = unsafe {
let fd = wasmedge_wasi_socket::TcpListener::bind(addr, true)?.into_raw_fd();
TcpListener::from_std(std::net::TcpListener::from_raw_fd(fd))?
};

loop {
let (stream, _) = listener.accept().await?;
println!("accept");
let io = TokioIo::new(stream);

tokio::task::spawn(async move {
if let Err(err) = http1::Builder::new()
.serve_connection(io, service_fn(classify))
.await
{
println!("Error serving connection: {:?}", err);
}
});
}
Ok(())

/*
let addr = SocketAddr::from(([0, 0, 0, 0], 8080));
9 changes: 7 additions & 2 deletions server/Cargo.toml
Original file line number Diff line number Diff line change
@@ -4,5 +4,10 @@ version = "0.1.0"
edition = "2021"

[dependencies]
hyper_wasi = { version = "0.15", features = ["full"]}
tokio_wasi = { version = "1", features = ["rt", "macros", "net", "time", "io-util"]}
hyper = { version = "1", features = ["full"] }
tokio = { version = "1", features = ["rt", "macros", "net", "time", "io-util"] }
wasmedge_wasi_socket = "0.5"

pin-project = "1.1.3"
http-body-util = "0.1.0"
bytes = "1"
166 changes: 154 additions & 12 deletions server/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,49 +1,191 @@
#![deny(warnings)]

use std::net::SocketAddr;
use std::os::fd::{FromRawFd, IntoRawFd};
use std::pin::Pin;
use std::task::{Context, Poll};

use hyper::server::conn::Http;
use bytes::Bytes;
use http_body_util::{combinators::BoxBody, BodyExt, Empty, Full};
use hyper::body::Frame;
use hyper::server::conn::http1;
use hyper::service::service_fn;
use hyper::{Body, Method, Request, Response, StatusCode};
use hyper::{body::Body, Method, Request, Response, StatusCode};
use pin_project::pin_project;
use tokio::net::TcpListener;

#[pin_project]
#[derive(Debug)]
struct TokioIo<T> {
#[pin]
inner: T,
}

impl<T> TokioIo<T> {
pub fn new(inner: T) -> Self {
Self { inner }
}

#[allow(dead_code)]
pub fn inner(self) -> T {
self.inner
}
}

impl<T> hyper::rt::Read for TokioIo<T>
where
T: tokio::io::AsyncRead,
{
fn poll_read(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
mut buf: hyper::rt::ReadBufCursor<'_>,
) -> std::task::Poll<Result<(), std::io::Error>> {
let n = unsafe {
let mut tbuf = tokio::io::ReadBuf::uninit(buf.as_mut());
match tokio::io::AsyncRead::poll_read(self.project().inner, cx, &mut tbuf) {
Poll::Ready(Ok(())) => tbuf.filled().len(),
other => return other,
}
};

unsafe {
buf.advance(n);
}
Poll::Ready(Ok(()))
}
}

impl<T> hyper::rt::Write for TokioIo<T>
where
T: tokio::io::AsyncWrite,
{
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, std::io::Error>> {
tokio::io::AsyncWrite::poll_write(self.project().inner, cx, buf)
}

fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
tokio::io::AsyncWrite::poll_flush(self.project().inner, cx)
}

fn poll_shutdown(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Result<(), std::io::Error>> {
tokio::io::AsyncWrite::poll_shutdown(self.project().inner, cx)
}

fn is_write_vectored(&self) -> bool {
tokio::io::AsyncWrite::is_write_vectored(&self.inner)
}

fn poll_write_vectored(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &[std::io::IoSlice<'_>],
) -> Poll<std::prelude::v1::Result<usize, std::io::Error>> {
tokio::io::AsyncWrite::poll_write_vectored(self.project().inner, cx, bufs)
}
}

/// This is our service handler. It receives a Request, routes on its
/// path, and returns a Future of a Response.
async fn echo(req: Request<Body>) -> Result<Response<Body>, hyper::Error> {
async fn echo(
req: Request<hyper::body::Incoming>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
match (req.method(), req.uri().path()) {
// Serve some instructions at /
(&Method::GET, "/") => Ok(Response::new(Body::from(
"Try POSTing data to /echo such as: `curl localhost:8080/echo -XPOST -d 'hello world'`",
(&Method::GET, "/") => Ok(Response::new(full(
"Try POSTing data to /echo such as: `curl localhost:3000/echo -XPOST -d \"hello world\"`",
))),

// Simply echo the body back to the client.
(&Method::POST, "/echo") => Ok(Response::new(req.into_body())),
(&Method::POST, "/echo") => Ok(Response::new(req.into_body().boxed())),

// Convert to uppercase before sending back to client using a stream.
(&Method::POST, "/echo/uppercase") => {
let frame_stream = req.into_body().map_frame(|frame| {
let frame = if let Ok(data) = frame.into_data() {
data.iter()
.map(|byte| byte.to_ascii_uppercase())
.collect::<Bytes>()
} else {
Bytes::new()
};

Frame::data(frame)
});

Ok(Response::new(frame_stream.boxed()))
}

// Reverse the entire body before sending back to the client.
//
// Since we don't know the end yet, we can't simply stream
// the chunks as they arrive as we did with the above uppercase endpoint.
// So here we do `.await` on the future, waiting on concatenating the full body,
// then afterwards the content can be reversed. Only then can we return a `Response`.
(&Method::POST, "/echo/reversed") => {
let whole_body = hyper::body::to_bytes(req.into_body()).await?;
// To protect our server, reject requests with bodies larger than
// 64kbs of data.
let max = req.body().size_hint().upper().unwrap_or(u64::MAX);
if max > 1024 * 64 {
let mut resp = Response::new(full("Body too big"));
*resp.status_mut() = hyper::StatusCode::PAYLOAD_TOO_LARGE;
return Ok(resp);
}

let whole_body = req.collect().await?.to_bytes();

let reversed_body = whole_body.iter().rev().cloned().collect::<Vec<u8>>();
Ok(Response::new(Body::from(reversed_body)))
Ok(Response::new(full(reversed_body)))
}

// Return the 404 Not Found for other routes.
_ => {
let mut not_found = Response::default();
let mut not_found = Response::new(empty());
*not_found.status_mut() = StatusCode::NOT_FOUND;
Ok(not_found)
}
}
}

fn empty() -> BoxBody<Bytes, hyper::Error> {
Empty::<Bytes>::new()
.map_err(|never| match never {})
.boxed()
}

fn full<T: Into<Bytes>>(chunk: T) -> BoxBody<Bytes, hyper::Error> {
Full::new(chunk.into())
.map_err(|never| match never {})
.boxed()
}

#[tokio::main(flavor = "current_thread")]
async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let addr = SocketAddr::from(([0, 0, 0, 0], 8080));
let addr = SocketAddr::from(([127, 0, 0, 1], 3000));

let listener = unsafe {
let fd = wasmedge_wasi_socket::TcpListener::bind(addr, true)?.into_raw_fd();
TcpListener::from_std(std::net::TcpListener::from_raw_fd(fd))?
};

let listener = TcpListener::bind(addr).await?;
println!("Listening on http://{}", addr);
loop {
let (stream, _) = listener.accept().await?;
println!("accept");
let io = TokioIo::new(stream);

tokio::task::spawn(async move {
if let Err(err) = Http::new().serve_connection(stream, service_fn(echo)).await {
if let Err(err) = http1::Builder::new()
.serve_connection(io, service_fn(echo))
.await
{
println!("Error serving connection: {:?}", err);
}
});

0 comments on commit 6edb00a

Please sign in to comment.