Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: use the same transaction id for matching roles and context #106

Closed
wants to merge 17 commits into from
274 changes: 229 additions & 45 deletions src/index.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import * as crypto from "crypto";
import { Prisma, PrismaClient } from "@prisma/client";
import { Prisma, PrismaClient, PrismaPromise } from "@prisma/client";
import logger from "debug";
import difference from "lodash/difference";
import flatMap from "lodash/flatMap";
Expand All @@ -11,6 +11,112 @@ const VALID_OPERATIONS = ["SELECT", "UPDATE", "INSERT", "DELETE"] as const;

const debug = logger("yates");

interface Batch {
pgRole: string;
context?: { [x: string]: string | number | string[] };
requests: Array<{
params: object;
query: (args: unknown[]) => PrismaPromise<unknown>;
args: unknown;
resolve: (result: unknown) => void;
reject: (error: unknown) => void;
}>;
}

const BatchTxIdCounter = {
id: 0,
nextId() {
return ++this.id;
},
};
export interface ErrorWithBatchIndex {
batchRequestIdx?: number;
}

export function hasBatchIndex(
value: object,
): value is Required<ErrorWithBatchIndex> {
// @ts-ignore
return typeof value["batchRequestIdx"] === "number";
}
export function waitForBatch<T extends PromiseLike<unknown>[]>(
promises: T,
): Promise<{ [K in keyof T]: Awaited<T[K]> }> {
if (promises.length === 0) {
return Promise.resolve([] as { [K in keyof T]: Awaited<T[K]> });
}
return new Promise((resolve, reject) => {
const successfulResults = new Array(promises.length) as {
[K in keyof T]: Awaited<T[K]>;
};
let bestError: unknown = null;
let done = false;
let settledPromisesCount = 0;

const settleOnePromise = () => {
if (done) {
return;
}
settledPromisesCount++;
if (settledPromisesCount === promises.length) {
done = true;
if (bestError) {
reject(bestError);
} else {
resolve(successfulResults);
}
}
};

const immediatelyReject = (error: unknown) => {
if (!done) {
done = true;
reject(error);
}
};

for (let i = 0; i < promises.length; i++) {
promises[i].then(
(result) => {
successfulResults[i] = result;
settleOnePromise();
},
(error) => {
if (!hasBatchIndex(error)) {
immediatelyReject(error);
return;
}

if (error.batchRequestIdx === i) {
immediatelyReject(error);
} else {
if (!bestError) {
bestError = error;
}
settleOnePromise();
}
},
);
}
});
}

export function getLockCountPromise<V = void>(
knock: number,
cb: () => V | void = () => {},
) {
let resolve: (v: V | void) => void;
const lock = new Promise<V | void>((res) => (resolve = res));

return {
then(onFulfilled) {
if (--knock === 0) resolve(cb());

return onFulfilled?.(lock as unknown as V | void);
},
} as PromiseLike<V | void>;
}

type Operation = (typeof VALID_OPERATIONS)[number];
export type Models = Prisma.ModelName;

Expand Down Expand Up @@ -76,6 +182,7 @@ export type CustomAbilities<

