diff --git a/Cargo.lock b/Cargo.lock index 43c855a5..9016602a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2093,6 +2093,7 @@ name = "worker-sandbox" version = "0.1.0" dependencies = [ "blake2", + "bytes", "cfg-if", "chrono", "console_error_panic_hook", diff --git a/package.json b/package.json index 35d45efb..97bfc9f3 100644 --- a/package.json +++ b/package.json @@ -22,6 +22,6 @@ "vitest": "^0.32.0" }, "scripts": { - "test": "" + "test": "cd worker-sandbox && worker-build --dev && vitest" } } diff --git a/worker-sandbox/Cargo.toml b/worker-sandbox/Cargo.toml index e2815c61..6fa76ba8 100644 --- a/worker-sandbox/Cargo.toml +++ b/worker-sandbox/Cargo.toml @@ -14,6 +14,7 @@ default = ["console_error_panic_hook"] [dependencies] blake2 = "0.10.6" +bytes = "1.4.0" chrono = { version = "0.4.26", default-features = false, features = [ "wasmbind", "clock", @@ -26,11 +27,11 @@ http = "0.2.9" regex = "1.8.4" serde = { version = "1.0.164", features = ["derive"] } serde_json = "1.0.96" -worker = { path = "../worker", version = "0.0.17", features= ["queue", "d1"] } +worker = { path = "../worker", version = "0.0.17", features = ["queue", "d1"] } futures-channel = "0.3.28" futures-util = { version = "0.3.28", default-features = false } rand = "0.8.5" -uuid = {version = "1.3.3", features = ["v4", "serde"]} +uuid = { version = "1.3.3", features = ["v4", "serde"] } serde-wasm-bindgen = "0.5.0" md5 = "0.7.0" router-service = { git = "https://github.com/zebp/router-service", version = "0.1.0" } @@ -38,7 +39,9 @@ tower = "0.4.13" [dev-dependencies] futures-channel = { version = "0.3.28", features = ["sink"] } -futures-util = { version = "0.3.28", default-features = false, features = ["sink"] } +futures-util = { version = "0.3.28", default-features = false, features = [ + "sink", +] } reqwest = { version = "0.11.18", features = [ "blocking", "json", diff --git a/worker-sandbox/src/d1.rs b/worker-sandbox/src/d1.rs index 96f992cf..58268777 100644 --- a/worker-sandbox/src/d1.rs +++ b/worker-sandbox/src/d1.rs @@ -8,10 +8,7 @@ struct Person { age: u32, } -pub async fn prepared_statement( - _req: http::Request, - env: Env, -) -> Result> { +pub async fn prepared_statement(env: &Env) -> Result> { let db = env.d1("DB")?; let stmt = worker::query!(&db, "SELECT * FROM people WHERE name = ?", "Ryan Upton")?; @@ -47,7 +44,7 @@ pub async fn prepared_statement( Ok(http::Response::new("ok".into())) } -pub async fn batch(_req: http::Request, env: Env) -> Result> { +pub async fn batch(env: &Env) -> Result> { let db = env.d1("DB")?; let mut results = db .batch(vec![ @@ -71,7 +68,7 @@ pub async fn batch(_req: http::Request, env: Env) -> Result, env: Env) -> Result> { +pub async fn exec(req: http::Request, env: &Env) -> Result> { let db = env.d1("DB")?; let result = db @@ -84,13 +81,13 @@ pub async fn exec(req: http::Request, env: Env) -> Result, env: Env) -> Result> { +pub async fn dump(env: &Env) -> Result> { let db = env.d1("DB")?; let bytes = db.dump().await?; Ok(http::Response::new(bytes.into())) } -pub async fn error(_req: http::Request, env: Env) -> Result> { +pub async fn error(env: &Env) -> Result> { let db = env.d1("DB")?; let error = db .exec("THIS IS NOT VALID SQL") @@ -98,7 +95,10 @@ pub async fn error(_req: http::Request, env: Env) -> Result { + if let Some(text) = msg.text() { + server.send_with_str(text).expect("could not relay text"); + } + } + WebsocketEvent::Close(_) => { + // Sets a key in a test KV so the integration tests can query if we + // actually got the close event. We can't use the shared dat a for this + // because miniflare resets that every request. + some_namespace_kv + .put("got-close-event", "true") + .unwrap() + .execute() + .await + .unwrap(); + } + } + } + }); + + let mut response = Response::builder() + .status(101) + .body(Body::empty()) + .unwrap(); + + response.extensions_mut().insert(pair.client); + + Ok(response) + }) + .post("/xor/:num", |req, ctx| async move { + let num: u8 = match ctx.param("num").unwrap().parse() { + Ok(num) => num, + Err(_) => return Response::builder() + .status(400) + .body("invalid byte".into()) + .map_err(|e| Error::RustError(e.to_string())) + }; + + let xor_stream = req.into_body().into_stream().map_ok(move |buf| { + let mut vec = buf.to_vec(); + vec.iter_mut().for_each(|x| *x ^= num); + Bytes::from(vec) + }); + + let body = worker::body::Body::from_stream(xor_stream)?; + let resp = Response::builder() + .body(body) + .unwrap(); + Ok(resp) + }) + .get("/request-init-fetch", |_, _| async move { + let init = RequestInit::new(); + let req = http::Request::post("https://cloudflare.com").body(()).unwrap(); + fetch_with_init(req, &init).await + }) + .get("/request-init-fetch-post", |_, _| async move { + let mut init = RequestInit::new(); + init.method = Method::POST; + + let req = http::Request::post("https://httpbin.org/post").body(()).unwrap(); + fetch_with_init(req, &init).await + }) + .get("/cancelled-fetch", |_, _| async move { + let controller = AbortController::default(); + let signal = controller.signal(); + + let (tx, rx) = futures_channel::oneshot::channel(); + + // Spawns a future that'll make our fetch request and not block this function. + wasm_bindgen_futures::spawn_local({ + async move { + let req = http::Request::post("https://cloudflare.com").body(()).unwrap(); + let resp = fetch_with_signal(req, &signal).await; + tx.send(resp).unwrap(); + } + }); + + // And then we try to abort that fetch as soon as we start it, hopefully before + // cloudflare.com responds. + controller.abort(); + + let res = rx.await.unwrap(); + let res = res.unwrap_or_else(|err| { + let text = err.to_string(); + Response::new(text.into()) + }); + + Ok(res) + }) + .get("/fetch-timeout", |_, _| async move { + let controller = AbortController::default(); + let signal = controller.signal(); + + let fetch_fut = async { + let req = http::Request::post("https://miniflare.mocks/delay").body(()).unwrap(); + let resp = fetch_with_signal(req, &signal).await?; + let text = resp.into_body().text().await?; + Ok::(text) + }; + let delay_fut = async { + Delay::from(Duration::from_millis(1)).await; + controller.abort(); + Ok(Response::new("Cancelled".into())) + }; + + futures_util::pin_mut!(fetch_fut); + futures_util::pin_mut!(delay_fut); + + match futures_util::future::select(delay_fut, fetch_fut).await { + Either::Left((res, cancelled_fut)) => { + // Ensure that the cancelled future returns an AbortError. + match cancelled_fut.await { + Err(e) if e.to_string().contains("AbortError") => { /* Yay! It worked, let's do nothing to celebrate */}, + Err(e) => panic!("Fetch errored with a different error than expected: {:#?}", e), + Ok(text) => panic!("Fetch unexpectedly succeeded: {}", text) + } + + res + }, + Either::Right(_) => panic!("Delay future should have resolved first"), + } + }) + .get("/redirect-default", |_, _| async move { + Ok(Response::builder() + .status(302) + .header("Location", "https://example.com/") + .body(Body::empty()).unwrap()) + }) + .get("/redirect-307", |_, _| async move { + Ok(Response::builder() + .status(307) + .header("Location", "https://example.com/") + .body(Body::empty()).unwrap()) + }) + .get("/now", |_, _| async move { + let now = chrono::Utc::now(); + let js_date: Date = now.into(); + Ok(Response::new(js_date.to_string().into())) + }) + .get("/custom-response-body", |_, _| async move { + Ok(Response::new(vec![b'h', b'e', b'l', b'l', b'o'].into())) + }) + .get("/init-called", |_, _| async move { + let init_called = GLOBAL_STATE.load(Ordering::SeqCst); + Ok(Response::new(init_called.to_string().into())) + }) + .get("/cache-example", |req, _| async move { + //console_log!("url: {}", req.uri().to_string()); + let cache = Cache::default(); + let key = req.uri().to_string(); + if let Some(resp) = cache.get(&key, true).await? { + //console_log!("Cache HIT!"); + Ok(resp) + } else { + //console_log!("Cache MISS!"); + + let mut resp = Response::builder() + .header("content-type", "application/json") + // Cache API respects Cache-Control headers. Setting s-max-age to 10 + // will limit the response to be in cache for 10 seconds max + .header("cache-control", "s-maxage=10") + .body(serde_json::json!({ "timestamp": Date::now().as_millis() }).to_string().into()) + .map_err(|e| Error::RustError(e.to_string())) + .unwrap(); + + cache.put(key, resp.clone()).await?; + Ok(resp) + } + }) + .get("/cache-api/get/:key", |_req, ctx| async move { + if let Some(key) = ctx.param("key") { + let cache = Cache::default(); + if let Some(resp) = cache.get(format!("https://{key}"), true).await? { + return Ok(resp); + } else { + return Ok(Response::new("cache miss".into())); + } + } + + Response::builder() + .status(400) + .body("key missing".into()) + .map_err(|e| Error::RustError(e.to_string())) + }) + .put("/cache-api/put/:key", |_req, ctx| async move { + if let Some(key) = ctx.param("key") { + let cache = Cache::default(); + + let mut resp = Response::builder() + .header("content-type", "application/json") + // Cache API respects Cache-Control headers. Setting s-max-age to 10 + // will limit the response to be in cache for 10 seconds max + .header("cache-control", "s-maxage=10") + .body(serde_json::json!({ "timestamp": Date::now().as_millis() }).to_string().into()) + .map_err(|e| Error::RustError(e.to_string())) + .unwrap(); + + cache.put(format!("https://{key}"), resp.clone()).await?; + return Ok(resp); + } + + Response::builder() + .status(400) + .body("key missing".into()) + .map_err(|e| Error::RustError(e.to_string())) + }) + .post("/cache-api/delete/:key", |_req, ctx| async move { + if let Some(key) = ctx.param("key") { + let cache = Cache::default(); + + let res = cache.delete(format!("https://{key}"), true).await?; + return Ok(Response::new(serde_json::to_string(&res)?.into())); + } + + Response::builder() + .status(400) + .body("key missing".into()) + .map_err(|e| Error::RustError(e.to_string())) + }) + .get("/cache-stream", |req, _| async move { + //console_log!("url: {}", req.uri().to_string()); + let cache = Cache::default(); + let key = req.uri().to_string(); + if let Some(resp) = cache.get(&key, true).await? { + //console_log!("Cache HIT!"); + Ok(resp) + } else { + //console_log!("Cache MISS!"); + let mut rng = rand::thread_rng(); + let count = rng.gen_range(0..10); + let stream = futures_util::stream::repeat("Hello, world!\n") + .take(count) + .then(|text| async move { + Delay::from(Duration::from_millis(50)).await; + Result::Ok(text.as_bytes().to_vec()) + }); + + let body = worker::body::Body::from_stream(stream)?; + + //console_log!("resp = {:?}", resp); + + let mut resp = Response::builder() + // Cache API respects Cache-Control headers. Setting s-max-age to 10 + // will limit the response to be in cache for 10 seconds max + .header("cache-control", "s-maxage=10") + .body(body) + .unwrap(); + + cache.put(key, resp.clone()).await?; + Ok(resp) + } + }) + .get("/remote-by-request", |req, ctx| async move { + let fetcher = ctx.data.service("remote")?; + fetcher.fetch_request(req).await + }) + .get("/remote-by-path", |req, ctx| async move { + let fetcher = ctx.data.service("remote")?; + let mut init = RequestInit::new(); + init.with_method(Method::POST); + + fetcher.fetch(req.uri().to_string(), Some(init)).await + }) + .post("/queue/send/:id", |_req, ctx| async move { + let id = match ctx.param("id").map(|id| Uuid::try_parse(id).ok()).and_then(|u|u) { + Some(id) => id, + None => { + return Response::builder() + .status(400) + .body("error".into()) + .map_err(|_| Error::RustError("Failed to parse id, expected a UUID".into())); + } + }; + let my_queue = match ctx.data.queue("my_queue") { + Ok(queue) => queue, + Err(err) => { + return Response::builder() + .status(500) + .body(format!("Failed to get queue: {err:?}").into()) + .map_err(|e| Error::RustError(e.to_string())); + } + }; + match my_queue.send(&QueueBody { + id: id.to_string(), + }).await { + Ok(_) => { + Ok(Response::new("Message sent".into())) + } + Err(err) => { + Response::builder() + .status(500) + .body(format!("Failed to send message to queue: {err:?}").into()) + .map_err(|e| Error::RustError(e.to_string())) + } + } + }) + .get("/queue", |_req, _ctx| async move { + let guard = GLOBAL_QUEUE_STATE.lock().unwrap(); + let messages: Vec = guard.clone(); + let json = serde_json::to_string(&messages).unwrap(); + Ok(Response::new(Body::from(json))) + }) + .get("/d1/prepared", |_, ctx| async move { + d1::prepared_statement(ctx.data.as_ref()).await + }) + .get("/d1/batch", |_, ctx| async move { + d1::batch(ctx.data.as_ref()).await + }) + .get("/d1/dump", |_, ctx| async move { + d1::dump(ctx.data.as_ref()).await + }) + .post("/d1/exec", |req, ctx| async move { + d1::exec(req, ctx.data.as_ref()).await + }) + .get("/d1/error", |_, ctx| async move { + d1::error(ctx.data.as_ref()).await + }) + .get("/r2/list-empty", |_, ctx| async move { + r2::list_empty(ctx.data.as_ref()).await + }) + .get("/r2/list", |_, ctx| async move { + r2::list(ctx.data.as_ref()).await + }) + .get("/r2/get-empty", |_, ctx| async move { + r2::get_empty(ctx.data.as_ref()).await + }) + .get("/r2/get", |_, ctx| async move { + r2::get(ctx.data.as_ref()).await + }) + .put("/r2/put", |_, ctx| async move { + r2::put(ctx.data.as_ref()).await + }) + .put("/r2/put-properties", |_, ctx| async move { + r2::put_properties(ctx.data.as_ref()).await + }) + .put("/r2/put-multipart", |_, ctx| async move { + r2::put_multipart(ctx.data.as_ref()).await + }) + .delete("/r2/delete", |_, ctx| async move { + r2::delete(ctx.data.as_ref()).await + }) .any("/*catchall", |_, ctx| async move { Ok(Response::builder() .status(404) @@ -411,12 +770,12 @@ pub struct QueueBody { pub async fn queue(message_batch: MessageBatch, _env: Env, _ctx: Context) -> Result<()> { let mut guard = GLOBAL_QUEUE_STATE.lock().unwrap(); for message in message_batch.messages()? { - console_log!( + /*console_log!( "Received queue message {:?}, with id {} and timestamp: {}", message.body, message.id, message.timestamp.to_string() - ); + );*/ guard.push(message.body); } Ok(()) diff --git a/worker-sandbox/tests/d1.spec.ts b/worker-sandbox/tests/d1.spec.ts index 5e426a52..1a7849d1 100644 --- a/worker-sandbox/tests/d1.spec.ts +++ b/worker-sandbox/tests/d1.spec.ts @@ -1,11 +1,8 @@ import { describe, test, expect, beforeAll } from "vitest"; - -const hasLocalDevServer = await fetch("http://localhost:8787/request") - .then((resp) => resp.ok) - .catch(() => false); +import { mf } from "./mf"; async function exec(query: string): Promise { - const resp = await fetch("http://localhost:8787/d1/exec", { + const resp = await mf.dispatchFetch("https://fake.host/d1/exec", { method: "POST", body: query.split("\n").join(""), }); @@ -15,7 +12,7 @@ async function exec(query: string): Promise { return Number(body); } -describe.skipIf(!hasLocalDevServer)("d1", () => { +describe("d1", () => { test("create table", async () => { const query = `CREATE TABLE IF NOT EXISTS uniqueTable ( id INTEGER PRIMARY KEY, @@ -49,22 +46,22 @@ describe.skipIf(!hasLocalDevServer)("d1", () => { }); test("prepared statement", async () => { - const resp = await fetch("http://localhost:8787/d1/prepared"); + const resp = await mf.dispatchFetch("https://fake.host/d1/prepared"); expect(resp.status).toBe(200); }); test("batch", async () => { - const resp = await fetch("http://localhost:8787/d1/batch"); + const resp = await mf.dispatchFetch("https://fake.host/d1/batch"); expect(resp.status).toBe(200); }); test("dump", async () => { - const resp = await fetch("http://localhost:8787/d1/dump"); + const resp = await mf.dispatchFetch("https://fake.host/d1/dump"); expect(resp.status).toBe(200); }); - test("dump", async () => { - const resp = await fetch("http://localhost:8787/d1/error"); + test("error", async () => { + const resp = await mf.dispatchFetch("https://fake.host/d1/error"); expect(resp.status).toBe(200); }); }); diff --git a/worker-sandbox/tests/mf.ts b/worker-sandbox/tests/mf.ts index a2560212..0c6bc64d 100644 --- a/worker-sandbox/tests/mf.ts +++ b/worker-sandbox/tests/mf.ts @@ -6,6 +6,7 @@ export const mf = new Miniflare({ cache: true, cachePersist: false, d1Persist: false, + d1Databases: ["DB"], kvPersist: false, r2Persist: false, modules: true, diff --git a/worker-sandbox/tests/request.spec.ts b/worker-sandbox/tests/request.spec.ts index 56a24a0c..96f85b62 100644 --- a/worker-sandbox/tests/request.spec.ts +++ b/worker-sandbox/tests/request.spec.ts @@ -118,7 +118,7 @@ test("catchall", async () => { method: "OPTIONS", }); - expect(await resp.text()).toBe("/hello-world"); + expect(await resp.text()).toBe("hello-world"); }); test("redirect default", async () => { diff --git a/worker-sandbox/tests/subrequest.spec.ts b/worker-sandbox/tests/subrequest.spec.ts index 83e1ec2c..ffe5aee2 100644 --- a/worker-sandbox/tests/subrequest.spec.ts +++ b/worker-sandbox/tests/subrequest.spec.ts @@ -17,7 +17,7 @@ describe("subrequest", () => { expect(await resp.text()).toBe("Cancelled"); }); - test.skip("request init fetch post", async () => { + test("request init fetch post", async () => { const resp = await mf.dispatchFetch( "https://fake.host/request-init-fetch-post" ); diff --git a/worker-sys/src/types/incoming_request_cf_properties.rs b/worker-sys/src/types/incoming_request_cf_properties.rs index 7e745422..057492db 100644 --- a/worker-sys/src/types/incoming_request_cf_properties.rs +++ b/worker-sys/src/types/incoming_request_cf_properties.rs @@ -14,6 +14,9 @@ extern "C" { #[wasm_bindgen(method, getter)] pub fn asn(this: &IncomingRequestCfProperties) -> u32; + #[wasm_bindgen(method, getter, js_name=asOrganization)] + pub fn as_organization(this: &IncomingRequestCfProperties) -> String; + #[wasm_bindgen(method, getter)] pub fn country(this: &IncomingRequestCfProperties) -> Option; diff --git a/worker/src/body/body.rs b/worker/src/body/body.rs index 7718e133..2b0c45ef 100644 --- a/worker/src/body/body.rs +++ b/worker/src/body/body.rs @@ -9,9 +9,12 @@ use crate::{ Error, }; use bytes::Bytes; -use futures_util::{AsyncRead, Stream}; +use futures_util::{AsyncRead, Stream, TryStream, TryStreamExt}; use http::HeaderMap; +use js_sys::Uint8Array; use serde::de::DeserializeOwned; +use wasm_bindgen::{JsCast, JsValue}; +use web_sys::ReadableStream; type BoxBody = http_body::combinators::UnsyncBoxBody; @@ -189,6 +192,31 @@ impl Body { crate::body::BodyInner::None => None, } } + + /// Create a `Body` using a [`Stream`](futures::stream::Stream) + pub fn from_stream(stream: S) -> Result + where + S: TryStream + 'static, + S::Ok: Into>, + S::Error: Into, + { + let js_stream = stream + .map_ok(|item| -> Vec { item.into() }) + .map_ok(|chunk| { + let array = Uint8Array::new_with_length(chunk.len() as _); + array.copy_from(&chunk); + + array.into() + }) + .map_err(|err| -> crate::Error { err.into() }) + .map_err(|e| JsValue::from(e.to_string())); + + let stream = wasm_streams::ReadableStream::from_stream(js_stream); + let stream: ReadableStream = stream.into_raw().dyn_into().unwrap(); + + let edge_res = web_sys::Response::new_with_opt_readable_stream(Some(&stream))?; + Ok(Self::from(edge_res)) + } } impl Default for Body { diff --git a/worker/src/cf.rs b/worker/src/cf.rs index 41fe5a04..8de0054e 100644 --- a/worker/src/cf.rs +++ b/worker/src/cf.rs @@ -1,7 +1,3 @@ -mod properties; - -pub use properties::{CfProperties, MinifyConfig, PolishConfig}; - /// In addition to the methods on the `Request` struct, the `Cf` struct on an inbound Request contains information about the request provided by Cloudflare’s edge. /// /// [Details](https://developers.cloudflare.com/workers/runtime-apis/request#incomingrequestcfproperties) @@ -18,6 +14,10 @@ impl Cf { Self { inner } } + pub fn inner(&self) -> &worker_sys::IncomingRequestCfProperties { + &self.inner + } + /// The three-letter airport code (e.g. `ATX`, `LUX`) representing /// the colocation which processed the request pub fn colo(&self) -> String { @@ -29,6 +29,11 @@ impl Cf { self.inner.asn() } + /// The Autonomous System organization name of the request, e.g. `Cloudflare, Inc.` + pub fn as_organization(&self) -> String { + self.inner.as_organization() + } + /// The two-letter country code of origin for the request. /// This is the same value as that provided in the CF-IPCountry header, e.g. `"US"` pub fn country(&self) -> Option { @@ -89,9 +94,7 @@ impl Cf { /// Information about the client's authorization. /// Only set when using Cloudflare Access or API Shield. pub fn tls_client_auth(&self) -> Option { - self.inner - .tls_client_auth() - .map(|inner| TlsClientAuth { inner }) + self.inner.tls_client_auth().map(Into::into) } /// The TLS version of the connection to Cloudflare, e.g. TLSv1.3. @@ -178,6 +181,12 @@ pub struct RequestPriority { pub group_weight: usize, } +impl From for Cf { + fn from(inner: worker_sys::IncomingRequestCfProperties) -> Self { + Self { inner } + } +} + /// Only set when using Cloudflare Access or API Shield #[derive(Debug)] pub struct TlsClientAuth { @@ -236,3 +245,9 @@ impl TlsClientAuth { self.inner.cert_subject_dn_rfc2253() } } + +impl From for TlsClientAuth { + fn from(inner: worker_sys::TlsClientAuth) -> Self { + Self { inner } + } +} diff --git a/worker/src/durable.rs b/worker/src/durable.rs index 3ecc2475..cf119ca4 100644 --- a/worker/src/durable.rs +++ b/worker/src/durable.rs @@ -27,7 +27,7 @@ use chrono::{DateTime, Utc}; use futures_util::Future; use js_sys::{Map, Number, Object}; use serde::{de::DeserializeOwned, Serialize}; -use wasm_bindgen::{prelude::*, JsCast}; +use wasm_bindgen::prelude::*; use wasm_bindgen_futures::future_to_promise; use worker_sys::{ DurableObject as EdgeDurableObject, DurableObjectId, diff --git a/worker/src/env.rs b/worker/src/env.rs index 57607f81..3a67ead8 100644 --- a/worker/src/env.rs +++ b/worker/src/env.rs @@ -6,7 +6,7 @@ use crate::Queue; use crate::{durable::ObjectNamespace, Bucket, DynamicDispatcher, Fetcher, Result}; use js_sys::Object; -use wasm_bindgen::{prelude::*, JsCast, JsValue}; +use wasm_bindgen::prelude::*; use worker_kv::KvStore; #[wasm_bindgen] diff --git a/worker/src/fetch.rs b/worker/src/fetch.rs index 12e9a7aa..22588d3d 100644 --- a/worker/src/fetch.rs +++ b/worker/src/fetch.rs @@ -5,7 +5,7 @@ use crate::{ body::Body, futures::SendJsFuture, http::{request, response}, - Error, Result, + AbortSignal, Error, RequestInit, Result, }; /// Fetch a resource from the network. @@ -46,6 +46,47 @@ pub async fn fetch(req: http::Request>) -> Result>, + init: &RequestInit, +) -> Result> { + let fut = { + let req = req.map(Into::into); + let global = js_sys::global().unchecked_into::(); + + let req = request::into_wasm(req); + let promise = global.fetch_with_request_and_init(&req, &init.into()); + + SendJsFuture::from(promise) + }; + + fut.await + .map(|res| response::from_wasm(res.unchecked_into())) + .map_err(Error::from) +} + +pub async fn fetch_with_signal( + req: http::Request>, + signal: &AbortSignal, +) -> Result> { + let mut init = web_sys::RequestInit::new(); + init.signal(Some(signal.inner())); + + let fut = { + let req = req.map(Into::into); + let global = js_sys::global().unchecked_into::(); + + let req = request::into_wasm(req); + let promise = global.fetch_with_request_and_init(&req, &init); + + SendJsFuture::from(promise) + }; + + fut.await + .map(|res| response::from_wasm(res.unchecked_into())) + .map_err(Error::from) +} + fn _assert_send() { use crate::futures::assert_send_value; assert_send_value(fetch(http::Request::new(()))); diff --git a/worker/src/fetcher.rs b/worker/src/fetcher.rs index 533d4f15..06fd0742 100644 --- a/worker/src/fetcher.rs +++ b/worker/src/fetcher.rs @@ -5,7 +5,7 @@ use crate::{ env::EnvBinding, futures::SendJsFuture, http::{request, response}, - Result, + RequestInit, Result, }; /// A struct for invoking fetch events to other Workers. @@ -16,7 +16,27 @@ unsafe impl Sync for Fetcher {} impl Fetcher { /// Invoke a fetch event in a worker with a url and optionally a [RequestInit]. - pub async fn fetch(&self, req: http::Request) -> Result> { + pub async fn fetch( + &self, + url: impl Into, + init: Option, + ) -> Result> { + let path = url.into(); + let fut = { + let promise = match init { + Some(ref init) => self.0.fetch_with_str_and_init(&path, &init.into()), + None => self.0.fetch_with_str(&path), + }; + + SendJsFuture::from(promise) + }; + + let res = fut.await?.dyn_into()?; + Ok(response::from_wasm(res)) + } + + /// Invoke a fetch event with an existing [Request]. + pub async fn fetch_request(&self, req: http::Request) -> Result> { let fut = { let req = request::into_wasm(req); let promise = self.0.fetch(&req); diff --git a/worker/src/headers.rs b/worker/src/headers.rs new file mode 100644 index 00000000..1d97ab61 --- /dev/null +++ b/worker/src/headers.rs @@ -0,0 +1,183 @@ +use crate::{error::Error, Result}; + +use std::{ + iter::{FromIterator, Map}, + result::Result as StdResult, + str::FromStr, +}; + +use http::{header::HeaderName, HeaderMap, HeaderValue}; +use js_sys::Array; +use wasm_bindgen::JsValue; +use worker_sys::ext::HeadersExt; + +/// A [Headers](https://developer.mozilla.org/en-US/docs/Web/API/Headers) representation used in +/// Request and Response objects. +pub struct Headers(pub web_sys::Headers); + +impl std::fmt::Debug for Headers { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str("Headers {\n")?; + for (k, v) in self.entries() { + f.write_str(&format!("{k} = {v}\n"))?; + } + f.write_str("}\n") + } +} + +impl Headers { + /// Construct a new `Headers` struct. + pub fn new() -> Self { + Default::default() + } + + /// Returns all the values of a header within a `Headers` object with a given name. + /// Returns an error if the name is invalid (e.g. contains spaces) + pub fn get(&self, name: &str) -> Result> { + self.0.get(name).map_err(Error::from) + } + + /// Returns a boolean stating whether a `Headers` object contains a certain header. + /// Returns an error if the name is invalid (e.g. contains spaces) + pub fn has(&self, name: &str) -> Result { + self.0.has(name).map_err(Error::from) + } + + /// Returns an error if the name is invalid (e.g. contains spaces) + pub fn append(&mut self, name: &str, value: &str) -> Result<()> { + self.0.append(name, value).map_err(Error::from) + } + + /// Sets a new value for an existing header inside a `Headers` object, or adds the header if it does not already exist. + /// Returns an error if the name is invalid (e.g. contains spaces) + pub fn set(&mut self, name: &str, value: &str) -> Result<()> { + self.0.set(name, value).map_err(Error::from) + } + + /// Deletes a header from a `Headers` object. + /// Returns an error if the name is invalid (e.g. contains spaces) + /// or if the JS Headers object's guard is immutable (e.g. for an incoming request) + pub fn delete(&mut self, name: &str) -> Result<()> { + self.0.delete(name).map_err(Error::from) + } + + /// Returns an iterator allowing to go through all key/value pairs contained in this object. + pub fn entries(&self) -> HeaderIterator { + self.0 + .entries() + .into_iter() + // The entries iterator.next() will always return a proper value: https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Iteration_protocols + .map((|a| a.unwrap().into()) as F1) + // The entries iterator always returns an array[2] of strings + .map(|a: Array| (a.get(0).as_string().unwrap(), a.get(1).as_string().unwrap())) + } + + /// Returns an iterator allowing you to go through all keys of the key/value pairs contained in + /// this object. + pub fn keys(&self) -> impl Iterator { + js_sys::Object::keys(&self.0) + .into_iter() + // The keys iterator.next() will always return a proper value containing a string + .map(|a| a.as_string().unwrap()) + } + + /// Returns an iterator allowing you to go through all values of the key/value pairs contained + /// in this object. + pub fn values(&self) -> impl Iterator { + js_sys::Object::values(&self.0) + .into_iter() + // The values iterator.next() will always return a proper value containing a string + .map(|a| a.as_string().unwrap()) + } +} + +impl Default for Headers { + fn default() -> Self { + // This cannot throw an error: https://developer.mozilla.org/en-US/docs/Web/API/Headers/Headers + Headers(web_sys::Headers::new().unwrap()) + } +} + +type F1 = fn(StdResult) -> Array; +type HeaderIterator = Map, fn(Array) -> (String, String)>; + +impl IntoIterator for &Headers { + type Item = (String, String); + + type IntoIter = HeaderIterator; + + fn into_iter(self) -> Self::IntoIter { + self.entries() + } +} + +impl> FromIterator<(T, T)> for Headers { + fn from_iter>(iter: U) -> Self { + let mut headers = Headers::new(); + iter.into_iter().for_each(|(name, value)| { + headers.append(name.as_ref(), value.as_ref()).ok(); + }); + headers + } +} + +impl<'a, T: AsRef> FromIterator<&'a (T, T)> for Headers { + fn from_iter>(iter: U) -> Self { + let mut headers = Headers::new(); + iter.into_iter().for_each(|(name, value)| { + headers.append(name.as_ref(), value.as_ref()).ok(); + }); + headers + } +} + +impl AsRef for Headers { + fn as_ref(&self) -> &JsValue { + &self.0 + } +} + +impl From<&HeaderMap> for Headers { + fn from(map: &HeaderMap) -> Self { + map.keys() + .flat_map(|name| { + map.get_all(name) + .into_iter() + .map(move |value| (name.to_string(), value.to_str().unwrap().to_owned())) + }) + .collect() + } +} + +impl From for Headers { + fn from(map: HeaderMap) -> Self { + (&map).into() + } +} + +impl From<&Headers> for HeaderMap { + fn from(headers: &Headers) -> Self { + headers + .into_iter() + .map(|(name, value)| { + ( + HeaderName::from_str(&name).unwrap(), + HeaderValue::from_str(&value).unwrap(), + ) + }) + .collect() + } +} + +impl From for HeaderMap { + fn from(headers: Headers) -> Self { + (&headers).into() + } +} + +impl Clone for Headers { + fn clone(&self) -> Self { + // Headers constructor doesn't throw an error + Headers(web_sys::Headers::new_with_headers(&self.0).unwrap()) + } +} diff --git a/worker/src/http/request.rs b/worker/src/http/request.rs index a789b25e..8b89611d 100644 --- a/worker/src/http/request.rs +++ b/worker/src/http/request.rs @@ -3,7 +3,7 @@ use wasm_bindgen::JsCast; use worker_sys::ext::{HeadersExt, RequestExt}; -use crate::{AbortSignal, Cf, CfProperties}; +use crate::{AbortSignal, Cf}; use crate::body::Body; @@ -113,12 +113,12 @@ pub fn into_wasm(mut req: http::Request) -> web_sys::Request { init.redirect(redirect.into()); } - if let Some(cf) = req.extensions_mut().remove::() { + if let Some(cf) = req.extensions_mut().remove::() { // TODO: this should be handled in worker-sys let r = ::js_sys::Reflect::set( init.as_ref(), &wasm_bindgen::JsValue::from("cf"), - &wasm_bindgen::JsValue::from(&cf), + &wasm_bindgen::JsValue::from(cf.inner()), ); debug_assert!( r.is_ok(), diff --git a/worker/src/lib.rs b/worker/src/lib.rs index 95f201ad..ae104db2 100644 --- a/worker/src/lib.rs +++ b/worker/src/lib.rs @@ -33,11 +33,14 @@ pub use crate::dynamic_dispatch::*; pub use crate::env::{Env, EnvBinding, Secret, Var}; pub use crate::error::Error; pub use crate::fetch::fetch; +pub use crate::fetch::fetch_with_init; +pub use crate::fetch::fetch_with_signal; pub use crate::fetcher::Fetcher; // pub use crate::futures::spawn_local; #[cfg(feature = "queue")] pub use crate::queue::*; pub use crate::r2::*; +pub use crate::request_init::RequestInit; pub use crate::schedule::*; pub use crate::socket::*; pub use crate::streams::*; @@ -59,10 +62,12 @@ mod error; mod fetch; mod fetcher; mod futures; +mod headers; pub mod http; #[cfg(feature = "queue")] mod queue; mod r2; +mod request_init; mod schedule; mod socket; mod streams; diff --git a/worker/src/queue.rs b/worker/src/queue.rs index 60dce091..8de04584 100644 --- a/worker/src/queue.rs +++ b/worker/src/queue.rs @@ -3,7 +3,7 @@ use std::marker::PhantomData; use crate::{env::EnvBinding, futures::SendJsFuture, Date, Error, Result}; use js_sys::Array; use serde::{de::DeserializeOwned, Serialize}; -use wasm_bindgen::{prelude::*, JsCast}; +use wasm_bindgen::prelude::*; use worker_sys::{MessageBatch as MessageBatchSys, Queue as EdgeQueue}; static BODY_KEY_STR: &str = "body"; diff --git a/worker/src/cf/properties.rs b/worker/src/request_init.rs similarity index 68% rename from worker/src/cf/properties.rs rename to worker/src/request_init.rs index 99f5ab4e..26aa418a 100644 --- a/worker/src/cf/properties.rs +++ b/worker/src/request_init.rs @@ -1,12 +1,97 @@ -// TODO: the worker-sys crate should contain the JS bindings rather than doing it in here - use std::collections::HashMap; +use crate::headers::Headers; +use crate::http::Method; + use js_sys::Object; -use wasm_bindgen::{prelude::wasm_bindgen, JsValue}; +use serde::Serialize; +use wasm_bindgen::prelude::*; + +/// Optional options struct that contains settings to apply to the `Request`. +pub struct RequestInit { + /// Currently requires a manual conversion from your data into a [`wasm_bindgen::JsValue`]. + pub body: Option, + /// Headers associated with the outbound `Request`. + pub headers: Headers, + /// Cloudflare-specific properties that can be set on the `Request` that control how Cloudflare’s + /// edge handles the request. + pub cf: CfProperties, + /// The HTTP Method used for this `Request`. + pub method: Method, + /// The redirect mode to use: follow, error, or manual. The default for a new Request object is + /// follow. Note, however, that the incoming Request property of a FetchEvent will have redirect + /// mode manual. + pub redirect: RequestRedirect, +} + +impl RequestInit { + pub fn new() -> Self { + Default::default() + } + + pub fn with_headers(&mut self, headers: Headers) -> &mut Self { + self.headers = headers; + self + } + + pub fn with_method(&mut self, method: Method) -> &mut Self { + self.method = method; + self + } + + pub fn with_redirect(&mut self, redirect: RequestRedirect) -> &mut Self { + self.redirect = redirect; + self + } + + pub fn with_body(&mut self, body: Option) -> &mut Self { + self.body = body; + self + } + + pub fn with_cf_properties(&mut self, props: CfProperties) -> &mut Self { + self.cf = props; + self + } +} + +impl From<&RequestInit> for web_sys::RequestInit { + fn from(req: &RequestInit) -> Self { + let mut inner = web_sys::RequestInit::new(); + inner.headers(req.headers.as_ref()); + inner.method(req.method.as_ref()); + inner.redirect(req.redirect.into()); + inner.body(req.body.as_ref()); + + // set the Cloudflare-specific `cf` property on FFI RequestInit + let r = ::js_sys::Reflect::set( + inner.as_ref(), + &JsValue::from("cf"), + &JsValue::from(&req.cf), + ); + debug_assert!( + r.is_ok(), + "setting properties should never fail on our dictionary objects" + ); + let _ = r; + + inner + } +} + +impl Default for RequestInit { + fn default() -> Self { + Self { + body: None, + headers: Headers::new(), + cf: CfProperties::default(), + method: Method::GET, + redirect: RequestRedirect::default(), + } + } +} /// -#[derive(Clone)] pub struct CfProperties { /// Whether Cloudflare Apps should be enabled for this request. Defaults to `true`. pub apps: Option, @@ -25,7 +110,7 @@ pub struct CfProperties { pub cache_ttl: Option, /// This option is a version of the cacheTtl feature which chooses a TTL based on the response’s /// status code. If the response to this request has a status code that matches, Cloudflare will - /// cache for the instructed time, and override cache instructives sent by the origin. For + /// cache for the instructed time, and override cache directives sent by the origin. For /// example: { "200-299": 86400, 404: 1, "500-599": 0 }. The value can be any integer, including /// zero and negative integers. A value of 0 indicates that the cache asset expires immediately. /// Any negative value instructs Cloudflare not to cache at all. @@ -58,13 +143,11 @@ pub struct CfProperties { pub scrape_shield: Option, } -unsafe impl Send for CfProperties {} -unsafe impl Sync for CfProperties {} - impl From<&CfProperties> for JsValue { fn from(props: &CfProperties) -> Self { let obj = js_sys::Object::new(); let defaults = CfProperties::default(); + let serializer = serde_wasm_bindgen::Serializer::new().serialize_maps_as_objects(true); set_prop( &obj, @@ -110,7 +193,7 @@ impl From<&CfProperties> for JsValue { set_prop( &obj, &JsValue::from("cacheTtlByStatus"), - &serde_wasm_bindgen::to_value(&ttl_status_map).unwrap_or_default(), + &ttl_status_map.serialize(&serializer).unwrap_or_default(), ); set_prop( @@ -198,9 +281,6 @@ pub struct MinifyConfig { pub css: bool, } -unsafe impl Send for MinifyConfig {} -unsafe impl Sync for MinifyConfig {} - /// Configuration options for Cloudflare's image optimization feature: /// #[wasm_bindgen] @@ -211,9 +291,6 @@ pub enum PolishConfig { Lossless, } -unsafe impl Send for PolishConfig {} -unsafe impl Sync for PolishConfig {} - impl Default for PolishConfig { fn default() -> Self { Self::Off @@ -229,3 +306,32 @@ impl From for &str { } } } + +#[wasm_bindgen] +#[derive(Default, Clone, Copy)] +pub enum RequestRedirect { + Error, + #[default] + Follow, + Manual, +} + +impl From for &str { + fn from(redirect: RequestRedirect) -> Self { + match redirect { + RequestRedirect::Error => "error", + RequestRedirect::Follow => "follow", + RequestRedirect::Manual => "manual", + } + } +} + +impl From for web_sys::RequestRedirect { + fn from(redir: RequestRedirect) -> Self { + match redir { + RequestRedirect::Error => web_sys::RequestRedirect::Error, + RequestRedirect::Follow => web_sys::RequestRedirect::Follow, + RequestRedirect::Manual => web_sys::RequestRedirect::Manual, + } + } +}