Skip to content

Commit

Permalink
Merge pull request #1509 from Angelelz/feat-onupdate
Browse files Browse the repository at this point in the history
Feat: add an `$onUpdate` method to columns
  • Loading branch information
AndriiSherman authored Mar 27, 2024
2 parents 3216719 + 9d789cf commit 77763ae
Show file tree
Hide file tree
Showing 8 changed files with 390 additions and 39 deletions.
21 changes: 21 additions & 0 deletions drizzle-orm/src/column-builder.ts
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ export type ColumnBuilderRuntimeConfig<TData, TRuntimeConfig extends object = ob
notNull: boolean;
default: TData | SQL | undefined;
defaultFn: (() => TData | SQL) | undefined;
onUpdateFn: (() => TData | SQL) | undefined;
hasDefault: boolean;
primaryKey: boolean;
isUnique: boolean;
Expand Down Expand Up @@ -192,6 +193,26 @@ export abstract class ColumnBuilder<
*/
$default = this.$defaultFn;

/**
* Adds a dynamic update value to the column.
* The function will be called when the row is updated, and the returned value will be used as the column value if none is provided.
* If no `default` (or `$defaultFn`) value is provided, the function will be called when the row is inserted as well, and the returned value will be used as the column value.
*
* **Note:** This value does not affect the `drizzle-kit` behavior, it is only used at runtime in `drizzle-orm`.
*/
$onUpdateFn(
fn: () => (this['_'] extends { $type: infer U } ? U : this['_']['data']) | SQL,
): HasDefault<this> {
this.config.onUpdateFn = fn;
this.config.hasDefault = true;
return this as HasDefault<this>;
}

/**
* Alias for {@link $defaultFn}.
*/
$onUpdate = this.$onUpdateFn;

