From 682a06ea476477795e1ceacdc49bc04daa3d62ef Mon Sep 17 00:00:00 2001
From: Ernest Badu <ernestjbadu@outlook.com>
Date: Thu, 13 Jun 2024 20:49:25 +0100
Subject: [PATCH] feat: improve integration with rpc worker entrypoints

---
 .../src/__test__/fixtures/handler.ts          | 61 +++++++++++++++++++
 .../cloudflare/src/__test__/fixtures/users.ts |  1 +
 .../cloudflare/src/__test__/handler.test.ts   | 29 ++++++++-
 packages/cloudflare/src/handler.ts            | 18 +++---
 packages/cloudflare/vitest.config.ts          |  3 +
 5 files changed, 104 insertions(+), 8 deletions(-)
 create mode 100644 packages/cloudflare/src/__test__/fixtures/handler.ts
 create mode 100644 packages/cloudflare/src/__test__/fixtures/users.ts

diff --git a/packages/cloudflare/src/__test__/fixtures/handler.ts b/packages/cloudflare/src/__test__/fixtures/handler.ts
new file mode 100644
index 0000000..4edaf4d
--- /dev/null
+++ b/packages/cloudflare/src/__test__/fixtures/handler.ts
@@ -0,0 +1,61 @@
+import { WorkerEntrypoint } from "cloudflare:workers";
+import { cloudflare } from "src/handler";
+import { inject } from "src/inject";
+import { cache } from "src/modules/cache/cache";
+import { cookies } from "src/modules/cookies/cookies";
+import { bindings, env, locals } from "src/modules/env";
+import { executionContext, waitUntil } from "src/modules/execution-context";
+import {
+	url,
+	headers,
+	ip,
+	method,
+	pathname,
+	request,
+	searchParams,
+} from "src/modules/request";
+import { response, status } from "src/modules/response";
+import { expect } from "vitest";
+import { sendWelcomeEmail } from "./users";
+
+inject(() => ({ requestId: "123" }));
+
+export default class TestHandler extends cloudflare(
+	WorkerEntrypoint<{ TestBucket: R2Bucket }>
+) {
+	// @ts-expect-error
+	async fetch(req: Request) {
+		const services = bindings();
+
+		expect(services).toHaveProperty("TestBucket");
+		// @ts-expect-error TODO: Fix
+		expect(services.TestBucket).toMatchObject({
+			put: expect.any(Function),
+			get: expect.any(Function),
+		});
+
+		cache({ maxAge: "1w" });
+		waitUntil(new Promise((resolve) => sendWelcomeEmail().then(resolve)));
+		status(201);
+		response().headers.set("X-Foo", "Bar");
+		response().headers.set("Content-Type", "application/json");
+		cookies().set("session", "123");
+
+		expect(method()).toEqual("GET");
+		expect(request()).toEqual(req);
+		expect(pathname()).toEqual("/users/123");
+		expect(ip()).toEqual("some-ip");
+		expect(url()).toEqual(new URL("https://example.com/users/123?sort=asc"));
+		expect(env()).toEqual({ DatabaseUrl: "some-url" });
+		expect(locals()).toEqual({ requestId: "123" });
+		expect(executionContext()).toBe(this.ctx);
+
+		expect(Object.fromEntries(searchParams())).toEqual({ sort: "asc" });
+		expect(Object.fromEntries(headers())).toEqual({
+			"x-request-id": "some-req-id",
+			"cf-connecting-ip": "some-ip",
+		});
+
+		return { id: "123" };
+	}
+}
diff --git a/packages/cloudflare/src/__test__/fixtures/users.ts b/packages/cloudflare/src/__test__/fixtures/users.ts
new file mode 100644
index 0000000..42d55b8
--- /dev/null
+++ b/packages/cloudflare/src/__test__/fixtures/users.ts
@@ -0,0 +1 @@
+export async function sendWelcomeEmail() {}
diff --git a/packages/cloudflare/src/__test__/handler.test.ts b/packages/cloudflare/src/__test__/handler.test.ts
index eec67ef..9274cda 100644
--- a/packages/cloudflare/src/__test__/handler.test.ts
+++ b/packages/cloudflare/src/__test__/handler.test.ts
@@ -1,4 +1,4 @@
-import { createExecutionContext, waitOnExecutionContext } from "cloudflare:test";
+import { SELF, createExecutionContext, waitOnExecutionContext } from "cloudflare:test";
 import { cloudflare } from "src/handler";
 import { inject } from "src/inject";
 import { cache, cached } from "src/modules/cache/cache";
