Skip to content

Commit

Permalink
add returning support in MERGE queries. (#1171)
Browse files Browse the repository at this point in the history
  • Loading branch information
igalklebanov committed Dec 8, 2024
1 parent 7d293a7 commit c937f1a
Show file tree
Hide file tree
Showing 11 changed files with 503 additions and 89 deletions.
1 change: 1 addition & 0 deletions deno.check.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import type {
export interface Database {
audit: AuditTable
person: PersonTable
person_backup: PersonTable
pet: PetTable
toy: ToyTable
wine: WineTable
Expand Down
52 changes: 52 additions & 0 deletions src/helpers/postgres.ts
Original file line number Diff line number Diff line change
Expand Up @@ -169,3 +169,55 @@ export function jsonBuildObject<O extends Record<string, Expression<unknown>>>(
Object.keys(obj).flatMap((k) => [sql.lit(k), obj[k]]),
)})`
}

export type MergeAction = 'INSERT' | 'UPDATE' | 'DELETE'

/**
* The PostgreSQL `merge_action` function.
*
* This function can be used in a `returning` clause to get the action that was
* performed in a `mergeInto` query. The function returns one of the following
* strings: `'INSERT'`, `'UPDATE'`, or `'DELETE'`.
*
* ### Examples
*
* ```ts
* import { mergeAction } from 'kysely/helpers/postgres'
*
* const result = await db
* .mergeInto('person as p')
* .using('person_backup as pb', 'p.id', 'pb.id')
* .whenMatched()
* .thenUpdateSet(({ ref }) => ({
* first_name: ref('pb.first_name'),
* updated_at: ref('pb.updated_at').$castTo<string | null>(),
* }))
* .whenNotMatched()
* .thenInsertValues(({ ref}) => ({
* id: ref('pb.id'),
* first_name: ref('pb.first_name'),
* created_at: ref('pb.updated_at'),
* updated_at: ref('pb.updated_at').$castTo<string | null>(),
* }))
* .returning([mergeAction().as('action'), 'p.id', 'p.updated_at'])
* .execute()
*
* result[0].action
* ```
*
* The generated SQL (PostgreSQL):
*
* ```sql
* merge into "person" as "p"
* using "person_backup" as "pb" on "p"."id" = "pb"."id"
* when matched then update set
* "first_name" = "pb"."first_name",
* "updated_at" = "pb"."updated_at"::text
* when not matched then insert values ("id", "first_name", "created_at", "updated_at")
* values ("pb"."id", "pb"."first_name", "pb"."updated_at", "pb"."updated_at")
* returning merge_action() as "action", "p"."id", "p"."updated_at"
* ```
*/
export function mergeAction(): RawBuilder<MergeAction> {
return sql`merge_action()`
}
2 changes: 2 additions & 0 deletions src/operation-node/merge-query-node.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import { AliasNode } from './alias-node.js'
import { JoinNode } from './join-node.js'
import { OperationNode } from './operation-node.js'
import { OutputNode } from './output-node.js'
import { ReturningNode } from './returning-node.js'
import { TableNode } from './table-node.js'
import { TopNode } from './top-node.js'
import { WhenNode } from './when-node.js'
Expand All @@ -15,6 +16,7 @@ export interface MergeQueryNode extends OperationNode {
readonly whens?: ReadonlyArray<WhenNode>
readonly with?: WithNode
readonly top?: TopNode
readonly returning?: ReturningNode
readonly output?: OutputNode
readonly endModifiers?: ReadonlyArray<OperationNode>
}
Expand Down
1 change: 1 addition & 0 deletions src/operation-node/operation-node-transformer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1039,6 +1039,7 @@ export class OperationNodeTransformer {
top: this.transformNode(node.top),
endModifiers: this.transformNodeList(node.endModifiers),
output: this.transformNode(node.output),
returning: this.transformNode(node.returning),
})
}

Expand Down
4 changes: 2 additions & 2 deletions src/query-builder/delete-query-builder.ts
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ import { QueryId } from '../util/query-id.js'
import { freeze } from '../util/object-utils.js'
import { KyselyPlugin } from '../plugin/kysely-plugin.js'
import { WhereInterface } from './where-interface.js'
import { ReturningInterface } from './returning-interface.js'
import { MultiTableReturningInterface } from './returning-interface.js'
import {
isNoResultErrorConstructor,
NoResultError,
Expand Down Expand Up @@ -82,7 +82,7 @@ import {
export class DeleteQueryBuilder<DB, TB extends keyof DB, O>
implements
WhereInterface<DB, TB>,
ReturningInterface<DB, TB, O>,
MultiTableReturningInterface<DB, TB, O>,
OutputInterface<DB, TB, O, 'deleted'>,
OperationNodeSource,
Compilable<O>,
Expand Down
117 changes: 111 additions & 6 deletions src/query-builder/merge-query-builder.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,17 @@ import {
} from '../parser/join-parser.js'
import { parseMergeThen, parseMergeWhen } from '../parser/merge-parser.js'
import { ReferenceExpression } from '../parser/reference-parser.js'
import { ReturningAllRow, ReturningRow } from '../parser/returning-parser.js'
import { parseSelectAll, parseSelectArg } from '../parser/select-parser.js'
import {
ReturningAllRow,
ReturningCallbackRow,
ReturningRow,
} from '../parser/returning-parser.js'
import {
parseSelectAll,
parseSelectArg,
SelectCallback,
SelectExpression,
} from '../parser/select-parser.js'
import { TableExpression } from '../parser/table-parser.js'
import { parseTop } from '../parser/top-parser.js'
import {
Expand Down Expand Up @@ -58,10 +67,13 @@ import {
SelectExpressionFromOutputCallback,
SelectExpressionFromOutputExpression,
} from './output-interface.js'
import { MultiTableReturningInterface } from './returning-interface.js'
import { UpdateQueryBuilder } from './update-query-builder.js'

export class MergeQueryBuilder<DB, TT extends keyof DB, O>
implements OutputInterface<DB, TT, O>
implements
MultiTableReturningInterface<DB, TT, O>,
OutputInterface<DB, TT, O>
{
readonly #props: MergeQueryBuilderProps

Expand Down Expand Up @@ -215,6 +227,44 @@ export class MergeQueryBuilder<DB, TT extends keyof DB, O>
})
}

returning<SE extends SelectExpression<DB, TT>>(
selections: ReadonlyArray<SE>,
): MergeQueryBuilder<DB, TT, ReturningRow<DB, TT, O, SE>>

returning<CB extends SelectCallback<DB, TT>>(
callback: CB,
): MergeQueryBuilder<DB, TT, ReturningCallbackRow<DB, TT, O, CB>>

returning<SE extends SelectExpression<DB, TT>>(
selection: SE,
): MergeQueryBuilder<DB, TT, ReturningRow<DB, TT, O, SE>>

returning(args: any): any {
return new MergeQueryBuilder({
...this.#props,
queryNode: QueryNode.cloneWithReturning(
this.#props.queryNode,
parseSelectArg(args),
),
})
}

returningAll<T extends TT>(
table: T,
): MergeQueryBuilder<DB, TT, ReturningAllRow<DB, T, O>>

returningAll(): MergeQueryBuilder<DB, TT, ReturningAllRow<DB, TT, O>>

returningAll(table?: any): any {
return new MergeQueryBuilder({
...this.#props,
queryNode: QueryNode.cloneWithReturning(
this.#props.queryNode,
parseSelectAll(table),
),
})
}

output<OE extends OutputExpression<DB, TT>>(
selections: readonly OE[],
): MergeQueryBuilder<
Expand Down Expand Up @@ -274,7 +324,11 @@ export class WheneableMergeQueryBuilder<
ST extends keyof DB,
O,
>
implements Compilable<O>, OutputInterface<DB, TT, O>, OperationNodeSource
implements
Compilable<O>,
MultiTableReturningInterface<DB, TT | ST, O>,
OutputInterface<DB, TT, O>,
OperationNodeSource
{
readonly #props: MergeQueryBuilderProps

Expand Down Expand Up @@ -608,6 +662,54 @@ export class WheneableMergeQueryBuilder<
return this.#whenNotMatched([lhs, op, rhs], true, true)
}

returning<SE extends SelectExpression<DB, TT | ST>>(
selections: ReadonlyArray<SE>,
): WheneableMergeQueryBuilder<DB, TT, ST, ReturningRow<DB, TT | ST, O, SE>>

returning<CB extends SelectCallback<DB, TT | ST>>(
callback: CB,
): WheneableMergeQueryBuilder<
DB,
TT,
ST,
ReturningCallbackRow<DB, TT | ST, O, CB>
>

returning<SE extends SelectExpression<DB, TT | ST>>(
selection: SE,
): WheneableMergeQueryBuilder<DB, TT, ST, ReturningRow<DB, TT | ST, O, SE>>

returning(args: any): any {
return new WheneableMergeQueryBuilder({
...this.#props,
queryNode: QueryNode.cloneWithReturning(
this.#props.queryNode,
parseSelectArg(args),
),
})
}

returningAll<T extends TT | ST>(
table: T,
): WheneableMergeQueryBuilder<DB, TT, ST, ReturningAllRow<DB, T, O>>

returningAll(): WheneableMergeQueryBuilder<
DB,
TT,
ST,
ReturningAllRow<DB, TT | ST, O>
>

returningAll(table?: any): any {
return new WheneableMergeQueryBuilder({
...this.#props,
queryNode: QueryNode.cloneWithReturning(
this.#props.queryNode,
parseSelectAll(table),
),
})
}

output<OE extends OutputExpression<DB, TT>>(
selections: readonly OE[],
): WheneableMergeQueryBuilder<
Expand Down Expand Up @@ -788,9 +890,12 @@ export class WheneableMergeQueryBuilder<
this.#props.queryId,
)

const { adapter } = this.#props.executor
const query = compiledQuery.query as MergeQueryNode

if (
(compiledQuery.query as MergeQueryNode).output &&
this.#props.executor.adapter.supportsOutput
(query.returning && adapter.supportsReturning) ||
(query.output && adapter.supportsOutput)
) {
return result.rows as any
}
Expand Down
24 changes: 22 additions & 2 deletions src/query-builder/returning-interface.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import {
ReturningAllRow,
ReturningCallbackRow,
ReturningRow,
} from '../parser/returning-parser.js'
Expand All @@ -10,7 +11,7 @@ export interface ReturningInterface<DB, TB extends keyof DB, O> {
* Allows you to return data from modified rows.
*
* On supported databases like PostgreSQL, this method can be chained to
* `insert`, `update` and `delete` queries to return data.
* `insert`, `update`, `delete` and `merge` queries to return data.
*
* Note that on SQLite you need to give aliases for the expressions to avoid
* [this bug](https://sqlite.org/forum/forumpost/033daf0b32) in SQLite.
Expand Down Expand Up @@ -78,10 +79,29 @@ export interface ReturningInterface<DB, TB extends keyof DB, O> {
): ReturningInterface<DB, TB, ReturningRow<DB, TB, O, SE>>

/**
* Adds a `returning *` to an insert/update/delete query on databases
* Adds a `returning *` to an insert/update/delete/merge query on databases
* that support `returning` such as PostgreSQL.
*
* Also see the {@link returning} method.
*/
returningAll(): ReturningInterface<DB, TB, Selectable<DB[TB]>>
}

export interface MultiTableReturningInterface<DB, TB extends keyof DB, O>
extends ReturningInterface<DB, TB, O> {
/**
* Adds a `returning *` or `returning table.*` to an insert/update/delete/merge
* query on databases that support `returning` such as PostgreSQL.
*
* Also see the {@link returning} method.
*/
returningAll<T extends TB>(
tables: ReadonlyArray<T>,
): MultiTableReturningInterface<DB, TB, ReturningAllRow<DB, T, O>>

returningAll<T extends TB>(
table: T,
): MultiTableReturningInterface<DB, TB, ReturningAllRow<DB, T, O>>

returningAll(): ReturningInterface<DB, TB, Selectable<DB[TB]>>
}
4 changes: 2 additions & 2 deletions src/query-builder/update-query-builder.ts
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ import { freeze } from '../util/object-utils.js'
import { UpdateResult } from './update-result.js'
import { KyselyPlugin } from '../plugin/kysely-plugin.js'
import { WhereInterface } from './where-interface.js'
import { ReturningInterface } from './returning-interface.js'
import { MultiTableReturningInterface } from './returning-interface.js'
import {
isNoResultErrorConstructor,
NoResultError,
Expand Down Expand Up @@ -83,7 +83,7 @@ import {
export class UpdateQueryBuilder<DB, UT extends keyof DB, TB extends keyof DB, O>
implements
WhereInterface<DB, TB>,
ReturningInterface<DB, TB, O>,
MultiTableReturningInterface<DB, TB, O>,
OutputInterface<DB, TB, O>,
OperationNodeSource,
Compilable<O>,
Expand Down
5 changes: 5 additions & 0 deletions src/query-compiler/default-query-compiler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1585,6 +1585,11 @@ export class DefaultQueryCompiler
this.compileList(node.whens, ' ')
}

if (node.returning) {
this.append(' ')
this.visitNode(node.returning)
}

if (node.output) {
this.append(' ')
this.visitNode(node.output)
Expand Down
Loading

0 comments on commit c937f1a

Please sign in to comment.