diff --git a/Cargo.lock b/Cargo.lock index 192febf0..e8a4e077 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -111,6 +111,18 @@ dependencies = [ "tower-service", ] +[[package]] +name = "axum-macros" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "00c055ee2d014ae5981ce1016374e8213682aa14d9bf40e48ab48b5f3ef20eaa" +dependencies = [ + "heck 0.4.1", + "proc-macro2", + "quote", + "syn 2.0.52", +] + [[package]] name = "backtrace" version = "0.3.69" @@ -501,6 +513,12 @@ dependencies = [ "unicode-segmentation", ] +[[package]] +name = "heck" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" + [[package]] name = "hermit-abi" version = "0.3.9" @@ -1769,7 +1787,7 @@ version = "0.19.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0a6e5bd22c71e77d60140b0bd5be56155a37e5bd14e24f5f87298040d0cc40d7" dependencies = [ - "heck", + "heck 0.3.3", "proc-macro2", "quote", "syn 1.0.109", @@ -2251,6 +2269,8 @@ dependencies = [ name = "worker-sandbox" version = "0.1.0" dependencies = [ + "axum", + "axum-macros", "blake2", "cfg-if", "chrono", @@ -2268,6 +2288,7 @@ dependencies = [ "serde-wasm-bindgen 0.6.5", "serde_json", "tokio", + "tower-service", "tungstenite", "uuid", "wasm-bindgen-test", diff --git a/worker-macros/src/lib.rs b/worker-macros/src/lib.rs index c0f1e3fc..235c4f06 100644 --- a/worker-macros/src/lib.rs +++ b/worker-macros/src/lib.rs @@ -1,5 +1,6 @@ mod durable_object; mod event; +mod send; use proc_macro::TokenStream; @@ -21,3 +22,21 @@ pub fn event(attr: TokenStream, item: TokenStream) -> TokenStream { pub fn event(attr: TokenStream, item: TokenStream) -> TokenStream { event::expand_macro(attr, item, false) } + +#[proc_macro_attribute] +/// Convert an async function which is `!Send` to be `Send`. +/// +/// This is useful for implementing async handlers in frameworks which +/// expect the handler to be `Send`, such as `axum`. +/// +/// ```rust +/// #[worker::send] +/// async fn foo() { +/// // JsFuture is !Send +/// let fut = JsFuture::from(promise); +/// fut.await +/// } +/// ``` +pub fn send(attr: TokenStream, stream: TokenStream) -> TokenStream { + send::expand_macro(attr, stream) +} diff --git a/worker-macros/src/send.rs b/worker-macros/src/send.rs new file mode 100644 index 00000000..69ed1e8f --- /dev/null +++ b/worker-macros/src/send.rs @@ -0,0 +1,26 @@ +use proc_macro::TokenStream; +use quote::quote; +use syn::{parse_macro_input, ItemFn}; + +pub fn expand_macro(_attr: TokenStream, stream: TokenStream) -> TokenStream { + let stream_clone = stream.clone(); + let input = parse_macro_input!(stream_clone as ItemFn); + + let ItemFn { + attrs, + vis, + sig, + block, + } = input; + let stmts = &block.stmts; + + let tokens = quote! { + #(#attrs)* #vis #sig { + worker::send::SendFuture::new(async { + #(#stmts)* + }).await + } + }; + + TokenStream::from(tokens) +} diff --git a/worker-sandbox/Cargo.toml b/worker-sandbox/Cargo.toml index ca70afae..148c4ebc 100644 --- a/worker-sandbox/Cargo.toml +++ b/worker-sandbox/Cargo.toml @@ -14,7 +14,7 @@ path = "src/lib.rs" [features] default = ["console_error_panic_hook"] -http = ["worker/http"] +http = ["worker/http", "dep:axum", "dep:tower-service", "dep:axum-macros"] [dependencies] futures-channel.workspace = true @@ -38,6 +38,20 @@ uuid = { version = "1.3.3", features = ["v4", "serde"] } serde-wasm-bindgen = "0.6.1" md5 = "0.7.0" +[dependencies.axum] +version = "0.7" +optional = true +default-features = false + +[dependencies.axum-macros] +version = "0.4" +optional = true +default-features = false + +[dependencies.tower-service] +version = "0.3" +optional = true + [dev-dependencies] wasm-bindgen-test.workspace = true futures-channel = { version = "0.3.29", features = ["sink"] } diff --git a/worker-sandbox/src/alarm.rs b/worker-sandbox/src/alarm.rs index 3bf4681f..38fed43b 100644 --- a/worker-sandbox/src/alarm.rs +++ b/worker-sandbox/src/alarm.rs @@ -2,6 +2,8 @@ use std::time::Duration; use worker::*; +use super::SomeSharedData; + #[durable_object] pub struct AlarmObject { state: State, @@ -39,3 +41,31 @@ impl DurableObject for AlarmObject { Response::ok("ALARMED") } } + +#[worker::send] +pub async fn handle_alarm(_req: Request, env: Env, _data: SomeSharedData) -> Result { + let namespace = env.durable_object("ALARM")?; + let stub = namespace.id_from_name("alarm")?.get_stub()?; + // when calling fetch to a Durable Object, a full URL must be used. Alternatively, a + // compatibility flag can be provided in wrangler.toml to opt-in to older behavior: + // https://developers.cloudflare.com/workers/platform/compatibility-dates#durable-object-stubfetch-requires-a-full-url + stub.fetch_with_str("https://fake-host/alarm").await +} + +#[worker::send] +pub async fn handle_id(_req: Request, env: Env, _data: SomeSharedData) -> Result { + let namespace = env.durable_object("COUNTER").expect("DAWJKHDAD"); + let stub = namespace.id_from_name("A")?.get_stub()?; + // when calling fetch to a Durable Object, a full URL must be used. Alternatively, a + // compatibility flag can be provided in wrangler.toml to opt-in to older behavior: + // https://developers.cloudflare.com/workers/platform/compatibility-dates#durable-object-stubfetch-requires-a-full-url + stub.fetch_with_str("https://fake-host/").await +} + +#[worker::send] +pub async fn handle_put_raw(req: Request, env: Env, _data: SomeSharedData) -> Result { + let namespace = env.durable_object("PUT_RAW_TEST_OBJECT")?; + let id = namespace.unique_id()?; + let stub = id.get_stub()?; + stub.fetch_with_request(req).await +} diff --git a/worker-sandbox/src/cache.rs b/worker-sandbox/src/cache.rs new file mode 100644 index 00000000..19de6f34 --- /dev/null +++ b/worker-sandbox/src/cache.rs @@ -0,0 +1,123 @@ +use super::SomeSharedData; +use futures_util::stream::StreamExt; +use rand::Rng; +use std::time::Duration; +use worker::{console_log, Cache, Date, Delay, Env, Request, Response, Result}; + +fn key(req: &Request) -> Result> { + let uri = req.url()?; + let mut segments = uri.path_segments().unwrap(); + Ok(segments.nth(2).map(|s| s.to_owned())) +} + +#[worker::send] +pub async fn handle_cache_example( + req: Request, + _env: Env, + _data: SomeSharedData, +) -> Result { + console_log!("url: {}", req.url()?.to_string()); + let cache = Cache::default(); + let key = req.url()?.to_string(); + if let Some(resp) = cache.get(&key, true).await? { + console_log!("Cache HIT!"); + Ok(resp) + } else { + console_log!("Cache MISS!"); + let mut resp = + Response::from_json(&serde_json::json!({ "timestamp": Date::now().as_millis() }))?; + + // Cache API respects Cache-Control headers. Setting s-max-age to 10 + // will limit the response to be in cache for 10 seconds max + resp.headers_mut().set("cache-control", "s-maxage=10")?; + cache.put(key, resp.cloned()?).await?; + Ok(resp) + } +} + +#[worker::send] +pub async fn handle_cache_api_get( + req: Request, + _env: Env, + _data: SomeSharedData, +) -> Result { + if let Some(key) = key(&req)? { + let cache = Cache::default(); + if let Some(resp) = cache.get(format!("https://{key}"), true).await? { + return Ok(resp); + } else { + return Response::ok("cache miss"); + } + } + Response::error("key missing", 400) +} + +#[worker::send] +pub async fn handle_cache_api_put( + req: Request, + _env: Env, + _data: SomeSharedData, +) -> Result { + if let Some(key) = key(&req)? { + let cache = Cache::default(); + + let mut resp = + Response::from_json(&serde_json::json!({ "timestamp": Date::now().as_millis() }))?; + + // Cache API respects Cache-Control headers. Setting s-max-age to 10 + // will limit the response to be in cache for 10 seconds max + resp.headers_mut().set("cache-control", "s-maxage=10")?; + cache.put(format!("https://{key}"), resp.cloned()?).await?; + return Ok(resp); + } + Response::error("key missing", 400) +} + +#[worker::send] +pub async fn handle_cache_api_delete( + req: Request, + _env: Env, + _data: SomeSharedData, +) -> Result { + if let Some(key) = key(&req)? { + let cache = Cache::default(); + + let res = cache.delete(format!("https://{key}"), true).await?; + return Response::ok(serde_json::to_string(&res)?); + } + Response::error("key missing", 400) +} + +#[worker::send] +pub async fn handle_cache_stream( + req: Request, + _env: Env, + _data: SomeSharedData, +) -> Result { + console_log!("url: {}", req.url()?.to_string()); + let cache = Cache::default(); + let key = req.url()?.to_string(); + if let Some(resp) = cache.get(&key, true).await? { + console_log!("Cache HIT!"); + Ok(resp) + } else { + console_log!("Cache MISS!"); + let mut rng = rand::thread_rng(); + let count = rng.gen_range(0..10); + let stream = futures_util::stream::repeat("Hello, world!\n") + .take(count) + .then(|text| async move { + Delay::from(Duration::from_millis(50)).await; + Result::Ok(text.as_bytes().to_vec()) + }); + + let mut resp = Response::from_stream(stream)?; + console_log!("resp = {:?}", resp); + // Cache API respects Cache-Control headers. Setting s-max-age to 10 + // will limit the response to be in cache for 10 seconds max + resp.headers_mut().set("cache-control", "s-maxage=10")?; + + cache.put(key, resp.cloned()?).await?; + Ok(resp) + } +} diff --git a/worker-sandbox/src/d1.rs b/worker-sandbox/src/d1.rs index 1e750e99..7f0f1a40 100644 --- a/worker-sandbox/src/d1.rs +++ b/worker-sandbox/src/d1.rs @@ -10,11 +10,13 @@ struct Person { age: u32, } +#[worker::send] pub async fn prepared_statement( _req: Request, - ctx: RouteContext, + env: Env, + _data: SomeSharedData, ) -> Result { - let db = ctx.env.d1("DB")?; + let db = env.d1("DB")?; let stmt = worker::query!(&db, "SELECT * FROM people WHERE name = ?", "Ryan Upton")?; // All rows @@ -49,8 +51,9 @@ pub async fn prepared_statement( Response::ok("ok") } -pub async fn batch(_req: Request, ctx: RouteContext) -> Result { - let db = ctx.env.d1("DB")?; +#[worker::send] +pub async fn batch(_req: Request, env: Env, _data: SomeSharedData) -> Result { + let db = env.d1("DB")?; let mut results = db .batch(vec![ worker::query!(&db, "SELECT * FROM people WHERE id < 4"), @@ -73,8 +76,9 @@ pub async fn batch(_req: Request, ctx: RouteContext) -> Result) -> Result { - let db = ctx.env.d1("DB")?; +#[worker::send] +pub async fn exec(mut req: Request, env: Env, _data: SomeSharedData) -> Result { + let db = env.d1("DB")?; let result = db .exec(req.text().await?.as_ref()) .await @@ -83,14 +87,16 @@ pub async fn exec(mut req: Request, ctx: RouteContext) -> Result Response::ok(result.count().unwrap_or_default().to_string()) } -pub async fn dump(_req: Request, ctx: RouteContext) -> Result { - let db = ctx.env.d1("DB")?; +#[worker::send] +pub async fn dump(_req: Request, env: Env, _data: SomeSharedData) -> Result { + let db = env.d1("DB")?; let bytes = db.dump().await?; Response::from_bytes(bytes) } -pub async fn error(_req: Request, ctx: RouteContext) -> Result { - let db = ctx.env.d1("DB")?; +#[worker::send] +pub async fn error(_req: Request, env: Env, _data: SomeSharedData) -> Result { + let db = env.d1("DB")?; let error = db .exec("THIS IS NOT VALID SQL") .await diff --git a/worker-sandbox/src/fetch.rs b/worker-sandbox/src/fetch.rs new file mode 100644 index 00000000..038def08 --- /dev/null +++ b/worker-sandbox/src/fetch.rs @@ -0,0 +1,179 @@ +use super::{ApiData, SomeSharedData}; +use futures_util::future::Either; +use std::time::Duration; +use worker::{ + wasm_bindgen_futures, AbortController, Delay, Env, Fetch, Method, Request, RequestInit, + Response, Result, +}; + +#[worker::send] +pub async fn handle_fetch(_req: Request, _env: Env, _data: SomeSharedData) -> Result { + let req = Request::new("https://example.com", Method::Post)?; + let resp = Fetch::Request(req).send().await?; + let resp2 = Fetch::Url("https://example.com".parse()?).send().await?; + Response::ok(format!( + "received responses with codes {} and {}", + resp.status_code(), + resp2.status_code() + )) +} + +#[worker::send] +pub async fn handle_fetch_json( + _req: Request, + _env: Env, + _data: SomeSharedData, +) -> Result { + let data: ApiData = Fetch::Url( + "https://jsonplaceholder.typicode.com/todos/1" + .parse() + .unwrap(), + ) + .send() + .await? + .json() + .await?; + Response::ok(format!( + "API Returned user: {} with title: {} and completed: {}", + data.user_id, data.title, data.completed + )) +} + +#[worker::send] +pub async fn handle_proxy_request( + req: Request, + _env: Env, + _data: SomeSharedData, +) -> Result { + let uri = req.url()?; + let url = uri + .path_segments() + .unwrap() + .skip(1) + .collect::>() + .join("/"); + crate::console_log!("{}", url); + Fetch::Url(url.parse()?).send().await +} + +#[worker::send] +pub async fn handle_request_init_fetch( + _req: Request, + _env: Env, + _data: SomeSharedData, +) -> Result { + let init = RequestInit::new(); + Fetch::Request(Request::new_with_init("https://cloudflare.com", &init)?) + .send() + .await +} + +#[worker::send] +pub async fn handle_request_init_fetch_post( + _req: Request, + _env: Env, + _data: SomeSharedData, +) -> Result { + let mut init = RequestInit::new(); + init.method = Method::Post; + Fetch::Request(Request::new_with_init("https://httpbin.org/post", &init)?) + .send() + .await +} + +#[worker::send] +pub async fn handle_cancelled_fetch( + _req: Request, + _env: Env, + _data: SomeSharedData, +) -> Result { + let controller = AbortController::default(); + let signal = controller.signal(); + + let (tx, rx) = futures_channel::oneshot::channel(); + + // Spawns a future that'll make our fetch request and not block this function. + wasm_bindgen_futures::spawn_local({ + async move { + let fetch = Fetch::Url("https://cloudflare.com".parse().unwrap()); + let res = fetch.send_with_signal(&signal).await; + tx.send(res).unwrap(); + } + }); + + // And then we try to abort that fetch as soon as we start it, hopefully before + // cloudflare.com responds. + controller.abort(); + + let res = rx.await.unwrap(); + let res = res.unwrap_or_else(|err| { + let text = err.to_string(); + Response::ok(text).unwrap() + }); + + Ok(res) +} + +#[worker::send] +pub async fn handle_fetch_timeout( + _req: Request, + _env: Env, + _data: SomeSharedData, +) -> Result { + let controller = AbortController::default(); + let signal = controller.signal(); + + let fetch_fut = async { + let fetch = Fetch::Url("https://miniflare.mocks/delay".parse().unwrap()); + let mut res = fetch.send_with_signal(&signal).await?; + let text = res.text().await?; + Ok::(text) + }; + let delay_fut = async { + Delay::from(Duration::from_millis(100)).await; + controller.abort(); + Response::ok("Cancelled") + }; + + futures_util::pin_mut!(fetch_fut); + futures_util::pin_mut!(delay_fut); + + match futures_util::future::select(delay_fut, fetch_fut).await { + Either::Left((res, cancelled_fut)) => { + // Ensure that the cancelled future returns an AbortError. + match cancelled_fut.await { + Err(e) if e.to_string().contains("AbortError") => { /* Yay! It worked, let's do nothing to celebrate */ + } + Err(e) => panic!( + "Fetch errored with a different error than expected: {:#?}", + e + ), + Ok(text) => panic!("Fetch unexpectedly succeeded: {}", text), + } + + res + } + Either::Right(_) => panic!("Delay future should have resolved first"), + } +} + +#[worker::send] +pub async fn handle_cloned_fetch( + _req: Request, + _env: Env, + _data: SomeSharedData, +) -> Result { + let mut resp = Fetch::Url( + "https://jsonplaceholder.typicode.com/todos/1" + .parse() + .unwrap(), + ) + .send() + .await?; + let mut resp1 = resp.cloned()?; + + let left = resp.text().await?; + let right = resp1.text().await?; + + Response::ok((left == right).to_string()) +} diff --git a/worker-sandbox/src/form.rs b/worker-sandbox/src/form.rs new file mode 100644 index 00000000..bccff11f --- /dev/null +++ b/worker-sandbox/src/form.rs @@ -0,0 +1,130 @@ +use super::SomeSharedData; +use blake2::Blake2b512; +use blake2::Digest; +use serde::{Deserialize, Serialize}; +use worker::kv; +use worker::{Env, Request, Result}; +use worker::{FormEntry, Response}; + +#[worker::send] +pub async fn handle_formdata_name( + mut req: Request, + _env: Env, + _data: SomeSharedData, +) -> Result { + let form = req.form_data().await?; + const NAME: &str = "name"; + let bad_request = Response::error("Bad Request", 400); + + if !form.has(NAME) { + return bad_request; + } + + let names: Vec = form + .get_all(NAME) + .unwrap_or_default() + .into_iter() + .map(|entry| match entry { + FormEntry::Field(s) => s, + FormEntry::File(f) => f.name(), + }) + .collect(); + if names.len() > 1 { + return Response::from_json(&serde_json::json!({ "names": names })); + } + + if let Some(value) = form.get(NAME) { + match value { + FormEntry::Field(v) => Response::from_json(&serde_json::json!({ NAME: v })), + _ => bad_request, + } + } else { + bad_request + } +} + +#[derive(Deserialize, Serialize)] +struct FileSize { + name: String, + size: u32, +} + +#[worker::send] +pub async fn handle_formdata_file_size( + mut req: Request, + env: Env, + _data: SomeSharedData, +) -> Result { + let form = req.form_data().await?; + + if let Some(entry) = form.get("file") { + return match entry { + FormEntry::File(file) => { + let kv: kv::KvStore = env.kv("FILE_SIZES")?; + + // create a new FileSize record to store + let b = file.bytes().await?; + let record = FileSize { + name: file.name(), + size: b.len() as u32, + }; + + // hash the file, and use result as the key + let mut hasher = Blake2b512::new(); + hasher.update(b); + let hash = hasher.finalize(); + let key = hex::encode(&hash[..]); + + // serialize the record and put it into kv + let val = serde_json::to_string(&record)?; + kv.put(&key, val)?.execute().await?; + + // list the default number of keys from the namespace + Response::from_json(&kv.list().execute().await?.keys) + } + _ => Response::error("Bad Request", 400), + }; + } + + Response::error("Bad Request", 400) +} + +#[worker::send] +pub async fn handle_formdata_file_size_hash( + req: Request, + env: Env, + _data: SomeSharedData, +) -> Result { + let uri = req.url()?; + let mut segments = uri.path_segments().unwrap(); + let hash = segments.nth(1); + if let Some(hash) = hash { + let kv = env.kv("FILE_SIZES")?; + return match kv.get(hash).json::().await? { + Some(val) => Response::from_json(&val), + None => Response::error("Not Found", 404), + }; + } + + Response::error("Bad Request", 400) +} + +#[worker::send] +pub async fn handle_is_secret( + mut req: Request, + env: Env, + _data: SomeSharedData, +) -> Result { + let form = req.form_data().await?; + if let Some(secret) = form.get("secret") { + match secret { + FormEntry::Field(name) => { + let val = env.secret(&name)?; + return Response::ok(val.to_string()); + } + _ => return Response::error("Bad Request", 400), + }; + } + + Response::error("Bad Request", 400) +} diff --git a/worker-sandbox/src/kv.rs b/worker-sandbox/src/kv.rs new file mode 100644 index 00000000..da221ca4 --- /dev/null +++ b/worker-sandbox/src/kv.rs @@ -0,0 +1,22 @@ +use super::SomeSharedData; +use worker::{Env, Request, Response, Result}; + +#[worker::send] +pub async fn handle_post_key_value( + req: Request, + env: Env, + _data: SomeSharedData, +) -> Result { + let uri = req.url()?; + let mut segments = uri.path_segments().unwrap(); + let key = segments.nth(1); + let value = segments.next(); + let kv = env.kv("SOME_NAMESPACE")?; + if let Some(key) = key { + if let Some(value) = value { + kv.put(key, value)?.execute().await?; + } + } + + Response::from_json(&kv.list().execute().await?) +} diff --git a/worker-sandbox/src/lib.rs b/worker-sandbox/src/lib.rs index 8e18e040..fb889395 100644 --- a/worker-sandbox/src/lib.rs +++ b/worker-sandbox/src/lib.rs @@ -1,25 +1,27 @@ -use blake2::{Blake2b512, Digest}; -use futures_util::{future::Either, StreamExt, TryStreamExt}; -use rand::Rng; use serde::{Deserialize, Serialize}; -#[cfg(feature = "http")] -use std::convert::TryInto; -use std::{ - sync::{ - atomic::{AtomicBool, Ordering}, - Mutex, - }, - time::Duration, +use std::sync::{ + atomic::{AtomicBool, Ordering}, + Mutex, }; -use uuid::Uuid; +#[cfg(feature = "http")] +use tower_service::Service; use worker::*; - mod alarm; +mod cache; mod counter; mod d1; +mod fetch; +mod form; +mod kv; +mod queue; mod r2; +mod request; +mod router; +mod service; mod test; +mod user; mod utils; +mod ws; #[derive(Deserialize, Serialize)] struct MyData { @@ -38,49 +40,14 @@ struct ApiData { completed: bool, } -#[derive(Serialize)] -struct User { - id: String, - timestamp: u64, - date_from_int: String, - date_from_str: String, -} - -#[derive(Deserialize, Serialize)] -struct FileSize { - name: String, - size: u32, -} - +#[derive(Clone)] pub struct SomeSharedData { regex: regex::Regex, } -fn handle_a_request(req: Request, _ctx: RouteContext) -> Result { - Response::ok(format!( - "req at: {}, located at: {:?}, within: {}", - req.path(), - req.cf().map(|cf| cf.coordinates().unwrap_or_default()), - req.cf() - .map(|cf| cf.region().unwrap_or_else(|| "unknown region".into())) - .unwrap_or(String::from("No CF properties")) - )) -} - -async fn handle_async_request(req: Request, _ctx: RouteContext) -> Result { - Response::ok(format!( - "[async] req at: {}, located at: {:?}, within: {}", - req.path(), - req.cf().map(|cf| cf.coordinates().unwrap_or_default()), - req.cf() - .map(|cf| cf.region().unwrap_or_else(|| "unknown region".into())) - .unwrap_or(String::from("No CF properties")) - )) -} - static GLOBAL_STATE: AtomicBool = AtomicBool::new(false); -static GLOBAL_QUEUE_STATE: Mutex> = Mutex::new(Vec::new()); +static GLOBAL_QUEUE_STATE: Mutex> = Mutex::new(Vec::new()); // We're able to specify a start event that is called when the WASM is initialized before any // requests. This is useful if you have some global state or setup code, like a logger. This is @@ -98,7 +65,7 @@ type HandlerRequest = HttpRequest; #[cfg(not(feature = "http"))] type HandlerRequest = Request; #[cfg(feature = "http")] -type HandlerResponse = HttpResponse; +type HandlerResponse = http::Response; #[cfg(not(feature = "http"))] type HandlerResponse = Response; @@ -112,706 +79,17 @@ pub async fn main( regex: regex::Regex::new(r"^\d{4}-\d{2}-\d{2}$").unwrap(), }; - let router = Router::with_data(data); // if no data is needed, pass `()` or any other valid data - #[cfg(feature = "http")] - let req: Request = request.try_into()?; - #[cfg(not(feature = "http"))] - let req = request; - - let worker_response = router - .get("/request", handle_a_request) // can pass a fn pointer to keep routes tidy - .get_async("/async-request", handle_async_request) - .get("/websocket", |_, ctx| { - // Accept / handle a websocket connection - let pair = WebSocketPair::new()?; - let server = pair.server; - server.accept()?; - - let some_namespace_kv = ctx.kv("SOME_NAMESPACE")?; - - wasm_bindgen_futures::spawn_local(async move { - let mut event_stream = server.events().expect("could not open stream"); - - while let Some(event) = event_stream.next().await { - match event.expect("received error in websocket") { - WebsocketEvent::Message(msg) => { - if let Some(text) = msg.text() { - server.send_with_str(text).expect("could not relay text"); - } - } - WebsocketEvent::Close(_) => { - // Sets a key in a test KV so the integration tests can query if we - // actually got the close event. We can't use the shared dat a for this - // because miniflare resets that every request. - some_namespace_kv - .put("got-close-event", "true") - .unwrap() - .execute() - .await - .unwrap(); - } - } - } - }); - - Response::from_websocket(pair.client) - }) - .get_async("/got-close-event", |_, ctx| async move { - let some_namespace_kv = ctx.kv("SOME_NAMESPACE")?; - let got_close_event = some_namespace_kv - .get("got-close-event") - .text() - .await? - .unwrap_or_else(|| "false".into()); - - // Let the integration tests have some way of knowing if we successfully received the closed event. - Response::ok(got_close_event) - }) - .get_async("/ws-client", |_, _| async move { - let ws = WebSocket::connect("wss://echo.miniflare.mocks/".parse()?).await?; - - // It's important that we call this before we send our first message, otherwise we will - // not have any event listeners on the socket to receive the echoed message. - let mut event_stream = ws.events()?; - - ws.accept()?; - ws.send_with_str("Hello, world!")?; - - while let Some(event) = event_stream.next().await { - let event = event?; - - if let WebsocketEvent::Message(msg) = event { - if let Some(text) = msg.text() { - return Response::ok(text); - } - } - } - - Response::error("never got a message echoed back :(", 500) - }) - .get("/test-data", |_, ctx| { - // just here to test data works - if ctx.data.regex.is_match("2014-01-01") { - Response::ok("data ok") - } else { - Response::error("bad match", 500) - } - }) - .post("/xor/:num", |mut req, ctx| { - let num: u8 = match ctx.param("num").unwrap().parse() { - Ok(num) => num, - Err(_) => return Response::error("invalid byte", 400), - }; - - let xor_stream = req.stream()?.map_ok(move |mut buf| { - buf.iter_mut().for_each(|x| *x ^= num); - buf - }); - - Response::from_stream(xor_stream) - }) - .post("/headers", |req, _ctx| { - let mut headers: http::HeaderMap = req.headers().into(); - headers.append("Hello", "World!".parse().unwrap()); - - Response::ok("returned your headers to you.") - .map(|res| res.with_headers(headers.into())) - }) - .post_async("/formdata-name", |mut req, _ctx| async move { - let form = req.form_data().await?; - const NAME: &str = "name"; - let bad_request = Response::error("Bad Request", 400); - - if !form.has(NAME) { - return bad_request; - } - - let names: Vec = form - .get_all(NAME) - .unwrap_or_default() - .into_iter() - .map(|entry| match entry { - FormEntry::Field(s) => s, - FormEntry::File(f) => f.name(), - }) - .collect(); - if names.len() > 1 { - return Response::from_json(&serde_json::json!({ "names": names })); - } - - if let Some(value) = form.get(NAME) { - match value { - FormEntry::Field(v) => Response::from_json(&serde_json::json!({ NAME: v })), - _ => bad_request, - } - } else { - bad_request - } - }) - .post_async("/is-secret", |mut req, ctx| async move { - let form = req.form_data().await?; - if let Some(secret) = form.get("secret") { - match secret { - FormEntry::Field(name) => { - let val = ctx.secret(&name)?; - return Response::ok(val.to_string()); - } - _ => return Response::error("Bad Request", 400), - }; - } - - Response::error("Bad Request", 400) - }) - .post_async("/formdata-file-size", |mut req, ctx| async move { - let form = req.form_data().await?; - - if let Some(entry) = form.get("file") { - return match entry { - FormEntry::File(file) => { - let kv: kv::KvStore = ctx.kv("FILE_SIZES")?; - - // create a new FileSize record to store - let b = file.bytes().await?; - let record = FileSize { - name: file.name(), - size: b.len() as u32, - }; - - // hash the file, and use result as the key - let mut hasher = Blake2b512::new(); - hasher.update(b); - let hash = hasher.finalize(); - let key = hex::encode(&hash[..]); - - // serialize the record and put it into kv - let val = serde_json::to_string(&record)?; - kv.put(&key, val)?.execute().await?; - - // list the default number of keys from the namespace - Response::from_json(&kv.list().execute().await?.keys) - } - _ => Response::error("Bad Request", 400), - }; - } - - Response::error("Bad Request", 400) - }) - .get_async("/formdata-file-size/:hash", |_, ctx| async move { - if let Some(hash) = ctx.param("hash") { - let kv = ctx.kv("FILE_SIZES")?; - return match kv.get(hash).json::().await? { - Some(val) => Response::from_json(&val), - None => Response::error("Not Found", 404), - }; - } - - Response::error("Bad Request", 400) - }) - .post_async("/post-file-size", |mut req, _| async move { - let bytes = req.bytes().await?; - Response::ok(format!("size = {}", bytes.len())) - }) - .get("/user/:id/test", |_req, ctx| { - if let Some(id) = ctx.param("id") { - return Response::ok(format!("TEST user id: {id}")); - } - - Response::error("Error", 500) - }) - .get("/user/:id", |_req, ctx| { - if let Some(id) = ctx.param("id") { - return Response::from_json(&User { - id: id.to_string(), - timestamp: Date::now().as_millis(), - date_from_int: Date::new(DateInit::Millis(1234567890)).to_string(), - date_from_str: Date::new(DateInit::String( - "Wed Jan 14 1980 23:56:07 GMT-0700 (Mountain Standard Time)".into(), - )) - .to_string(), - }); - } - - Response::error("Bad Request", 400) - }) - .post("/account/:id/zones", |_, ctx| { - Response::ok(format!( - "Create new zone for Account: {}", - ctx.param("id").unwrap_or(&"not found".into()) - )) - }) - .get("/account/:id/zones", |_, ctx| { - Response::ok(format!( - "Account id: {}..... You get a zone, you get a zone!", - ctx.param("id").unwrap_or(&"not found".into()) - )) - }) - .post_async("/async-text-echo", |mut req, _ctx| async move { - Response::ok(req.text().await?) - }) - .get_async("/fetch", |_req, _ctx| async move { - let req = Request::new("https://example.com", Method::Post)?; - let resp = Fetch::Request(req).send().await?; - let resp2 = Fetch::Url("https://example.com".parse()?).send().await?; - Response::ok(format!( - "received responses with codes {} and {}", - resp.status_code(), - resp2.status_code() - )) - }) - .get_async("/fetch_json", |_req, _ctx| async move { - let data: ApiData = Fetch::Url( - "https://jsonplaceholder.typicode.com/todos/1" - .parse() - .unwrap(), - ) - .send() - .await? - .json() - .await?; - Response::ok(format!( - "API Returned user: {} with title: {} and completed: {}", - data.user_id, data.title, data.completed - )) - }) - .get_async("/proxy_request/*url", |_req, ctx| async move { - let url = ctx.param("url").unwrap(); - Fetch::Url(url.parse()?).send().await - }) - .get_async("/durable/alarm", |_req, ctx| async move { - let namespace = ctx.durable_object("ALARM")?; - let stub = namespace.id_from_name("alarm")?.get_stub()?; - // when calling fetch to a Durable Object, a full URL must be used. Alternatively, a - // compatibility flag can be provided in wrangler.toml to opt-in to older behavior: - // https://developers.cloudflare.com/workers/platform/compatibility-dates#durable-object-stubfetch-requires-a-full-url - stub.fetch_with_str("https://fake-host/alarm").await - }) - .get_async("/durable/:id", |_req, ctx| async move { - let namespace = ctx.durable_object("COUNTER").expect("DAWJKHDAD"); - let stub = namespace.id_from_name("A")?.get_stub()?; - // when calling fetch to a Durable Object, a full URL must be used. Alternatively, a - // compatibility flag can be provided in wrangler.toml to opt-in to older behavior: - // https://developers.cloudflare.com/workers/platform/compatibility-dates#durable-object-stubfetch-requires-a-full-url - stub.fetch_with_str("https://fake-host/").await - }) - .get_async("/durable/put-raw", |req, ctx| async move { - let namespace = ctx.durable_object("PUT_RAW_TEST_OBJECT")?; - let id = namespace.unique_id()?; - let stub = id.get_stub()?; - stub.fetch_with_request(req).await - }) - .get("/secret", |_req, ctx| { - Response::ok(ctx.secret("SOME_SECRET")?.to_string()) - }) - .get("/var", |_req, ctx| { - Response::ok(ctx.var("SOME_VARIABLE")?.to_string()) - }) - .post_async("/kv/:key/:value", |_req, ctx| async move { - let kv = ctx.kv("SOME_NAMESPACE")?; - if let Some(key) = ctx.param("key") { - if let Some(value) = ctx.param("value") { - kv.put(key, value)?.execute().await?; - } - } - - Response::from_json(&kv.list().execute().await?) - }) - .get("/bytes", |_, _| { - Response::from_bytes(vec![1, 2, 3, 4, 5, 6, 7]) - }) - .post_async("/api-data", |mut req, _ctx| async move { - let data = req.bytes().await?; - let mut todo: ApiData = match serde_json::from_slice(&data) { - Ok(todo) => todo, - Err(e) => { - return Response::ok(e.to_string()); - } - }; - - unsafe { todo.title.as_mut_vec().reverse() }; - - console_log!("todo = (title {}) (id {})", todo.title, todo.user_id); - - Response::from_bytes(serde_json::to_vec(&todo)?) - }) - .post_async("/nonsense-repeat", |_, ctx| async move { - if ctx.data.regex.is_match("2014-01-01") { - Response::ok("data ok") - } else { - Response::error("bad match", 500) - } - }) - .get("/status/:code", |_, ctx| { - if let Some(code) = ctx.param("code") { - return match code.parse::() { - Ok(status) => Response::ok("You set the status code!") - .map(|resp| resp.with_status(status)), - Err(_e) => Response::error("Failed to parse your status code.", 400), - }; - } - - Response::error("Bad Request", 400) - }) - .put("/", respond) - .patch("/", respond) - .delete("/", respond) - .head("/", respond) - .put_async("/async", respond_async) - .patch_async("/async", respond_async) - .delete_async("/async", respond_async) - .head_async("/async", respond_async) - .options("/*catchall", |_, ctx| { - Response::ok(ctx.param("catchall").unwrap()) - }) - .get_async("/request-init-fetch", |_, _| async move { - let init = RequestInit::new(); - Fetch::Request(Request::new_with_init("https://cloudflare.com", &init)?) - .send() - .await - }) - .get_async("/request-init-fetch-post", |_, _| async move { - let mut init = RequestInit::new(); - init.method = Method::Post; - Fetch::Request(Request::new_with_init("https://httpbin.org/post", &init)?) - .send() - .await - }) - .get_async("/cancelled-fetch", |_, _| async move { - let controller = AbortController::default(); - let signal = controller.signal(); - - let (tx, rx) = futures_channel::oneshot::channel(); - - // Spawns a future that'll make our fetch request and not block this function. - wasm_bindgen_futures::spawn_local({ - async move { - let fetch = Fetch::Url("https://cloudflare.com".parse().unwrap()); - let res = fetch.send_with_signal(&signal).await; - tx.send(res).unwrap(); - } - }); - - // And then we try to abort that fetch as soon as we start it, hopefully before - // cloudflare.com responds. - controller.abort(); - - let res = rx.await.unwrap(); - let res = res.unwrap_or_else(|err| { - let text = err.to_string(); - Response::ok(text).unwrap() - }); - - Ok(res) - }) - .get_async("/fetch-timeout", |_, _| async move { - let controller = AbortController::default(); - let signal = controller.signal(); - - let fetch_fut = async { - let fetch = Fetch::Url("https://miniflare.mocks/delay".parse().unwrap()); - let mut res = fetch.send_with_signal(&signal).await?; - let text = res.text().await?; - Ok::(text) - }; - let delay_fut = async { - Delay::from(Duration::from_millis(100)).await; - controller.abort(); - Response::ok("Cancelled") - }; - - futures_util::pin_mut!(fetch_fut); - futures_util::pin_mut!(delay_fut); - - match futures_util::future::select(delay_fut, fetch_fut).await { - Either::Left((res, cancelled_fut)) => { - // Ensure that the cancelled future returns an AbortError. - match cancelled_fut.await { - Err(e) if e.to_string().contains("AbortError") => { /* Yay! It worked, let's do nothing to celebrate */}, - Err(e) => panic!("Fetch errored with a different error than expected: {:#?}", e), - Ok(text) => panic!("Fetch unexpectedly succeeded: {}", text) - } - - res - }, - Either::Right(_) => panic!("Delay future should have resolved first"), - } - }) - .get("/redirect-default", |_, _| { - Response::redirect("https://example.com".parse().unwrap()) - }) - .get("/redirect-307", |_, _| { - Response::redirect_with_status("https://example.com".parse().unwrap(), 307) - }) - .get("/now", |_, _| { - let now = chrono::Utc::now(); - let js_date: Date = now.into(); - Response::ok(js_date.to_string()) - }) - .get_async("/cloned", |_, _| async { - let mut resp = Response::ok("Hello")?; - let mut resp1 = resp.cloned()?; - - let left = resp.text().await?; - let right = resp1.text().await?; - - Response::ok((left == right).to_string()) - }) - .get_async("/cloned-stream", |_, _| async { - let stream = futures_util::stream::repeat(()) - .take(10) - .enumerate() - .then(|(index, _)| async move { - Delay::from(Duration::from_millis(100)).await; - Result::Ok(index.to_string().into_bytes()) - }); - - let mut resp = Response::from_stream(stream)?; - let mut resp1 = resp.cloned()?; - - let left = resp.text().await?; - let right = resp1.text().await?; - - Response::ok((left == right).to_string()) - }) - .get_async("/cloned-fetch", |_, _| async { - let mut resp = Fetch::Url( - "https://jsonplaceholder.typicode.com/todos/1" - .parse() - .unwrap(), - ) - .send() - .await?; - let mut resp1 = resp.cloned()?; - - let left = resp.text().await?; - let right = resp1.text().await?; - - Response::ok((left == right).to_string()) - }) - .get_async("/wait/:delay", |_, ctx| async move { - let delay: Delay = match ctx.param("delay").unwrap().parse() { - Ok(delay) => Duration::from_millis(delay).into(), - Err(_) => return Response::error("invalid delay", 400), - }; - - // Wait for the delay to pass - delay.await; - - Response::ok("Waited!\n") - }) - .get("/custom-response-body", |_, _| { - Response::from_body(ResponseBody::Body(vec![b'h', b'e', b'l', b'l', b'o'])) - }) - .get("/init-called", |_, _| { - let init_called = GLOBAL_STATE.load(Ordering::SeqCst); - Response::ok(init_called.to_string()) - }) - .get_async("/cache-example", |req, _| async move { - console_log!("url: {}", req.url()?.to_string()); - let cache = Cache::default(); - let key = req.url()?.to_string(); - if let Some(resp) = cache.get(&key, true).await? { - console_log!("Cache HIT!"); - Ok(resp) - } else { - - console_log!("Cache MISS!"); - let mut resp = Response::from_json(&serde_json::json!({ "timestamp": Date::now().as_millis() }))?; - - // Cache API respects Cache-Control headers. Setting s-max-age to 10 - // will limit the response to be in cache for 10 seconds max - resp.headers_mut().set("cache-control", "s-maxage=10")?; - cache.put(key, resp.cloned()?).await?; - Ok(resp) - } - }) - .get_async("/cache-api/get/:key", |_req, ctx| async move { - if let Some(key) = ctx.param("key") { - let cache = Cache::default(); - if let Some(resp) = cache.get(format!("https://{key}"), true).await? { - return Ok(resp); - } else { - return Response::ok("cache miss"); - } - } - Response::error("key missing", 400) - }) - .put_async("/cache-api/put/:key", |_req, ctx| async move { - if let Some(key) = ctx.param("key") { - let cache = Cache::default(); - - let mut resp = Response::from_json(&serde_json::json!({ "timestamp": Date::now().as_millis() }))?; - - // Cache API respects Cache-Control headers. Setting s-max-age to 10 - // will limit the response to be in cache for 10 seconds max - resp.headers_mut().set("cache-control", "s-maxage=10")?; - cache.put(format!("https://{key}"), resp.cloned()?).await?; - return Ok(resp); - } - Response::error("key missing", 400) - }) - .post_async("/cache-api/delete/:key", |_req, ctx| async move { - if let Some(key) = ctx.param("key") { - let cache = Cache::default(); - - let res = cache.delete(format!("https://{key}"), true).await?; - return Response::ok(serde_json::to_string(&res)?); - } - Response::error("key missing", 400) - }) - .get_async("/cache-stream", |req, _| async move { - console_log!("url: {}", req.url()?.to_string()); - let cache = Cache::default(); - let key = req.url()?.to_string(); - if let Some(resp) = cache.get(&key, true).await? { - console_log!("Cache HIT!"); - Ok(resp) - } else { - console_log!("Cache MISS!"); - let mut rng = rand::thread_rng(); - let count = rng.gen_range(0..10); - let stream = futures_util::stream::repeat("Hello, world!\n") - .take(count) - .then(|text| async move { - Delay::from(Duration::from_millis(50)).await; - Result::Ok(text.as_bytes().to_vec()) - }); - - let mut resp = Response::from_stream(stream)?; - console_log!("resp = {:?}", resp); - // Cache API respects Cache-Control headers. Setting s-max-age to 10 - // will limit the response to be in cache for 10 seconds max - resp.headers_mut().set("cache-control", "s-maxage=10")?; - - cache.put(key, resp.cloned()?).await?; - Ok(resp) - } - }) - .get_async("/remote-by-request", |req, ctx| async move { - let fetcher = ctx.service("remote")?; - - #[cfg(feature="http")] - let http_request = req.try_into()?; - #[cfg(not(feature="http"))] - let http_request = req; - - let response = fetcher.fetch_request(http_request).await?; - - #[cfg(feature="http")] - let result = Ok(TryInto::::try_into(response)?); - #[cfg(not(feature="http"))] - let result = Ok(response); - - result - }) - .get_async("/remote-by-path", |req, ctx| async move { - let fetcher = ctx.service("remote")?; - let mut init = RequestInit::new(); - init.with_method(Method::Post); - let response = fetcher.fetch(req.url()?.to_string(), Some(init)).await?; - - #[cfg(feature="http")] - let result = Ok(TryInto::::try_into(response)?); - #[cfg(not(feature="http"))] - let result = Ok(response); - - result - }) - .post_async("/queue/send/:id", |_req, ctx| async move { - let id = match ctx.param("id").map(|id|Uuid::try_parse(id).ok()).and_then(|u|u) { - Some(id) => id, - None => { - return Response::error("Failed to parse id, expected a UUID", 400); - } - }; - let my_queue = match ctx.env.queue("my_queue") { - Ok(queue) => queue, - Err(err) => { - return Response::error(format!("Failed to get queue: {err:?}"), 500) - } - }; - match my_queue.send(&QueueBody { - id, - id_string: id.to_string(), - }).await { - Ok(_) => { - Response::ok("Message sent") - } - Err(err) => { - Response::error(format!("Failed to send message to queue: {err:?}"), 500) - } - } - }).get_async("/queue", |_req, _ctx| async move { - let guard = GLOBAL_QUEUE_STATE.lock().unwrap(); - let messages: Vec = guard.clone(); - Response::from_json(&messages) - }) - .get_async("/d1/prepared", d1::prepared_statement) - .get_async("/d1/batch", d1::batch) - .get_async("/d1/dump", d1::dump) - .post_async("/d1/exec", d1::exec) - .get_async("/d1/error", d1::error) - .get_async("/r2/list-empty", r2::list_empty) - .get_async("/r2/list", r2::list) - .get_async("/r2/get-empty", r2::get_empty) - .get_async("/r2/get", r2::get) - .put_async("/r2/put", r2::put) - .put_async("/r2/put-properties", r2::put_properties) - .put_async("/r2/put-multipart", r2::put_multipart) - .delete_async("/r2/delete", r2::delete) - .or_else_any_method_async("/*catchall", |_, ctx| async move { - console_log!( - "[or_else_any_method_async] caught: {}", - ctx.param("catchall").unwrap_or(&"?".to_string()) - ); + let res = { + let mut router = router::make_router(data, env); + Ok(Service::call(&mut router, request).await?) + }; - Fetch::Url("https://github.com/404".parse().unwrap()) - .send() - .await - .map(|resp| resp.with_status(404)) - }) - .run(req, env) - .await; - #[cfg(feature = "http")] - let res = worker_response.map(|r| r.try_into())?; #[cfg(not(feature = "http"))] - let res = worker_response; - res -} - -fn respond(req: Request, _ctx: RouteContext) -> Result { - Response::ok(format!("Ok: {}", String::from(req.method()))).map(|resp| { - let mut headers = Headers::new(); - headers.set("x-testing", "123").unwrap(); - resp.with_headers(headers) - }) -} - -async fn respond_async(req: Request, _ctx: RouteContext) -> Result { - Response::ok(format!("Ok (async): {}", String::from(req.method()))).map(|resp| { - let mut headers = Headers::new(); - headers.set("x-testing", "123").unwrap(); - resp.with_headers(headers) - }) -} - -#[derive(Serialize, Debug, Clone, Deserialize)] -pub struct QueueBody { - pub id: Uuid, - pub id_string: String, -} + let res = { + let router = router::make_router(data); + router.run(request, env).await + }; -#[event(queue)] -pub async fn queue(message_batch: MessageBatch, _env: Env, _ctx: Context) -> Result<()> { - let mut guard = GLOBAL_QUEUE_STATE.lock().unwrap(); - for message in message_batch.messages()? { - console_log!( - "Received queue message {:?}, with id {} and timestamp: {}", - message.body, - message.id, - message.timestamp.to_string() - ); - guard.push(message.body); - } - Ok(()) + res } diff --git a/worker-sandbox/src/queue.rs b/worker-sandbox/src/queue.rs new file mode 100644 index 00000000..896b6a7d --- /dev/null +++ b/worker-sandbox/src/queue.rs @@ -0,0 +1,61 @@ +use serde::{Deserialize, Serialize}; +use uuid::Uuid; + +use super::{SomeSharedData, GLOBAL_QUEUE_STATE}; +use worker::{console_log, event, Context, Env, MessageBatch, Request, Response, Result}; +#[derive(Serialize, Debug, Clone, Deserialize)] +pub struct QueueBody { + pub id: Uuid, + pub id_string: String, +} + +#[event(queue)] +pub async fn queue(message_batch: MessageBatch, _env: Env, _ctx: Context) -> Result<()> { + let mut guard = GLOBAL_QUEUE_STATE.lock().unwrap(); + for message in message_batch.messages()? { + console_log!( + "Received queue message {:?}, with id {} and timestamp: {}", + message.body, + message.id, + message.timestamp.to_string() + ); + guard.push(message.body); + } + Ok(()) +} + +#[worker::send] +pub async fn handle_queue_send(req: Request, env: Env, _data: SomeSharedData) -> Result { + let uri = req.url()?; + let mut segments = uri.path_segments().unwrap(); + let id = match segments + .nth(2) + .map(|id| Uuid::try_parse(id).ok()) + .and_then(|u| u) + { + Some(id) => id, + None => { + return Response::error("Failed to parse id, expected a UUID", 400); + } + }; + let my_queue = match env.queue("my_queue") { + Ok(queue) => queue, + Err(err) => return Response::error(format!("Failed to get queue: {err:?}"), 500), + }; + match my_queue + .send(&QueueBody { + id, + id_string: id.to_string(), + }) + .await + { + Ok(_) => Response::ok("Message sent"), + Err(err) => Response::error(format!("Failed to send message to queue: {err:?}"), 500), + } +} + +pub async fn handle_queue(_req: Request, _env: Env, _data: SomeSharedData) -> Result { + let guard = GLOBAL_QUEUE_STATE.lock().unwrap(); + let messages: Vec = guard.clone(); + Response::from_json(&messages) +} diff --git a/worker-sandbox/src/r2.rs b/worker-sandbox/src/r2.rs index 2c201edf..59e67244 100644 --- a/worker-sandbox/src/r2.rs +++ b/worker-sandbox/src/r2.rs @@ -2,8 +2,8 @@ use std::{collections::HashMap, sync::Mutex}; use futures_util::StreamExt; use worker::{ - Bucket, Conditional, Data, Date, FixedLengthStream, HttpMetadata, Include, Request, Response, - Result, RouteContext, + Bucket, Conditional, Data, Date, Env, FixedLengthStream, HttpMetadata, Include, Request, + Response, Result, }; use crate::SomeSharedData; @@ -32,8 +32,9 @@ pub async fn seed_bucket(bucket: &Bucket) -> Result<()> { Ok(()) } -pub async fn list_empty(_req: Request, ctx: RouteContext) -> Result { - let bucket = ctx.bucket("EMPTY_BUCKET")?; +#[worker::send] +pub async fn list_empty(_req: Request, env: Env, _data: SomeSharedData) -> Result { + let bucket = env.bucket("EMPTY_BUCKET")?; let objects = bucket.list().execute().await?; assert_eq!(objects.objects().len(), 0); @@ -43,8 +44,9 @@ pub async fn list_empty(_req: Request, ctx: RouteContext) -> Res Response::ok("ok") } -pub async fn list(_req: Request, ctx: RouteContext) -> Result { - let bucket = ctx.bucket("SEEDED_BUCKET")?; +#[worker::send] +pub async fn list(_req: Request, env: Env, _data: SomeSharedData) -> Result { + let bucket = env.bucket("SEEDED_BUCKET")?; seed_bucket(&bucket).await?; let objects = bucket.list().execute().await?; @@ -96,8 +98,9 @@ pub async fn list(_req: Request, ctx: RouteContext) -> Result) -> Result { - let bucket = ctx.bucket("EMPTY_BUCKET")?; +#[worker::send] +pub async fn get_empty(_req: Request, env: Env, _data: SomeSharedData) -> Result { + let bucket = env.bucket("EMPTY_BUCKET")?; let object = bucket.get("doesnt-exist").execute().await?; assert!(object.is_none()); @@ -118,8 +121,9 @@ pub async fn get_empty(_req: Request, ctx: RouteContext) -> Resu Response::ok("ok") } -pub async fn get(_req: Request, ctx: RouteContext) -> Result { - let bucket = ctx.bucket("SEEDED_BUCKET")?; +#[worker::send] +pub async fn get(_req: Request, env: Env, _data: SomeSharedData) -> Result { + let bucket = env.bucket("SEEDED_BUCKET")?; seed_bucket(&bucket).await?; let item = bucket.get("no-props").execute().await?.unwrap(); @@ -139,8 +143,9 @@ pub async fn get(_req: Request, ctx: RouteContext) -> Result) -> Result { - let bucket = ctx.bucket("PUT_BUCKET")?; +#[worker::send] +pub async fn put(_req: Request, env: Env, _data: SomeSharedData) -> Result { + let bucket = env.bucket("PUT_BUCKET")?; // R2 requires that we use a fixed-length-stream for the body. let stream = futures_util::stream::repeat_with(|| Ok(vec![0u8; 16])).take(16); @@ -172,8 +177,9 @@ pub async fn put(_req: Request, ctx: RouteContext) -> Result) -> Result { - let bucket = ctx.bucket("PUT_BUCKET")?; +#[worker::send] +pub async fn put_properties(_req: Request, env: Env, _data: SomeSharedData) -> Result { + let bucket = env.bucket("PUT_BUCKET")?; let (http_metadata, custom_metadata, object_with_props) = put_full_properties("with_props", &bucket).await?; @@ -185,11 +191,12 @@ pub async fn put_properties(_req: Request, ctx: RouteContext) -> Response::ok("ok") } -pub async fn put_multipart(_req: Request, ctx: RouteContext) -> Result { +#[worker::send] +pub async fn put_multipart(_req: Request, env: Env, _data: SomeSharedData) -> Result { const R2_MULTIPART_CHUNK_MIN_SIZE: usize = 5 * 1_024 * 1_024; // 5MiB. // const TEST_CHUNK_COUNT: usize = 3; - let bucket = ctx.bucket("PUT_BUCKET")?; + let bucket = env.bucket("PUT_BUCKET")?; let upload = bucket .create_multipart_upload("multipart_upload") @@ -236,8 +243,9 @@ pub async fn put_multipart(_req: Request, ctx: RouteContext) -> Response::ok("ok") } -pub async fn delete(_req: Request, ctx: RouteContext) -> Result { - let bucket = ctx.bucket("DELETE_BUCKET")?; +#[worker::send] +pub async fn delete(_req: Request, env: Env, _data: SomeSharedData) -> Result { + let bucket = env.bucket("DELETE_BUCKET")?; bucket.put("key", Data::Empty).execute().await?; diff --git a/worker-sandbox/src/request.rs b/worker-sandbox/src/request.rs new file mode 100644 index 00000000..00ca6b30 --- /dev/null +++ b/worker-sandbox/src/request.rs @@ -0,0 +1,225 @@ +use crate::SomeSharedData; + +use super::ApiData; +use futures_util::StreamExt; +use futures_util::TryStreamExt; +use std::time::Duration; +use worker::Env; +use worker::{console_log, Date, Delay, Request, Response, ResponseBody, Result}; +pub fn handle_a_request(req: Request, _env: Env, _data: SomeSharedData) -> Result { + Response::ok(format!( + "req at: {}, located at: {:?}, within: {}", + req.path(), + req.cf().map(|cf| cf.coordinates().unwrap_or_default()), + req.cf() + .map(|cf| cf.region().unwrap_or_else(|| "unknown region".into())) + .unwrap_or(String::from("No CF properties")) + )) +} + +pub async fn handle_async_request( + req: Request, + _env: Env, + _data: SomeSharedData, +) -> Result { + Response::ok(format!( + "[async] req at: {}, located at: {:?}, within: {}", + req.path(), + req.cf().map(|cf| cf.coordinates().unwrap_or_default()), + req.cf() + .map(|cf| cf.region().unwrap_or_else(|| "unknown region".into())) + .unwrap_or(String::from("No CF properties")) + )) +} + +pub async fn handle_test_data(_req: Request, _env: Env, data: SomeSharedData) -> Result { + // just here to test data works + if data.regex.is_match("2014-01-01") { + Response::ok("data ok") + } else { + Response::error("bad match", 500) + } +} + +pub async fn handle_xor(mut req: Request, _env: Env, _data: SomeSharedData) -> Result { + let url = req.url()?; + let num: u8 = match url.path_segments().unwrap().nth(1).unwrap().parse() { + Ok(num) => num, + Err(_) => return Response::error("invalid byte", 400), + }; + + let xor_stream = req.stream()?.map_ok(move |mut buf| { + buf.iter_mut().for_each(|x| *x ^= num); + buf + }); + + Response::from_stream(xor_stream) +} + +pub async fn handle_headers(req: Request, _env: Env, _data: SomeSharedData) -> Result { + let mut headers: http::HeaderMap = req.headers().into(); + headers.append("Hello", "World!".parse().unwrap()); + + Response::ok("returned your headers to you.").map(|res| res.with_headers(headers.into())) +} + +#[worker::send] +pub async fn handle_post_file_size( + mut req: Request, + _env: Env, + _data: SomeSharedData, +) -> Result { + let bytes = req.bytes().await?; + Response::ok(format!("size = {}", bytes.len())) +} + +#[worker::send] +pub async fn handle_async_text_echo( + mut req: Request, + _env: Env, + _data: SomeSharedData, +) -> Result { + Response::ok(req.text().await?) +} + +pub async fn handle_secret(_req: Request, env: Env, _data: SomeSharedData) -> Result { + Response::ok(env.secret("SOME_SECRET")?.to_string()) +} + +pub async fn handle_var(_req: Request, env: Env, _data: SomeSharedData) -> Result { + Response::ok(env.var("SOME_VARIABLE")?.to_string()) +} + +pub async fn handle_bytes(_req: Request, _env: Env, _data: SomeSharedData) -> Result { + Response::from_bytes(vec![1, 2, 3, 4, 5, 6, 7]) +} + +#[worker::send] +pub async fn handle_api_data( + mut req: Request, + _env: Env, + _data: SomeSharedData, +) -> Result { + let data = req.bytes().await?; + let mut todo: ApiData = match serde_json::from_slice(&data) { + Ok(todo) => todo, + Err(e) => { + return Response::ok(e.to_string()); + } + }; + + unsafe { todo.title.as_mut_vec().reverse() }; + + console_log!("todo = (title {}) (id {})", todo.title, todo.user_id); + + Response::from_bytes(serde_json::to_vec(&todo)?) +} + +pub async fn handle_nonsense_repeat( + _req: Request, + _env: Env, + data: SomeSharedData, +) -> Result { + if data.regex.is_match("2014-01-01") { + Response::ok("data ok") + } else { + Response::error("bad match", 500) + } +} + +pub async fn handle_status(req: Request, _env: Env, _data: SomeSharedData) -> Result { + let uri = req.url()?; + let mut segments = uri.path_segments().unwrap(); + let code = segments.nth(1); + if let Some(code) = code { + return match code.parse::() { + Ok(status) => { + Response::ok("You set the status code!").map(|resp| resp.with_status(status)) + } + Err(_e) => Response::error("Failed to parse your status code.", 400), + }; + } + + Response::error("Bad Request", 400) +} + +pub async fn handle_redirect_default( + _req: Request, + _env: Env, + _data: SomeSharedData, +) -> Result { + Response::redirect("https://example.com".parse().unwrap()) +} + +pub async fn handle_redirect_307( + _req: Request, + _env: Env, + _data: SomeSharedData, +) -> Result { + Response::redirect_with_status("https://example.com".parse().unwrap(), 307) +} + +pub async fn handle_now(_req: Request, _env: Env, _data: SomeSharedData) -> Result { + let now = chrono::Utc::now(); + let js_date: Date = now.into(); + Response::ok(js_date.to_string()) +} + +#[worker::send] +pub async fn handle_cloned(_req: Request, _env: Env, _data: SomeSharedData) -> Result { + let mut resp = Response::ok("Hello")?; + let mut resp1 = resp.cloned()?; + + let left = resp.text().await?; + let right = resp1.text().await?; + + Response::ok((left == right).to_string()) +} + +#[worker::send] +pub async fn handle_cloned_stream( + _req: Request, + _env: Env, + _data: SomeSharedData, +) -> Result { + let stream = + futures_util::stream::repeat(()) + .take(10) + .enumerate() + .then(|(index, _)| async move { + Delay::from(Duration::from_millis(100)).await; + Result::Ok(index.to_string().into_bytes()) + }); + + let mut resp = Response::from_stream(stream)?; + let mut resp1 = resp.cloned()?; + + let left = resp.text().await?; + let right = resp1.text().await?; + + Response::ok((left == right).to_string()) +} + +pub async fn handle_custom_response_body( + _req: Request, + _env: Env, + _data: SomeSharedData, +) -> Result { + Response::from_body(ResponseBody::Body(vec![b'h', b'e', b'l', b'l', b'o'])) +} + +#[worker::send] +pub async fn handle_wait_delay(req: Request, _env: Env, _data: SomeSharedData) -> Result { + let uri = req.url()?; + let mut segments = uri.path_segments().unwrap(); + let delay = segments.nth(1); + let delay: Delay = match delay.unwrap().parse() { + Ok(delay) => Duration::from_millis(delay).into(), + Err(_) => return Response::error("invalid delay", 400), + }; + + // Wait for the delay to pass + delay.await; + + Response::ok("Waited!\n") +} diff --git a/worker-sandbox/src/router.rs b/worker-sandbox/src/router.rs new file mode 100644 index 00000000..979cbf6f --- /dev/null +++ b/worker-sandbox/src/router.rs @@ -0,0 +1,373 @@ +use crate::{ + alarm, cache, d1, fetch, form, kv, queue, r2, request, service, user, ws, SomeSharedData, + GLOBAL_STATE, +}; +#[cfg(feature = "http")] +use std::convert::TryInto; +use std::sync::atomic::Ordering; + +use worker::{console_log, Env, Fetch, Headers, Request, Response, Result}; + +#[cfg(not(feature = "http"))] +use worker::{RouteContext, Router}; + +#[cfg(feature = "http")] +use axum::{ + routing::{delete, get, head, options, patch, post, put}, + Extension, +}; + +/// Rewrites a handler with legacy http types to use axum extractors / response type. +#[cfg(feature = "http")] +macro_rules! handler ( + ($name:path) => { + |Extension(env): Extension, Extension(data): Extension, req: axum::extract::Request| async { + let resp = $name(req.try_into().expect("convert request"), env, data).await.expect("handler result"); + Into::>::into(resp) + } + } +); + +#[cfg(feature = "http")] +macro_rules! handler_sync ( + ($name:path) => { + |Extension(env): Extension, Extension(data): Extension, req: axum::extract::Request| async { + let resp = $name(req.try_into().expect("convert request"), env, data).expect("handler result"); + Into::>::into(resp) + } + } +); + +#[cfg(not(feature = "http"))] +macro_rules! handler ( + ($name:path) => { + |req: Request, ctx: RouteContext| async { + $name(req, ctx.env, ctx.data).await + } + } +); + +#[cfg(not(feature = "http"))] +macro_rules! handler_sync ( + ($name:path) => { + |req: Request, ctx: RouteContext| { + $name(req, ctx.env, ctx.data) + } + } +); + +#[cfg(feature = "http")] +pub fn make_router(data: SomeSharedData, env: Env) -> axum::Router { + axum::Router::new() + .route("/request", get(handler_sync!(request::handle_a_request))) + .route( + "/async-request", + get(handler!(request::handle_async_request)), + ) + .route("/websocket", get(handler!(ws::handle_websocket))) + .route("/got-close-event", get(handler!(handle_close_event))) + .route("/ws-client", get(handler!(ws::handle_websocket_client))) + .route("/test-data", get(handler!(request::handle_test_data))) + .route("/xor/:num", post(handler!(request::handle_xor))) + .route("/headers", post(handler!(request::handle_headers))) + .route("/formdata-name", post(handler!(form::handle_formdata_name))) + .route("/is-secret", post(handler!(form::handle_is_secret))) + .route( + "/formdata-file-size", + post(handler!(form::handle_formdata_file_size)), + ) + .route( + "/formdata-file-size/:hash", + get(handler!(form::handle_formdata_file_size_hash)), + ) + .route( + "/post-file-size", + post(handler!(request::handle_post_file_size)), + ) + .route("/user/:id/test", get(handler!(user::handle_user_id_test))) + .route("/user/:id", get(handler!(user::handle_user_id))) + .route( + "/account/:id/zones", + post(handler!(user::handle_post_account_id_zones)), + ) + .route( + "/account/:id/zones", + get(handler!(user::handle_get_account_id_zones)), + ) + .route( + "/async-text-echo", + post(handler!(request::handle_async_text_echo)), + ) + .route("/fetch", get(handler!(fetch::handle_fetch))) + .route("/fetch_json", get(handler!(fetch::handle_fetch_json))) + .route( + "/proxy_request/*url", + get(handler!(fetch::handle_proxy_request)), + ) + .route("/durable/alarm", get(handler!(alarm::handle_alarm))) + .route("/durable/:id", get(handler!(alarm::handle_id))) + .route("/durable/put-raw", get(handler!(alarm::handle_put_raw))) + .route("/var", get(handler!(request::handle_var))) + .route("/secret", get(handler!(request::handle_secret))) + .route("/kv/:key/:value", post(handler!(kv::handle_post_key_value))) + .route("/bytes", get(handler!(request::handle_bytes))) + .route("/api-data", post(handler!(request::handle_api_data))) + .route( + "/nonsense-repeat", + post(handler!(request::handle_nonsense_repeat)), + ) + .route("/status/:code", get(handler!(request::handle_status))) + .route("/", put(handler_sync!(respond))) + .route("/", patch(handler_sync!(respond))) + .route("/", delete(handler_sync!(respond))) + .route("/", head(handler_sync!(respond))) + .route("/async", put(handler!(respond_async))) + .route("/async", patch(handler!(respond_async))) + .route("/async", delete(handler!(respond_async))) + .route("/async", head(handler!(respond_async))) + .route("/*catchall", options(handler!(handle_options_catchall))) + .route( + "/request-init-fetch", + get(handler!(fetch::handle_request_init_fetch)), + ) + .route( + "/request-init-fetch-post", + get(handler!(fetch::handle_request_init_fetch_post)), + ) + .route( + "/cancelled-fetch", + get(handler!(fetch::handle_cancelled_fetch)), + ) + .route("/fetch-timeout", get(handler!(fetch::handle_fetch_timeout))) + .route( + "/redirect-default", + get(handler!(request::handle_redirect_default)), + ) + .route("/redirect-307", get(handler!(request::handle_redirect_307))) + .route("/now", get(handler!(request::handle_now))) + .route("/cloned", get(handler!(request::handle_cloned))) + .route( + "/cloned-stream", + get(handler!(request::handle_cloned_stream)), + ) + .route("/cloned-fetch", get(handler!(fetch::handle_cloned_fetch))) + .route("/wait/:delay", get(handler!(request::handle_wait_delay))) + .route( + "/custom-response-body", + get(handler!(request::handle_custom_response_body)), + ) + .route("/init-called", get(handler!(handle_init_called))) + .route("/cache-example", get(handler!(cache::handle_cache_example))) + .route( + "/cache-api/get/:key", + get(handler!(cache::handle_cache_api_get)), + ) + .route( + "/cache-api/put/:key", + put(handler!(cache::handle_cache_api_put)), + ) + .route( + "/cache-api/delete/:key", + post(handler!(cache::handle_cache_api_delete)), + ) + .route("/cache-stream", get(handler!(cache::handle_cache_stream))) + .route( + "/remote-by-request", + get(handler!(service::handle_remote_by_request)), + ) + .route( + "/remote-by-path", + get(handler!(service::handle_remote_by_path)), + ) + .route("/queue/send/:id", post(handler!(queue::handle_queue_send))) + .route("/queue", get(handler!(queue::handle_queue))) + .route("/d1/prepared", get(handler!(d1::prepared_statement))) + .route("/d1/batch", get(handler!(d1::batch))) + .route("/d1/dump", get(handler!(d1::dump))) + .route("/d1/exec", post(handler!(d1::exec))) + .route("/d1/error", get(handler!(d1::error))) + .route("/r2/list-empty", get(handler!(r2::list_empty))) + .route("/r2/list", get(handler!(r2::list))) + .route("/r2/get-empty", get(handler!(r2::get_empty))) + .route("/r2/get", get(handler!(r2::get))) + .route("/r2/put", put(handler!(r2::put))) + .route("/r2/put-properties", put(handler!(r2::put_properties))) + .route("/r2/put-multipart", put(handler!(r2::put_multipart))) + .route("/r2/delete", delete(handler!(r2::delete))) + .fallback(get(handler!(catchall))) + .layer(Extension(env)) + .layer(Extension(data)) +} + +#[cfg(not(feature = "http"))] +pub fn make_router<'a>(data: SomeSharedData) -> Router<'a, SomeSharedData> { + Router::with_data(data) + .get("/request", handler_sync!(request::handle_a_request)) // can pass a fn pointer to keep routes tidy + .get_async("/async-request", handler!(request::handle_async_request)) + .get_async("/websocket", handler!(ws::handle_websocket)) + .get_async("/got-close-event", handler!(handle_close_event)) + .get_async("/ws-client", handler!(ws::handle_websocket_client)) + .get_async("/test-data", handler!(request::handle_test_data)) + .post_async("/xor/:num", handler!(request::handle_xor)) + .post_async("/headers", handler!(request::handle_headers)) + .post_async("/formdata-name", handler!(form::handle_formdata_name)) + .post_async("/is-secret", handler!(form::handle_is_secret)) + .post_async( + "/formdata-file-size", + handler!(form::handle_formdata_file_size), + ) + .get_async( + "/formdata-file-size/:hash", + handler!(form::handle_formdata_file_size_hash), + ) + .post_async("/post-file-size", handler!(request::handle_post_file_size)) + .get_async("/user/:id/test", handler!(user::handle_user_id_test)) + .get_async("/user/:id", handler!(user::handle_user_id)) + .post_async( + "/account/:id/zones", + handler!(user::handle_post_account_id_zones), + ) + .get_async( + "/account/:id/zones", + handler!(user::handle_get_account_id_zones), + ) + .post_async( + "/async-text-echo", + handler!(request::handle_async_text_echo), + ) + .get_async("/fetch", handler!(fetch::handle_fetch)) + .get_async("/fetch_json", handler!(fetch::handle_fetch_json)) + .get_async("/proxy_request/*url", handler!(fetch::handle_proxy_request)) + .get_async("/durable/alarm", handler!(alarm::handle_alarm)) + .get_async("/durable/:id", handler!(alarm::handle_id)) + .get_async("/durable/put-raw", handler!(alarm::handle_put_raw)) + .get_async("/secret", handler!(request::handle_secret)) + .get_async("/var", handler!(request::handle_var)) + .post_async("/kv/:key/:value", handler!(kv::handle_post_key_value)) + .get_async("/bytes", handler!(request::handle_bytes)) + .post_async("/api-data", handler!(request::handle_api_data)) + .post_async( + "/nonsense-repeat", + handler!(request::handle_nonsense_repeat), + ) + .get_async("/status/:code", handler!(request::handle_status)) + .put("/", handler_sync!(respond)) + .patch("/", handler_sync!(respond)) + .delete("/", handler_sync!(respond)) + .head("/", handler_sync!(respond)) + .put_async("/async", handler!(respond_async)) + .patch_async("/async", handler!(respond_async)) + .delete_async("/async", handler!(respond_async)) + .head_async("/async", handler!(respond_async)) + .options_async("/*catchall", handler!(handle_options_catchall)) + .get_async( + "/request-init-fetch", + handler!(fetch::handle_request_init_fetch), + ) + .get_async( + "/request-init-fetch-post", + handler!(fetch::handle_request_init_fetch_post), + ) + .get_async("/cancelled-fetch", handler!(fetch::handle_cancelled_fetch)) + .get_async("/fetch-timeout", handler!(fetch::handle_fetch_timeout)) + .get_async( + "/redirect-default", + handler!(request::handle_redirect_default), + ) + .get_async("/redirect-307", handler!(request::handle_redirect_307)) + .get_async("/now", handler!(request::handle_now)) + .get_async("/cloned", handler!(request::handle_cloned)) + .get_async("/cloned-stream", handler!(request::handle_cloned_stream)) + .get_async("/cloned-fetch", handler!(fetch::handle_cloned_fetch)) + .get_async("/wait/:delay", handler!(request::handle_wait_delay)) + .get_async( + "/custom-response-body", + handler!(request::handle_custom_response_body), + ) + .get_async("/init-called", handler!(handle_init_called)) + .get_async("/cache-example", handler!(cache::handle_cache_example)) + .get_async("/cache-api/get/:key", handler!(cache::handle_cache_api_get)) + .put_async("/cache-api/put/:key", handler!(cache::handle_cache_api_put)) + .post_async( + "/cache-api/delete/:key", + handler!(cache::handle_cache_api_delete), + ) + .get_async("/cache-stream", handler!(cache::handle_cache_stream)) + .get_async( + "/remote-by-request", + handler!(service::handle_remote_by_request), + ) + .get_async("/remote-by-path", handler!(service::handle_remote_by_path)) + .post_async("/queue/send/:id", handler!(queue::handle_queue_send)) + .get_async("/queue", handler!(queue::handle_queue)) + .get_async("/d1/prepared", handler!(d1::prepared_statement)) + .get_async("/d1/batch", handler!(d1::batch)) + .get_async("/d1/dump", handler!(d1::dump)) + .post_async("/d1/exec", handler!(d1::exec)) + .get_async("/d1/error", handler!(d1::error)) + .get_async("/r2/list-empty", handler!(r2::list_empty)) + .get_async("/r2/list", handler!(r2::list)) + .get_async("/r2/get-empty", handler!(r2::get_empty)) + .get_async("/r2/get", handler!(r2::get)) + .put_async("/r2/put", handler!(r2::put)) + .put_async("/r2/put-properties", handler!(r2::put_properties)) + .put_async("/r2/put-multipart", handler!(r2::put_multipart)) + .delete_async("/r2/delete", handler!(r2::delete)) + .or_else_any_method_async("/*catchall", handler!(catchall)) +} + +fn respond(req: Request, _env: Env, _data: SomeSharedData) -> Result { + Response::ok(format!("Ok: {}", String::from(req.method()))).map(|resp| { + let mut headers = Headers::new(); + headers.set("x-testing", "123").unwrap(); + resp.with_headers(headers) + }) +} + +async fn respond_async(req: Request, _env: Env, _data: SomeSharedData) -> Result { + Response::ok(format!("Ok (async): {}", String::from(req.method()))).map(|resp| { + let mut headers = Headers::new(); + headers.set("x-testing", "123").unwrap(); + resp.with_headers(headers) + }) +} + +#[worker::send] +async fn handle_close_event(_req: Request, env: Env, _data: SomeSharedData) -> Result { + let some_namespace_kv = env.kv("SOME_NAMESPACE")?; + let got_close_event = some_namespace_kv + .get("got-close-event") + .text() + .await? + .unwrap_or_else(|| "false".into()); + + // Let the integration tests have some way of knowing if we successfully received the closed event. + Response::ok(got_close_event) +} + +#[worker::send] +async fn catchall(req: Request, _env: Env, _data: SomeSharedData) -> Result { + let uri = req.url()?; + let path = uri.path(); + console_log!("[or_else_any_method_async] caught: {}", path); + + Fetch::Url("https://github.com/404".parse().unwrap()) + .send() + .await + .map(|resp| resp.with_status(404)) +} + +async fn handle_options_catchall( + req: Request, + _env: Env, + _data: SomeSharedData, +) -> Result { + let uri = req.url()?; + let path = uri.path(); + Response::ok(path) +} + +async fn handle_init_called(_req: Request, _env: Env, _data: SomeSharedData) -> Result { + let init_called = GLOBAL_STATE.load(Ordering::SeqCst); + Response::ok(init_called.to_string()) +} diff --git a/worker-sandbox/src/service.rs b/worker-sandbox/src/service.rs new file mode 100644 index 00000000..15c83f69 --- /dev/null +++ b/worker-sandbox/src/service.rs @@ -0,0 +1,46 @@ +use super::SomeSharedData; +#[cfg(feature = "http")] +use std::convert::TryInto; +use worker::{Env, Method, Request, RequestInit, Response, Result}; + +#[worker::send] +pub async fn handle_remote_by_request( + req: Request, + env: Env, + _data: SomeSharedData, +) -> Result { + let fetcher = env.service("remote")?; + + #[cfg(feature = "http")] + let http_request = req.try_into()?; + #[cfg(not(feature = "http"))] + let http_request = req; + + let response = fetcher.fetch_request(http_request).await?; + + #[cfg(feature = "http")] + let result = Ok(TryInto::::try_into(response)?); + #[cfg(not(feature = "http"))] + let result = Ok(response); + + result +} + +#[worker::send] +pub async fn handle_remote_by_path( + req: Request, + env: Env, + _data: SomeSharedData, +) -> Result { + let fetcher = env.service("remote")?; + let mut init = RequestInit::new(); + init.with_method(Method::Post); + let response = fetcher.fetch(req.url()?.to_string(), Some(init)).await?; + + #[cfg(feature = "http")] + let result = Ok(TryInto::::try_into(response)?); + #[cfg(not(feature = "http"))] + let result = Ok(response); + + result +} diff --git a/worker-sandbox/src/user.rs b/worker-sandbox/src/user.rs new file mode 100644 index 00000000..6f5e2e69 --- /dev/null +++ b/worker-sandbox/src/user.rs @@ -0,0 +1,70 @@ +use serde::Serialize; +use worker::{Date, DateInit, Env, Request, Response, Result}; + +use crate::SomeSharedData; + +#[derive(Serialize)] +struct User { + id: String, + timestamp: u64, + date_from_int: String, + date_from_str: String, +} + +pub async fn handle_user_id_test( + req: Request, + _env: Env, + _data: SomeSharedData, +) -> Result { + let url = req.url()?; + let id = url.path_segments().unwrap().nth(1); + if let Some(id) = id { + return Response::ok(format!("TEST user id: {id}")); + } + + Response::error("Error", 500) +} + +pub async fn handle_user_id(req: Request, _env: Env, _data: SomeSharedData) -> Result { + let url = req.url()?; + let id = url.path_segments().unwrap().nth(1); + if let Some(id) = id { + return Response::from_json(&User { + id: id.to_string(), + timestamp: Date::now().as_millis(), + date_from_int: Date::new(DateInit::Millis(1234567890)).to_string(), + date_from_str: Date::new(DateInit::String( + "Wed Jan 14 1980 23:56:07 GMT-0700 (Mountain Standard Time)".into(), + )) + .to_string(), + }); + } + + Response::error("Bad Request", 400) +} + +pub async fn handle_post_account_id_zones( + req: Request, + _env: Env, + _data: SomeSharedData, +) -> Result { + let url = req.url()?; + let id = url.path_segments().unwrap().nth(1); + Response::ok(format!( + "Create new zone for Account: {}", + id.unwrap_or("not found") + )) +} + +pub async fn handle_get_account_id_zones( + req: Request, + _env: Env, + _data: SomeSharedData, +) -> Result { + let url = req.url()?; + let id = url.path_segments().unwrap().nth(1); + Response::ok(format!( + "Account id: {}..... You get a zone, you get a zone!", + id.unwrap_or("not found") + )) +} diff --git a/worker-sandbox/src/ws.rs b/worker-sandbox/src/ws.rs new file mode 100644 index 00000000..e0c3899a --- /dev/null +++ b/worker-sandbox/src/ws.rs @@ -0,0 +1,69 @@ +use super::SomeSharedData; +use futures_util::StreamExt; +use worker::{ + wasm_bindgen_futures, Env, Request, Response, Result, WebSocket, WebSocketPair, WebsocketEvent, +}; + +pub async fn handle_websocket(_req: Request, env: Env, _data: SomeSharedData) -> Result { + // Accept / handle a websocket connection + let pair = WebSocketPair::new()?; + let server = pair.server; + server.accept()?; + + let some_namespace_kv = env.kv("SOME_NAMESPACE")?; + + wasm_bindgen_futures::spawn_local(async move { + let mut event_stream = server.events().expect("could not open stream"); + + while let Some(event) = event_stream.next().await { + match event.expect("received error in websocket") { + WebsocketEvent::Message(msg) => { + if let Some(text) = msg.text() { + server.send_with_str(text).expect("could not relay text"); + } + } + WebsocketEvent::Close(_) => { + // Sets a key in a test KV so the integration tests can query if we + // actually got the close event. We can't use the shared dat a for this + // because miniflare resets that every request. + some_namespace_kv + .put("got-close-event", "true") + .unwrap() + .execute() + .await + .unwrap(); + } + } + } + }); + + Response::from_websocket(pair.client) +} + +#[worker::send] +pub async fn handle_websocket_client( + _req: Request, + _env: Env, + _data: SomeSharedData, +) -> Result { + let ws = WebSocket::connect("wss://echo.miniflare.mocks/".parse()?).await?; + + // It's important that we call this before we send our first message, otherwise we will + // not have any event listeners on the socket to receive the echoed message. + let mut event_stream = ws.events()?; + + ws.accept()?; + ws.send_with_str("Hello, world!")?; + + while let Some(event) = event_stream.next().await { + let event = event?; + + if let WebsocketEvent::Message(msg) = event { + if let Some(text) = msg.text() { + return Response::ok(text); + } + } + } + + Response::error("never got a message echoed back :(", 500) +} diff --git a/worker-sandbox/tests/request.spec.ts b/worker-sandbox/tests/request.spec.ts index ed99bdb9..6f648887 100644 --- a/worker-sandbox/tests/request.spec.ts +++ b/worker-sandbox/tests/request.spec.ts @@ -189,7 +189,7 @@ test("catchall", async () => { method: "OPTIONS", }); - expect(await resp.text()).toBe("hello-world"); + expect(await resp.text()).toBe("/hello-world"); }); test("redirect default", async () => { diff --git a/worker/src/context.rs b/worker/src/context.rs index 0b4a5773..e96bd06f 100644 --- a/worker/src/context.rs +++ b/worker/src/context.rs @@ -11,6 +11,9 @@ pub struct Context { inner: JsContext, } +unsafe impl Send for Context {} +unsafe impl Sync for Context {} + impl Context { /// Constructs a context from an underlying JavaScript context object. pub fn new(inner: JsContext) -> Self { diff --git a/worker/src/d1/macros.rs b/worker/src/d1/macros.rs index e32fa31d..2a7c7af5 100644 --- a/worker/src/d1/macros.rs +++ b/worker/src/d1/macros.rs @@ -2,7 +2,7 @@ /// /// Any parameter provided is required to implement [`serde::Serialize`] to be used. /// -/// Using [`query`] is equivalent to using db.prepare('').bind('') in Javascript. +/// Using [`query`](crate::query) is equivalent to using db.prepare('').bind('') in Javascript. /// /// # Example /// diff --git a/worker/src/env.rs b/worker/src/env.rs index 0ea2f107..ecd6daf7 100644 --- a/worker/src/env.rs +++ b/worker/src/env.rs @@ -12,6 +12,7 @@ use worker_kv::KvStore; #[wasm_bindgen] extern "C" { /// Env contains any bindings you have associated with the Worker when you uploaded it. + #[derive(Clone)] pub type Env; } diff --git a/worker/src/http/body.rs b/worker/src/http/body.rs index a0fb8492..523ae0b9 100644 --- a/worker/src/http/body.rs +++ b/worker/src/http/body.rs @@ -12,6 +12,8 @@ use futures_util::TryStreamExt; use futures_util::{stream::FusedStream, Stream, StreamExt}; use http_body::{Body as HttpBody, Frame}; use js_sys::Uint8Array; +use pin_project::pin_project; +use wasm_bindgen::JsValue; #[derive(Debug)] pub struct Body(Option>); @@ -132,3 +134,49 @@ impl Stream for Body { }) } } + +#[pin_project] +pub(crate) struct BodyStream { + #[pin] + inner: B, +} + +impl BodyStream { + pub(crate) fn new(inner: B) -> Self { + Self { inner } + } +} + +impl> Stream for BodyStream { + type Item = std::result::Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.project(); + let inner: Pin<&mut B> = this.inner; + inner.poll_frame(cx).map(|o| { + if let Some(r) = o { + match r { + Ok(f) => { + if f.is_data() { + // Should not be Err after checking on previous line + let b = f.into_data().unwrap(); + let array = Uint8Array::new_with_length(b.len() as _); + array.copy_from(&b); + Some(Ok(array.into())) + } else { + None + } + } + Err(_) => Some(Err(JsValue::from_str("Error polling body"))), + } + } else { + None + } + }) + } + + fn size_hint(&self) -> (usize, Option) { + let hint = self.inner.size_hint(); + (hint.lower() as usize, hint.upper().map(|u| u as usize)) + } +} diff --git a/worker/src/http/request.rs b/worker/src/http/request.rs index 59798e12..0785a11a 100644 --- a/worker/src/http/request.rs +++ b/worker/src/http/request.rs @@ -4,7 +4,9 @@ use crate::Result; use crate::{http::redirect::RequestRedirect, AbortSignal}; use worker_sys::ext::RequestExt; +use crate::http::body::BodyStream; use crate::http::header::{header_map_from_web_sys_headers, web_sys_headers_from_header_map}; +use bytes::Bytes; fn version_from_string(version: &str) -> http::Version { match version { @@ -18,7 +20,7 @@ fn version_from_string(version: &str) -> http::Version { } /// **Requires** `http` feature. Convert [`web_sys::Request`](web_sys::Request) -/// to [`worker::HttpRequest`](worker::HttpRequest) +/// to [`worker::HttpRequest`](crate::HttpRequest) pub fn from_wasm(req: web_sys::Request) -> Result> { let mut builder = http::request::Builder::new() .uri(req.url()) @@ -43,9 +45,11 @@ pub fn from_wasm(req: web_sys::Request) -> Result> { }) } -/// **Requires** `http` feature. Convert [`worker::HttpRequest`](worker::HttpRequest) +/// **Requires** `http` feature. Convert [`http::Request`](http::Request) /// to [`web_sys::Request`](web_sys::Request) -pub fn to_wasm(mut req: http::Request) -> Result { +pub fn to_wasm + 'static>( + mut req: http::Request, +) -> Result { let mut init = web_sys::RequestInit::new(); init.method(req.method().as_str()); let headers = web_sys_headers_from_header_map(req.headers())?; @@ -73,7 +77,10 @@ pub fn to_wasm(mut req: http::Request) -> Result { let _ = r; } - if let Some(readable_stream) = req.into_body().into_inner() { + let body = req.into_body(); + if !body.is_end_stream() { + let readable_stream = + wasm_streams::ReadableStream::from_stream(BodyStream::new(body)).into_raw(); init.body(Some(readable_stream.as_ref())); } diff --git a/worker/src/http/response.rs b/worker/src/http/response.rs index 1b6022db..755bcc3d 100644 --- a/worker/src/http/response.rs +++ b/worker/src/http/response.rs @@ -4,63 +4,12 @@ use crate::HttpResponse; use crate::Result; use crate::WebSocket; use bytes::Bytes; -use futures_util::Stream; -use js_sys::Uint8Array; -use pin_project::pin_project; -use std::pin::Pin; -use std::task::Context; -use std::task::Poll; -use wasm_bindgen::JsValue; + +use crate::http::body::BodyStream; use worker_sys::ext::ResponseExt; use worker_sys::ext::ResponseInitExt; -#[pin_project] -struct BodyStream { - #[pin] - inner: B, -} - -impl BodyStream { - fn new(inner: B) -> Self { - Self { inner } - } -} - -impl> Stream for BodyStream { - type Item = std::result::Result; - - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let this = self.project(); - let inner: Pin<&mut B> = this.inner; - inner.poll_frame(cx).map(|o| { - if let Some(r) = o { - match r { - Ok(f) => { - if f.is_data() { - // Should not be Err after checking on previous line - let b = f.into_data().unwrap(); - let array = Uint8Array::new_with_length(b.len() as _); - array.copy_from(&b); - Some(Ok(array.into())) - } else { - None - } - } - Err(_) => Some(Err(JsValue::from_str("Error polling body"))), - } - } else { - None - } - }) - } - - fn size_hint(&self) -> (usize, Option) { - let hint = self.inner.size_hint(); - (hint.lower() as usize, hint.upper().map(|u| u as usize)) - } -} - -/// **Requires** `http` feature. Convert generic [`http::Response`](worker::HttpResponse) +/// **Requires** `http` feature. Convert generic [`http::Response`](crate::HttpResponse) /// to [`web_sys::Resopnse`](web_sys::Response) where `B` can be any [`http_body::Body`](http_body::Body) pub fn to_wasm(mut res: http::Response) -> Result where @@ -92,8 +41,8 @@ where )?) } -/// **Requires** `http` feature. Convert [`web_sys::Resopnse`](web_sys::Response) -/// to [`worker::HttpResponse`](worker::HttpResponse) +/// **Requires** `http` feature. Convert [`web_sys::Response`](web_sys::Response) +/// to [`worker::HttpResponse`](crate::HttpResponse) pub fn from_wasm(res: web_sys::Response) -> Result { let mut builder = http::response::Builder::new().status(http::StatusCode::from_u16(res.status())?); @@ -109,3 +58,27 @@ pub fn from_wasm(res: web_sys::Response) -> Result { builder.body(Body::empty())? }) } + +#[cfg(feature = "http")] +impl From for http::Response { + fn from(resp: crate::Response) -> http::Response { + let res: web_sys::Response = resp.into(); + let mut builder = http::response::Builder::new() + .status(http::StatusCode::from_u16(res.status()).unwrap()); + if let Some(headers) = builder.headers_mut() { + crate::http::header::header_map_from_web_sys_headers(res.headers(), headers).unwrap(); + } + if let Some(ws) = res.websocket() { + builder = builder.extension(WebSocket::from(ws)); + } + if let Some(body) = res.body() { + builder + .body(axum::body::Body::new(crate::Body::new(body))) + .unwrap() + } else { + builder + .body(axum::body::Body::new(crate::Body::empty())) + .unwrap() + } + } +} diff --git a/worker/src/lib.rs b/worker/src/lib.rs index edb0f4dd..dca00a50 100644 --- a/worker/src/lib.rs +++ b/worker/src/lib.rs @@ -21,11 +21,11 @@ //! ## `http` //! `worker` `0.0.21` introduced an `http` feature flag which starts to replace custom types with widely used types from the [`http`](https://docs.rs/http/latest/http/) crate. //! -//! This makes it much easier to use crates which use these standard types such as [`axum`](axum). +//! This makes it much easier to use crates which use these standard types such as [`axum`]. //! //! This currently does a few things: //! -//! 1. Introduce [`Body`](worker::Body), which implements [`http_body::Body`](http_body::Body) and is a simple wrapper around [`web_sys::ReadableStream`](web_sys::ReadableStream). +//! 1. Introduce [`Body`], which implements [`http_body::Body`] and is a simple wrapper around [`web_sys::ReadableStream`]. //! 1. The `req` argument when using the [`[event(fetch)]`](worker_macros::event) macro becomes `http::Request`. //! 1. The expected return type for the fetch handler is `http::Response` where `B` can be any [`http_body::Body`](http_body::Body). //! 1. The argument for [`Fetcher::fetch_request`](Fetcher::fetch_request) is `http::Request`. @@ -52,7 +52,55 @@ //! } //! ``` //! -//! We also implement `try_from` between `worker::Request` and `http::Request`, and between `worker::Response` and `http::Response`. This allows you to convert your code incrementally if it is tightly coupled to the original types. +//! We also implement `try_from` between `worker::Request` and `http::Request`, and between `worker::Response` and `http::Response`. +//! This allows you to convert your code incrementally if it is tightly coupled to the original types. +//! +//! ### `Send` Helpers +//! +//! A number of frameworks (including `axum`) require that objects that they are given (including route handlers) can be +//! sent between threads (i.e are marked as `Send`). Unfortuntately, objects which interact with JavaScript are frequently +//! not marked as `Send`. In the Workers environment, this is not an issue, because Workers are single threaded. There are still +//! some ergonomic difficulties which we address with some wrapper types: +//! +//! 1. [`send::SendFuture`] - wraps any `Future` and marks it as `Send`: +//! +//! ```rust +//! // `fut` is `Send` +//! let fut = send::SendFuture::new(async move { +//! // `JsFuture` is not `Send` +//! JsFuture::from(promise).await +//! }); +//! ``` +//! +//! 2. [`send::SendWrapper`] - Marks an arbitrary object as `Send` and implements `Deref` and `DerefMut`, as well as `Clone`, `Debug`, and `Display` if the +//! inner type does. This is useful for attaching types as state to an `axum` `Router`: +//! +//! ```rust +//! // `KvStore` is not `Send` +//! let store = env.kv("FOO")?; +//! // `state` is `Send` +//! let state = send::SendWrapper::new(store); +//! let router = axum::Router::new() +//! .layer(Extension(state)); +//! ``` +//! +//! 3. [`[worker::send]`](macro@crate::send) - Macro to make any `async` function `Send`. This can be a little tricky to identify as the problem, but +//! `axum`'s `[debug_handler]` macro can help, and looking for warnings that a function or object cannot safely be sent +//! between threads. +//! +//! ```rust +//! // This macro makes the whole function (i.e. the `Future` it returns) `Send`. +//! #[worker::send] +//! async fn handler(Extension(env): Extension) -> Response { +//! let kv = env.kv("FOO").unwrap()?; +//! // Holding `kv`, which is not `Send` across `await` boundary would mark this function as `!Send` +//! let value = kv.get("foo").text().await?; +//! Ok(format!("Got value: {:?}", value)); +//! } +//! +//! let router = axum::Router::new() +//! .route("/", get(handler)) +//! ``` #[doc(hidden)] use std::result::Result as StdResult; @@ -69,7 +117,7 @@ pub use wasm_bindgen_futures; pub use worker_kv as kv; pub use cf::{Cf, TlsClientAuth}; -pub use worker_macros::{durable_object, event}; +pub use worker_macros::{durable_object, event, send}; #[doc(hidden)] pub use worker_sys; pub use worker_sys::{console_debug, console_error, console_log, console_warn}; @@ -98,7 +146,6 @@ pub use crate::request::Request; pub use crate::request_init::*; pub use crate::response::{Response, ResponseBody}; pub use crate::router::{RouteContext, RouteParams, Router}; -pub use crate::schedule::*; pub use crate::socket::*; pub use crate::streams::*; pub use crate::websocket::*; @@ -130,7 +177,7 @@ mod request; mod request_init; mod response; mod router; -mod schedule; +pub mod send; mod socket; mod streams; mod websocket; diff --git a/worker/src/request.rs b/worker/src/request.rs index ebb24094..7cbfe9ae 100644 --- a/worker/src/request.rs +++ b/worker/src/request.rs @@ -24,10 +24,13 @@ pub struct Request { immutable: bool, } +unsafe impl Send for Request {} +unsafe impl Sync for Request {} + #[cfg(feature = "http")] -impl TryFrom for Request { +impl + 'static> TryFrom> for Request { type Error = crate::Error; - fn try_from(req: crate::HttpRequest) -> Result { + fn try_from(req: http::Request) -> Result { let web_request: web_sys::Request = crate::http::request::to_wasm(req)?; Ok(Request::from(web_request)) } diff --git a/worker/src/response.rs b/worker/src/response.rs index e419e241..122c4f34 100644 --- a/worker/src/response.rs +++ b/worker/src/response.rs @@ -5,6 +5,8 @@ use crate::ByteStream; use crate::Result; use crate::WebSocket; +#[cfg(feature = "http")] +use bytes::Bytes; use futures_util::{TryStream, TryStreamExt}; use js_sys::Uint8Array; use serde::{de::DeserializeOwned, Serialize}; @@ -35,9 +37,9 @@ pub struct Response { } #[cfg(feature = "http")] -impl TryFrom for Response { +impl + 'static> TryFrom> for Response { type Error = crate::Error; - fn try_from(res: crate::HttpResponse) -> Result { + fn try_from(res: http::Response) -> Result { let resp = crate::http::response::to_wasm(res)?; Ok(resp.into()) } diff --git a/worker/src/send.rs b/worker/src/send.rs new file mode 100644 index 00000000..5c6d16c7 --- /dev/null +++ b/worker/src/send.rs @@ -0,0 +1,96 @@ +//! This module provides utilities for working with JavaScript types +//! which do not implement `Send`, in contexts where `Send` is required. +//! Workers is guaranteed to be single-threaded, so it is safe to +//! wrap any type with `Send` and `Sync` traits. + +use futures_util::future::Future; +use pin_project::pin_project; +use std::fmt::Debug; +use std::fmt::Display; +use std::pin::Pin; +use std::task::Context; +use std::task::Poll; + +#[pin_project] +/// Wrap any future to make it `Send`. +/// +/// ```rust +/// let fut = SendFuture::new(JsFuture::from(promise)); +/// fut.await +/// ``` +pub struct SendFuture { + #[pin] + inner: F, +} + +impl SendFuture { + pub fn new(inner: F) -> Self { + Self { inner } + } +} + +unsafe impl Send for SendFuture {} + +impl Future for SendFuture { + type Output = F::Output; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.project(); + this.inner.poll(cx) + } +} + +/// Wrap any type to make it `Send`. +/// +/// ```rust +/// // js_sys::Promise is !Send +/// let send_promise = SendWrapper::new(promise); +/// ``` +pub struct SendWrapper(pub T); + +unsafe impl Send for SendWrapper {} +unsafe impl Sync for SendWrapper {} + +impl SendWrapper { + pub fn new(inner: T) -> Self { + Self(inner) + } +} + +impl std::ops::Deref for SendWrapper { + type Target = T; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl std::ops::DerefMut for SendWrapper { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + +impl Debug for SendWrapper { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "SendWrapper({:?})", self.0) + } +} + +impl Clone for SendWrapper { + fn clone(&self) -> Self { + Self(self.0.clone()) + } +} + +impl Default for SendWrapper { + fn default() -> Self { + Self(T::default()) + } +} + +impl Display for SendWrapper { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "SendWrapper({})", self.0) + } +} diff --git a/worker/src/websocket.rs b/worker/src/websocket.rs index 8c5afec4..aeddb37d 100644 --- a/worker/src/websocket.rs +++ b/worker/src/websocket.rs @@ -1,16 +1,20 @@ -use crate::{Error, Fetch, Method, Request, Result}; +use crate::{Error, Method, Request, Result}; use futures_channel::mpsc::UnboundedReceiver; use futures_util::Stream; use serde::Serialize; use url::Url; use worker_sys::ext::WebSocketExt; +#[cfg(not(feature = "http"))] +use crate::Fetch; use std::pin::Pin; use std::rc::Rc; use std::task::{Context, Poll}; use wasm_bindgen::convert::FromWasmAbi; use wasm_bindgen::prelude::Closure; use wasm_bindgen::JsCast; +#[cfg(feature = "http")] +use wasm_bindgen_futures::JsFuture; pub use crate::ws_events::*; @@ -105,7 +109,10 @@ impl WebSocket { } } + #[cfg(not(feature = "http"))] let res = Fetch::Request(req).send().await?; + #[cfg(feature = "http")] + let res: crate::Response = fetch_with_request_raw(req).await?.into(); match res.websocket() { Some(ws) => Ok(ws), @@ -440,3 +447,15 @@ pub mod ws_events { } } } + +/// TODO: Convert WebSocket to use `http` types and `reqwest`. +#[cfg(feature = "http")] +async fn fetch_with_request_raw(request: crate::Request) -> Result { + let req = request.inner(); + let fut = { + let worker: web_sys::WorkerGlobalScope = js_sys::global().unchecked_into(); + crate::send::SendFuture::new(JsFuture::from(worker.fetch_with_request(req))) + }; + let resp = fut.await?; + Ok(resp.dyn_into()?) +}