diff --git a/packages/persistence/src/db.provider.ts b/packages/persistence/src/db.provider.ts index 1fc9733b6..38fbf12a3 100644 --- a/packages/persistence/src/db.provider.ts +++ b/packages/persistence/src/db.provider.ts @@ -1,5 +1,26 @@ -import { inject } from "@undb/di" +import { inject, singleton } from "@undb/di" export const DB_PROVIDER = Symbol.for("DB_PROVIDER") export const injectDbProvider = () => inject(DB_PROVIDER) + +export interface IDbProvider { + getDbProvider(): string + + isPostgres(): boolean + isSqlite(): boolean +} + +@singleton() +export class DbProviderService implements IDbProvider { + constructor(@inject(DB_PROVIDER) private readonly dbProvider: string) {} + getDbProvider(): string { + return this.dbProvider + } + isPostgres(): boolean { + return this.dbProvider === "postgres" + } + isSqlite(): boolean { + return this.dbProvider === "sqlite" || this.dbProvider === "turso" || !this.dbProvider + } +} diff --git a/packages/persistence/src/record/record-query.helper.ts b/packages/persistence/src/record/record-query.helper.ts index eafd0faa8..bdc5c74d0 100644 --- a/packages/persistence/src/record/record-query.helper.ts +++ b/packages/persistence/src/record/record-query.helper.ts @@ -5,7 +5,7 @@ import { FieldIdVo, type Field, type IViewSort, type RecordComositeSpecification import { sql, type ExpressionBuilder, type SelectQueryBuilder } from "kysely" import type { ITxContext } from "../ctx.interface" import { injectTxCTX } from "../ctx.provider" -import { injectDbProvider } from "../db.provider" +import { DbProviderService, type IDbProvider } from "../db.provider" import { injectQueryBuilder } from "../qb.provider" import type { IRecordQueryBuilder } from "../qb.type" import { UnderlyingTable } from "../underlying/underlying-table" @@ -26,8 +26,8 @@ export class RecordQueryHelper { private readonly context: IContext, @injectTxCTX() private readonly txContext: ITxContext, - @injectDbProvider() - private readonly dbProvider: string, + @inject(DbProviderService) + private readonly dbProvider: IDbProvider, @inject(DatabaseFnUtil) private readonly dbFnUtil: IDatabaseFnUtil, ) {} diff --git a/packages/persistence/src/record/record-select-field-visitor.ts b/packages/persistence/src/record/record-select-field-visitor.ts index 25f6d086c..b43fe02f6 100644 --- a/packages/persistence/src/record/record-select-field-visitor.ts +++ b/packages/persistence/src/record/record-select-field-visitor.ts @@ -32,6 +32,7 @@ import { import type { FormulaField } from "@undb/table/src/modules/schema/fields/variants/formula-field" import { getTableName } from "drizzle-orm" import { sql, type ExpressionBuilder, type SelectExpression } from "kysely" +import type { IDbProvider } from "../db.provider" import { users } from "../schema/sqlite" import type { UnderlyingTable } from "../underlying/underlying-table" import { getDateRangeFieldName } from "../underlying/underlying-table.util" @@ -60,15 +61,13 @@ export class RecordSelectFieldVisitor implements IFieldVisitor { private readonly table: UnderlyingTable, private readonly foreignTables: Map, private readonly eb: ExpressionBuilder, - private readonly dbProvider: string, + private readonly dbProvider: IDbProvider, private readonly dbFnUtil: IDatabaseFnUtil, ) { this.#addSelect(this.getField(ID_TYPE)) } #selectSingelUser(field: UserField | CreatedByField | UpdatedByField) { - const db = this.dbProvider - const as = createDisplayFieldName(field) const user = getTableName(users) diff --git a/packages/persistence/src/record/record.mutate-visitor.ts b/packages/persistence/src/record/record.mutate-visitor.ts index 832f32654..f8bacb29a 100644 --- a/packages/persistence/src/record/record.mutate-visitor.ts +++ b/packages/persistence/src/record/record.mutate-visitor.ts @@ -76,6 +76,7 @@ import { startOfDay, startOfToday, startOfTomorrow, startOfYesterday } from "dat import { sql, type ExpressionBuilder } from "kysely" import { unique } from "radash" import { AbstractQBMutationVisitor } from "../abstract-qb.visitor" +import type { IDbProvider } from "../db.provider" import type { IQueryBuilder, IRecordQueryBuilder } from "../qb.type" import { JoinTable } from "../underlying/reference/join-table" import { getDateRangeFieldName } from "../underlying/underlying-table.util" @@ -90,14 +91,14 @@ export class RecordMutateVisitor extends AbstractQBMutationVisitor implements IR private readonly qb: IRecordQueryBuilder, private readonly eb: ExpressionBuilder, private readonly context: IContext, - private readonly dbProvider: string, + private readonly dbProvider: IDbProvider, ) { super() } #setDate(fieldId: string, value: Date | null) { if (value) { - this.setData(fieldId, this.dbProvider === "postgres" ? value : value.getTime()) + this.setData(fieldId, this.dbProvider.isPostgres() ? value : value.getTime()) } else { this.setData(fieldId, null) } diff --git a/packages/persistence/src/record/record.repository.ts b/packages/persistence/src/record/record.repository.ts index 0757e1214..e076d4f7b 100644 --- a/packages/persistence/src/record/record.repository.ts +++ b/packages/persistence/src/record/record.repository.ts @@ -22,7 +22,7 @@ import { chunk } from "es-toolkit/array" import { sql, type CompiledQuery, type ExpressionBuilder } from "kysely" import type { ITxContext } from "../ctx.interface" import { injectTxCTX } from "../ctx.provider" -import { injectDbProvider } from "../db.provider" +import { DbProviderService, type IDbProvider } from "../db.provider" import { UnderlyingTable } from "../underlying/underlying-table" import { RecordQueryHelper } from "./record-query.helper" import { getRecordDTOFromEntity } from "./record-utils" @@ -44,8 +44,8 @@ export class RecordRepository implements IRecordRepository { private readonly context: IContext, @injectTxCTX() private readonly txContext: ITxContext, - @injectDbProvider() - private readonly dbProvider: string, + @inject(DbProviderService) + private readonly dbProvider: IDbProvider, ) {} private async getForeignTables(table: TableDo, fields: Field[]): Promise> { diff --git a/packages/persistence/src/underlying/underlying-table-field.visitor.ts b/packages/persistence/src/underlying/underlying-table-field.visitor.ts index 960f83585..d9b93c0ff 100644 --- a/packages/persistence/src/underlying/underlying-table-field.visitor.ts +++ b/packages/persistence/src/underlying/underlying-table-field.visitor.ts @@ -31,6 +31,7 @@ import { type UpdatedAtField, } from "@undb/table" import { AlterTableBuilder, AlterTableColumnAlteringBuilder, CompiledQuery, CreateTableBuilder, sql } from "kysely" +import type { IDbProvider } from "../db.provider" import type { IQueryBuilder } from "../qb.type" import { JoinTable } from "./reference/join-table" import { getUnderlyingFormulaType } from "./underlying-formula.util" @@ -45,7 +46,7 @@ export class UnderlyingTableFieldVisitor private readonly qb: IQueryBuilder, private readonly t: UnderlyingTable, public tb: TB, - private readonly dbProvider: string, + private readonly dbProvider: IDbProvider, public readonly isNew: boolean = false, ) {} public atb: AlterTableColumnAlteringBuilder | CreateTableBuilder | null = null @@ -70,7 +71,7 @@ export class UnderlyingTableFieldVisitor const c = this.tb.addColumn(field.id.value, "timestamp", (b) => b.defaultTo(sql`(CURRENT_TIMESTAMP)`).notNull()) this.addColumn(c) - if (this.dbProvider === "postgres") { + if (this.dbProvider.isPostgres()) { const query = sql .raw( ` @@ -104,7 +105,7 @@ CREATE TRIGGER update_customer_modtime_${tableName} BEFORE UPDATE ON ${tableName } } autoIncrement(field: AutoIncrementField): void { - if (this.dbProvider === "postgres") { + if (this.dbProvider.isPostgres()) { const c = this.tb.addColumn(field.id.value, "bigserial", (b) => b.primaryKey()) this.addColumn(c) } else { @@ -203,7 +204,7 @@ CREATE TRIGGER update_customer_modtime_${tableName} BEFORE UPDATE ON ${tableName } rollup(field: RollupField): void {} checkbox(field: CheckboxField): void { - const defaultValue = this.dbProvider === "postgres" ? false : 0 + const defaultValue = this.dbProvider.isPostgres() ? false : 0 const c = this.tb.addColumn(field.id.value, "boolean", (b) => b.defaultTo(defaultValue).notNull()) this.addColumn(c) } @@ -230,7 +231,7 @@ CREATE TRIGGER update_customer_modtime_${tableName} BEFORE UPDATE ON ${tableName const type = getUnderlyingFormulaType(field.returnType) const c = this.tb.addColumn(field.id.value, type, (b) => { const column = b.generatedAlwaysAs(sql.raw(parsed)) - if (this.dbProvider === "postgres") { + if (this.dbProvider.isPostgres()) { return column.stored() } return this.isNew ? column.stored() : column diff --git a/packages/persistence/src/underlying/underlying-table-spec.visitor.ts b/packages/persistence/src/underlying/underlying-table-spec.visitor.ts index 8269e4162..2d8ea49c3 100644 --- a/packages/persistence/src/underlying/underlying-table-spec.visitor.ts +++ b/packages/persistence/src/underlying/underlying-table-spec.visitor.ts @@ -47,6 +47,7 @@ import type { } from "@undb/table/src/specifications/table-forms.specification" import type { WithTableRLS } from "@undb/table/src/specifications/table-rls.specification" import { AlterTableBuilder, AlterTableColumnAlteringBuilder, CompiledQuery, CreateTableBuilder, sql } from "kysely" +import type { IDbProvider } from "../db.provider" import type { IRecordQueryBuilder } from "../qb.type" import type { IDatabaseFnUtil } from "../utils/fn.util" import { ConversionContext } from "./conversion/conversion.context" @@ -62,7 +63,7 @@ export class UnderlyingTableSpecVisitor implements ITableSpecVisitor { public readonly table: UnderlyingTable, public readonly qb: IRecordQueryBuilder, public readonly context: IContext, - private readonly dbProvider: string, + private readonly dbProvider: IDbProvider, private readonly dbFnUtil: IDatabaseFnUtil, ) { this.tb = qb.schema.alterTable(table.name) diff --git a/packages/persistence/src/underlying/underlying-table.service.ts b/packages/persistence/src/underlying/underlying-table.service.ts index f04fd966c..9a97ac525 100644 --- a/packages/persistence/src/underlying/underlying-table.service.ts +++ b/packages/persistence/src/underlying/underlying-table.service.ts @@ -5,7 +5,7 @@ import type { TableComositeSpecification, TableDo } from "@undb/table" import type { CompiledQuery } from "kysely" import type { ITxContext } from "../ctx.interface" import { injectTxCTX } from "../ctx.provider" -import { injectDbProvider } from "../db.provider" +import { DbProviderService, type IDbProvider } from "../db.provider" import { DatabaseFnUtil, type IDatabaseFnUtil } from "../utils/fn.util" import { JoinTable } from "./reference/join-table" import { UnderlyingTable } from "./underlying-table" @@ -18,8 +18,8 @@ export class UnderlyingTableService { @injectContext() private readonly context: IContext, @injectTxCTX() private readonly txContext: ITxContext, - @injectDbProvider() - private readonly dbProvider: string, + @inject(DbProviderService) + private readonly dbProvider: IDbProvider, @inject(DatabaseFnUtil) private readonly dbFnUtil: IDatabaseFnUtil, ) {} diff --git a/packages/persistence/src/utils/fn.util.ts b/packages/persistence/src/utils/fn.util.ts index 6db0a8310..06efae869 100644 --- a/packages/persistence/src/utils/fn.util.ts +++ b/packages/persistence/src/utils/fn.util.ts @@ -1,6 +1,6 @@ -import { singleton } from "@undb/di" +import { inject, singleton } from "@undb/di" import { match } from "ts-pattern" -import { injectDbProvider } from "../db.provider" +import { DbProviderService, type IDbProvider } from "../db.provider" export interface IDatabaseFnUtil { get jsonGroupArray(): string @@ -10,23 +10,32 @@ export interface IDatabaseFnUtil { @singleton() export class DatabaseFnUtil implements IDatabaseFnUtil { - constructor(@injectDbProvider() private readonly dbProvider: string) {} + constructor(@inject(DbProviderService) private readonly dbProvider: IDbProvider) {} get jsonGroupArray() { return match(this.dbProvider) - .with("postgres", () => "json_agg") + .when( + (p) => p.isPostgres(), + () => "json_agg", + ) .otherwise(() => "json_group_array") } get jsonObject() { return match(this.dbProvider) - .with("postgres", () => "json_build_object") + .when( + (p) => p.isPostgres(), + () => "json_build_object", + ) .otherwise(() => "json_object") } get jsonArray() { return match(this.dbProvider) - .with("postgres", () => "json_build_array") + .when( + (p) => p.isPostgres(), + () => "json_build_array", + ) .otherwise(() => "json_array") } }