export type GetContextFn<ContextKeys extends string = string> = () => {
role: string;
transactionId?: string;
context?: {
[key in ContextKeys]: string | number | string[];
};
Expand Down Expand Up @@ -179,6 +286,38 @@ export const createRoleName = (name: string) => {
return sanitizeSlug(hashWithPrefix("yates_role_", `${name}`));
};

// @ts-ignore
export function getBatchId(query: any): string | undefined {
if (query.action !== "findUnique" && query.action !== "findUniqueOrThrow") {
return undefined;
}
const parts: string[] = [];
if (query.modelName) {
parts.push(query.modelName);
}

if (query.query.arguments) {
parts.push(buildKeysString(query.query.arguments));
}
parts.push(buildKeysString(query.query.selection));

return parts.join("");
}
function buildKeysString(obj: object): string {
const keysArray = Object.keys(obj)
.sort()
.map((key) => {
// @ts-ignore
const value = obj[key];
if (typeof value === "object" && value !== null) {
return `(${key} ${buildKeysString(value)})`;
}
return key;
});

return `(${keysArray.join(" ")})`;
}

// This uses client extensions to set the role and context for the current user so that RLS can be applied
export const createClient = (
prisma: PrismaClient,
Expand All @@ -187,6 +326,63 @@ export const createClient = (
) => {
// Set default options
const { txMaxWait = 30000, txTimeout = 30000 } = options;

// @ts-ignore
(prisma as any)._requestHandler.dataloader.options.batchBy = (n) => {
return n.transaction?.yates_id
? n.transaction.yates_id + (getBatchId(n.protocolQuery) || "")
: n.transaction?.id
? `transaction-${n.transaction.id}`
: getBatchId(n.protocolQuery);
};

let tickActive = false;
const batches: Record<string, Batch> = {};

const dispatchBatches = () => {
for (const [key, batch] of Object.entries(batches)) {
console.log(key, batch);
delete batches[key];

prisma
.$transaction(async (tx) => {
await tx.$queryRawUnsafe(`SET ROLE ${batch.pgRole}`);
// Now set all the context variables using `set_config` so that they can be used in RLS
for (const [key, value] of toPairs(batch.context)) {
await tx.$queryRaw`SELECT set_config(${key}, ${value.toString()}, true);`;
}
// https://github.com/prisma/prisma/blob/4.11.0/packages/client/src/runtime/getPrismaClient.ts#L1013
// biome-ignore lint/suspicious/noExplicitAny: This is a private API, so not much we can do about it
const txId = (tx as any)[Symbol.for("prisma.client.transaction.id")];
const results = await Promise.all(
batch.requests.map((request) =>
prisma._executeRequest({
...request.params,
transaction: {
kind: "itx",
id: txId,
yates_id: key,
},
}),
),
);
// Switch role back to admin user
await tx.$queryRawUnsafe("SET ROLE none");

return results;
})
.then((results) => {
results.forEach((result, index) => {
batch.requests[index].resolve(result);
});
})
.catch((e) => {
batch.requests.forEach((request) => request.reject(e));
delete batches[key];
});
}
};

const client = prisma.$extends({
name: "Yates client",
query: {
Expand Down Expand Up @@ -244,51 +440,39 @@ export const createClient = (
}

try {
// Because batch transactions inside a prisma client query extension can run out of order if used with async middleware,
// we need to run the logic inside an interactive transaction, however this brings a different set of problems in that the
// main query will no longer automatically run inside the transaction. We resolve this issue by manually executing the prisma request.
// See https://github.com/prisma/prisma/issues/18276
const queryResults = await prisma.$transaction(
async (tx) => {
// Switch to the user role, We can't use a prepared statement here, due to limitations in PG not allowing prepared statements to be used in SET ROLE
await tx.$queryRawUnsafe(`SET ROLE ${pgRole}`);
// Now set all the context variables using `set_config` so that they can be used in RLS
for (const [key, value] of toPairs(context)) {
await tx.$queryRaw`SELECT set_config(${key}, ${value.toString()}, true);`;
}

// Inconveniently, the `query` function will not run inside an interactive transaction.
// We need to manually reconstruct the query, and attached the "secret" transaction ID.
// This ensures that the query will run inside the transaction AND that middlewares will not be re-applied

// https://github.com/prisma/prisma/blob/4.11.0/packages/client/src/runtime/getPrismaClient.ts#L1013
// biome-ignore lint/suspicious/noExplicitAny: This is a private API, so not much we can do about it
const txId = (tx as any)[
Symbol.for("prisma.client.transaction.id")
];

// See https://github.com/prisma/prisma/blob/4.11.0/packages/client/src/runtime/getPrismaClient.ts#L860
// biome-ignore lint/suspicious/noExplicitAny: This is a private API, so not much we can do about it
const __internalParams = (params as any).__internalParams;
const result = await prisma._executeRequest({
...__internalParams,
transaction: {
kind: "itx",
id: txId,
},
const txId = hashWithPrefix("yates_tx_", JSON.stringify(ctx));

const hash = txId;
if (!batches[hash]) {
batches[hash] = {
pgRole,
context,
requests: [],
};

// make sure, that we only tick once at a time
if (!tickActive) {
tickActive = true;
process.nextTick(() => {
dispatchBatches();
tickActive = false;
});
// Switch role back to admin user
await tx.$queryRawUnsafe("SET ROLE none");

return result;
},
{
maxWait: txMaxWait,
timeout: txTimeout,
},
);

return queryResults;
}
}

// See https://github.com/prisma/prisma/blob/4.11.0/packages/client/src/runtime/getPrismaClient.ts#L860
// biome-ignore lint/suspicious/noExplicitAny: This is a private API, so not much we can do about it
const __internalParams = (params as any).__internalParams;

return new Promise((resolve, reject) => {
batches[hash].requests.push({
params: __internalParams,
query,
args,
resolve,
reject,
});
});
} catch (e) {
// Normalize RLS errors to make them a bit more readable.
if (
Expand Down
Loading