diff --git a/src/ast-fragments.ts b/src/ast-fragments.ts new file mode 100644 index 0000000..5ed1f77 --- /dev/null +++ b/src/ast-fragments.ts @@ -0,0 +1,80 @@ +import { escapeLiteral } from "./escape"; + +/** + * Generates an AST fragment that will check if a column value exists in a JSONB array stored in a `current_setting` + * The AST fragment represents SQL that looks like this: + * = ANY (SELECT jsonb_array_elements_text(current_setting('ctx.my_context_value')::jsonb)) + */ +export const jsonb_array_elements_text = (setting: string) => { + return { + type: "function", + name: "ANY", + args: { + type: "expr_list", + value: [ + { + ast: { + with: null, + type: "select", + options: null, + distinct: { + type: null, + }, + columns: [ + { + type: "expr", + expr: { + type: "function", + name: "jsonb_array_elements_text", + args: { + type: "expr_list", + value: [ + { + type: "cast", + keyword: "cast", + expr: { + type: "function", + name: "current_setting", + args: { + type: "expr_list", + value: [ + { + type: "parameter", + value: escapeLiteral(setting.replace(/^___yates_context_/, "")), + }, + ], + }, + }, + as: null, + symbol: "::", + target: { + dataType: "jsonb", + }, + arrows: [], + properties: [], + }, + ], + }, + }, + as: null, + }, + ], + into: { + position: null, + }, + from: null, + where: null, + groupby: null, + having: null, + orderby: null, + limit: { + seperator: "", + value: [], + }, + window: null, + }, + }, + ], + }, + }; +}; diff --git a/src/expressions.ts b/src/expressions.ts index 518927a..b96abfc 100644 --- a/src/expressions.ts +++ b/src/expressions.ts @@ -4,6 +4,7 @@ import matches from "lodash/matches"; import { Parser } from "node-sql-parser"; import { escapeLiteral } from "./escape"; import { defineDmmfProperty } from "@prisma/client/runtime/library"; +import { jsonb_array_elements_text } from "./ast-fragments"; // This is black magic to get the runtime data model from the Prisma client // It's not exported, so we need to use some type infiltration to get it @@ -61,6 +62,8 @@ const getDmmfMetaData = (client: PrismaClient, model: string, field: string) => // Perform substitution of Ints so that Prisma doesn't throw an error due to mismatched type values // After we've captured the SQL, we can replace the Ints with the original values +// The returned tokens are a map of the token int, and the AST fragment that will replace it. +// We can then reconstruct the query using the AST fragments. const tokenizeWhereExpression = ( /** The Prisma client to use for metadata */ client: PrismaClient, @@ -116,6 +119,7 @@ const tokenizeWhereExpression = ( const isNumeric = PRISMA_NUMERIC_TYPES.includes(fieldData.type); const isColumnName = typeof value === "string" && !!value.match(/^___yates_row_/); const isContext = typeof value === "string" && !!value.match(/^___yates_context_/); + const isInStatement = !!value.in; switch (true) { case isColumnName: @@ -185,6 +189,37 @@ const tokenizeWhereExpression = ( }; break; + case isInStatement: + if (Array.isArray(value.in)) { + const values = []; + for (const item in value.in) { + values.push({ + type: "single_quote_string", + value: item, + }); + } + astFragment = { + type: "binary_expr", + operator: "IN", + left: { + type: "column_ref", + schema: "public", + table: table, + column: field, + }, + right: { + type: "expr_list", + value: values, + }, + }; + } else { + // If the value of `in` is a context value, we assume that it is an array that has been JSON encoded + // We create an AST fragment representing a function call to `jsonb_array_elements_text` with the context value as the argument + astFragment = jsonb_array_elements_text(value.in); + } + + break; + // All other types are treated as strings default: astFragment = { @@ -235,7 +270,9 @@ export const expressionToSQL = async (getExpression: Expression, table: string): } if ("where" in args && args.where) { + console.log(args.where); const { where } = tokenizeWhereExpression(baseClient, args.where, table, model, tokens); + console.log("tokenized", where); args.where = where; } @@ -261,6 +298,7 @@ export const expressionToSQL = async (getExpression: Expression, table: string): try { const parser = new Parser(); // Parse the query into an AST + console.log(e.query); const ast: any = parser.astify(e.query, { database: "postgresql", }); diff --git a/src/index.ts b/src/index.ts index f592dc0..e148ec5 100644 --- a/src/index.ts +++ b/src/index.ts @@ -8,7 +8,7 @@ import { Expression, expressionToSQL, RuntimeDataModel } from "./expressions"; const VALID_OPERATIONS = ["SELECT", "UPDATE", "INSERT", "DELETE"] as const; -type Operation = typeof VALID_OPERATIONS[number]; +type Operation = (typeof VALID_OPERATIONS)[number]; export type Models = Prisma.ModelName; interface ClientOptions { @@ -69,7 +69,11 @@ const hashWithPrefix = (prefix: string, abilityName: string) => { }; // Sanitize a single string by ensuring the it has only lowercase alpha characters and underscores -const sanitizeSlug = (slug: string) => slug.toLowerCase().replace("-", "_").replace(/[^a-z0-9_]/gi, ""); +const sanitizeSlug = (slug: string) => + slug + .toLowerCase() + .replace("-", "_") + .replace(/[^a-z0-9_]/gi, ""); export const createAbilityName = (model: string, ability: string) => { return sanitizeSlug(hashWithPrefix("yates_ability_", `${model}_${ability}`)); @@ -111,8 +115,17 @@ export const createClient = (prisma: PrismaClient, getContext: GetContextFn, opt `Context variable "${k}" contains invalid characters. Context variables must only contain lowercase letters, numbers, periods and underscores.`, ); } - if (typeof context[k] !== "number" && typeof context[k] !== "string") { - throw new Error(`Context variable "${k}" must be a string or number. Got ${typeof context[k]}`); + if (typeof context[k] !== "number" && typeof context[k] !== "string" && !Array.isArray(context[k])) { + throw new Error(`Context variable "${k}" must be a string, number or array. Got ${typeof context[k]}`); + } + if (Array.isArray(context[k])) { + for (const v of context[k] as any[]) { + if (typeof v !== "string") { + throw new Error(`Context variable "${k}" must be an array of strings. Got ${typeof v}`); + } + } + // Cast to a JSON string so that it can be used in RLS expressions + context[k] = JSON.stringify(context[k]); } } } @@ -183,6 +196,9 @@ const setRLS = async ( rawExpression: Expression, ) => { let expression = await expressionToSQL(rawExpression, table); + if (expression !== "true") { + console.log("GOT expression", expression); + } // Check if RLS exists const policyName = roleName; diff --git a/test/integration/expressions.spec.ts b/test/integration/expressions.spec.ts index 7f40e5e..731d832 100644 --- a/test/integration/expressions.spec.ts +++ b/test/integration/expressions.spec.ts @@ -2,6 +2,7 @@ import { PrismaClient } from "@prisma/client"; import _ from "lodash"; import { v4 as uuid } from "uuid"; import { setup } from "../../src"; +import { Parser } from "node-sql-parser"; jest.setTimeout(30000); @@ -832,5 +833,82 @@ describe("expressions", () => { expect(result2).toBeNull(); }); + + it("should be able to handle context values that are arrays", async () => { + const initial = new PrismaClient(); + + const role = `USER_${uuid()}`; + + const testTitle1 = `test_${uuid()}`; + const testTitle2 = `test_${uuid()}`; + + const client = await setup({ + prisma: initial, + customAbilities: { + Post: { + customCreateAbility: { + description: "Create posts where there is already a tag label with the same title", + operation: "INSERT", + expression: (client: PrismaClient, _row, context) => { + return client.tag.findFirst({ + where: { + label: { + in: context("post.title") as any as string[], + }, + }, + }); + }, + }, + }, + }, + getRoles(abilities) { + return { + [role]: [abilities.Post.customCreateAbility, abilities.Post.read, abilities.Tag.read, abilities.Tag.create], + }; + }, + getContext: () => ({ + role, + context: { + "post.title": [testTitle1, testTitle2], + }, + }), + }); + + await expect( + client.post.create({ + data: { + title: testTitle1, + }, + }), + ).rejects.toThrow(); + + await client.tag.create({ + data: { + label: testTitle1, + }, + }); + + const post1 = await client.post.create({ + data: { + title: testTitle1, + }, + }); + + expect(post1.id).toBeDefined(); + + await client.tag.create({ + data: { + label: testTitle2, + }, + }); + + const post2 = await client.post.create({ + data: { + title: testTitle2, + }, + }); + + expect(post2.id).toBeDefined(); + }); }); });