Skip to content

Commit

Permalink
feat: Add support for using the in operator with context values
Browse files Browse the repository at this point in the history
This change adds support for the `in` operator when using a context
value in a prisma expression. This allows you to do useful stuff
like allow a match against multiple contextually provided values (e.g.
org membership).

Signed-off-by: Lucian Buzzo <[email protected]>
  • Loading branch information
LucianBuzzo committed Jan 9, 2024
1 parent f7d21d0 commit e39732c
Show file tree
Hide file tree
Showing 4 changed files with 216 additions and 4 deletions.
80 changes: 80 additions & 0 deletions src/ast-fragments.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
import { escapeLiteral } from "./escape";

/**
* Generates an AST fragment that will check if a column value exists in a JSONB array stored in a `current_setting`
* The AST fragment represents SQL that looks like this:
* = ANY (SELECT jsonb_array_elements_text(current_setting('ctx.my_context_value')::jsonb))
*/
export const jsonb_array_elements_text = (setting: string) => {
return {
type: "function",
name: "ANY",
args: {
type: "expr_list",
value: [
{
ast: {
with: null,
type: "select",
options: null,
distinct: {
type: null,
},
columns: [
{
type: "expr",
expr: {
type: "function",
name: "jsonb_array_elements_text",
args: {
type: "expr_list",
value: [
{
type: "cast",
keyword: "cast",
expr: {
type: "function",
name: "current_setting",
args: {
type: "expr_list",
value: [
{
type: "parameter",
value: escapeLiteral(setting.replace(/^___yates_context_/, "")),
},
],
},
},
as: null,
symbol: "::",
target: {
dataType: "jsonb",
},
arrows: [],
properties: [],
},
],
},
},
as: null,
},
],
into: {
position: null,
},
from: null,
where: null,
groupby: null,
having: null,
orderby: null,
limit: {
seperator: "",
value: [],
},
window: null,
},
},
],
},
};
};
38 changes: 38 additions & 0 deletions src/expressions.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import matches from "lodash/matches";
import { Parser } from "node-sql-parser";
import { escapeLiteral } from "./escape";
import { defineDmmfProperty } from "@prisma/client/runtime/library";
import { jsonb_array_elements_text } from "./ast-fragments";

// This is black magic to get the runtime data model from the Prisma client
// It's not exported, so we need to use some type infiltration to get it
Expand Down Expand Up @@ -61,6 +62,8 @@ const getDmmfMetaData = (client: PrismaClient, model: string, field: string) =>