/**
* Adds a `primary key` clause to the column definition. This implicitly makes the column `not null`.
*
Expand Down
2 changes: 2 additions & 0 deletions drizzle-orm/src/column.ts
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ export abstract class Column<
readonly notNull: boolean;
readonly default: T['data'] | SQL | undefined;
readonly defaultFn: (() => T['data'] | SQL) | undefined;
readonly onUpdateFn: (() => T['data'] | SQL) | undefined;
readonly hasDefault: boolean;
readonly isUnique: boolean;
readonly uniqueName: string | undefined;
Expand All @@ -79,6 +80,7 @@ export abstract class Column<
this.notNull = config.notNull;
this.default = config.default;
this.defaultFn = config.defaultFn;
this.onUpdateFn = config.onUpdateFn;
this.hasDefault = config.hasDefault;
this.primary = config.primaryKey;
this.isUnique = config.isUnique;
Expand Down
35 changes: 22 additions & 13 deletions drizzle-orm/src/mysql-core/dialect.ts
Original file line number Diff line number Diff line change
Expand Up @@ -110,20 +110,24 @@ export class MySqlDialect {
}

buildUpdateSet(table: MySqlTable, set: UpdateSet): SQL {
const setEntries = Object.entries(set);

const setSize = setEntries.length;
return sql.join(
setEntries
.flatMap(([colName, value], i): SQL[] => {
const col: MySqlColumn = table[Table.Symbol.Columns][colName]!;
const res = sql`${sql.identifier(col.name)} = ${value}`;
if (i < setSize - 1) {
return [res, sql.raw(', ')];
}
return [res];
}),
const tableColumns = table[Table.Symbol.Columns];

const columnNames = Object.keys(tableColumns).filter((colName) =>
set[colName] !== undefined || tableColumns[colName]?.onUpdateFn !== undefined
);

const setSize = columnNames.length;
return sql.join(columnNames.flatMap((colName, i) => {
const col = tableColumns[colName]!;

const value = set[colName] ?? sql.param(col.onUpdateFn!(), col);
const res = sql`${sql.identifier(col.name)} = ${value}`;

if (i < setSize - 1) {
return [res, sql.raw(', ')];
}
return [res];
}));
}

buildUpdateQuery({ table, set, where, returning, withList }: MySqlUpdateConfig): SQL {
Expand Down Expand Up @@ -423,6 +427,11 @@ export class MySqlDialect {
const defaultFnResult = col.defaultFn();
const defaultValue = is(defaultFnResult, SQL) ? defaultFnResult : sql.param(defaultFnResult, col);
valueList.push(defaultValue);
// eslint-disable-next-line unicorn/no-negated-condition
} else if (!col.default && col.onUpdateFn !== undefined) {
const onUpdateFnResult = col.onUpdateFn();
const newValue = is(onUpdateFnResult, SQL) ? onUpdateFnResult : sql.param(onUpdateFnResult, col);
valueList.push(newValue);
} else {
valueList.push(sql`default`);
}
Expand Down
35 changes: 22 additions & 13 deletions drizzle-orm/src/pg-core/dialect.ts
Original file line number Diff line number Diff line change
Expand Up @@ -126,20 +126,24 @@ export class PgDialect {
}

buildUpdateSet(table: PgTable, set: UpdateSet): SQL {
const setEntries = Object.entries(set);

const setSize = setEntries.length;
return sql.join(
setEntries
.flatMap(([colName, value], i): SQL[] => {
const col: PgColumn = table[Table.Symbol.Columns][colName]!;
const res = sql`${sql.identifier(col.name)} = ${value}`;
if (i < setSize - 1) {
return [res, sql.raw(', ')];
}
return [res];
}),
const tableColumns = table[Table.Symbol.Columns];

const columnNames = Object.keys(tableColumns).filter((colName) =>
set[colName] !== undefined || tableColumns[colName]?.onUpdateFn !== undefined
);

const setSize = columnNames.length;
return sql.join(columnNames.flatMap((colName, i) => {
const col = tableColumns[colName]!;

const value = set[colName] ?? sql.param(col.onUpdateFn!(), col);
const res = sql`${sql.identifier(col.name)} = ${value}`;

if (i < setSize - 1) {
return [res, sql.raw(', ')];
}
return [res];
}));
}

buildUpdateQuery({ table, set, where, returning, withList }: PgUpdateConfig): SQL {
Expand Down Expand Up @@ -455,6 +459,11 @@ export class PgDialect {
const defaultFnResult = col.defaultFn();
const defaultValue = is(defaultFnResult, SQL) ? defaultFnResult : sql.param(defaultFnResult, col);
valueList.push(defaultValue);
// eslint-disable-next-line unicorn/no-negated-condition
} else if (!col.default && col.onUpdateFn !== undefined) {
const onUpdateFnResult = col.onUpdateFn();
const newValue = is(onUpdateFnResult, SQL) ? onUpdateFnResult : sql.param(onUpdateFnResult, col);
valueList.push(newValue);
} else {
valueList.push(sql`default`);
}
Expand Down
34 changes: 21 additions & 13 deletions drizzle-orm/src/sqlite-core/dialect.ts
Original file line number Diff line number Diff line change
Expand Up @@ -76,20 +76,24 @@ export abstract class SQLiteDialect {
}

buildUpdateSet(table: SQLiteTable, set: UpdateSet): SQL {
const setEntries = Object.entries(set);

const setSize = setEntries.length;
return sql.join(
setEntries
.flatMap(([colName, value], i): SQL[] => {
const col: SQLiteColumn = table[Table.Symbol.Columns][colName]!;
const res = sql`${sql.identifier(col.name)} = ${value}`;
if (i < setSize - 1) {
return [res, sql.raw(', ')];
}
return [res];
}),
const tableColumns = table[Table.Symbol.Columns];

const columnNames = Object.keys(tableColumns).filter((colName) =>
set[colName] !== undefined || tableColumns[colName]?.onUpdateFn !== undefined
);

const setSize = columnNames.length;
return sql.join(columnNames.flatMap((colName, i) => {
const col = tableColumns[colName]!;

const value = set[colName] ?? sql.param(col.onUpdateFn!(), col);
const res = sql`${sql.identifier(col.name)} = ${value}`;

if (i < setSize - 1) {
return [res, sql.raw(', ')];
}
return [res];
}));
}

buildUpdateQuery({ table, set, where, returning, withList }: SQLiteUpdateConfig): SQL {
Expand Down Expand Up @@ -387,6 +391,10 @@ export abstract class SQLiteDialect {
} else if (col.defaultFn !== undefined) {
const defaultFnResult = col.defaultFn();
defaultValue = is(defaultFnResult, SQL) ? defaultFnResult : sql.param(defaultFnResult, col);
// eslint-disable-next-line unicorn/no-negated-condition
} else if (!col.default && col.onUpdateFn !== undefined) {
const onUpdateFnResult = col.onUpdateFn();
defaultValue = is(onUpdateFnResult, SQL) ? onUpdateFnResult : sql.param(onUpdateFnResult, col);
} else {
defaultValue = sql`null`;
}
Expand Down
99 changes: 99 additions & 0 deletions integration-tests/tests/libsql.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import {
countDistinct,
eq,
exists,
getTableColumns,
gt,
gte,
inArray,
Expand Down Expand Up @@ -67,6 +68,17 @@ const usersTable = sqliteTable('users', {
createdAt: integer('created_at', { mode: 'timestamp' }).notNull().default(sql`strftime('%s', 'now')`),
});

const usersOnUpdate = sqliteTable('users_on_update', {
id: integer('id').primaryKey({ autoIncrement: true }),
name: text('name').notNull(),
updateCounter: integer('update_counter').default(sql`1`).$onUpdateFn(() => sql`update_counter + 1`),
updatedAt: integer('updated_at', { mode: 'timestamp_ms' }).$onUpdate(() => new Date()),
alwaysNull: text('always_null').$type<string | null>().$onUpdate(() => null),
// uppercaseName: text('uppercase_name').$onUpdateFn(() =>
// sql`upper(s.name)`
// ), This doesn't seem to be supported in sqlite
});

const users2Table = sqliteTable('users2', {
id: integer('id').primaryKey(),
name: text('name').notNull(),
Expand Down Expand Up @@ -2699,3 +2711,90 @@ test.serial('aggregate function: min', async (t) => {
t.deepEqual(result1[0]?.value, 10);
t.deepEqual(result2[0]?.value, null);
});

test.serial('test $onUpdateFn and $onUpdate works as $default', async (t) => {
const { db } = t.context;

await db.run(sql`drop table if exists ${usersOnUpdate}`);

await db.run(
sql`
create table ${usersOnUpdate} (
id integer primary key autoincrement,
name text not null,
update_counter integer default 1 not null,
updated_at integer,
always_null text
)
`,
);

await db.insert(usersOnUpdate).values([
{ name: 'John' },
{ name: 'Jane' },
{ name: 'Jack' },
{ name: 'Jill' },
]);
const { updatedAt, ...rest } = getTableColumns(usersOnUpdate);

const justDates = await db.select({ updatedAt }).from(usersOnUpdate).orderBy(asc(usersOnUpdate.id));

const response = await db.select({ ...rest }).from(usersOnUpdate).orderBy(asc(usersOnUpdate.id));

t.deepEqual(response, [
{ name: 'John', id: 1, updateCounter: 1, alwaysNull: null },
{ name: 'Jane', id: 2, updateCounter: 1, alwaysNull: null },
{ name: 'Jack', id: 3, updateCounter: 1, alwaysNull: null },
{ name: 'Jill', id: 4, updateCounter: 1, alwaysNull: null },
]);
const msDelay = 250;

for (const eachUser of justDates) {
t.assert(eachUser.updatedAt!.valueOf() > Date.now() - msDelay);
}
});

test.serial('test $onUpdateFn and $onUpdate works updating', async (t) => {
const { db } = t.context;

await db.run(sql`drop table if exists ${usersOnUpdate}`);

await db.run(
sql`
create table ${usersOnUpdate} (
id integer primary key autoincrement,
name text not null,
update_counter integer default 1,
updated_at integer,
always_null text
)
`,
);

await db.insert(usersOnUpdate).values([
{ name: 'John', alwaysNull: 'this will be null after updating' },
{ name: 'Jane' },
{ name: 'Jack' },
{ name: 'Jill' },
]);
const { updatedAt, ...rest } = getTableColumns(usersOnUpdate);

await db.update(usersOnUpdate).set({ name: 'Angel' }).where(eq(usersOnUpdate.id, 1));
await db.update(usersOnUpdate).set({ updateCounter: null }).where(eq(usersOnUpdate.id, 2));

const justDates = await db.select({ updatedAt }).from(usersOnUpdate).orderBy(asc(usersOnUpdate.id));

const response = await db.select({ ...rest }).from(usersOnUpdate).orderBy(asc(usersOnUpdate.id));

t.deepEqual(response, [
{ name: 'Angel', id: 1, updateCounter: 2, alwaysNull: null },
{ name: 'Jane', id: 2, updateCounter: null, alwaysNull: null },
{ name: 'Jack', id: 3, updateCounter: 1, alwaysNull: null },
{ name: 'Jill', id: 4, updateCounter: 1, alwaysNull: null },
]);
const msDelay = 250;

for (const eachUser of justDates) {
t.assert(eachUser.updatedAt!.valueOf() > Date.now() - msDelay);
}
});
Loading

0 comments on commit 77763ae

Please sign in to comment.