diff --git a/quirrel/src/api/scheduler/token-auth.ts b/quirrel/src/api/scheduler/token-auth.ts index 8e4d170e7..4fb1ec0f1 100644 --- a/quirrel/src/api/scheduler/token-auth.ts +++ b/quirrel/src/api/scheduler/token-auth.ts @@ -1,6 +1,6 @@ import { FastifyPluginCallback, FastifyReply, FastifyRequest } from "fastify"; import fp from "fastify-plugin"; -import { IncomingMessage } from "http"; +import { IncomingMessage, IncomingHttpHeaders } from "http"; import { UsageMeter } from "../shared/usage-meter"; import basicAuth from "basic-auth"; @@ -33,17 +33,21 @@ const tokenAuthServicePlugin: FastifyPluginCallback = ( ) => { const usageMeter = new UsageMeter(fastify.redis); - async function getTokenID(authorizationHeader?: string) { - if (!authorizationHeader) { + async function getTokenID(headers: IncomingHttpHeaders) { + const { authorization } = headers; + if (!authorization) { return null; } - if (authorizationHeader.startsWith("Bearer ")) { - const [_, token] = authorizationHeader.split("Bearer "); + if (authorization.startsWith("Bearer ")) { + const [_, token] = authorization.split("Bearer "); const tokenId = await fastify.tokens.check(token); - return tokenId; - } else if (authorizationHeader.startsWith("Basic ")) { - const basicCredentials = basicAuth.parse(authorizationHeader); + if (!tokenId) { + return null; + } + return { tokenId, countUsage: true }; + } else if (authorization.startsWith("Basic ")) { + const basicCredentials = basicAuth.parse(authorization); if (!basicCredentials) { return null; @@ -51,7 +55,8 @@ const tokenAuthServicePlugin: FastifyPluginCallback = ( const isRootUser = opts.passphrases.includes(basicCredentials.pass); if (isRootUser) { - return basicCredentials.name; + const countUsage = !!headers["x-quirrel-count-usage"]; + return { tokenId: basicCredentials.name, countUsage }; } } @@ -62,10 +67,14 @@ const tokenAuthServicePlugin: FastifyPluginCallback = ( request: FastifyRequest | IncomingMessage ): Promise { if (opts.auth) { - const { authorization } = request.headers; - const tokenId = await getTokenID(authorization); + const result = await getTokenID(request.headers); + if (!result) { + return null; + } + + const { tokenId, countUsage } = result; - if (tokenId) { + if (countUsage) { usageMeter.record(tokenId); } diff --git a/quirrel/src/api/test/authenticated_jobs.test.ts b/quirrel/src/api/test/authenticated_jobs.test.ts index 88442f40f..56abe7be2 100644 --- a/quirrel/src/api/test/authenticated_jobs.test.ts +++ b/quirrel/src/api/test/authenticated_jobs.test.ts @@ -93,8 +93,12 @@ function testAgainst(backend: "Redis" | "Mock") { .auth("ignored", passphrase) .expect(200, {}); }); - test("admin impersonation", async () => { + await request(quirrel) + .delete("/usage") + .auth("ignored", passphrase) + .expect(200); + const { text: token } = await request(quirrel) .put("/tokens/this.is.a.project") .auth("ignored", passphrase) @@ -113,6 +117,33 @@ function testAgainst(backend: "Redis" | "Mock") { expect(lastBody).toEqual('{"foo":"bar"}'); expect(lastSignature).toMatch(/v=(\d+),d=([\da-f]+)/); expect(verify(lastBody, token, lastSignature)).toBe(true); + + await request(quirrel) + .delete("/usage") + .auth("ignored", passphrase) + .expect(200, { + // only one for execution + "this.is.a.project": 1, + }); + + await request(quirrel) + .post("/queues/" + endpoint) + .set("x-quirrel-count-usage", "true") + .auth("this.is.a.project", passphrase) + .send({ + body: JSON.stringify({ foo: "bar" }), + }) + .expect(201); + + await delay(300); + + await request(quirrel) + .delete("/usage") + .auth("ignored", passphrase) + .expect(200, { + // one for enqueueing, one for execution + "this.is.a.project": 2, + }); }); }); }