Skip to content

Commit

Permalink
fix: avoid race condition when creating policies
Browse files Browse the repository at this point in the history
If you're trying to startup multiple Yates clients simultaneously, it's
possible that you hit an error where on client tries to create a policy
that the other client has already created.
This occurs because we read *all* abilities in one go. To avoid this, we
check the ability table once per RLS policy, and we do it inside
a transaction with a lock. This means that the second client will behave
correctly and not try to create the policy again.
  • Loading branch information
LucianBuzzo committed Apr 9, 2024
1 parent 3cd9dc7 commit b7b3b78
Showing 1 changed file with 41 additions and 39 deletions.
80 changes: 41 additions & 39 deletions src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,6 @@ export const createClient = (

const setRLS = async <ContextKeys extends string, YModel extends Models>(
prisma: PrismaClient,
existingAbilities: PgYatesAbility[],
table: string,
roleName: string,
slug: string,
Expand All @@ -324,50 +323,54 @@ const setRLS = async <ContextKeys extends string, YModel extends Models>(
throw new Error("Expression must be defined for RLS abilities");
}

// Check if RLS exists
const policyName = roleName;
const existingAbility = existingAbilities.find(
(row) =>
row.ability_model === table && row.ability_policy_name === policyName,
);

let shouldUpdateAbilityTable = false;

// IF RLS doesn't exist or expression is different, set RLS
if (!existingAbility) {
debug("Creating RLS policy for", roleName, "on", table, "for", operation);
const expression = await expressionToSQL(rawExpression, table);

// If the operation is an insert or update, we need to use a different syntax as the "WITH CHECK" expression is used.
if (operation === "INSERT") {
await prisma.$queryRawUnsafe(`
// Take a lock and run the RLS setup in a transaction to prevent conflicts
// in a multi-server environment
await prisma.$transaction(async (tx) => {
await takeLock(tx as PrismaClient);
// Check if RLS exists
const policyName = roleName;
const existingAbilities: PgYatesAbility[] = await tx.$queryRaw`
select * from _yates._yates_abilities where ability_model = ${table} and ability_policy_name = ${policyName}
`;
const existingAbility = existingAbilities[0];

let shouldUpdateAbilityTable = false;

// IF RLS doesn't exist or expression is different, set RLS
if (!existingAbility) {
debug("Creating RLS policy for", roleName, "on", table, "for", operation);
const expression = await expressionToSQL(rawExpression, table);

// If the operation is an insert or update, we need to use a different syntax as the "WITH CHECK" expression is used.
if (operation === "INSERT") {
await tx.$queryRawUnsafe(`
CREATE POLICY ${policyName} ON "public"."${table}" FOR ${operation} TO ${roleName} WITH CHECK (${expression});
`);
} else {
await prisma.$queryRawUnsafe(`
} else {
await tx.$queryRawUnsafe(`
CREATE POLICY ${policyName} ON "public"."${table}" FOR ${operation} TO ${roleName} USING (${expression});
`);
}
shouldUpdateAbilityTable = true;
} else if (existingAbility.ability_expression !== rawExpression.toString()) {
debug("Updating RLS policy for", roleName, "on", table, "for", operation);
const expression = await expressionToSQL(rawExpression, table);
if (operation === "INSERT") {
await prisma.$queryRawUnsafe(`
}
shouldUpdateAbilityTable = true;
} else if (
existingAbility.ability_expression !== rawExpression.toString()
) {
debug("Updating RLS policy for", roleName, "on", table, "for", operation);
const expression = await expressionToSQL(rawExpression, table);
if (operation === "INSERT") {
await tx.$queryRawUnsafe(`
ALTER POLICY ${policyName} ON "public"."${table}" TO ${roleName} WITH CHECK (${expression});
`);
} else {
await prisma.$queryRawUnsafe(`
} else {
await tx.$queryRawUnsafe(`
ALTER POLICY ${policyName} ON "public"."${table}" TO ${roleName} USING (${expression});
`);
}
shouldUpdateAbilityTable = true;
}
shouldUpdateAbilityTable = true;
}

if (shouldUpdateAbilityTable) {
await prisma.$transaction([
takeLock(prisma),
upsertAbility(prisma, {
if (shouldUpdateAbilityTable) {
await upsertAbility(tx as PrismaClient, {
ability_model: table,
ability_name: slug,
ability_policy_name: policyName,
Expand All @@ -376,9 +379,9 @@ const setRLS = async <ContextKeys extends string, YModel extends Models>(
// We store the string representation of the expression so that
// we can compare it later without having to recompute the SQL
ability_expression: rawExpression.toString(),
}),
]);
}
});
}
});
};

export const createRoles = async <
Expand Down Expand Up @@ -559,7 +562,6 @@ export const createRoles = async <
if (ability.expression) {
await setRLS(
prisma,
existingAbilities,
table,
roleName,
slug,
Expand Down

0 comments on commit b7b3b78

Please sign in to comment.