diff --git a/example/with-fabbrica/package-lock.json b/example/with-fabbrica/package-lock.json index 5e862d3..fbb1061 100644 --- a/example/with-fabbrica/package-lock.json +++ b/example/with-fabbrica/package-lock.json @@ -1,10 +1,10 @@ { - "name": "with-jest-example", + "name": "with-fabbrica-example", "lockfileVersion": 3, "requires": true, "packages": { "": { - "name": "with-jest-example", + "name": "with-fabbrica-example", "dependencies": { "@prisma/client": "^5.11.0" }, diff --git a/example/with-fabbrica/prisma/schema.prisma b/example/with-fabbrica/prisma/schema.prisma index e1980a3..f330f24 100644 --- a/example/with-fabbrica/prisma/schema.prisma +++ b/example/with-fabbrica/prisma/schema.prisma @@ -13,15 +13,30 @@ generator fabbrica { } model User { - id Int @id @default(autoincrement()) - name String - posts Post[] + id Int @id @default(autoincrement()) + name String + posts Post[] + comment Comment[] } model Post { - id Int @id @default(autoincrement()) + id Int @id @default(autoincrement()) title String content String - User User? @relation(fields: [userId], references: [id]) + user User? @relation(fields: [userId], references: [id]) userId Int? + comment Comment[] + + @@map("posts") +} + +model Comment { + id Int @id @default(autoincrement()) + body String + user User? @relation(fields: [userId], references: [id]) + post Post @relation(fields: [postId], references: [id]) + userId Int? + postId Int + + @@map("comments") } diff --git a/example/with-fabbrica/src/UserService.test.ts b/example/with-fabbrica/src/UserService.test.ts index 8b123d9..ba3237f 100644 --- a/example/with-fabbrica/src/UserService.test.ts +++ b/example/with-fabbrica/src/UserService.test.ts @@ -1,26 +1,29 @@ import { PrismaClient } from "@prisma/client"; -import { defineUserFactory } from "./__generated__/fabbrica"; -import { UserService } from "./UserService"; +import { definePostFactory, defineUserFactory } from "./__generated__/fabbrica"; import { cleaner } from "../test/cleaner"; const UserFactory = defineUserFactory(); +const PostFactory = definePostFactory({ + defaultData: { + user: UserFactory, + }, +}); describe("UserService", () => { const prisma = new PrismaClient().$extends( cleaner.withCleaner(), ) as PrismaClient; - const userService = new UserService(prisma); - it("should create a new user", async () => { // this record will delete by prisma-cleaner in afterEach defined by setup.ts - const user = await userService.createUser("xxx"); - expect(user.name).toEqual("xxx"); - await UserFactory.create(); - expect(await prisma.user.count()).toEqual(2); + const post = await PostFactory.create({ title: "xxx" }); + expect(post.title).toEqual("xxx"); + + expect(await prisma.user.count()).toEqual(1); + expect(await prisma.post.count()).toEqual(1); }); it("should be cleanup user table by cleaner", async () => { - const count = await prisma.user.count(); - expect(count).toEqual(0); + expect(await prisma.user.count()).toEqual(0); + expect(await prisma.post.count()).toEqual(0); }); }); diff --git a/example/with-jest/prisma/schema.prisma b/example/with-jest/prisma/schema.prisma deleted file mode 100644 index 8f8cc17..0000000 --- a/example/with-jest/prisma/schema.prisma +++ /dev/null @@ -1,22 +0,0 @@ -generator client { - provider = "prisma-client-js" -} - -datasource db { - provider = "postgresql" - url = "postgresql://postgres@localhost:5432/prisma_cleaner" -} - -model User { - id Int @id @default(autoincrement()) - name String - posts Post[] -} - -model Post { - id Int @id @default(autoincrement()) - title String - content String - User User? @relation(fields: [userId], references: [id]) - userId Int? -} diff --git a/example/with-jest/prisma/schema.prisma b/example/with-jest/prisma/schema.prisma new file mode 120000 index 0000000..59b62d0 --- /dev/null +++ b/example/with-jest/prisma/schema.prisma @@ -0,0 +1 @@ +../../../prisma/schema.prisma \ No newline at end of file diff --git a/example/with-mock-client/prisma/schema.prisma b/example/with-mock-client/prisma/schema.prisma deleted file mode 100644 index 8f8cc17..0000000 --- a/example/with-mock-client/prisma/schema.prisma +++ /dev/null @@ -1,22 +0,0 @@ -generator client { - provider = "prisma-client-js" -} - -datasource db { - provider = "postgresql" - url = "postgresql://postgres@localhost:5432/prisma_cleaner" -} - -model User { - id Int @id @default(autoincrement()) - name String - posts Post[] -} - -model Post { - id Int @id @default(autoincrement()) - title String - content String - User User? @relation(fields: [userId], references: [id]) - userId Int? -} diff --git a/example/with-mock-client/prisma/schema.prisma b/example/with-mock-client/prisma/schema.prisma new file mode 120000 index 0000000..59b62d0 --- /dev/null +++ b/example/with-mock-client/prisma/schema.prisma @@ -0,0 +1 @@ +../../../prisma/schema.prisma \ No newline at end of file diff --git a/example/with-nestjs/prisma/schema.prisma b/example/with-nestjs/prisma/schema.prisma deleted file mode 100644 index 8f8cc17..0000000 --- a/example/with-nestjs/prisma/schema.prisma +++ /dev/null @@ -1,22 +0,0 @@ -generator client { - provider = "prisma-client-js" -} - -datasource db { - provider = "postgresql" - url = "postgresql://postgres@localhost:5432/prisma_cleaner" -} - -model User { - id Int @id @default(autoincrement()) - name String - posts Post[] -} - -model Post { - id Int @id @default(autoincrement()) - title String - content String - User User? @relation(fields: [userId], references: [id]) - userId Int? -} diff --git a/example/with-nestjs/prisma/schema.prisma b/example/with-nestjs/prisma/schema.prisma new file mode 120000 index 0000000..59b62d0 --- /dev/null +++ b/example/with-nestjs/prisma/schema.prisma @@ -0,0 +1 @@ +../../../prisma/schema.prisma \ No newline at end of file diff --git a/example/with-vitest/prisma/schema.prisma b/example/with-vitest/prisma/schema.prisma deleted file mode 100644 index 8f8cc17..0000000 --- a/example/with-vitest/prisma/schema.prisma +++ /dev/null @@ -1,22 +0,0 @@ -generator client { - provider = "prisma-client-js" -} - -datasource db { - provider = "postgresql" - url = "postgresql://postgres@localhost:5432/prisma_cleaner" -} - -model User { - id Int @id @default(autoincrement()) - name String - posts Post[] -} - -model Post { - id Int @id @default(autoincrement()) - title String - content String - User User? @relation(fields: [userId], references: [id]) - userId Int? -} diff --git a/example/with-vitest/prisma/schema.prisma b/example/with-vitest/prisma/schema.prisma new file mode 120000 index 0000000..59b62d0 --- /dev/null +++ b/example/with-vitest/prisma/schema.prisma @@ -0,0 +1 @@ +../../../prisma/schema.prisma \ No newline at end of file diff --git a/prisma/schema.prisma b/prisma/schema.prisma index 8f8cc17..86c0b62 100644 --- a/prisma/schema.prisma +++ b/prisma/schema.prisma @@ -8,15 +8,30 @@ datasource db { } model User { - id Int @id @default(autoincrement()) - name String - posts Post[] + id Int @id @default(autoincrement()) + name String + posts Post[] + comment Comment[] } model Post { - id Int @id @default(autoincrement()) + id Int @id @default(autoincrement()) title String content String - User User? @relation(fields: [userId], references: [id]) + user User? @relation(fields: [userId], references: [id]) userId Int? + comment Comment[] + + @@map("posts") +} + +model Comment { + id Int @id @default(autoincrement()) + body String + user User? @relation(fields: [userId], references: [id]) + post Post @relation(fields: [postId], references: [id]) + userId Int? + postId Int + + @@map("comments") } diff --git a/src/index.ts b/src/index.ts index 0a48edd..6c4d647 100644 --- a/src/index.ts +++ b/src/index.ts @@ -10,6 +10,11 @@ type PrismaClientLike = { type ModelLike = { name: string; dbName: string | null; + fields: readonly { + name: string; + kind: string; + type: string; + }[]; }; type Table = { @@ -19,10 +24,24 @@ type Table = { const targetOperations = ["create", "createMany", "upsert"]; +function isPlainObject(obj: unknown): obj is Record { + return obj != null && Object.getPrototypeOf(obj) === Object.prototype; +} + export class PrismaCleaner { private readonly prisma: PrismaClientLike; private readonly cleanupTargetModels = new Set(); - private readonly tableByModel = new Map(); + private readonly modelsMap = new Map< + string, // model name + { + table: string; + fields: readonly { + name: string; + kind: string; + type: string; + }[]; + } + >(); private tables: Table[] | null = null; private schemaListByTableName: Record | null = null; @@ -35,8 +54,14 @@ export class PrismaCleaner { models: readonly ModelLike[] | ModelLike[]; }) { this.prisma = prisma; - this.tableByModel = new Map( - models.map((model) => [model.name, model.dbName || model.name]), + this.modelsMap = new Map( + models.map((model) => [ + model.name, + { + table: model.dbName || model.name, + fields: model.fields, + }, + ]), ); } @@ -49,6 +74,7 @@ export class PrismaCleaner { async $allOperations({ operation, model, args, query }) { if (model && targetOperations.includes(operation)) { self.cleanupTargetModels.add(model); + self.addTargetModelByArgs(model, args); } return query(args); }, @@ -83,7 +109,7 @@ export class PrismaCleaner { const targetTableNames = Array.from(this.cleanupTargetModels) .map((model) => { - return this.tableByModel.get(model); + return this.modelsMap.get(model)?.table; }) .filter((table): table is string => table != null); const schemaListByTableName = await this.getSchemaListByTableName(); @@ -137,4 +163,44 @@ AND table_name != '_prisma_migrations' ); return this.tables; } + + private addTargetModelByArgs(modelName: string, args: unknown): void { + if (!isPlainObject(args)) return; + this.addTargetModelByInputData(modelName, args.data); + } + + private addTargetModelByInputData(modelName: string, data: unknown): void { + const model = this.modelsMap.get(modelName); + if (!model) return; + + if (Array.isArray(data)) { + data.forEach((d) => this.addTargetModelByInputData(modelName, d)); + return; + } + if (!isPlainObject(data)) return; + + for (const [key, value] of Object.entries(data)) { + if (!isPlainObject(value)) continue; + if (isPlainObject(value.create)) { + const field = model.fields.find((f) => f.name === key); + if (field) { + this.cleanupTargetModels.add(field.type); + this.addTargetModelByInputData(field.type, value.create); + } + } + if ( + isPlainObject(value.connectOrCreate) && + isPlainObject(value.connectOrCreate.create) + ) { + const field = model.fields.find((f) => f.name === key); + if (field) { + this.cleanupTargetModels.add(field.type); + this.addTargetModelByInputData( + field.type, + value.connectOrCreate.create, + ); + } + } + } + } } diff --git a/test/cleaner.test.ts b/test/cleaner.test.ts index e60ebf7..a1b131b 100644 --- a/test/cleaner.test.ts +++ b/test/cleaner.test.ts @@ -44,6 +44,34 @@ describe("PrismaCleaner", () => { expect(await prisma.user.count()).toBe(0); }); + test("with nesting", async () => { + await prisma.comment.create({ + data: { + body: "xxx", + post: { + create: { + title: "yyy", + content: "zzz", + user: { + connectOrCreate: { + where: { + id: 1, + }, + create: { + name: "foo", + }, + }, + }, + }, + }, + }, + }); + await cleaner.cleanup(); + expect(await prisma.user.count()).toBe(0); + expect(await prisma.post.count()).toBe(0); + expect(await prisma.comment.count()).toBe(0); + }); + test("manually cleanup", async () => { const insert = () => prisma.$executeRaw`insert into "User" (name) values ('xxx')`;