@@ -16,6 +16,7 @@ import {
 } from "src/modules/request";
 import { response, status } from "src/modules/response";
 import { describe, expect, test, vi } from "vitest";
+import * as users from "./fixtures/users";
 
 describe("context", () => {
 	test("should throw error trying to access request at the module level", async () => {
@@ -160,4 +161,30 @@ describe("context", () => {
 		expect(waitUntilSpy).toHaveBeenCalledTimes(2);
 		expect(cacheSpy).toHaveBeenCalledOnce();
 	});
+
+	test("context methods should be available on worker entrypoint", async () => {
+		const sendEmailSpy = vi.spyOn(users, "sendWelcomeEmail");
+
+		const req = new Request("https://example.com/users/123?sort=asc", {
+			headers: {
+				"X-Request-Id": "some-req-id",
+				"cf-connecting-ip": "some-ip",
+			},
+		});
+
+		const res = await SELF.fetch(req.clone());
+
+		expect(res.status).toEqual(201);
+		expect(res.headers.get("X-Foo")).toEqual("Bar");
+		expect(res.headers.get("Set-Cookie")).toEqual("session=123");
+		expect(res.headers.get("Content-Type")).toEqual("application/json;charset=utf-8");
+		expect(res.headers.get("Cache-Control")).toEqual("max-age=604800");
+		await expect(res.json()).resolves.toEqual({ id: "123" });
+
+		expect(sendEmailSpy).toHaveBeenCalledOnce();
+
+		await SELF.fetch(req.clone());
+
+		expect(sendEmailSpy).toHaveBeenCalledTimes(2);
+	});
 });
diff --git a/packages/cloudflare/src/handler.ts b/packages/cloudflare/src/handler.ts
index 2f4ce4e..4c964e6 100644
--- a/packages/cloudflare/src/handler.ts
+++ b/packages/cloudflare/src/handler.ts
@@ -1,5 +1,5 @@
 import type { WorkerEntrypoint } from "cloudflare:workers";
-import type { ExportedHandler } from "@cloudflare/workers-types";
+import type { ExecutionContext, ExportedHandler } from "@cloudflare/workers-types";
 import { createResponse } from "./create-response";
 import {
 	HandlerContext,
@@ -42,12 +42,16 @@ export function cloudflare<T extends ExportedWorker<Env>>(handler: T): MakeAsync
 		const Cls = handler as typeof WorkerEntrypoint<Env>;
 		// @ts-expect-error
 		return class extends Cls {
-			fetch(req: any) {
-				return handleRequest(
-					createContext(req, this.env, this.ctx),
-					// biome-ignore lint/style/noNonNullAssertion: <explanation>
-					() => super.fetch?.(req)!
-				);
+			constructor(ctx: ExecutionContext, env: Env) {
+				super(ctx, env);
+
+				// biome-ignore lint/style/noNonNullAssertion: <explanation>
+				const fetchFn = this.fetch!;
+
+				this.fetch = async function handle(request) {
+					const ctx = createContext(request, this.env, this.ctx);
+					return await handleRequest(ctx, () => fetchFn.bind(this)(request));
+				};
 			}
 		};
 	}
diff --git a/packages/cloudflare/vitest.config.ts b/packages/cloudflare/vitest.config.ts
index 85e133e..abcc890 100644
--- a/packages/cloudflare/vitest.config.ts
+++ b/packages/cloudflare/vitest.config.ts
@@ -19,10 +19,13 @@ export default defineWorkersConfig({
 		poolOptions: {
 			workers: {
 				singleWorker: true,
+				main: "./src/__test__/fixtures/handler.ts",
 				miniflare: {
 					name: "main",
 					compatibilityDate: COMPATIBILITY_DATE,
 					compatibilityFlags: COMPATIBILITY_FLAGS,
+					r2Buckets: ["TestBucket"],
+					bindings: { DatabaseUrl: "some-url" },
 				},
 			},
 		},