Skip to content

Commit

Permalink
Duplicate test harness using axum router (cloudflare#481)
Browse files Browse the repository at this point in the history
* Prepare to remove global Fetch for http

* clippy

* Refactor worker test harness

* Axum test harness

* Get more tests working

* Macro for marking future as Send

* Remaining axum routes

* More documentation

* Cleanup
  • Loading branch information
kflansburg authored and jdon committed Mar 27, 2024
1 parent 81e9179 commit 3454cf6
Show file tree
Hide file tree
Showing 30 changed files with 1,751 additions and 852 deletions.
23 changes: 22 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

19 changes: 19 additions & 0 deletions worker-macros/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
mod durable_object;
mod event;
mod send;

use proc_macro::TokenStream;

Expand All @@ -21,3 +22,21 @@ pub fn event(attr: TokenStream, item: TokenStream) -> TokenStream {
pub fn event(attr: TokenStream, item: TokenStream) -> TokenStream {
event::expand_macro(attr, item, false)
}

#[proc_macro_attribute]
/// Convert an async function which is `!Send` to be `Send`.
///
/// This is useful for implementing async handlers in frameworks which
/// expect the handler to be `Send`, such as `axum`.
///
/// ```rust
/// #[worker::send]
/// async fn foo() {
/// // JsFuture is !Send
/// let fut = JsFuture::from(promise);
/// fut.await
/// }
/// ```
pub fn send(attr: TokenStream, stream: TokenStream) -> TokenStream {
send::expand_macro(attr, stream)
}
26 changes: 26 additions & 0 deletions worker-macros/src/send.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
use proc_macro::TokenStream;
use quote::quote;
use syn::{parse_macro_input, ItemFn};

pub fn expand_macro(_attr: TokenStream, stream: TokenStream) -> TokenStream {
let stream_clone = stream.clone();
let input = parse_macro_input!(stream_clone as ItemFn);

let ItemFn {
attrs,
vis,
sig,
block,
} = input;
let stmts = &block.stmts;

let tokens = quote! {
#(#attrs)* #vis #sig {
worker::send::SendFuture::new(async {
#(#stmts)*
}).await
}
};

TokenStream::from(tokens)
}
16 changes: 15 additions & 1 deletion worker-sandbox/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ path = "src/lib.rs"

[features]
default = ["console_error_panic_hook"]
http = ["worker/http"]
http = ["worker/http", "dep:axum", "dep:tower-service", "dep:axum-macros"]

[dependencies]
futures-channel.workspace = true
Expand All @@ -38,6 +38,20 @@ uuid = { version = "1.3.3", features = ["v4", "serde"] }
serde-wasm-bindgen = "0.6.1"
md5 = "0.7.0"

[dependencies.axum]
version = "0.7"
optional = true
default-features = false

[dependencies.axum-macros]
version = "0.4"
optional = true
default-features = false

[dependencies.tower-service]
version = "0.3"
optional = true

[dev-dependencies]
wasm-bindgen-test.workspace = true
futures-channel = { version = "0.3.29", features = ["sink"] }
Expand Down
30 changes: 30 additions & 0 deletions worker-sandbox/src/alarm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ use std::time::Duration;

use worker::*;

use super::SomeSharedData;

#[durable_object]
pub struct AlarmObject {
state: State,
Expand Down Expand Up @@ -39,3 +41,31 @@ impl DurableObject for AlarmObject {
Response::ok("ALARMED")
}
}

#[worker::send]
pub async fn handle_alarm(_req: Request, env: Env, _data: SomeSharedData) -> Result<Response> {
let namespace = env.durable_object("ALARM")?;
let stub = namespace.id_from_name("alarm")?.get_stub()?;
// when calling fetch to a Durable Object, a full URL must be used. Alternatively, a
// compatibility flag can be provided in wrangler.toml to opt-in to older behavior:
// https://developers.cloudflare.com/workers/platform/compatibility-dates#durable-object-stubfetch-requires-a-full-url
stub.fetch_with_str("https://fake-host/alarm").await
}

#[worker::send]
pub async fn handle_id(_req: Request, env: Env, _data: SomeSharedData) -> Result<Response> {
let namespace = env.durable_object("COUNTER").expect("DAWJKHDAD");
let stub = namespace.id_from_name("A")?.get_stub()?;
// when calling fetch to a Durable Object, a full URL must be used. Alternatively, a
// compatibility flag can be provided in wrangler.toml to opt-in to older behavior:
// https://developers.cloudflare.com/workers/platform/compatibility-dates#durable-object-stubfetch-requires-a-full-url
stub.fetch_with_str("https://fake-host/").await
}

#[worker::send]
pub async fn handle_put_raw(req: Request, env: Env, _data: SomeSharedData) -> Result<Response> {
let namespace = env.durable_object("PUT_RAW_TEST_OBJECT")?;
let id = namespace.unique_id()?;
let stub = id.get_stub()?;
stub.fetch_with_request(req).await
}
123 changes: 123 additions & 0 deletions worker-sandbox/src/cache.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
use super::SomeSharedData;
use futures_util::stream::StreamExt;
use rand::Rng;
use std::time::Duration;
use worker::{console_log, Cache, Date, Delay, Env, Request, Response, Result};

fn key(req: &Request) -> Result<Option<String>> {
let uri = req.url()?;
let mut segments = uri.path_segments().unwrap();
Ok(segments.nth(2).map(|s| s.to_owned()))
}

#[worker::send]
pub async fn handle_cache_example(
req: Request,
_env: Env,
_data: SomeSharedData,
) -> Result<Response> {
console_log!("url: {}", req.url()?.to_string());
let cache = Cache::default();
let key = req.url()?.to_string();
if let Some(resp) = cache.get(&key, true).await? {
console_log!("Cache HIT!");
Ok(resp)
} else {
console_log!("Cache MISS!");
let mut resp =
Response::from_json(&serde_json::json!({ "timestamp": Date::now().as_millis() }))?;

// Cache API respects Cache-Control headers. Setting s-max-age to 10
// will limit the response to be in cache for 10 seconds max
resp.headers_mut().set("cache-control", "s-maxage=10")?;
cache.put(key, resp.cloned()?).await?;
Ok(resp)
}
}

#[worker::send]
pub async fn handle_cache_api_get(
req: Request,
_env: Env,
_data: SomeSharedData,
) -> Result<Response> {
if let Some(key) = key(&req)? {
let cache = Cache::default();
if let Some(resp) = cache.get(format!("https://{key}"), true).await? {
return Ok(resp);
} else {
return Response::ok("cache miss");
}
}
Response::error("key missing", 400)
}

#[worker::send]
pub async fn handle_cache_api_put(
req: Request,
_env: Env,
_data: SomeSharedData,
) -> Result<Response> {
if let Some(key) = key(&req)? {
let cache = Cache::default();

let mut resp =
Response::from_json(&serde_json::json!({ "timestamp": Date::now().as_millis() }))?;

// Cache API respects Cache-Control headers. Setting s-max-age to 10
// will limit the response to be in cache for 10 seconds max
resp.headers_mut().set("cache-control", "s-maxage=10")?;
cache.put(format!("https://{key}"), resp.cloned()?).await?;
return Ok(resp);
}
Response::error("key missing", 400)
}

#[worker::send]
pub async fn handle_cache_api_delete(
req: Request,
_env: Env,
_data: SomeSharedData,
) -> Result<Response> {
if let Some(key) = key(&req)? {
let cache = Cache::default();

let res = cache.delete(format!("https://{key}"), true).await?;
return Response::ok(serde_json::to_string(&res)?);
}
Response::error("key missing", 400)
}

#[worker::send]
pub async fn handle_cache_stream(
req: Request,
_env: Env,
_data: SomeSharedData,
) -> Result<Response> {
console_log!("url: {}", req.url()?.to_string());
let cache = Cache::default();
let key = req.url()?.to_string();
if let Some(resp) = cache.get(&key, true).await? {
console_log!("Cache HIT!");
Ok(resp)
} else {
console_log!("Cache MISS!");
let mut rng = rand::thread_rng();
let count = rng.gen_range(0..10);
let stream = futures_util::stream::repeat("Hello, world!\n")
.take(count)
.then(|text| async move {
Delay::from(Duration::from_millis(50)).await;
Result::Ok(text.as_bytes().to_vec())
});

let mut resp = Response::from_stream(stream)?;
console_log!("resp = {:?}", resp);
// Cache API respects Cache-Control headers. Setting s-max-age to 10
// will limit the response to be in cache for 10 seconds max
resp.headers_mut().set("cache-control", "s-maxage=10")?;

cache.put(key, resp.cloned()?).await?;
Ok(resp)
}
}
26 changes: 16 additions & 10 deletions worker-sandbox/src/d1.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,13 @@ struct Person {
age: u32,
}

#[worker::send]
pub async fn prepared_statement(
_req: Request,
ctx: RouteContext<SomeSharedData>,
env: Env,
_data: SomeSharedData,
) -> Result<Response> {
let db = ctx.env.d1("DB")?;
let db = env.d1("DB")?;
let stmt = worker::query!(&db, "SELECT * FROM people WHERE name = ?", "Ryan Upton")?;

// All rows
Expand Down Expand Up @@ -49,8 +51,9 @@ pub async fn prepared_statement(
Response::ok("ok")
}

pub async fn batch(_req: Request, ctx: RouteContext<SomeSharedData>) -> Result<Response> {
let db = ctx.env.d1("DB")?;
#[worker::send]
pub async fn batch(_req: Request, env: Env, _data: SomeSharedData) -> Result<Response> {
let db = env.d1("DB")?;
let mut results = db
.batch(vec![
worker::query!(&db, "SELECT * FROM people WHERE id < 4"),
Expand All @@ -73,8 +76,9 @@ pub async fn batch(_req: Request, ctx: RouteContext<SomeSharedData>) -> Result<R
Response::ok("ok")
}

pub async fn exec(mut req: Request, ctx: RouteContext<SomeSharedData>) -> Result<Response> {
let db = ctx.env.d1("DB")?;
#[worker::send]
pub async fn exec(mut req: Request, env: Env, _data: SomeSharedData) -> Result<Response> {
let db = env.d1("DB")?;
let result = db
.exec(req.text().await?.as_ref())
.await
Expand All @@ -83,14 +87,16 @@ pub async fn exec(mut req: Request, ctx: RouteContext<SomeSharedData>) -> Result
Response::ok(result.count().unwrap_or_default().to_string())
}

pub async fn dump(_req: Request, ctx: RouteContext<SomeSharedData>) -> Result<Response> {
let db = ctx.env.d1("DB")?;
#[worker::send]
pub async fn dump(_req: Request, env: Env, _data: SomeSharedData) -> Result<Response> {
let db = env.d1("DB")?;
let bytes = db.dump().await?;
Response::from_bytes(bytes)
}

pub async fn error(_req: Request, ctx: RouteContext<SomeSharedData>) -> Result<Response> {
let db = ctx.env.d1("DB")?;
#[worker::send]
pub async fn error(_req: Request, env: Env, _data: SomeSharedData) -> Result<Response> {
let db = env.d1("DB")?;
let error = db
.exec("THIS IS NOT VALID SQL")
.await
Expand Down
Loading

0 comments on commit 3454cf6

Please sign in to comment.