// Perform substitution of Ints so that Prisma doesn't throw an error due to mismatched type values
// After we've captured the SQL, we can replace the Ints with the original values
// The returned tokens are a map of the token int, and the AST fragment that will replace it.
// We can then reconstruct the query using the AST fragments.
const tokenizeWhereExpression = (
/** The Prisma client to use for metadata */
client: PrismaClient,
Expand Down Expand Up @@ -116,6 +119,7 @@ const tokenizeWhereExpression = (
const isNumeric = PRISMA_NUMERIC_TYPES.includes(fieldData.type);
const isColumnName = typeof value === "string" && !!value.match(/^___yates_row_/);
const isContext = typeof value === "string" && !!value.match(/^___yates_context_/);
const isInStatement = !!value.in;

switch (true) {
case isColumnName:
Expand Down Expand Up @@ -185,6 +189,37 @@ const tokenizeWhereExpression = (
};
break;

case isInStatement:
if (Array.isArray(value.in)) {
const values = [];
for (const item in value.in) {
values.push({
type: "single_quote_string",
value: item,
});
}
astFragment = {
type: "binary_expr",
operator: "IN",
left: {
type: "column_ref",
schema: "public",
table: table,
column: field,
},
right: {
type: "expr_list",
value: values,
},
};
} else {
// If the value of `in` is a context value, we assume that it is an array that has been JSON encoded
// We create an AST fragment representing a function call to `jsonb_array_elements_text` with the context value as the argument
astFragment = jsonb_array_elements_text(value.in);
}

break;

// All other types are treated as strings
default:
astFragment = {
Expand Down Expand Up @@ -235,7 +270,9 @@ export const expressionToSQL = async (getExpression: Expression, table: string):
}

if ("where" in args && args.where) {
console.log(args.where);
const { where } = tokenizeWhereExpression(baseClient, args.where, table, model, tokens);
console.log("tokenized", where);
args.where = where;
}

Expand All @@ -261,6 +298,7 @@ export const expressionToSQL = async (getExpression: Expression, table: string):
try {
const parser = new Parser();
// Parse the query into an AST
console.log(e.query);
const ast: any = parser.astify(e.query, {
database: "postgresql",
});
Expand Down
24 changes: 20 additions & 4 deletions src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import { Expression, expressionToSQL, RuntimeDataModel } from "./expressions";

const VALID_OPERATIONS = ["SELECT", "UPDATE", "INSERT", "DELETE"] as const;

type Operation = typeof VALID_OPERATIONS[number];
type Operation = (typeof VALID_OPERATIONS)[number];
export type Models = Prisma.ModelName;

interface ClientOptions {
Expand Down Expand Up @@ -69,7 +69,11 @@ const hashWithPrefix = (prefix: string, abilityName: string) => {
};

// Sanitize a single string by ensuring the it has only lowercase alpha characters and underscores
const sanitizeSlug = (slug: string) => slug.toLowerCase().replace("-", "_").replace(/[^a-z0-9_]/gi, "");
const sanitizeSlug = (slug: string) =>
slug
.toLowerCase()
.replace("-", "_")
.replace(/[^a-z0-9_]/gi, "");

export const createAbilityName = (model: string, ability: string) => {
return sanitizeSlug(hashWithPrefix("yates_ability_", `${model}_${ability}`));
Expand Down Expand Up @@ -111,8 +115,17 @@ export const createClient = (prisma: PrismaClient, getContext: GetContextFn, opt
`Context variable "${k}" contains invalid characters. Context variables must only contain lowercase letters, numbers, periods and underscores.`,
);
}
if (typeof context[k] !== "number" && typeof context[k] !== "string") {
throw new Error(`Context variable "${k}" must be a string or number. Got ${typeof context[k]}`);
if (typeof context[k] !== "number" && typeof context[k] !== "string" && !Array.isArray(context[k])) {
throw new Error(`Context variable "${k}" must be a string, number or array. Got ${typeof context[k]}`);
}
if (Array.isArray(context[k])) {
for (const v of context[k] as any[]) {
if (typeof v !== "string") {
throw new Error(`Context variable "${k}" must be an array of strings. Got ${typeof v}`);
}
}
// Cast to a JSON string so that it can be used in RLS expressions
context[k] = JSON.stringify(context[k]);
}
}
}
Expand Down Expand Up @@ -183,6 +196,9 @@ const setRLS = async (
rawExpression: Expression,
) => {
let expression = await expressionToSQL(rawExpression, table);
if (expression !== "true") {
console.log("GOT expression", expression);
}

// Check if RLS exists
const policyName = roleName;
Expand Down
78 changes: 78 additions & 0 deletions test/integration/expressions.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import { PrismaClient } from "@prisma/client";
import _ from "lodash";
import { v4 as uuid } from "uuid";
import { setup } from "../../src";
import { Parser } from "node-sql-parser";

jest.setTimeout(30000);

Expand Down Expand Up @@ -832,5 +833,82 @@ describe("expressions", () => {

expect(result2).toBeNull();
});

it("should be able to handle context values that are arrays", async () => {
const initial = new PrismaClient();

const role = `USER_${uuid()}`;

const testTitle1 = `test_${uuid()}`;
const testTitle2 = `test_${uuid()}`;

const client = await setup({
prisma: initial,
customAbilities: {
Post: {
customCreateAbility: {
description: "Create posts where there is already a tag label with the same title",
operation: "INSERT",
expression: (client: PrismaClient, _row, context) => {
return client.tag.findFirst({
where: {
label: {
in: context("post.title") as any as string[],
},
},
});
},
},
},
},
getRoles(abilities) {
return {
[role]: [abilities.Post.customCreateAbility, abilities.Post.read, abilities.Tag.read, abilities.Tag.create],
};
},
getContext: () => ({
role,
context: {
"post.title": [testTitle1, testTitle2],
},
}),
});

await expect(
client.post.create({
data: {
title: testTitle1,
},
}),
).rejects.toThrow();

await client.tag.create({
data: {
label: testTitle1,
},
});

const post1 = await client.post.create({
data: {
title: testTitle1,
},
});

expect(post1.id).toBeDefined();

await client.tag.create({
data: {
label: testTitle2,
},
});

const post2 = await client.post.create({
data: {
title: testTitle2,
},
});

expect(post2.id).toBeDefined();
});
});
});

0 comments on commit e39732c

Please sign in to comment.