diff --git a/.vscode/launch.json b/.vscode/launch.json index 5ed666059a9..be13d723cc3 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -360,6 +360,28 @@ }, "console": "internalConsole" }, + { + "name": "server: RQ - consensus", + "type": "debugpy", + "request": "launch", + "stopOnEntry": false, + "justMyCode": false, + "python": "${command:python.interpreterPath}", + "program": "${workspaceRoot}/manage.py", + "args": [ + "rqworker", + "consensus", + "--worker-class", + "cvat.rqworker.SimpleWorker" + ], + "django": true, + "cwd": "${workspaceFolder}", + "env": { + "DJANGO_LOG_SERVER_HOST": "localhost", + "DJANGO_LOG_SERVER_PORT": "8282" + }, + "console": "internalConsole" + }, { "name": "server: migrate", "type": "debugpy", @@ -537,7 +559,8 @@ "server: RQ - scheduler", "server: RQ - quality reports", "server: RQ - analytics reports", - "server: RQ - cleaning" + "server: RQ - cleaning", + "server: RQ - consensus", ] } ] diff --git a/cvat-core/src/api-implementation.ts b/cvat-core/src/api-implementation.ts index 0e9f400ad49..9de58b83770 100644 --- a/cvat-core/src/api-implementation.ts +++ b/cvat-core/src/api-implementation.ts @@ -31,8 +31,11 @@ import Organization, { Invitation } from './organization'; import Webhook from './webhook'; import { ArgumentError } from './exceptions'; import { - AnalyticsReportFilter, QualityConflictsFilter, QualityReportsFilter, + AnalyticsReportFilter, ConflictsFilter, QualityReportsFilter, QualitySettingsFilter, SerializedAsset, + ConsensusReportsFilter, + AssigneeConsensusReportsFilter, + ConsensusSettingsFilter, } from './server-response-types'; import QualityReport from './quality-report'; import QualityConflict, { ConflictSeverity } from './quality-conflict'; @@ -44,11 +47,85 @@ import { convertDescriptions, getServerAPISchema } from './server-schema'; import { JobType } from './enums'; import { PaginatedResource } from './core-types'; import CVATCore from '.'; +import ConsensusSettings from './consensus-settings'; +import ConsensusReport from './consensus-report'; +import AssigneeConsensusReport from './assignee-consensus-report'; +import ConsensusConflict from './consensus-conflict'; function implementationMixin(func: Function, implementation: Function): void { Object.assign(func, { implementation }); } +type ConflictType = ConsensusConflict | QualityConflict; + +function mergeConflicts(conflicts: T[]): T[] { + const frames = Array.from(new Set(conflicts.map((conflict) => conflict.frame))) + .sort((a, b) => a - b); + + const mergedConflicts: T[] = []; + + for (const frame of frames) { + const frameConflicts = conflicts.filter((conflict) => conflict.frame === frame); + const conflictsByObject: Record = {}; + + frameConflicts.forEach((qualityConflict: T) => { + const { type, serverID } = qualityConflict.annotationConflicts[0]; + const firstObjID = `${type}_${serverID}`; + conflictsByObject[firstObjID] = conflictsByObject[firstObjID] || []; + conflictsByObject[firstObjID].push(qualityConflict); + }); + + for (const objectConflicts of Object.values(conflictsByObject)) { + if (objectConflicts.length === 1) { + mergedConflicts.push(objectConflicts[0]); + } else { + const [firstConflict] = objectConflicts; + let mainObjectConflict: T; + + if (firstConflict instanceof QualityConflict) { + mainObjectConflict = objectConflicts.find( + (conflict) => (conflict as QualityConflict).severity === ConflictSeverity.ERROR, + ) || firstConflict; + } else { + mainObjectConflict = firstConflict; + } + const descriptionList: string[] = 
[mainObjectConflict.description]; + + for (const objectConflict of objectConflicts) { + if (objectConflict !== mainObjectConflict) { + descriptionList.push(objectConflict.description); + + for (const annotationConflict of objectConflict.annotationConflicts) { + if (!mainObjectConflict.annotationConflicts.find((_annotationConflict) => ( + _annotationConflict.serverID === annotationConflict.serverID && + _annotationConflict.type === annotationConflict.type)) + ) { + mainObjectConflict.annotationConflicts.push(annotationConflict); + } + } + } + } + + const description = descriptionList.join(', '); + const visibleConflict = new Proxy(mainObjectConflict, { + get(target, prop) { + if (prop === 'description') { + return description; + } + + const val = Reflect.get(target, prop); + return typeof val === 'function' ? (...args: any[]) => val.apply(target, args) : val; + }, + }); + + mergedConflicts.push(visibleConflict); + } + } + } + + return mergedConflicts; +} + export default function implementAPI(cvat: CVATCore): CVATCore { implementationMixin(cvat.plugins.list, PluginRegistry.list); implementationMixin(cvat.plugins.register, PluginRegistry.register.bind(cvat)); @@ -434,7 +511,7 @@ export default function implementAPI(cvat: CVATCore): CVATCore { ); return reports; }); - implementationMixin(cvat.analytics.quality.conflicts, async (filter: QualityConflictsFilter) => { + implementationMixin(cvat.analytics.quality.conflicts, async (filter: ConflictsFilter) => { checkFilter(filter, { reportID: isInteger, }); @@ -443,72 +520,7 @@ export default function implementAPI(cvat: CVATCore): CVATCore { const conflictsData = await serverProxy.analytics.quality.conflicts(params); const conflicts = conflictsData.map((conflict) => new QualityConflict({ ...conflict })); - const frames = Array.from(new Set(conflicts.map((conflict) => conflict.frame))) - .sort((a, b) => a - b); - - // each QualityConflict may have several AnnotationConflicts bound - // at the same time, many quality conflicts may refer - // to the same labeled object (e.g. 
mismatch label, low overlap) - // the code below unites quality conflicts bound to the same object into one QualityConflict object - const mergedConflicts: QualityConflict[] = []; - - for (const frame of frames) { - const frameConflicts = conflicts.filter((conflict) => conflict.frame === frame); - const conflictsByObject: Record = {}; - - frameConflicts.forEach((qualityConflict: QualityConflict) => { - const { type, serverID } = qualityConflict.annotationConflicts[0]; - const firstObjID = `${type}_${serverID}`; - conflictsByObject[firstObjID] = conflictsByObject[firstObjID] || []; - conflictsByObject[firstObjID].push(qualityConflict); - }); - - for (const objectConflicts of Object.values(conflictsByObject)) { - if (objectConflicts.length === 1) { - // only one quality conflict refers to the object on current frame - mergedConflicts.push(objectConflicts[0]); - } else { - const mainObjectConflict = objectConflicts - .find((conflict) => conflict.severity === ConflictSeverity.ERROR) || objectConflicts[0]; - const descriptionList: string[] = [mainObjectConflict.description]; - - for (const objectConflict of objectConflicts) { - if (objectConflict !== mainObjectConflict) { - descriptionList.push(objectConflict.description); - - for (const annotationConflict of objectConflict.annotationConflicts) { - if (!mainObjectConflict.annotationConflicts.find((_annotationConflict) => ( - _annotationConflict.serverID === annotationConflict.serverID && - _annotationConflict.type === annotationConflict.type)) - ) { - mainObjectConflict.annotationConflicts.push(annotationConflict); - } - } - } - } - - // decorate the original conflict to avoid changing it - const description = descriptionList.join(', '); - const visibleConflict = new Proxy(mainObjectConflict, { - get(target, prop) { - if (prop === 'description') { - return description; - } - - // By default, it looks like Reflect.get(target, prop, receiver) - // which has a different value of `this`. It doesn't allow to - // work with methods / properties that use private members. - const val = Reflect.get(target, prop); - return typeof val === 'function' ? 
(...args: any[]) => val.apply(target, args) : val; - }, - }); - - mergedConflicts.push(visibleConflict); - } - } - } - - return mergedConflicts; + return mergeConflicts(conflicts); }); implementationMixin(cvat.analytics.quality.settings.get, async (filter: QualitySettingsFilter) => { checkFilter(filter, { @@ -523,6 +535,58 @@ export default function implementAPI(cvat: CVATCore): CVATCore { return new QualitySettings({ ...settings, descriptions }); }); + implementationMixin(cvat.consensus.reports, async (filter: ConsensusReportsFilter) => { + checkFilter(filter, { + page: isInteger, + pageSize: isPageSize, + projectID: isInteger, + taskID: isInteger, + jobID: isInteger, + filter: isString, + search: isString, + target: isString, + sort: isString, + }); + + const params = fieldsToSnakeCase({ ...filter, sort: '-created_date' }); + + const reportsData = await serverProxy.consensus.reports(params); + const reports = Object.assign( + reportsData.map((report) => new ConsensusReport({ ...report })), + { count: reportsData.count }, + ); + return reports; + }); + implementationMixin(cvat.consensus.assigneeReports, async (filter: AssigneeConsensusReportsFilter) => { + checkFilter(filter, { + page: isInteger, + pageSize: isPageSize, + taskID: isInteger, + filter: isString, + consensusReportID: isInteger, + search: isString, + sort: isString, + }); + + const params = fieldsToSnakeCase({ ...filter, sort: '-id' }); + + const reportsData = await serverProxy.consensus.assigneeReports(params); + const reports = Object.assign( + reportsData.map((report) => new AssigneeConsensusReport({ ...report })), + { count: reportsData.count }, + ); + return reports; + }); + implementationMixin(cvat.consensus.settings.get, async (filter: ConsensusSettingsFilter) => { + checkFilter(filter, { + taskID: isInteger, + }); + + const params = fieldsToSnakeCase(filter); + + const settings = await serverProxy.consensus.settings.get(params); + return new ConsensusSettings({ ...settings }); + }); implementationMixin(cvat.analytics.performance.reports, async (filter: AnalyticsReportFilter) => { checkFilter(filter, { jobID: isInteger, @@ -538,6 +602,17 @@ export default function implementAPI(cvat: CVATCore): CVATCore { const reportData = await serverProxy.analytics.performance.reports(params); return new AnalyticsReport(reportData); }); + implementationMixin(cvat.consensus.conflicts, async (filter: ConflictsFilter) => { + checkFilter(filter, { + reportID: isInteger, + }); + + const params = fieldsToSnakeCase(filter); + + const conflictsData = await serverProxy.consensus.conflicts(params); + const conflicts = conflictsData.map((conflict) => new ConsensusConflict({ ...conflict })); + return mergeConflicts(conflicts); + }); implementationMixin(cvat.analytics.performance.calculate, async ( body: Parameters[0], onUpdate: Parameters[1], diff --git a/cvat-core/src/api.ts b/cvat-core/src/api.ts index fe3975217ce..5c056941c7c 100644 --- a/cvat-core/src/api.ts +++ b/cvat-core/src/api.ts @@ -401,6 +401,26 @@ function build(): CVATCore { return result; }, }, + consensus: { + async assigneeReports(filter = {}) { + const result = await PluginRegistry.apiWrapper(cvat.consensus.assigneeReports, filter); + return result; + }, + async reports(filter = {}) { + const result = await PluginRegistry.apiWrapper(cvat.consensus.reports, filter); + return result; + }, + async conflicts(filter = {}) { + const result = await PluginRegistry.apiWrapper(cvat.consensus.conflicts, filter); + return result; + }, + settings: { + async get(filter = {}) { + const result 
= await PluginRegistry.apiWrapper(cvat.consensus.settings.get, filter); + return result; + }, + }, + }, classes: { User, Project: implementProject(Project), @@ -452,6 +472,7 @@ function build(): CVATCore { cvat.organizations = Object.freeze(cvat.organizations); cvat.webhooks = Object.freeze(cvat.webhooks); cvat.analytics = Object.freeze(cvat.analytics); + cvat.consensus = Object.freeze(cvat.consensus); cvat.classes = Object.freeze(cvat.classes); cvat.utils = Object.freeze(cvat.utils); diff --git a/cvat-core/src/assignee-consensus-report.ts b/cvat-core/src/assignee-consensus-report.ts new file mode 100644 index 00000000000..246b9036c63 --- /dev/null +++ b/cvat-core/src/assignee-consensus-report.ts @@ -0,0 +1,53 @@ +// Copyright (C) 2024 CVAT.ai Corporation +// +// SPDX-License-Identifier: MIT + +import { SerializedAssigneeConsensusReportData } from './server-response-types'; +import User from './user'; + +export default class AssigneeConsensusReport { + #id: number; + #taskID: number; + #assignee: User; + #consensusScore: number; + #conflictCount: number; + #consensusReportID: number; + + constructor(initialData: SerializedAssigneeConsensusReportData) { + this.#id = initialData.id; + this.#taskID = initialData.task_id; + this.#consensusScore = initialData.consensus_score; + this.#consensusReportID = initialData.consensus_report_id; + this.#conflictCount = initialData.conflict_count; + + if (initialData.assignee) { + this.#assignee = new User(initialData.assignee); + } else { + this.#assignee = null; + } + } + + get id(): number { + return this.#id; + } + + get taskID(): number { + return this.#taskID; + } + + get assignee(): User { + return this.#assignee; + } + + get consensusScore(): number { + return this.#consensusScore; + } + + get conflictCount(): number { + return this.#conflictCount; + } + + get consensusReportID(): number { + return this.#consensusReportID; + } +} diff --git a/cvat-core/src/consensus-conflict.ts b/cvat-core/src/consensus-conflict.ts new file mode 100644 index 00000000000..79ae6316294 --- /dev/null +++ b/cvat-core/src/consensus-conflict.ts @@ -0,0 +1,103 @@ +// Copyright (C) 2023 CVAT.ai Corporation +// +// SPDX-License-Identifier: MIT + +import { SerializedAnnotationConsensusConflictData, SerializedConsensusConflictData } from './server-response-types'; +import { ObjectType } from './enums'; + +export enum ConsensusConflictType { + NO_MATCHING_ITEM = 'no_matching_item', + NO_MATCHING_ANNOTATION = 'no_matching_annotation', + ANNOTATION_TOO_CLOSE = 'annotation_too_close', + FAILED_LABEL_VOTING = 'failed_label_voting', +} + +export class AnnotationConflict { + #jobID: number; + #serverID: number; + #type: ObjectType; + #shapeType: string | null; + #description: string; + #conflictType: ConsensusConflictType; + + constructor(initialData: SerializedAnnotationConsensusConflictData) { + this.#jobID = initialData.job_id; + this.#serverID = initialData.obj_id; + this.#type = initialData.type; + this.#shapeType = initialData.shape_type; + this.#conflictType = initialData.conflict_type as ConsensusConflictType; + + const desc = this.#conflictType.split('_').join(' '); + this.#description = desc.charAt(0).toUpperCase() + desc.slice(1); + } + + get jobID(): number { + return this.#jobID; + } + + get serverID(): number { + return this.#serverID; + } + + get type(): ObjectType { + return this.#type; + } + + get shapeType(): string | null { + return this.#shapeType; + } + + get conflictType(): ConsensusConflictType { + return this.#conflictType; + } + + get description(): string { 
+ return this.#description; + } +} + +export default class ConsensusConflict { + #id: number; + #frame: number; + #type: ConsensusConflictType; + #annotationConflicts: AnnotationConflict[]; + #description: string; + + constructor(initialData: SerializedConsensusConflictData) { + this.#id = initialData.id; + this.#frame = initialData.frame; + this.#type = initialData.type as ConsensusConflictType; + this.#annotationConflicts = initialData.annotation_ids + .map((rawData: SerializedAnnotationConsensusConflictData) => new AnnotationConflict({ + ...rawData, + conflict_type: initialData.type, + })); + + const desc = initialData.type.split('_').join(' '); + this.#description = desc.charAt(0).toUpperCase() + desc.slice(1); + } + + get id(): number { + return this.#id; + } + + get frame(): number { + return this.#frame; + } + + get type(): ConsensusConflictType { + return this.#type; + } + + get annotationConflicts(): AnnotationConflict[] { + return this.#annotationConflicts; + } + + get description(): string { + return this.#description; + } + + set description(newDescription: string) { + this.#description = newDescription; + } +} diff --git a/cvat-core/src/consensus-report.ts b/cvat-core/src/consensus-report.ts new file mode 100644 index 00000000000..50bccbadce9 --- /dev/null +++ b/cvat-core/src/consensus-report.ts @@ -0,0 +1,85 @@ +// Copyright (C) 2024 CVAT.ai Corporation +// +// SPDX-License-Identifier: MIT + +import { SerializedConsensusReportData } from './server-response-types'; +import User from './user'; + +export interface ConsensusSummary { + frameCount: number; + conflictCount: number; + conflictsByType: { + no_matching_item: number; + no_matching_annotation: number; + annotation_too_close: number; + failed_label_voting: number; + } +} + +export default class ConsensusReport { + #id: number; + #taskID: number; + #jobID: number | null; + #createdDate: string; + #assignee: User | null; + #consensus_score: number; + #target: string; + #summary: Partial; + + constructor(initialData: SerializedConsensusReportData) { + this.#id = initialData.id; + this.#taskID = initialData.task_id; + this.#jobID = initialData.job_id; + this.#createdDate = initialData.created_date; + this.#target = initialData.target; + this.#consensus_score = initialData.consensus_score; + this.#summary = initialData.summary; + + if (initialData.assignee) { + this.#assignee = new User(initialData.assignee); + } else { + this.#assignee = null; + } + } + + get id(): number { + return this.#id; + } + + get taskID(): number { + return this.#taskID; + } + + get jobID(): number | null { + return this.#jobID; + } + + get createdDate(): string { + return this.#createdDate; + } + + get assignee(): User | null { + return this.#assignee; + } + + get consensus_score(): number { + return this.#consensus_score; + } + + get target(): string { + return this.#target; + } + + get summary(): ConsensusSummary { + return { + frameCount: this.#summary.frame_count, + conflictCount: this.#summary.conflict_count, + conflictsByType: { + no_matching_item: this.#summary.conflicts_by_type?.no_matching_item, + no_matching_annotation: this.#summary.conflicts_by_type?.no_matching_annotation, + annotation_too_close: this.#summary.conflicts_by_type?.annotation_too_close, + failed_label_voting: this.#summary.conflicts_by_type?.failed_label_voting, + }, + }; + } +} diff --git a/cvat-core/src/consensus-settings.ts b/cvat-core/src/consensus-settings.ts new file mode 100644 index 00000000000..51585bec88a --- /dev/null +++ b/cvat-core/src/consensus-settings.ts @@ 
-0,0 +1,103 @@ +// Copyright (C) 2024 CVAT.ai Corporation +// +// SPDX-License-Identifier: MIT + +import { SerializedConsensusSettingsData } from './server-response-types'; +import PluginRegistry from './plugins'; +import serverProxy from './server-proxy'; + +export default class ConsensusSettings { + #id: number; + #task: number; + #iouThreshold: number; + #quorum: number; + #agreementScoreThreshold: number; + #sigma: number; + #lineThickness: number; + + constructor(initialData: SerializedConsensusSettingsData) { + this.#id = initialData.id; + this.#task = initialData.task; + this.#iouThreshold = initialData.iou_threshold; + this.#agreementScoreThreshold = initialData.agreement_score_threshold; + this.#quorum = initialData.quorum; + this.#sigma = initialData.sigma; + this.#lineThickness = initialData.line_thickness; + } + + get id(): number { + return this.#id; + } + + get task(): number { + return this.#task; + } + + get iouThreshold(): number { + return this.#iouThreshold; + } + + set iouThreshold(newVal: number) { + this.#iouThreshold = newVal; + } + + get quorum(): number { + return this.#quorum; + } + + set quorum(newVal: number) { + this.#quorum = newVal; + } + + get sigma(): number { + return this.#sigma; + } + + set sigma(newVal: number) { + this.#sigma = newVal; + } + + get agreementScoreThreshold(): number { + return this.#agreementScoreThreshold; + } + + set agreementScoreThreshold(newVal: number) { + this.#agreementScoreThreshold = newVal; + } + + get lineThickness(): number { + return this.#lineThickness; + } + + set lineThickness(newVal: number) { + this.#lineThickness = newVal; + } + + public toJSON(): SerializedConsensusSettingsData { + const result: SerializedConsensusSettingsData = { + iou_threshold: this.#iouThreshold, + quorum: this.#quorum, + agreement_score_threshold: this.#agreementScoreThreshold, + sigma: this.#sigma, + line_thickness: this.#lineThickness, + }; + + return result; + } + + public async save(): Promise { + const result = await PluginRegistry.apiWrapper.call(this, ConsensusSettings.prototype.save); + return result; + } +} + +Object.defineProperties(ConsensusSettings.prototype.save, { + implementation: { + writable: false, + enumerable: false, + value: async function implementation() { + const result = await serverProxy.consensus.settings.update(this.id, this.toJSON()); + return new ConsensusSettings(result); + }, + }, +}); diff --git a/cvat-core/src/enums.ts b/cvat-core/src/enums.ts index 1b291662d21..cadd64230e9 100644 --- a/cvat-core/src/enums.ts +++ b/cvat-core/src/enums.ts @@ -37,6 +37,7 @@ export enum JobState { export enum JobType { ANNOTATION = 'annotation', GROUND_TRUTH = 'ground_truth', + CONSENSUS = 'consensus', } export enum DimensionType { diff --git a/cvat-core/src/index.ts b/cvat-core/src/index.ts index 4c68e2b2358..cdadd14fa5d 100644 --- a/cvat-core/src/index.ts +++ b/cvat-core/src/index.ts @@ -2,8 +2,12 @@ // // SPDX-License-Identifier: MIT +import ConsensusConflict from 'consensus-conflict'; +import AssigneeConsensusReport from 'assignee-consensus-report'; import { - AnalyticsReportFilter, QualityConflictsFilter, QualityReportsFilter, QualitySettingsFilter, + AnalyticsReportFilter, ConflictsFilter, QualityReportsFilter, QualitySettingsFilter, ConsensusReportsFilter, + AssigneeConsensusReportsFilter, + ConsensusSettingsFilter, } from './server-response-types'; import PluginRegistry from './plugins'; import serverProxy from './server-proxy'; @@ -30,6 +34,8 @@ import Webhook from './webhook'; import QualityReport from './quality-report'; 
import QualityConflict from './quality-conflict'; import QualitySettings from './quality-settings'; +import ConsensusReport from './consensus-report'; +import ConsensusSettings from './consensus-settings'; import AnalyticsReport from './analytics-report'; import AnnotationGuide from './guide'; import { Request } from './request'; @@ -131,10 +137,20 @@ export default interface CVATCore { webhooks: { get: any; }; + consensus: { + reports: (filter: ConsensusReportsFilter) => Promise>; + assigneeReports: ( + filter: AssigneeConsensusReportsFilter + ) => Promise>; + conflicts: (filter: ConflictsFilter) => Promise; + settings: { + get: (filter: ConsensusSettingsFilter) => Promise; + }; + } analytics: { quality: { reports: (filter: QualityReportsFilter) => Promise>; - conflicts: (filter: QualityConflictsFilter) => Promise; + conflicts: (filter: ConflictsFilter) => Promise; settings: { get: (filter: QualitySettingsFilter) => Promise; }; diff --git a/cvat-core/src/quality-conflict.ts b/cvat-core/src/quality-conflict.ts index 3d7252f37e8..7b107c8c836 100644 --- a/cvat-core/src/quality-conflict.ts +++ b/cvat-core/src/quality-conflict.ts @@ -2,7 +2,7 @@ // // SPDX-License-Identifier: MIT -import { SerializedAnnotationConflictData, SerializedQualityConflictData } from './server-response-types'; +import { SerializedAnnotationQualityConflictData, SerializedQualityConflictData } from './server-response-types'; import { ObjectType } from './enums'; export enum QualityConflictType { @@ -25,7 +25,7 @@ export class AnnotationConflict { #severity: ConflictSeverity; #description: string; - constructor(initialData: SerializedAnnotationConflictData) { + constructor(initialData: SerializedAnnotationQualityConflictData) { this.#jobID = initialData.job_id; this.#serverID = initialData.obj_id; this.#type = initialData.type; @@ -80,7 +80,7 @@ export default class QualityConflict { this.#type = initialData.type as QualityConflictType; this.#severity = initialData.severity as ConflictSeverity; this.#annotationConflicts = initialData.annotation_ids - .map((rawData: SerializedAnnotationConflictData) => new AnnotationConflict({ + .map((rawData: SerializedAnnotationQualityConflictData) => new AnnotationConflict({ ...rawData, conflict_type: initialData.type, severity: initialData.severity, diff --git a/cvat-core/src/server-proxy.ts b/cvat-core/src/server-proxy.ts index 51309426198..39dfc6f0fd0 100644 --- a/cvat-core/src/server-proxy.ts +++ b/cvat-core/src/server-proxy.ts @@ -17,9 +17,10 @@ import { SerializedAbout, SerializedRemoteFile, SerializedUserAgreement, SerializedRegister, JobsFilter, SerializedJob, SerializedGuide, SerializedAsset, SerializedAPISchema, SerializedInvitationData, SerializedCloudStorage, SerializedFramesMetaData, SerializedCollection, - SerializedQualitySettingsData, APIQualitySettingsFilter, SerializedQualityConflictData, APIQualityConflictsFilter, + SerializedQualitySettingsData, APIQualitySettingsFilter, SerializedQualityConflictData, APIConflictsFilter, SerializedQualityReportData, APIQualityReportsFilter, SerializedAnalyticsReport, APIAnalyticsReportFilter, - SerializedRequest, + SerializedConsensusSettingsData, SerializedRequest, APIConsensusReportsFilter, APIAssigneeConsensusReportsFilter, + SerializedConsensusConflictData, SerializedAssigneeConsensusReportData, SerializedConsensusReportData, } from './server-response-types'; import { PaginatedResource } from './core-types'; import { Request } from './request'; @@ -767,6 +768,41 @@ async function deleteTask(id: number, organizationID: string | 
null = null): Pro } } +async function mergeConsensusJobs(id: number, instanceType: string): Promise { + const { backendAPI } = config; + const url = `${backendAPI}/consensus/reports`; + const params = { + rq_id: null, + }; + const requestBody = { + task_id: 0, + job_id: 0, + }; + + if (instanceType === 'task') requestBody.task_id = id; + else requestBody.job_id = id; + + return new Promise((resolve, reject) => { + async function request() { + try { + const response = await Axios.post(url, requestBody, { params }); + params.rq_id = response.data.rq_id; + const { status } = response; + if (status === 202) { + setTimeout(request, 3000); + } else if (status === 201) { + resolve(); + } else { + reject(generateError(response)); + } + } catch (errorData) { + reject(generateError(errorData)); + } + } + setTimeout(request); + }); +} + async function getLabels(filter: { job_id?: number, task_id?: number, @@ -2150,8 +2186,44 @@ async function updateQualitySettings( } } +async function getConsensusSettings( + filter: APIQualitySettingsFilter, +): Promise { + const { backendAPI } = config; + + try { + const response = await Axios.get(`${backendAPI}/consensus/settings`, { + params: { + ...filter, + }, + }); + + return response.data.results[0]; + } catch (errorData) { + throw generateError(errorData); + } +} + +async function updateConsensusSettings( + settingsID: number, + settingsData: SerializedConsensusSettingsData, +): Promise { + const params = enableOrganization(); + const { backendAPI } = config; + + try { + const response = await Axios.patch(`${backendAPI}/consensus/settings/${settingsID}`, settingsData, { + params, + }); + + return response.data; + } catch (errorData) { + throw generateError(errorData); + } +} + async function getQualityConflicts( - filter: APIQualityConflictsFilter, + filter: APIConflictsFilter, ): Promise { const params = enableOrganization(); const { backendAPI } = config; @@ -2187,6 +2259,62 @@ async function getQualityReports( } } +async function getConsensusConflicts( + filter: APIConflictsFilter, +): Promise { + const params = enableOrganization(); + const { backendAPI } = config; + + try { + const response = await fetchAll(`${backendAPI}/consensus/conflicts`, { + ...params, + ...filter, + }); + + return response.results; + } catch (errorData) { + throw generateError(errorData); + } +} + +async function getConsensusReports( + filter: APIConsensusReportsFilter, +): Promise> { + const { backendAPI } = config; + + try { + const response = await Axios.get(`${backendAPI}/consensus/reports`, { + params: { + ...filter, + }, + }); + + response.data.results.count = response.data.count; + return response.data.results; + } catch (errorData) { + throw generateError(errorData); + } +} + +async function getAssigneeConsensusReports( + filter: APIAssigneeConsensusReportsFilter, +): Promise> { + const { backendAPI } = config; + + try { + const response = await Axios.get(`${backendAPI}/consensus/assignee_reports`, { + params: { + ...filter, + }, + }); + + response.data.results.count = response.data.count; + return response.data.results; + } catch (errorData) { + throw generateError(errorData); + } +} + async function getAnalyticsReports( filter: APIAnalyticsReportFilter, ): Promise { @@ -2362,6 +2490,7 @@ export default Object.freeze({ getPreview: getPreview('tasks'), backup: backupTask, restore: restoreTask, + mergeConsensusJobs, }), labels: Object.freeze({ @@ -2377,6 +2506,7 @@ export default Object.freeze({ create: createJob, delete: deleteJob, exportDataset: exportDataset('jobs'), + 
mergeConsensusJobs, }), users: Object.freeze({ @@ -2481,6 +2611,16 @@ export default Object.freeze({ }), }), + consensus: Object.freeze({ + assigneeReports: getAssigneeConsensusReports, + reports: getConsensusReports, + conflicts: getConsensusConflicts, + settings: Object.freeze({ + get: getConsensusSettings, + update: updateConsensusSettings, + }), + }), + requests: Object.freeze({ list: getRequestsList, status: getRequestStatus, diff --git a/cvat-core/src/server-response-types.ts b/cvat-core/src/server-response-types.ts index 5ec9e21d0b7..0654ffa250a 100644 --- a/cvat-core/src/server-response-types.ts +++ b/cvat-core/src/server-response-types.ts @@ -120,6 +120,7 @@ export interface SerializedTask { subset: string; updated_date: string; url: string; + consensus_jobs_per_regular_job: number; } export interface SerializedJob { @@ -146,6 +147,7 @@ export interface SerializedJob { url: string; source_storage: SerializedStorage | null; target_storage: SerializedStorage | null; + parent_job_id: number | null; } export type AttrInputType = 'select' | 'radio' | 'checkbox' | 'number' | 'text'; @@ -237,6 +239,7 @@ export interface APIQualitySettingsFilter extends APICommonFilterParams { task_id?: number; } export type QualitySettingsFilter = Camelized; +export type ConsensusSettingsFilter = QualitySettingsFilter; export interface SerializedQualitySettingsData { id?: number; @@ -259,12 +262,12 @@ export interface SerializedQualitySettingsData { descriptions?: Record; } -export interface APIQualityConflictsFilter extends APICommonFilterParams { +export interface APIConflictsFilter extends APICommonFilterParams { report_id?: number; } -export type QualityConflictsFilter = Camelized; +export type ConflictsFilter = Camelized; -export interface SerializedAnnotationConflictData { +export interface SerializedAnnotationQualityConflictData { job_id?: number; obj_id?: number; type?: ObjectType; @@ -277,7 +280,7 @@ export interface SerializedQualityConflictData { id?: number; frame?: number; type?: string; - annotation_ids?: SerializedAnnotationConflictData[]; + annotation_ids?: SerializedAnnotationQualityConflictData[]; data?: string; severity?: string; description?: string; @@ -285,7 +288,7 @@ export interface SerializedQualityConflictData { export interface APIQualityReportsFilter extends APICommonFilterParams { parent_id?: number; - peoject_id?: number; + project_id?: number; task_id?: number; job_id?: number; target?: string; @@ -324,6 +327,79 @@ export interface SerializedQualityReportData { }; } +export interface SerializedAnnotationConsensusConflictData { + job_id?: number; + obj_id?: number; + type?: ObjectType; + shape_type?: string | null; + conflict_type?: string; +} + +export interface SerializedConsensusConflictData { + id?: number; + frame?: number; + type?: string; + annotation_ids?: SerializedAnnotationConsensusConflictData[]; + data?: string; + description?: string; +} + +export interface SerializedConsensusSettingsData { + id?: number; + task?: number; + agreement_score_threshold?: number; + quorum?: number; + iou_threshold?: number; + sigma?: number; + line_thickness?: number; +} + +export interface APIConsensusReportsFilter extends APICommonFilterParams { + task_id?: number; + job_id?: number | null; + target?: string; +} + +export type ConsensusReportsFilter = Camelized; + +export interface SerializedConsensusReportData { + id?: number; + task_id?: number; + job_id?: number | null; + created_date?: string; + target?: string; + assignee?: SerializedUser | null; + consensus_score?: number; 
+ summary?: { + frame_count: number; + conflict_count: number; + conflicts_by_type: { + no_matching_item: number; + failed_attribute_voting: number; + no_matching_annotation: number; + annotation_too_close: number; + wrong_group: number; + failed_label_voting: number; + } + }; +} + +export interface APIAssigneeConsensusReportsFilter extends APICommonFilterParams { + task_id?: number; + consensus_report_id?: number; +} + +export type AssigneeConsensusReportsFilter = Camelized; + +export interface SerializedAssigneeConsensusReportData { + id?: number; + task_id?: number; + consensus_report_id?: number; + assignee?: SerializedUser; + consensus_score?: number; + conflict_count?: number; +} + export interface SerializedDataEntry { date?: string; value?: number | Record diff --git a/cvat-core/src/session-implementation.ts b/cvat-core/src/session-implementation.ts index 96177170872..509e7cb3c4a 100644 --- a/cvat-core/src/session-implementation.ts +++ b/cvat-core/src/session-implementation.ts @@ -594,6 +594,14 @@ export function implementJob(Job: typeof JobClass): typeof JobClass { }, }); + Object.defineProperty(Job.prototype.mergeConsensusJobs, 'implementation', { + value: function mergeConsensusJobsImplementation( + this: JobClass, + ): ReturnType { + return serverProxy.jobs.mergeConsensusJobs(this.id, 'job'); + }, + }); + return Job; } @@ -709,6 +717,10 @@ export function implementTask(Task: typeof TaskClass): typeof TaskClass { taskSpec.source_storage = this.sourceStorage.toJSON(); } + if (this.consensusJobsPerRegularJob) { + taskSpec.consensus_jobs_per_regular_job = this.consensusJobsPerRegularJob; + } + const taskDataSpec = { client_files: this.clientFiles, server_files: this.serverFiles, @@ -779,6 +791,14 @@ export function implementTask(Task: typeof TaskClass): typeof TaskClass { }, }); + Object.defineProperty(Task.prototype.mergeConsensusJobs, 'implementation', { + value: function mergeConsensusJobsImplementation( + this: TaskClass, + ): ReturnType { + return serverProxy.tasks.mergeConsensusJobs(this.id, 'task'); + }, + }); + Object.defineProperty(Task.prototype.issues, 'implementation', { value: function issuesImplementation( this: TaskClass, diff --git a/cvat-core/src/session.ts b/cvat-core/src/session.ts index 54133ff6b66..9324353b043 100644 --- a/cvat-core/src/session.ts +++ b/cvat-core/src/session.ts @@ -475,7 +475,7 @@ export class Job extends Session { frame_count?: number; project_id: number | null; guide_id: number | null; - task_id: number | null; + task_id: number; labels: Label[]; dimension?: DimensionType; data_compressed_chunk_type?: ChunkType; @@ -486,8 +486,8 @@ export class Job extends Session { updated_date?: string, source_storage: Storage, target_storage: Storage, + parent_job_id: number | null; }; - constructor(initialData: InitializerType) { super(); @@ -513,6 +513,7 @@ export class Job extends Session { updated_date: undefined, source_storage: undefined, target_storage: undefined, + parent_job_id: null, }; this.#data.id = initialData.id ?? this.#data.id; @@ -527,6 +528,7 @@ export class Job extends Session { this.#data.data_chunk_size = initialData.data_chunk_size ?? this.#data.data_chunk_size; this.#data.mode = initialData.mode ?? this.#data.mode; this.#data.created_date = initialData.created_date ?? this.#data.created_date; + this.#data.parent_job_id = initialData.parent_job_id ?? 
this.#data.parent_job_id; if (Array.isArray(initialData.labels)) { this.#data.labels = initialData.labels.map((labelData) => { @@ -622,7 +624,7 @@ export class Job extends Session { return this.#data.guide_id; } - public get taskId(): number | null { + public get taskId(): number { return this.#data.task_id; } @@ -630,6 +632,10 @@ export class Job extends Session { return this.#data.dimension; } + public get parent_job_id(): number | null { + return this.#data.parent_job_id; + } + public get dataChunkType(): ChunkType { return this.#data.data_compressed_chunk_type; } @@ -699,6 +705,11 @@ export class Job extends Session { const result = await PluginRegistry.apiWrapper.call(this, Job.prototype.delete); return result; } + + async mergeConsensusJobs(): Promise { + const result = await PluginRegistry.apiWrapper.call(this, Job.prototype.mergeConsensusJobs); + return result; + } } export class Task extends Session { @@ -727,6 +738,7 @@ export class Task extends Session { public readonly organization: number | null; public readonly progress: { count: number; completed: number }; public readonly jobs: Job[]; + public readonly consensusJobsPerRegularJob: number; public readonly startFrame: number; public readonly stopFrame: number; @@ -786,6 +798,8 @@ export class Task extends Session { cloud_storage_id: undefined, sorting_method: undefined, files: undefined, + consensus_jobs_per_regular_job: undefined, + quality_settings: undefined, }; const updateTrigger = new FieldUpdateTrigger(); @@ -861,6 +875,7 @@ export class Task extends Session { data_chunk_size: data.data_chunk_size, target_storage: initialData.target_storage, source_storage: initialData.source_storage, + parent_job_id: job.parent_job_id, }); data.jobs.push(jobInstance); } @@ -968,6 +983,9 @@ export class Task extends Session { copyData: { get: () => data.copy_data, }, + consensusJobsPerRegularJob: { + get: () => data.consensus_jobs_per_regular_job, + }, labels: { get: () => [...data.labels], set: (labels: Label[]) => { @@ -1149,6 +1167,11 @@ export class Task extends Session { return result; } + async mergeConsensusJobs(): Promise { + const result = await PluginRegistry.apiWrapper.call(this, Task.prototype.mergeConsensusJobs); + return result; + } + async backup(targetStorage: Storage, useDefaultSettings: boolean, fileName?: string): Promise { const result = await PluginRegistry.apiWrapper.call( this, diff --git a/cvat-ui/src/actions/consensus-actions.ts b/cvat-ui/src/actions/consensus-actions.ts new file mode 100644 index 00000000000..5834e18eae9 --- /dev/null +++ b/cvat-ui/src/actions/consensus-actions.ts @@ -0,0 +1,74 @@ +// Copyright (C) 2024 CVAT.ai Corporation +// +// SPDX-License-Identifier: MIT + +import { ActionUnion, createAction, ThunkAction } from 'utils/redux'; +import { ConsensusSettings, Job, Task } from 'cvat-core-wrapper'; + +export enum ConsensusActionTypes { + SET_FETCHING = 'SET_FETCHING', + SET_CONSENSUS_SETTINGS = 'SET_CONSENSUS_SETTINGS', + MERGE_CONSENSUS_JOBS = 'MERGE_CONSENSUS_JOBS', + MERGE_CONSENSUS_JOBS_SUCCESS = 'MERGE_CONSENSUS_JOBS_SUCCESS', + MERGE_CONSENSUS_JOBS_FAILED = 'MERGE_CONSENSUS_JOBS_FAILED', + MERGE_SPECIFIC_CONSENSUS_JOBS = 'MERGE_SPECIFIC_CONSENSUS_JOBS', + MERGE_SPECIFIC_CONSENSUS_JOBS_SUCCESS = 'MERGE_SPECIFIC_CONSENSUS_JOBS_SUCCESS', + MERGE_SPECIFIC_CONSENSUS_JOBS_FAILED = 'MERGE_SPECIFIC_CONSENSUS_JOBS_FAILED', +} + +export const consensusActions = { + setFetching: (fetching: boolean) => ( + createAction(ConsensusActionTypes.SET_FETCHING, { fetching }) + ), + setConsensusSettings: 
(consensusSettings: ConsensusSettings) => ( + createAction(ConsensusActionTypes.SET_CONSENSUS_SETTINGS, { consensusSettings }) + ), + mergeTaskConsensusJobs: (taskID: number) => ( + createAction(ConsensusActionTypes.MERGE_CONSENSUS_JOBS, { taskID }) + ), + mergeTaskConsensusJobsSuccess: (taskID: number) => ( + createAction(ConsensusActionTypes.MERGE_CONSENSUS_JOBS_SUCCESS, { taskID }) + ), + mergeTaskConsensusJobsFailed: (taskID: number, error: any) => ( + createAction(ConsensusActionTypes.MERGE_CONSENSUS_JOBS_FAILED, { taskID, error }) + ), + mergeSpecificTaskConsensusJobs: (jobID: number) => ( + createAction(ConsensusActionTypes.MERGE_SPECIFIC_CONSENSUS_JOBS, { jobID }) + ), + mergeSpecificTaskConsensusJobsSuccess: (jobID: number, taskID: number) => ( + createAction(ConsensusActionTypes.MERGE_SPECIFIC_CONSENSUS_JOBS_SUCCESS, { jobID, taskID }) + ), + mergeSpecificTaskConsensusJobsFailed: (jobID: number, taskID: number, error: any) => ( + createAction(ConsensusActionTypes.MERGE_SPECIFIC_CONSENSUS_JOBS_FAILED, { jobID, taskID, error }) + ), +}; + +export const mergeTaskConsensusJobsAsync = ( + taskInstance: Task, +): ThunkAction => async (dispatch) => { + try { + dispatch(consensusActions.mergeTaskConsensusJobs(taskInstance.id)); + await taskInstance.mergeConsensusJobs(); + } catch (error) { + dispatch(consensusActions.mergeTaskConsensusJobsFailed(taskInstance.id, error)); + return; + } + + dispatch(consensusActions.mergeTaskConsensusJobsSuccess(taskInstance.id)); +}; + +export const mergeTaskSpecificConsensusJobsAsync = ( + jobInstance: Job, +): ThunkAction => async (dispatch) => { + try { + dispatch(consensusActions.mergeSpecificTaskConsensusJobs(jobInstance.id)); + await jobInstance.mergeConsensusJobs(); + } catch (error) { + dispatch(consensusActions.mergeSpecificTaskConsensusJobsFailed(jobInstance.id, jobInstance.taskId, error)); + return; + } + + dispatch(consensusActions.mergeSpecificTaskConsensusJobsSuccess(jobInstance.id, jobInstance.taskId)); +}; + +export type ConsensusActions = ActionUnion; diff --git a/cvat-ui/src/actions/jobs-actions.ts b/cvat-ui/src/actions/jobs-actions.ts index 7c2df71a927..2db3bc4a387 100644 --- a/cvat-ui/src/actions/jobs-actions.ts +++ b/cvat-ui/src/actions/jobs-actions.ts @@ -27,6 +27,8 @@ export enum JobsActionTypes { DELETE_JOB = 'DELETE_JOB', DELETE_JOB_SUCCESS = 'DELETE_JOB_SUCCESS', DELETE_JOB_FAILED = 'DELETE_JOB_FAILED', + COLLAPSE_REGULAR_JOB = 'COLLAPSE_REGULAR_JOB', + UNCOLLAPSE_REGULAR_JOB = 'UNCOLLAPSE_REGULAR_JOB', } interface JobsList extends Array { @@ -69,6 +71,12 @@ const jobsActions = { deleteJobFailed: (jobID: number, error: any) => ( createAction(JobsActionTypes.DELETE_JOB_FAILED, { jobID, error }) ), + collapseRegularJob: (jobID: number) => ( + createAction(JobsActionTypes.COLLAPSE_REGULAR_JOB, { jobID }) + ), + uncollapseRegularJob: (jobID: number) => ( + createAction(JobsActionTypes.UNCOLLAPSE_REGULAR_JOB, { jobID }) + ), }; export type JobsActions = ActionUnion; @@ -134,3 +142,11 @@ export const deleteJobAsync = (job: Job): ThunkAction => async (dispatch) => { dispatch(jobsActions.deleteJobSuccess(job.id)); }; + +export const collapseRegularJob = (jobID: number, uncollapse: boolean): ThunkAction => async (dispatch) => { + if (uncollapse) { + dispatch(jobsActions.collapseRegularJob(jobID)); + } else { + dispatch(jobsActions.uncollapseRegularJob(jobID)); + } +}; diff --git a/cvat-ui/src/actions/tasks-actions.ts b/cvat-ui/src/actions/tasks-actions.ts index 70eb56d4c1e..fdc936e2c86 100644 --- a/cvat-ui/src/actions/tasks-actions.ts +++ 
b/cvat-ui/src/actions/tasks-actions.ts @@ -214,8 +214,9 @@ ThunkAction { use_zip_chunks: data.advanced.useZipChunks, use_cache: data.advanced.useCache, sorting_method: data.advanced.sortingMethod, - source_storage: new Storage(data.advanced.sourceStorage ?? { location: StorageLocation.LOCAL }).toJSON(), - target_storage: new Storage(data.advanced.targetStorage ?? { location: StorageLocation.LOCAL }).toJSON(), + source_storage: new Storage(data.advanced.sourceStorage || { location: StorageLocation.LOCAL }).toJSON(), + target_storage: new Storage(data.advanced.targetStorage || { location: StorageLocation.LOCAL }).toJSON(), + consensus_jobs_per_regular_job: data.advanced.consensusJobsPerRegularJob, }; if (data.projectId) { @@ -272,6 +273,9 @@ ThunkAction { validation_frames_per_job: data.quality.validationFramesPerJob, }; } + if (data.advanced.consensusJobsPerRegularJob) { + description.consensus_jobs_per_regular_job = +data.advanced.consensusJobsPerRegularJob; + } const taskInstance = new cvat.classes.Task(description); taskInstance.clientFiles = data.files.local; diff --git a/cvat-ui/src/components/actions-menu/actions-menu.tsx b/cvat-ui/src/components/actions-menu/actions-menu.tsx index a20502931f2..7e6ce230ff5 100644 --- a/cvat-ui/src/components/actions-menu/actions-menu.tsx +++ b/cvat-ui/src/components/actions-menu/actions-menu.tsx @@ -6,10 +6,12 @@ import './styles.scss'; import React, { useCallback } from 'react'; import Modal from 'antd/lib/modal'; +import { LoadingOutlined } from '@ant-design/icons'; import { DimensionType, CVATCore } from 'cvat-core-wrapper'; import Menu, { MenuInfo } from 'components/dropdown-menu'; import { usePlugins } from 'utils/hooks'; import { CombinedState } from 'reducers'; +import { useSelector } from 'react-redux'; type AnnotationFormats = Awaited>; @@ -22,6 +24,7 @@ interface Props { dumpers: AnnotationFormats['dumpers']; inferenceIsActive: boolean; taskDimension: DimensionType; + consensusJobsPerRegularJob: number; onClickMenu: (params: MenuInfo) => void; } @@ -35,6 +38,8 @@ export enum Actions { BACKUP_TASK = 'backup_task', VIEW_ANALYTICS = 'view_analytics', QUALITY_CONTROL = 'quality_control', + VIEW_CONSENSUS_ANALYTICS = 'view_consensus_analytics', + MERGE_CONSENSUS_JOBS = 'merge_consensus_jobs', } function ActionsMenuComponent(props: Props): JSX.Element { @@ -43,11 +48,15 @@ function ActionsMenuComponent(props: Props): JSX.Element { projectID, bugTracker, inferenceIsActive, + consensusJobsPerRegularJob, onClickMenu, } = props; const plugins = usePlugins((state: CombinedState) => state.plugins.components.taskActions.items, props); + const mergingConsensus = useSelector((state: CombinedState) => state.consensus.mergingConsensus); + const isTaskInMergingConsensus = mergingConsensus[`task_${taskID}`]; + const onClickMenuWrapper = useCallback( (params: MenuInfo) => { if (!params) { @@ -68,6 +77,20 @@ function ActionsMenuComponent(props: Props): JSX.Element { }, okText: 'Delete', }); + } else if (params.key === Actions.MERGE_CONSENSUS_JOBS) { + Modal.confirm({ + title: 'The consensus jobs will be merged', + content: 'Existing annotations in regular jobs will be updated. 
Continue?', + className: 'cvat-modal-confirm-delete-task', + onOk: () => { + onClickMenu(params); + }, + okButtonProps: { + type: 'primary', + danger: true, + }, + okText: 'Merge', + }); } else { onClickMenu(params); } @@ -120,6 +143,25 @@ function ActionsMenuComponent(props: Props): JSX.Element { ), 60]); + if (consensusJobsPerRegularJob) { + menuItems.push([( + + View Consensus Analytics + + ), 55]); + menuItems.push([( + } + > + Merge Consensus Jobs + + ), 60]); + } + if (projectID === null) { menuItems.push([( Move to project diff --git a/cvat-ui/src/components/analytics-page/consensus-analytics-page.tsx b/cvat-ui/src/components/analytics-page/consensus-analytics-page.tsx new file mode 100644 index 00000000000..e764a036ee5 --- /dev/null +++ b/cvat-ui/src/components/analytics-page/consensus-analytics-page.tsx @@ -0,0 +1,220 @@ +// Copyright (C) 2024 CVAT.ai Corporation +// +// SPDX-License-Identifier: MIT + +import './styles.scss'; + +import React, { useCallback, useEffect, useState } from 'react'; +import { useDispatch, useSelector } from 'react-redux'; +import { useParams } from 'react-router'; +import { Link } from 'react-router-dom'; +import { Row, Col } from 'antd/lib/grid'; +import Tabs from 'antd/lib/tabs'; +import Title from 'antd/lib/typography/Title'; +import notification from 'antd/lib/notification'; +import { useIsMounted } from 'utils/hooks'; +import { CombinedState, Task } from 'reducers'; +import { getCore } from 'cvat-core-wrapper'; +import CVATLoadingSpinner from 'components/common/loading-spinner'; +import GoBackButton from 'components/common/go-back-button'; +import { consensusActions } from 'actions/consensus-actions'; +import ConsensusSettingsForm from './task-consensus/consensus-settings-form'; +import TaskConsensusAnalyticsComponent from './task-consensus/task-consensus-component'; + +const core = getCore(); + +enum ConsensusAnalyticsTabs { + OVERVIEW = 'overview', + SETTINGS = 'settings', +} + +function getTabFromHash(): ConsensusAnalyticsTabs { + const tab = window.location.hash.slice(1) as ConsensusAnalyticsTabs; + return Object.values(ConsensusAnalyticsTabs).includes(tab) ? tab : ConsensusAnalyticsTabs.OVERVIEW; +} + +type InstanceType = 'task'; + +function TaskConsensusAnalyticsPage(): JSX.Element { + const dispatch = useDispatch(); + + const requestedInstanceType: InstanceType = 'task'; + const requestedInstanceID = +useParams<{ tid: string }>().tid; + + const [activeTab, setTab] = useState(getTabFromHash()); + const [instanceType, setInstanceType] = useState(null); + const [instance, setInstance] = useState(null); + const [fetching, setFetching] = useState(true); + const isMounted = useIsMounted(); + const consensusSettings = useSelector((state: CombinedState) => state.consensus?.consensusSettings); + + const onTabKeyChange = useCallback((key: string): void => { + setTab(key as ConsensusAnalyticsTabs); + }, []); + + const receiveInstance = async (type: InstanceType, id: number): Promise => { + let receivedInstance: Task | null = null; + + try { + switch (type) { + case 'task': { + [receivedInstance] = await core.tasks.get({ id }); + break; + } + default: + return; + } + + if (isMounted()) { + setInstance(receivedInstance); + setInstanceType(type); + } + } catch (error: unknown) { + notification.error({ + message: `Could not receive requested ${type}`, + description: `${error instanceof Error ? 
error.message : ''}`, + }); + } + }; + + useEffect(() => { + if (Number.isInteger(requestedInstanceID) && ['project', 'task', 'job'].includes(requestedInstanceType)) { + setFetching(true); + Promise.all([ + receiveInstance(requestedInstanceType, requestedInstanceID), + ]).finally(() => { + if (isMounted()) { + setFetching(false); + } + }); + } else { + notification.error({ + message: 'Could not load this page', + description: `Not valid resource ${requestedInstanceType} #${requestedInstanceID}`, + }); + } + + return () => { + if (isMounted()) { + setInstance(null); + } + }; + }, [requestedInstanceType, requestedInstanceID]); + + function handleError(error: Error): void { + notification.error({ + description: error.toString(), + message: 'Could not fetch consensus settings.', + }); + } + + useEffect(() => { + window.addEventListener('hashchange', () => { + const hash = getTabFromHash(); + setTab(hash); + }); + }, []); + + useEffect(() => { + if (instance) { + dispatch(consensusActions.setFetching(true)); + + const settingsRequest = core.consensus.settings.get({ taskID: instance.id }); + + Promise.all([settingsRequest]) + .then(([settings]) => { + dispatch(consensusActions.setConsensusSettings(settings)); + }) + .catch(handleError) + .finally(() => { + dispatch(consensusActions.setFetching(false)); + }); + } + }, [instance?.id]); + + useEffect(() => { + window.location.hash = activeTab; + }, [activeTab]); + + let backNavigation: JSX.Element | null = null; + let title: JSX.Element | null = null; + let tabs: JSX.Element | null = null; + if (instanceType && instance) { + backNavigation = ( + + + + ); + + const analyticsFor = {`Task #${instance.id}`}; + title = ( + + + Consensus Analytics for + {' '} + {analyticsFor} + + + ); + + const consensusSettingsForm = ( + dispatch(consensusActions.setConsensusSettings(settings))} + /> + ); + + tabs = ( + + ), + }, + ]), + ...(instance.consensusJobsPerRegularJob ? + [ + { + key: ConsensusAnalyticsTabs.SETTINGS, + label: 'Settings', + children: ( + consensusSettingsForm + ), + }, + ] : + []), + ]} + /> + ); + } + + return ( +
+ {fetching ? ( +
+ +
+ ) : ( + + {backNavigation} + + {title} + {tabs} + + + )} +
+ ); +} + +export default React.memo(TaskConsensusAnalyticsPage); diff --git a/cvat-ui/src/components/analytics-page/styles.scss b/cvat-ui/src/components/analytics-page/styles.scss index f9639e2966a..b4d90bc530f 100644 --- a/cvat-ui/src/components/analytics-page/styles.scss +++ b/cvat-ui/src/components/analytics-page/styles.scss @@ -88,3 +88,16 @@ width: 100%; height: 100%; } + +.cvat-task-analytics-tabs { + width: 100%; +} + +.cvat-task-gt-conflicts, .cvat-task-issues { + padding-top: $grid-unit-size; + padding-bottom: $grid-unit-size; +} + +.cvat-analytics-card-holder { + min-height: $grid-unit-size * 19; +} diff --git a/cvat-ui/src/components/analytics-page/task-consensus/assignee-list.tsx b/cvat-ui/src/components/analytics-page/task-consensus/assignee-list.tsx new file mode 100644 index 00000000000..1b93cf79d84 --- /dev/null +++ b/cvat-ui/src/components/analytics-page/task-consensus/assignee-list.tsx @@ -0,0 +1,110 @@ +// Copyright (C) 2023-2024 CVAT.ai Corporation +// +// SPDX-License-Identifier: MIT + +import React from 'react'; +import { ColumnFilterItem, Key } from 'antd/lib/table/interface'; +import Table from 'antd/lib/table'; +import Text from 'antd/lib/typography/Text'; + +import { + User, AssigneeConsensusReport, +} from 'cvat-core-wrapper'; +import Tag from 'antd/lib/tag'; +import { toRepresentation, consensusColorGenerator } from 'utils/consensus'; +import { sorter } from 'utils/quality'; + +interface Props { + assigneeReports: AssigneeConsensusReport[]; +} + +function AssigneeListComponent(props: Props): JSX.Element { + const { assigneeReports: assigneeReportsArray } = props; + const assigneeReports: Record = assigneeReportsArray + .reduce((acc, report) => ({ ...acc, [report?.assignee?.id]: report }), {}); + + function collectUsers(path: string): ColumnFilterItem[] { + return Array.from( + new Set( + Object.values(assigneeReports).map((report: AssigneeConsensusReport) => { + if (report[path] === null) { + return null; + } + + return report[path].username; + }), + ), + ).map((value: string | null) => ({ text: value ?? 'Is Empty', value: value ?? false })); + } + + const columns = [ + { + title: 'Assignee', + dataIndex: 'assignee', + key: 'assignee', + className: 'cvat-job-item-assignee', + render: (assignee: User): JSX.Element => {assignee?.username}, + sorter: sorter('assignee.username'), + filters: collectUsers('assignee'), + onFilter: (value: boolean | Key, assignee: any) => (assignee?.assignee?.username || false) === value, + }, + { + title: 'Conflicts', + dataIndex: 'conflict_count', + key: 'conflict_count', + className: 'cvat-job-item-conflict', + sorter: sorter('conflict_count'), + render: (value: number): JSX.Element => {value}, + }, + { + title: 'Score', + dataIndex: 'quality', + key: 'quality', + className: 'cvat-job-item-quality', + sorter: sorter('quality'), + render: (value: number): JSX.Element => { + const meanConsensusScore = value; + const consensusScoreRepresentation = toRepresentation(meanConsensusScore); + return consensusScoreRepresentation.includes('N/A') ? ( + + N/A + + ) : ( + {consensusScoreRepresentation} + ); + }, + }, + ]; + const data = assigneeReportsArray.reduce((acc: any[], assigneeReport: any) => { + const report = assigneeReports[assigneeReport?.assignee?.id]; + if (report?.assignee) { + acc.push({ + key: report.assignee.id || 0, + assignee: report.assignee, + quality: report.consensusScore, + conflict_count: report.conflictCount, + }); + } + + return acc; + }, []); + + return ( +
+ 'cvat-task-jobs-table-row'} + columns={columns} + dataSource={data} + size='small' + style={{ width: '100%' }} + /> + + ); +} + +export default React.memo(AssigneeListComponent); diff --git a/cvat-ui/src/components/analytics-page/task-consensus/consensus-conflicts.tsx b/cvat-ui/src/components/analytics-page/task-consensus/consensus-conflicts.tsx new file mode 100644 index 00000000000..f16f524440e --- /dev/null +++ b/cvat-ui/src/components/analytics-page/task-consensus/consensus-conflicts.tsx @@ -0,0 +1,68 @@ +// Copyright (C) 2023-2024 CVAT.ai Corporation +// +// SPDX-License-Identifier: MIT + +import React from 'react'; +import Text from 'antd/lib/typography/Text'; +import { Col, Row } from 'antd/lib/grid'; + +import ConsensusReport, { ConsensusSummary } from 'cvat-core/src/consensus-report'; +import { clampValue } from 'utils/consensus'; +import AnalyticsCard from '../views/analytics-card'; + +interface Props { + taskReport: ConsensusReport | null; +} + +interface ConflictTooltipProps { + reportSummary?: ConsensusSummary; +} + +export function ConflictsTooltip(props: ConflictTooltipProps): JSX.Element { + const { reportSummary } = props; + return ( + + + Conflicts: + + No matching item:  + {reportSummary?.conflictsByType.no_matching_item || 0} + + + No matching annotation:  + {reportSummary?.conflictsByType.no_matching_annotation || 0} + + + Annotation too close:  + {reportSummary?.conflictsByType.annotation_too_close || 0} + + + Failed label voting:  + {reportSummary?.conflictsByType.failed_label_voting || 0} + + + + ); +} + +function ConsensusConflicts(props: Props): JSX.Element { + const { taskReport } = props; + let conflictsRepresentation: string | number = 'N/A'; + let reportSummary; + if (taskReport) { + reportSummary = taskReport.summary; + conflictsRepresentation = clampValue(reportSummary?.conflictCount); + } + + return ( + } + size={{ cardSize: 12 }} + /> + ); +} + +export default React.memo(ConsensusConflicts); diff --git a/cvat-ui/src/components/analytics-page/task-consensus/consensus-settings-form.tsx b/cvat-ui/src/components/analytics-page/task-consensus/consensus-settings-form.tsx new file mode 100644 index 00000000000..bfd7c9a49cc --- /dev/null +++ b/cvat-ui/src/components/analytics-page/task-consensus/consensus-settings-form.tsx @@ -0,0 +1,205 @@ +// Copyright (C) 2024 CVAT.ai Corporation +// +// SPDX-License-Identifier: MIT + +import React, { useCallback, useState } from 'react'; +import { QuestionCircleOutlined } from '@ant-design/icons/lib/icons'; +import Text from 'antd/lib/typography/Text'; +import InputNumber from 'antd/lib/input-number'; +import { Col, Row } from 'antd/lib/grid'; +import Form from 'antd/lib/form'; +import { Button, Divider } from 'antd/lib'; +import notification from 'antd/lib/notification'; +import { LoadingOutlined } from '@ant-design/icons'; +import CVATTooltip from 'components/common/cvat-tooltip'; +import { ConsensusSettings } from 'cvat-core-wrapper'; + +interface Props { + settings: ConsensusSettings | null; + setConsensusSettings: (settings: ConsensusSettings) => void; +} + +export default function ConsensusSettingsForm(props: Props): JSX.Element | null { + const [form] = Form.useForm(); + const { settings, setConsensusSettings } = props; + const [updatingConsensusSetting, setUpdatingConsensusSetting] = useState(false); + + if (!settings) { + return No quality settings; + } + + const initialValues = { + iouThreshold: settings.iouThreshold * 100, + agreementScoreThreshold: settings.agreementScoreThreshold * 100, + quorum: 
settings.quorum,
+        sigma: settings.sigma * 100,
+        lineThickness: settings.lineThickness * 100,
+    };
+
+    const onSave = useCallback(async () => {
+        try {
+            if (settings) {
+                const values = await form.validateFields();
+
+                settings.iouThreshold = values.iouThreshold / 100;
+                settings.quorum = values.quorum;
+                settings.agreementScoreThreshold = values.agreementScoreThreshold / 100;
+                settings.sigma = values.sigma / 100;
+                settings.lineThickness = values.lineThickness / 100;
+
+                try {
+                    // enter the updating state before the request starts, not after it
+                    setUpdatingConsensusSetting(true);
+                    const responseSettings = await settings.save();
+                    setConsensusSettings(responseSettings);
+                } catch (error: unknown) {
+                    notification.error({
+                        message: 'Could not save consensus settings',
+                        description: error instanceof Error ? error.message : '',
+                    });
+                    throw error;
+                }
+            }
+
+            return settings;
+        } catch (e) {
+            return false;
+        } finally {
+            setUpdatingConsensusSetting(false);
+        }
+    }, [settings]);
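+
+    // Note: the form works in percentages (0-100) for readability, while ConsensusSettings
+    // stores fractions (0-1); e.g. an IoU threshold of 85 in the UI is saved as 0.85,
+    // hence the "* 100" in initialValues and the "/ 100" in onSave above.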
+
+    const shapeComparisonTooltip = (
+            Min overlap threshold (IoU) is used to distinguish between matched and unmatched shapes.
+ ); + + const KeypointTooltip = ( +
+            Sigma is used when computing the OKS (Object Keypoint Similarity) distance for keypoint comparison.
+ ); + + const LineThicknessTooltip = ( +
+            Relative thickness defines the polyline thickness used during line comparison.
+ ); + + const validationTooltip = ( +
+ + Quorum is the minimum number of annotations that should be present in a cluster for it to be considered. + + + Agreement score threshold prevents merged annotations with low overlap (IoU) in their cluster from being + accepted. + +
+ ); + + return ( +
+ + Consensus Settings + + + Shape comparison + + + + + +
+ + + + + + + + Consensus Validation + + + + + + + + + + + + + + + + + + + Keypoint Comparison + + + + + + + + + + + + + + Line Comparison + + + + + + + + + + + + + + + + + + ); +} diff --git a/cvat-ui/src/components/analytics-page/task-consensus/issues.tsx b/cvat-ui/src/components/analytics-page/task-consensus/issues.tsx new file mode 100644 index 00000000000..8483761e7cd --- /dev/null +++ b/cvat-ui/src/components/analytics-page/task-consensus/issues.tsx @@ -0,0 +1,66 @@ +// Copyright (C) 2024 CVAT.ai Corporation +// +// SPDX-License-Identifier: MIT + +import '../styles.scss'; + +import React, { useEffect, useState } from 'react'; +import Text from 'antd/lib/typography/Text'; +import notification from 'antd/lib/notification'; +import { Task } from 'cvat-core-wrapper'; +import { useIsMounted } from 'utils/hooks'; +import { clampValue, percent } from 'utils/consensus'; +import AnalyticsCard from '../views/analytics-card'; + +interface Props { + task: Task; +} + +function Issues(props: Props): JSX.Element { + const { task } = props; + + const [issuesCount, setIssuesCount] = useState(0); + const [resolvedIssues, setResolvedIssues] = useState(0); + const isMounted = useIsMounted(); + + useEffect(() => { + task + .issues() + .then((issues: any[]) => { + if (isMounted()) { + setIssuesCount(issues.length); + setResolvedIssues(issues.reduce((acc, issue) => (issue.resolved ? acc + 1 : acc), 0)); + } + }) + .catch((_error: any) => { + if (isMounted()) { + notification.error({ + description: _error.toString(), + message: "Couldn't fetch issues", + className: 'cvat-notification-notice-get-issues-error', + }); + } + }); + }, []); + + const bottomElement = ( + + Resolved: + {' '} + {clampValue(resolvedIssues)} + {resolvedIssues ? ` (${percent(resolvedIssues, issuesCount)})` : ''} + + ); + + return ( + + ); +} + +export default React.memo(Issues); diff --git a/cvat-ui/src/components/analytics-page/task-consensus/job-list.tsx b/cvat-ui/src/components/analytics-page/task-consensus/job-list.tsx new file mode 100644 index 00000000000..ec25d382089 --- /dev/null +++ b/cvat-ui/src/components/analytics-page/task-consensus/job-list.tsx @@ -0,0 +1,206 @@ +// Copyright (C) 2023-2024 CVAT.ai Corporation +// +// SPDX-License-Identifier: MIT + +import React, { useState } from 'react'; +import { useHistory } from 'react-router'; +import { DownloadOutlined, QuestionCircleOutlined } from '@ant-design/icons'; +import { ColumnFilterItem, Key } from 'antd/lib/table/interface'; +import Table from 'antd/lib/table'; +import Button from 'antd/lib/button'; +import Text from 'antd/lib/typography/Text'; + +import { + Task, Job, JobType, getCore, ConsensusReport, +} from 'cvat-core-wrapper'; +import CVATTooltip from 'components/common/cvat-tooltip'; +import Tag from 'antd/lib/tag'; +import { toRepresentation, consensusColorGenerator } from 'utils/consensus'; +import { sorter } from 'utils/quality'; +import { ConflictsTooltip } from './consensus-conflicts'; + +interface Props { + task: Task; + jobsReports: ConsensusReport[]; +} + +function JobListComponent(props: Props): JSX.Element { + const { task: taskInstance, jobsReports: jobsReportsArray } = props; + const jobsReports: Record = jobsReportsArray.reduce( + (acc, report) => { + if (!acc[report.jobID]) { + acc[report.jobID] = report; + } + return acc; + }, + {}, + ); + const history = useHistory(); + const { id: taskId, jobs } = taskInstance; + const [renderedJobs] = useState(jobs.filter((job: Job) => job.type === JobType.ANNOTATION)); + + function collectUsers(path: string): 
ColumnFilterItem[] { + return Array.from( + new Set( + Object.values(jobsReports).map((report: ConsensusReport) => { + if (report[path] === null) { + return null; + } + + return report[path].username; + }), + ), + ).map((value: string | null) => ({ text: value ?? 'Is Empty', value: value ?? false })); + } + + const columns = [ + { + title: 'Job', + dataIndex: 'job', + key: 'job', + sorter: sorter('key'), + render: (id: number): JSX.Element => ( +
+ +
+ ), + }, + { + title: 'Stage', + dataIndex: 'stage', + key: 'stage', + className: 'cvat-job-item-stage', + render: (jobInstance: any): JSX.Element => { + const { stage } = jobInstance; + + return ( +
+ {stage} +
+ ); + }, + sorter: sorter('stage.stage'), + filters: [ + { text: 'annotation', value: 'annotation' }, + { text: 'validation', value: 'validation' }, + { text: 'acceptance', value: 'acceptance' }, + ], + onFilter: (value: boolean | Key, record: any) => record.stage.stage === value, + }, + { + title: 'Assignee', + dataIndex: 'assignee', + key: 'assignee', + className: 'cvat-job-item-assignee', + render: (report: ConsensusReport): JSX.Element => {report?.assignee?.username}, + sorter: sorter('assignee.username'), + filters: collectUsers('assignee'), + onFilter: (value: boolean | Key, record: any) => (record.assignee.assignee?.username || false) === value, + }, + { + title: 'Conflicts', + dataIndex: 'conflicts', + key: 'conflicts', + className: 'cvat-job-item-conflicts', + sorter: sorter('conflicts.summary.conflictCount'), + render: (report: ConsensusReport): JSX.Element => { + const conflictCount = report?.summary?.conflictCount; + return ( +
+ {conflictCount || 0} + } + className='cvat-analytics-tooltip' + overlayStyle={{ maxWidth: '500px' }} + > + + +
+ ); + }, + }, + { + title: 'Score', + dataIndex: 'quality', + key: 'quality', + align: 'center' as const, + className: 'cvat-job-item-quality', + sorter: sorter('quality.consensus_score'), + render: (report?: ConsensusReport): JSX.Element => { + const meanConsensusScore = report?.consensus_score; + const consensusScoreRepresentation = toRepresentation(meanConsensusScore); + return consensusScoreRepresentation.includes('N/A') ? ( + + N/A + + ) : ( + {consensusScoreRepresentation} + ); + }, + }, + { + title: 'Download', + dataIndex: 'download', + key: 'download', + className: 'cvat-job-item-quality-report-download', + align: 'center' as const, + render: (job: Job): JSX.Element => { + const report = jobsReports[job.id]; + const reportID = report?.id; + return reportID ? ( + + + + ) : ( + + ); + }, + }, + ]; + const data = renderedJobs.reduce((acc: any[], job: any) => { + const report = jobsReports[job.id]; + + acc.push({ + key: job.id, + job: job.id, + download: job, + stage: job, + assignee: report, + quality: report, + conflicts: report, + }); + + return acc; + }, []); + + return ( +
+
'cvat-task-jobs-table-row'} + columns={columns} + dataSource={data} + size='small' + /> + + ); +} + +export default React.memo(JobListComponent); diff --git a/cvat-ui/src/components/analytics-page/task-consensus/mean-score.tsx b/cvat-ui/src/components/analytics-page/task-consensus/mean-score.tsx new file mode 100644 index 00000000000..7acd92baa1a --- /dev/null +++ b/cvat-ui/src/components/analytics-page/task-consensus/mean-score.tsx @@ -0,0 +1,67 @@ +// Copyright (C) 2023-2024 CVAT.ai Corporation +// +// SPDX-License-Identifier: MIT + +import React from 'react'; +import { DownloadOutlined } from '@ant-design/icons'; +import { Col, Row } from 'antd/lib/grid'; +import Text from 'antd/lib/typography/Text'; +import Button from 'antd/lib/button'; + +import { ConsensusReport, getCore } from 'cvat-core-wrapper'; +import { toRepresentation } from 'utils/consensus'; +import AnalyticsCard from '../views/analytics-card'; + +interface Props { + taskID: number; + taskReport: ConsensusReport | null; +} + +function MeanQuality(props: Props): JSX.Element { + const { taskID, taskReport } = props; + const reportSummary = taskReport?.summary; + + const tooltip = ( +
+ + Conflicting annotations:  + {reportSummary?.conflictCount || 0} + +
+ ); + + const downloadReportButton = ( +
+ +
+ { + taskReport?.id ? ( + + ) : null + } + + + + ); + + return ( + + ); +} + +export default React.memo(MeanQuality); diff --git a/cvat-ui/src/components/analytics-page/task-consensus/task-consensus-component.tsx b/cvat-ui/src/components/analytics-page/task-consensus/task-consensus-component.tsx new file mode 100644 index 00000000000..6508351a0e9 --- /dev/null +++ b/cvat-ui/src/components/analytics-page/task-consensus/task-consensus-component.tsx @@ -0,0 +1,223 @@ +// Copyright (C) 2023-2024 CVAT.ai Corporation +// +// SPDX-License-Identifier: MIT + +import moment from 'moment'; +import { Row } from 'antd/lib/grid'; +import Text from 'antd/lib/typography/Text'; +import notification from 'antd/lib/notification'; +import CVATLoadingSpinner from 'components/common/loading-spinner'; +import { + AssigneeConsensusReport, ConsensusReport, Task, getCore, +} from 'cvat-core-wrapper'; +import React, { + useCallback, useEffect, useReducer, useState, +} from 'react'; +import { useIsMounted } from 'utils/hooks'; +import { ActionUnion, createAction } from 'utils/redux'; +import { Tabs } from 'antd'; +import ConsensusConflicts from './consensus-conflicts'; +import Issues from './issues'; +import JobList from './job-list'; +import AssigneeListComponent from './assignee-list'; +import MeanQuality from './mean-score'; + +const core = getCore(); + +enum DetailsTabs { + JOBS = 'jobs', + ASSIGNEES = 'assignees', +} + +interface Props { + task: Task; +} + +interface State { + fetching: boolean; + taskReport: ConsensusReport | null; + jobsReports: ConsensusReport[]; + assigneeReports: AssigneeConsensusReport[]; +} + +enum ReducerActionType { + SET_FETCHING = 'SET_FETCHING', + SET_TASK_REPORT = 'SET_TASK_REPORT', + SET_JOBS_REPORTS = 'SET_JOBS_REPORTS', + SET_ASSIGNEE_REPORTS = 'SET_ASSIGNEE_REPORTS', +} + +export const reducerActions = { + setFetching: (fetching: boolean) => ( + createAction(ReducerActionType.SET_FETCHING, { fetching }) + ), + setTaskReport: (consensusReport: ConsensusReport) => ( + createAction(ReducerActionType.SET_TASK_REPORT, { consensusReport }) + ), + setJobsReports: (consensusReports: ConsensusReport[]) => ( + createAction(ReducerActionType.SET_JOBS_REPORTS, { consensusReports }) + ), + setAssigneeReports: (assigneeconsensusReports: AssigneeConsensusReport[]) => ( + createAction(ReducerActionType.SET_ASSIGNEE_REPORTS, { assigneeconsensusReports }) + ), +}; + +const reducer = (state: State, action: ActionUnion): State => { + if (action.type === ReducerActionType.SET_FETCHING) { + return { + ...state, + fetching: action.payload.fetching, + }; + } + + if (action.type === ReducerActionType.SET_TASK_REPORT) { + const taskReport = action.payload.consensusReport; + return { + ...state, + taskReport, + }; + } + + if (action.type === ReducerActionType.SET_JOBS_REPORTS) { + const jobsReports = action.payload.consensusReports; + return { + ...state, + jobsReports, + }; + } + + if (action.type === ReducerActionType.SET_ASSIGNEE_REPORTS) { + const assigneeReports = action.payload.assigneeconsensusReports; + return { + ...state, + assigneeReports, + }; + } + + return state; +}; + +function getTabFromHash(): DetailsTabs { + const tab = window.location.hash.slice(1) as DetailsTabs; + return Object.values(DetailsTabs).includes(tab) ? 
tab : DetailsTabs.JOBS; +} + +function TaskConsensusComponent(props: Props): JSX.Element { + const { task } = props; + const isMounted = useIsMounted(); + let tabs = null; + + const [state, dispatch] = useReducer(reducer, { + fetching: true, + taskReport: null, + jobsReports: [], + assigneeReports: [], + }); + const [activeTab, setTab] = useState(getTabFromHash()); + + useEffect(() => { + dispatch(reducerActions.setFetching(true)); + + function handleError(error: Error): void { + if (isMounted()) { + notification.error({ + description: error.toString(), + message: 'Could not initialize consensus analytics page', + }); + } + } + + core.consensus + .reports({ + pageSize: 1, target: 'task', taskID: task.id, + }) + .then(([report]) => { + let reportRequest = Promise.resolve([]); + let assigneeReportRequest = Promise.resolve([]); + if (report) { + reportRequest = core.consensus.reports({ + pageSize: task.jobs.length, + taskID: task.id, + target: 'job', + }); + assigneeReportRequest = core.consensus.assigneeReports({ + taskID: task.id, + consensusReportID: report.id, + }); + } + + Promise.all([reportRequest]) + .then(([jobReports]) => { + dispatch(reducerActions.setTaskReport(report || null)); + dispatch(reducerActions.setJobsReports(jobReports)); + Promise.all([assigneeReportRequest]) + .then(([assigneeReports]) => { + dispatch(reducerActions.setAssigneeReports(assigneeReports)); + }); + }) + .catch(handleError) + .finally(() => { + dispatch(reducerActions.setFetching(false)); + }); + }) + .catch(handleError); + }, [task?.id]); + + const { + fetching, taskReport, jobsReports, assigneeReports, + } = state; + + const onTabKeyChange = useCallback((key: string): void => { + setTab(key as DetailsTabs); + }, []); + + tabs = ( + , + }, + { + key: DetailsTabs.ASSIGNEES, + label: 'Assignees', + children: , + }, + ]} + /> + ); + + return ( +
+ {fetching ? ( + + ) : ( + <> + {taskReport?.id && ( + + + {`Created ${taskReport?.id ? moment(taskReport.createdDate).fromNow() : ''}`} + + + )} + + + + + + + + {tabs} + + )} +
+ ); +} + +export default React.memo(TaskConsensusComponent); diff --git a/cvat-ui/src/components/analytics-page/views/analytics-card.tsx b/cvat-ui/src/components/analytics-page/views/analytics-card.tsx index 655759bfaf7..26726494e51 100644 --- a/cvat-ui/src/components/analytics-page/views/analytics-card.tsx +++ b/cvat-ui/src/components/analytics-page/views/analytics-card.tsx @@ -30,7 +30,7 @@ function AnalyticsCard(props: Props): JSX.Element { return (
- + diff --git a/cvat-ui/src/components/create-task-page/advanced-configuration-form.tsx b/cvat-ui/src/components/create-task-page/advanced-configuration-form.tsx index f39afbe367f..7887a31edeb 100644 --- a/cvat-ui/src/components/create-task-page/advanced-configuration-form.tsx +++ b/cvat-ui/src/components/create-task-page/advanced-configuration-form.tsx @@ -17,6 +17,7 @@ import Text from 'antd/lib/typography/Text'; import { Store } from 'antd/lib/form/interface'; import CVATTooltip from 'components/common/cvat-tooltip'; import patterns from 'utils/validation-patterns'; +import { isInteger } from 'utils/validate-integer'; import { StorageLocation } from 'reducers'; import SourceStorageField from 'components/storage/source-storage-field'; import TargetStorageField from 'components/storage/target-storage-field'; @@ -47,6 +48,7 @@ export interface AdvancedConfiguration { sortingMethod: SortingMethod; useProjectSourceStorage: boolean; useProjectTargetStorage: boolean; + consensusJobsPerRegularJob: number; sourceStorage: StorageData; targetStorage: StorageData; } @@ -59,6 +61,7 @@ const initialValues: AdvancedConfiguration = { sortingMethod: SortingMethod.LEXICOGRAPHICAL, useProjectSourceStorage: true, useProjectTargetStorage: true, + consensusJobsPerRegularJob: 0, sourceStorage: { location: StorageLocation.LOCAL, @@ -92,30 +95,6 @@ function validateURL(_: RuleObject, value: string): Promise { return Promise.resolve(); } -const isInteger = ({ min, max }: { min?: number; max?: number }) => ( - _: RuleObject, - value?: number | string, -): Promise => { - if (typeof value === 'undefined' || value === '') { - return Promise.resolve(); - } - - const intValue = +value; - if (Number.isNaN(intValue) || !Number.isInteger(intValue)) { - return Promise.reject(new Error('Value must be a positive integer')); - } - - if (typeof min !== 'undefined' && intValue < min) { - return Promise.reject(new Error(`Value must be more than ${min}`)); - } - - if (typeof max !== 'undefined' && intValue > max) { - return Promise.reject(new Error(`Value must be less than ${max}`)); - } - - return Promise.resolve(); -}; - const validateOverlapSize: RuleRender = ({ getFieldValue }): RuleObject => ({ validator(_: RuleObject, value?: string | number): Promise { if (typeof value !== 'undefined' && value !== '') { @@ -402,6 +381,32 @@ class AdvancedConfigurationForm extends React.PureComponent { ); } + private renderconsensusJobsPerRegularJob(): JSX.Element { + return ( + intValue !== 1, + }), + }, + ]} + > + + + ); + } + private renderSourceStorage(): JSX.Element { const { projectId, @@ -483,6 +488,11 @@ class AdvancedConfigurationForm extends React.PureComponent { {this.renderChunkSize()} + + + {this.renderconsensusJobsPerRegularJob()} + + {this.renderBugTracker()} diff --git a/cvat-ui/src/components/create-task-page/create-task-content.tsx b/cvat-ui/src/components/create-task-page/create-task-content.tsx index 655d732c343..1408174155b 100644 --- a/cvat-ui/src/components/create-task-page/create-task-content.tsx +++ b/cvat-ui/src/components/create-task-page/create-task-content.tsx @@ -85,6 +85,7 @@ const defaultState: State = { }, useProjectSourceStorage: true, useProjectTargetStorage: true, + consensusJobsPerRegularJob: 0, }, quality: { validationMethod: ValidationMethod.NONE, diff --git a/cvat-ui/src/components/cvat-app.tsx b/cvat-ui/src/components/cvat-app.tsx index a9be65a055d..3d4257d2d2e 100644 --- a/cvat-ui/src/components/cvat-app.tsx +++ b/cvat-ui/src/components/cvat-app.tsx @@ -82,6 +82,7 @@ import 
IncorrectEmailConfirmationPage from './email-confirmation-pages/incorrect import CreateJobPage from './create-job-page/create-job-page'; import AnalyticsPage from './analytics-page/analytics-page'; import QualityControlPage from './quality-control/quality-control-page'; +import TaskConsensusAnalyticsPage from './analytics-page/consensus-analytics-page'; import InvitationWatcher from './invitation-watcher/invitation-watcher'; interface CVATAppProps { @@ -509,6 +510,7 @@ class CVATApplication extends React.PureComponent + diff --git a/cvat-ui/src/components/job-item/job-actions-menu.tsx b/cvat-ui/src/components/job-item/job-actions-menu.tsx index 0a3ac6c1900..7e84e30f234 100644 --- a/cvat-ui/src/components/job-item/job-actions-menu.tsx +++ b/cvat-ui/src/components/job-item/job-actions-menu.tsx @@ -3,22 +3,25 @@ // SPDX-License-Identifier: MIT import React, { useCallback } from 'react'; -import { useDispatch } from 'react-redux'; +import { useDispatch, useSelector } from 'react-redux'; import { useHistory } from 'react-router'; import Modal from 'antd/lib/modal'; - +import { LoadingOutlined } from '@ant-design/icons'; import { exportActions } from 'actions/export-actions'; import { deleteJobAsync } from 'actions/jobs-actions'; import { importActions } from 'actions/import-actions'; import { Job, JobType } from 'cvat-core-wrapper'; import Menu, { MenuInfo } from 'components/dropdown-menu'; +import { mergeTaskSpecificConsensusJobsAsync } from 'actions/consensus-actions'; +import { CombinedState } from 'reducers'; interface Props { job: Job; + consensusJobsPresent: boolean; } function JobActionsMenu(props: Props): JSX.Element { - const { job } = props; + const { job, consensusJobsPresent } = props; const dispatch = useDispatch(); const history = useHistory(); @@ -39,6 +42,9 @@ function JobActionsMenu(props: Props): JSX.Element { }); }, [job]); + const mergingConsensus = useSelector((state: CombinedState) => state.consensus.mergingConsensus); + const isTaskInMergingConsensus = mergingConsensus[`job_${job.id}`]; + return ( - Go to the task - Go to the project - Go to the bug tracker + + Go to the task + + + Go to the project + + + Go to the bug tracker + Import annotations Export annotations View analytics + {consensusJobsPresent && job.parent_job_id === null && ( + } + > + Merge Consensus Jobs + + )} - onDelete()} - > + onDelete()}> Delete diff --git a/cvat-ui/src/components/job-item/job-item.tsx b/cvat-ui/src/components/job-item/job-item.tsx index 3fe5267f9ef..54e11f18a64 100644 --- a/cvat-ui/src/components/job-item/job-item.tsx +++ b/cvat-ui/src/components/job-item/job-item.tsx @@ -5,6 +5,7 @@ import './styles.scss'; import React, { useEffect, useState } from 'react'; +import { useDispatch, useSelector } from 'react-redux'; import { Link } from 'react-router-dom'; import moment from 'moment'; import { Col, Row } from 'antd/lib/grid'; @@ -25,8 +26,9 @@ import { import { useIsMounted } from 'utils/hooks'; import UserSelector from 'components/task-page/user-selector'; import CVATTooltip from 'components/common/cvat-tooltip'; -import { useSelector } from 'react-redux'; import { CombinedState } from 'reducers'; +import Collapse from 'antd/lib/collapse'; +import { collapseRegularJob } from 'actions/jobs-actions'; import JobActionsMenu from './job-actions-menu'; interface Props { @@ -105,6 +107,7 @@ function JobItem(props: Props): JSX.Element { const deleted = job.id in deletes ? 
deletes[job.id] === true : false; const { stage } = job; + const dispatch = useDispatch(); const created = moment(job.createdDate); const updated = moment(job.updatedDate); const now = moment(moment.now()); @@ -116,6 +119,40 @@ function JobItem(props: Props): JSX.Element { } const frameCountPercent = ((job.frameCount / (task.size || 1)) * 100).toFixed(0); const frameCountPercentRepresentation = frameCountPercent === '0' ? '<1' : frameCountPercent; + let jobName = `Job #${job.id}`; + if (task.consensusJobsPerRegularJob && job.type !== JobType.GROUND_TRUTH) { + jobName = `Job #${job.id}`; + } + + let consensusJobs: Job[] = []; + if (task.consensusJobsPerRegularJob) { + consensusJobs = task.jobs.filter((eachJob: Job) => eachJob.parent_job_id === job.id).reverse(); + } + const consensusJobViews: React.JSX.Element[] = consensusJobs.map((eachJob: Job) => ( + + )); + + const regularJobViewUncollapse = useSelector((state: CombinedState) => state.jobs.regularJobViewUncollapse); + const regularJobUncollapsed = regularJobViewUncollapse[job.id]; + const handleCollapseChange = async (): Promise => { + await dispatch(collapseRegularJob(job.id, !regularJobUncollapsed)); + }; + + let tag = null; + if (job.type === JobType.GROUND_TRUTH) { + tag = ( + + Ground truth + + ); + } else if (job.type === JobType.CONSENSUS) { + tag = ( + + Consensus + + ); + } + return ( @@ -123,21 +160,16 @@ function JobItem(props: Props): JSX.Element { - {`Job #${job.id}`} + {jobName} - { - job.type === JobType.GROUND_TRUTH ? ( - - Ground truth - - ) : ( - - }> - - - - ) - } + {tag} + {job.type !== JobType.GROUND_TRUTH && ( + + }> + + + + )} @@ -208,15 +240,11 @@ function JobItem(props: Props): JSX.Element { onJobUpdate(job, { state: newValue }); }} > - - {JobState.NEW} - + {JobState.NEW} {JobState.IN_PROGRESS} - - {JobState.REJECTED} - + {JobState.REJECTED} {JobState.COMPLETED} @@ -233,7 +261,11 @@ function JobItem(props: Props): JSX.Element { Duration: - {`${moment.duration(now.diff(created)).humanize()}`} + + {`${moment + .duration(now.diff(created)) + .humanize()}`} + @@ -245,19 +277,17 @@ function JobItem(props: Props): JSX.Element { - { - job.type !== JobType.GROUND_TRUTH && ( - - - - Frame range: - - {`${job.startFrame}-${job.stopFrame}`} - - - - ) - } + {job.type !== JobType.GROUND_TRUTH && ( + + + + Frame range: + + {`${job.startFrame}-${job.stopFrame}`} + + + + )} @@ -265,10 +295,30 @@ function JobItem(props: Props): JSX.Element { } + className='job-actions-menu' + overlay={( + + )} > + {consensusJobs.length > 0 && ( + {`${consensusJobs.length} Consensus Jobs`}, + children: consensusJobViews, + }, + ]} + /> + )} ); diff --git a/cvat-ui/src/components/job-item/styles.scss b/cvat-ui/src/components/job-item/styles.scss index 39a137000ea..7723a24dec1 100644 --- a/cvat-ui/src/components/job-item/styles.scss +++ b/cvat-ui/src/components/job-item/styles.scss @@ -71,6 +71,20 @@ .cvat-job-item-dates-info { margin-top: $grid-unit-size; } + + .cvat-consensus-job-collapse { + margin-top: 12px; + + .ant-collapse-item > .ant-collapse-header { + align-items: center; + } + } + + .job-actions-menu { + position: absolute; + top: $grid-unit-size * 6.5; + } + } .ant-menu.cvat-job-item-menu { diff --git a/cvat-ui/src/components/jobs-page/job-card.tsx b/cvat-ui/src/components/jobs-page/job-card.tsx index a76ed0c3814..6d2593c44ba 100644 --- a/cvat-ui/src/components/jobs-page/job-card.tsx +++ b/cvat-ui/src/components/jobs-page/job-card.tsx @@ -11,7 +11,7 @@ import Descriptions from 'antd/lib/descriptions'; import { MoreOutlined } from 
'@ant-design/icons'; import Dropdown from 'antd/lib/dropdown'; -import { Job } from 'cvat-core-wrapper'; +import { Job, JobType } from 'cvat-core-wrapper'; import { useCardHeightHOC } from 'utils/hooks'; import Preview from 'components/common/preview'; import JobActionsMenu from 'components/job-item/job-actions-menu'; @@ -52,6 +52,13 @@ function JobCardComponent(props: Props): JSX.Element { (style as any).opacity = 0.5; } + let tag = null; + if (job.type === JobType.GROUND_TRUTH) { + tag = 'Ground truth'; + } else if (job.type === JobType.CONSENSUS) { + tag = 'Consensus'; + } + return ( + {tag &&
{tag}
}
{job.dimension.toUpperCase()}
)} @@ -78,14 +86,16 @@ function JobCardComponent(props: Props): JSX.Element { {`${job.stage} ${job.state}`} {job.stopFrame - job.startFrame + 1} - { job.assignee ? ( + {job.assignee ? ( {job.assignee.username} - ) : } + ) : ( + + )} )} + overlay={} > diff --git a/cvat-ui/src/components/jobs-page/jobs-filter-configuration.ts b/cvat-ui/src/components/jobs-page/jobs-filter-configuration.ts index 8c48a8f8609..ea95e4fe336 100644 --- a/cvat-ui/src/components/jobs-page/jobs-filter-configuration.ts +++ b/cvat-ui/src/components/jobs-page/jobs-filter-configuration.ts @@ -96,6 +96,19 @@ export const config: Partial = { valueSources: ['value'], operators: ['like'], }, + type: { + label: 'Job Type', + type: 'select', + operators: ['select_equals'], + valueSources: ['value'], + fieldSettings: { + listValues: [ + { value: 'annotation', title: 'annotation' }, + { value: 'ground_truth', title: 'ground_truth' }, + { value: 'consensus', title: 'consensus' }, + ], + }, + }, }, }; diff --git a/cvat-ui/src/components/jobs-page/styles.scss b/cvat-ui/src/components/jobs-page/styles.scss index e3a011a720a..40481b9aa38 100644 --- a/cvat-ui/src/components/jobs-page/styles.scss +++ b/cvat-ui/src/components/jobs-page/styles.scss @@ -75,6 +75,10 @@ .cvat-job-page-list-item-dimension { opacity: 1; } + + .cvat-job-page-list-item-type { + opacity: 1; + } } :nth-child(4n) { @@ -135,6 +139,20 @@ transition: 0.15s all ease; box-shadow: $box-shadow-base; } + + .cvat-job-page-list-item-type { + position: absolute; + top: $grid-unit-size * 5; + left: 0; + margin: $grid-unit-size $grid-unit-size $grid-unit-size 0; + width: fit-content; + background: white; + border-radius: 0 $border-radius-base $border-radius-base 0; + padding: $grid-unit-size; + opacity: 0.5; + transition: 0.15s all ease; + box-shadow: $box-shadow-base; + } } .cvat-jobs-page-pagination { diff --git a/cvat-ui/src/components/task-page/details.tsx b/cvat-ui/src/components/task-page/details.tsx index c1c986260bf..1d2a6c9f4c5 100644 --- a/cvat-ui/src/components/task-page/details.tsx +++ b/cvat-ui/src/components/task-page/details.tsx @@ -7,6 +7,7 @@ import React from 'react'; import { connect } from 'react-redux'; import { Row, Col } from 'antd/lib/grid'; +import Tag from 'antd/lib/tag'; import Text from 'antd/lib/typography/Text'; import Title from 'antd/lib/typography/Title'; import moment from 'moment'; @@ -59,6 +60,7 @@ const core = getCore(); interface State { name: string; subset: string; + consensusJobsPerRegularJob: number; } type Props = DispatchToProps & StateToProps & OwnProps; @@ -70,6 +72,7 @@ class DetailsComponent extends React.PureComponent { this.state = { name: taskInstance.name, subset: taskInstance.subset, + consensusJobsPerRegularJob: taskInstance.consensusJobsPerRegularJob, }; } @@ -86,29 +89,35 @@ class DetailsComponent extends React.PureComponent { private renderTaskName(): JSX.Element { const { name } = this.state; const { task: taskInstance, onUpdateTask } = this.props; + const taskName = name; return ( - { - this.setState({ - name: value, - }); - - taskInstance.name = value; - onUpdateTask(taskInstance); - }, - }} - className='cvat-text-color cvat-task-name' - > - {name} - + +
+ { + this.setState({ + name: value, + }); + + taskInstance.name = value; + onUpdateTask(taskInstance); + }, + }} + className='cvat-text-color cvat-task-name' + > + {taskName} + + + ); } private renderDescription(): JSX.Element { const { task: taskInstance, onUpdateTask } = this.props; + const { consensusJobsPerRegularJob } = this.state; const owner = taskInstance.owner ? taskInstance.owner.username : null; const assignee = taskInstance.assignee ? taskInstance.assignee : null; const created = moment(taskInstance.createdDate).format('MMMM Do YYYY'); @@ -127,8 +136,13 @@ class DetailsComponent extends React.PureComponent { {owner && ( - {`Task #${taskInstance.id} Created by ${owner} on ${created}`} +
+ + {`Task #${taskInstance.id} Created by ${owner} on ${created}`} + +
)} + {consensusJobsPerRegularJob > 0 && Consensus Based Annotation} Assigned to diff --git a/cvat-ui/src/components/task-page/job-list.tsx b/cvat-ui/src/components/task-page/job-list.tsx index bbd59652da2..c6f52cc662d 100644 --- a/cvat-ui/src/components/task-page/job-list.tsx +++ b/cvat-ui/src/components/task-page/job-list.tsx @@ -59,6 +59,9 @@ function setUpJobsList(jobs: Job[], query: JobsQuery): Job[] { result = result.filter((job, index) => jsonLogic.apply(filter, converted[index])); } + // consensus jobs will be under the collapse view + result = result.filter((job) => job.parent_job_id === null); + return result; } diff --git a/cvat-ui/src/components/task-page/top-bar.tsx b/cvat-ui/src/components/task-page/top-bar.tsx index a242aef4441..c0861f1669c 100644 --- a/cvat-ui/src/components/task-page/top-bar.tsx +++ b/cvat-ui/src/components/task-page/top-bar.tsx @@ -29,6 +29,10 @@ export default function DetailsComponent(props: DetailsComponentProps): JSX.Elem history.push(`/tasks/${taskInstance.id}/quality-control`); }; + const onViewConsensusAnalytics = (): void => { + history.push(`/tasks/${taskInstance.id}/consensus`); + }; + return ( @@ -63,6 +67,7 @@ export default function DetailsComponent(props: DetailsComponentProps): JSX.Elem taskInstance={taskInstance} onViewAnalytics={onViewAnalytics} onViewQualityControl={onViewQualityControl} + onViewConsensusAnalytics={onViewConsensusAnalytics} /> )} > diff --git a/cvat-ui/src/components/tasks-page/task-item.tsx b/cvat-ui/src/components/tasks-page/task-item.tsx index 9bc5fdec313..cf0ddd10517 100644 --- a/cvat-ui/src/components/tasks-page/task-item.tsx +++ b/cvat-ui/src/components/tasks-page/task-item.tsx @@ -242,6 +242,10 @@ class TaskItemComponent extends React.PureComponent { + history.push(`/tasks/${taskInstance.id}/consensus`); + }; + return ( @@ -271,6 +275,7 @@ class TaskItemComponent extends React.PureComponent )} > diff --git a/cvat-ui/src/containers/actions-menu/actions-menu.tsx b/cvat-ui/src/containers/actions-menu/actions-menu.tsx index e9773c2b905..ea8eff11eac 100644 --- a/cvat-ui/src/containers/actions-menu/actions-menu.tsx +++ b/cvat-ui/src/containers/actions-menu/actions-menu.tsx @@ -17,11 +17,13 @@ import { } from 'actions/tasks-actions'; import { exportActions } from 'actions/export-actions'; import { importActions } from 'actions/import-actions'; +import { mergeTaskConsensusJobsAsync } from 'actions/consensus-actions'; interface OwnProps { taskInstance: any; onViewAnalytics: () => void; onViewQualityControl: () => void; + onViewConsensusAnalytics: () => void; } interface StateToProps { @@ -35,6 +37,7 @@ interface DispatchToProps { openRunModelWindow: (taskInstance: any) => void; deleteTask: (taskInstance: any) => void; openMoveTaskToProjectWindow: (taskInstance: any) => void; + mergeConsensusJobs: (taskInstance: any) => void; } function mapStateToProps(state: CombinedState, own: OwnProps): StateToProps { @@ -73,6 +76,9 @@ function mapDispatchToProps(dispatch: any): DispatchToProps { openMoveTaskToProjectWindow: (taskId: number): void => { dispatch(switchMoveTaskModalVisible(true, taskId)); }, + mergeConsensusJobs: (taskInstance: any): void => { + dispatch(mergeTaskConsensusJobsAsync(taskInstance)); + }, }; } @@ -88,6 +94,8 @@ function ActionsMenuContainer(props: OwnProps & StateToProps & DispatchToProps): openMoveTaskToProjectWindow, onViewAnalytics, onViewQualityControl, + onViewConsensusAnalytics, + mergeConsensusJobs, } = props; const onClickMenu = (params: MenuInfo): void | JSX.Element => { const [action] = 
params.keyPath; @@ -109,6 +117,10 @@ function ActionsMenuContainer(props: OwnProps & StateToProps & DispatchToProps): onViewAnalytics(); } else if (action === Actions.QUALITY_CONTROL) { onViewQualityControl(); + } else if (action === Actions.VIEW_CONSENSUS_ANALYTICS) { + onViewConsensusAnalytics(); + } else if (action === Actions.MERGE_CONSENSUS_JOBS) { + mergeConsensusJobs(taskInstance); } }; @@ -123,6 +135,7 @@ function ActionsMenuContainer(props: OwnProps & StateToProps & DispatchToProps): inferenceIsActive={inferenceIsActive} onClickMenu={onClickMenu} taskDimension={taskInstance.dimension} + consensusJobsPerRegularJob={taskInstance.consensusJobsPerRegularJob} /> ); } diff --git a/cvat-ui/src/cvat-core-wrapper.ts b/cvat-ui/src/cvat-core-wrapper.ts index 295135e23e1..31fb64db2c0 100644 --- a/cvat-ui/src/cvat-core-wrapper.ts +++ b/cvat-ui/src/cvat-core-wrapper.ts @@ -22,6 +22,7 @@ import Project from 'cvat-core/src/project'; import QualityReport, { QualitySummary } from 'cvat-core/src/quality-report'; import QualityConflict, { AnnotationConflict, ConflictSeverity } from 'cvat-core/src/quality-conflict'; import QualitySettings, { TargetMetric } from 'cvat-core/src/quality-settings'; +import ConsensusSettings from 'cvat-core/src/consensus-settings'; import { FramesMetaData, FrameData } from 'cvat-core/src/frames'; import { ServerError, RequestError } from 'cvat-core/src/exceptions'; import { @@ -41,6 +42,8 @@ import { Event } from 'cvat-core/src/event'; import { APIWrapperEnterOptions } from 'cvat-core/src/plugins'; import BaseSingleFrameAction, { ActionParameterType, FrameSelectionType } from 'cvat-core/src/annotations-actions'; import { Request } from 'cvat-core/src/request'; +import ConsensusReport from 'cvat-core/src/consensus-report'; +import AssigneeConsensusReport from 'cvat-core/src/assignee-consensus-report'; const cvat: CVATCore = _cvat; @@ -91,6 +94,9 @@ export { QualityReport, QualityConflict, QualitySettings, + ConsensusSettings, + ConsensusReport, + AssigneeConsensusReport, TargetMetric, AnnotationConflict, ConflictSeverity, diff --git a/cvat-ui/src/reducers/consensus-reducer.ts b/cvat-ui/src/reducers/consensus-reducer.ts new file mode 100644 index 00000000000..188bd6ef874 --- /dev/null +++ b/cvat-ui/src/reducers/consensus-reducer.ts @@ -0,0 +1,118 @@ +// Copyright (C) 2024 CVAT.ai Corporation +// +// SPDX-License-Identifier: MIT + +import { ConsensusActions, ConsensusActionTypes } from 'actions/consensus-actions'; +import { ConsensusState } from '.'; + +const defaultState: ConsensusState = { + taskInstance: null, + jobInstance: null, + fetching: true, + consensusSettings: null, + mergingConsensus: {}, +}; + +function makeKey(id: number, instance: string): string { + return `${instance}_${id}`; +} + +export default (state: ConsensusState = defaultState, action: ConsensusActions): ConsensusState => { + switch (action.type) { + case ConsensusActionTypes.SET_FETCHING: { + return { + ...state, + fetching: action.payload.fetching, + }; + } + + case ConsensusActionTypes.SET_CONSENSUS_SETTINGS: { + return { + ...state, + consensusSettings: action.payload.consensusSettings, + }; + } + + case ConsensusActionTypes.MERGE_CONSENSUS_JOBS: { + const { taskID } = action.payload; + const { mergingConsensus } = state; + + mergingConsensus[makeKey(taskID, 'task')] = true; + + return { + ...state, + mergingConsensus: { + ...mergingConsensus, + }, + }; + } + + case ConsensusActionTypes.MERGE_CONSENSUS_JOBS_SUCCESS: { + const { taskID } = action.payload; + const { mergingConsensus } = state; + + 
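+            // the task-level merge has finished: clear its in-progress flag; the same keys
+            // (e.g. `job_<id>` read in JobActionsMenu) are used to reflect merge progress in the UI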
mergingConsensus[makeKey(taskID, 'task')] = false; + + return { + ...state, + mergingConsensus: { + ...mergingConsensus, + }, + }; + } + case ConsensusActionTypes.MERGE_CONSENSUS_JOBS_FAILED: { + const { taskID } = action.payload; + const { mergingConsensus } = state; + + delete mergingConsensus[makeKey(taskID, 'task')]; + + return { + ...state, + mergingConsensus: { + ...mergingConsensus, + }, + }; + } + case ConsensusActionTypes.MERGE_SPECIFIC_CONSENSUS_JOBS: { + const { jobID } = action.payload; + const { mergingConsensus } = state; + + mergingConsensus[makeKey(jobID, 'job')] = true; + + return { + ...state, + mergingConsensus: { + ...mergingConsensus, + }, + }; + } + case ConsensusActionTypes.MERGE_SPECIFIC_CONSENSUS_JOBS_SUCCESS: { + const { jobID } = action.payload; + const { mergingConsensus } = state; + + mergingConsensus[makeKey(jobID, 'job')] = false; + + return { + ...state, + mergingConsensus: { + ...mergingConsensus, + }, + }; + } + case ConsensusActionTypes.MERGE_SPECIFIC_CONSENSUS_JOBS_FAILED: { + const { jobID } = action.payload; + const { mergingConsensus } = state; + + delete mergingConsensus[makeKey(jobID, 'job')]; + + return { + ...state, + mergingConsensus: { + ...mergingConsensus, + }, + }; + } + default: + return state; + } +}; diff --git a/cvat-ui/src/reducers/index.ts b/cvat-ui/src/reducers/index.ts index 6da90eb551c..a8fcfdbdcf4 100644 --- a/cvat-ui/src/reducers/index.ts +++ b/cvat-ui/src/reducers/index.ts @@ -87,6 +87,9 @@ export interface JobsState { [tid: number]: boolean; }; }; + regularJobViewUncollapse: { + [tid: number]: boolean; + }; } export interface TasksState { @@ -167,6 +170,16 @@ export interface ImportState { instanceType: 'project' | 'task' | 'job' | null; } +export interface ConsensusState { + fetching: boolean; + consensusSettings: any | null; + taskInstance: any | null; + jobInstance: any | null; + mergingConsensus: { + [tid: string]: boolean; + }; +} + export interface FormatsState { annotationFormats: any; fetching: boolean; @@ -473,6 +486,7 @@ export interface NotificationsState { exporting: null | ErrorState; importing: null | ErrorState; moving: null | ErrorState; + mergingConsensus: null | ErrorState; }; jobs: { updating: null | ErrorState; @@ -596,6 +610,7 @@ export interface NotificationsState { loadingDone: null | NotificationState; importingDone: null | NotificationState; movingDone: null | NotificationState; + mergingConsensusDone: null | NotificationState; }; models: { inferenceDone: null | NotificationState; @@ -1004,6 +1019,7 @@ export interface CombinedState { review: ReviewState; export: ExportState; import: ImportState; + consensus: ConsensusState; cloudStorages: CloudStoragesState; organizations: OrganizationState; invitations: InvitationsState; diff --git a/cvat-ui/src/reducers/jobs-reducer.ts b/cvat-ui/src/reducers/jobs-reducer.ts index c7b07fc1fa3..d126422d6d2 100644 --- a/cvat-ui/src/reducers/jobs-reducer.ts +++ b/cvat-ui/src/reducers/jobs-reducer.ts @@ -20,6 +20,7 @@ const defaultState: JobsState = { activities: { deletes: {}, }, + regularJobViewUncollapse: {}, }; export default (state: JobsState = defaultState, action: JobsActions): JobsState => { @@ -161,6 +162,28 @@ export default (state: JobsState = defaultState, action: JobsActions): JobsState fetching: false, }; } + case JobsActionTypes.COLLAPSE_REGULAR_JOB: { + const { jobID } = action.payload; + state.regularJobViewUncollapse[jobID] = true; + + return { + ...state, + regularJobViewUncollapse: { + ...state.regularJobViewUncollapse, + }, + }; + } + case 
JobsActionTypes.UNCOLLAPSE_REGULAR_JOB: { + const { jobID } = action.payload; + state.regularJobViewUncollapse[jobID] = false; + + return { + ...state, + regularJobViewUncollapse: { + ...state.regularJobViewUncollapse, + }, + }; + } default: { return state; } diff --git a/cvat-ui/src/reducers/notifications-reducer.ts b/cvat-ui/src/reducers/notifications-reducer.ts index be3f96f10c0..c378973f81c 100644 --- a/cvat-ui/src/reducers/notifications-reducer.ts +++ b/cvat-ui/src/reducers/notifications-reducer.ts @@ -26,6 +26,7 @@ import { ServerAPIActionTypes } from 'actions/server-actions'; import { RequestsActionsTypes, getInstanceType } from 'actions/requests-actions'; import { ImportActionTypes } from 'actions/import-actions'; import { ExportActionTypes } from 'actions/export-actions'; +import { ConsensusActionTypes } from 'actions/consensus-actions'; import config from 'config'; import { NotificationsState } from '.'; @@ -67,6 +68,7 @@ const defaultState: NotificationsState = { exporting: null, importing: null, moving: null, + mergingConsensus: null, }, jobs: { updating: null, @@ -190,6 +192,7 @@ const defaultState: NotificationsState = { loadingDone: null, importingDone: null, movingDone: null, + mergingConsensusDone: null, }, models: { inferenceDone: null, @@ -729,6 +732,74 @@ export default function (state = defaultState, action: AnyAction): Notifications }, }; } + case ConsensusActionTypes.MERGE_CONSENSUS_JOBS_FAILED: { + const { taskID } = action.payload; + return { + ...state, + errors: { + ...state.errors, + tasks: { + ...state.errors.tasks, + mergingConsensus: { + message: `Could not merge the [task ${taskID}](/tasks/${taskID})`, + reason: action.payload.error, + shouldLog: !(action.payload.error instanceof ServerError), + className: 'cvat-notification-notice-merge-task-failed', + }, + }, + }, + }; + } + case ConsensusActionTypes.MERGE_CONSENSUS_JOBS_SUCCESS: { + const { taskID } = action.payload; + return { + ...state, + messages: { + ...state.messages, + tasks: { + ...state.messages.tasks, + mergingConsensusDone: { + message: `Consensus Jobs in the [task ${taskID}](/tasks/${taskID}) \ + have been merged`, + }, + }, + }, + }; + } + case ConsensusActionTypes.MERGE_SPECIFIC_CONSENSUS_JOBS_FAILED: { + const { jobID, taskID } = action.payload; + return { + ...state, + errors: { + ...state.errors, + tasks: { + ...state.errors.tasks, + mergingConsensus: { + message: `Could not merge the [job ${jobID}](/tasks/${taskID}/jobs/${jobID})`, + reason: action.payload.error, + shouldLog: !(action.payload.error instanceof ServerError), + className: 'cvat-notification-notice-merge-task-failed', + }, + }, + }, + }; + } + case ConsensusActionTypes.MERGE_SPECIFIC_CONSENSUS_JOBS_SUCCESS: { + const { jobID, taskID } = action.payload; + return { + ...state, + messages: { + ...state.messages, + tasks: { + ...state.messages.tasks, + mergingConsensusDone: { + message: `Consensus Jobs in the [job ${jobID}](/tasks/${taskID}/jobs/${jobID}) \ + have been merged`, + }, + }, + }, + }; + } case TasksActionTypes.CREATE_TASK_FAILED: { return { ...state, diff --git a/cvat-ui/src/reducers/root-reducer.ts b/cvat-ui/src/reducers/root-reducer.ts index 13429f80d59..31a1278911e 100644 --- a/cvat-ui/src/reducers/root-reducer.ts +++ b/cvat-ui/src/reducers/root-reducer.ts @@ -20,6 +20,7 @@ import userAgreementsReducer from './useragreements-reducer'; import reviewReducer from './review-reducer'; import exportReducer from './export-reducer'; import importReducer from './import-reducer'; +import consensusReducer from 
'./consensus-reducer'; import cloudStoragesReducer from './cloud-storages-reducer'; import organizationsReducer from './organizations-reducer'; import webhooksReducer from './webhooks-reducer'; @@ -46,6 +47,7 @@ export default function createRootReducer(): Reducer { review: reviewReducer, export: exportReducer, import: importReducer, + consensus: consensusReducer, cloudStorages: cloudStoragesReducer, organizations: organizationsReducer, webhooks: webhooksReducer, diff --git a/cvat-ui/src/utils/consensus.ts b/cvat-ui/src/utils/consensus.ts new file mode 100644 index 00000000000..1c611c0aaa5 --- /dev/null +++ b/cvat-ui/src/utils/consensus.ts @@ -0,0 +1,88 @@ +// Copyright (C) 2024 CVAT.ai Corporation +// +// SPDX-License-Identifier: MIT + +import { ColumnFilterItem } from 'antd/lib/table/interface'; +import { ConsensusReport } from 'cvat-core-wrapper'; +import config from 'config'; + +export enum ConsensusColors { + GREEN = '#237804', + YELLOW = '#ed9c00', + RED = '#ff4d4f', + GRAY = '#8c8c8c', +} + +const ratios = { + low: 0.82, + middle: 0.9, + high: 1, +}; + +export const consensusColorGenerator = (targetMetric: number) => (value?: number) => { + const baseValue = targetMetric * 100; + + const thresholds = { + low: baseValue * ratios.low, + middle: baseValue * ratios.middle, + high: baseValue * ratios.high, + }; + + if (!value) { + return ConsensusColors.GRAY; + } + + if (value >= thresholds.high) { + return ConsensusColors.GREEN; + } + if (value >= thresholds.middle) { + return ConsensusColors.YELLOW; + } + if (value >= thresholds.low) { + return ConsensusColors.RED; + } + + return ConsensusColors.GRAY; +}; + +export function collectAssignees(reports: ConsensusReport[]): ColumnFilterItem[] { + return Array.from( + new Set( + reports.map((report: ConsensusReport) => report.assignee?.username ?? null), + ), + ).map((value: string | null) => ({ text: value ?? 'Is Empty', value: value ?? false })); +} + +export function toRepresentation(val?: number, isPercent = true, decimals = 1): string { + if (!Number.isFinite(val)) { + return 'N/A'; + } + + let repr = ''; + if (!val || (isPercent && (val === 100))) { + repr = `${val}`; // remove noise in the fractional part + } else { + repr = `${val?.toFixed(decimals)}`; + } + + if (isPercent) { + repr += `${isPercent ? 
'%' : ''}`; + } + + return repr; +} + +export function percent(a?: number, b?: number, decimals = 1): string | number { + if (typeof a !== 'undefined' && Number.isFinite(a) && b) { + return toRepresentation(Number(a / b) * 100, true, decimals); + } + return 'N/A'; +} + +export function clampValue(a?: number): string | number { + if (typeof a !== 'undefined' && Number.isFinite(a)) { + if (a <= config.NUMERIC_VALUE_CLAMP_THRESHOLD) return a; + return `> ${config.NUMERIC_VALUE_CLAMP_THRESHOLD}`; + } + return 'N/A'; +} diff --git a/cvat-ui/src/utils/validate-integer.ts b/cvat-ui/src/utils/validate-integer.ts new file mode 100644 index 00000000000..9c6178757c0 --- /dev/null +++ b/cvat-ui/src/utils/validate-integer.ts @@ -0,0 +1,37 @@ +// Copyright (C) 2024 CVAT.ai Corporation +// +// SPDX-License-Identifier: MIT + +import { RuleObject } from 'antd/lib/form'; + +export const isInteger = ({ min, max, filter }: { + min?: number; + max?: number; + filter?: (intValue: number) => boolean; +}) => ( + _: RuleObject, + value?: number | string, +): Promise => { + if (typeof value === 'undefined' || value === '') { + return Promise.resolve(); + } + + const intValue = +value; + if (Number.isNaN(intValue) || !Number.isInteger(intValue)) { + return Promise.reject(new Error('Value must be a positive integer')); + } + + if (typeof min !== 'undefined' && intValue < min) { + return Promise.reject(new Error(`Value must be more than ${min}`)); + } + + if (typeof max !== 'undefined' && intValue > max) { + return Promise.reject(new Error(`Value must be less than ${max}`)); + } + + if (filter && !filter(intValue)) { + return Promise.reject(new Error(`Value can not be equal to ${intValue}`)); + } + + return Promise.resolve(); +}; diff --git a/cvat/apps/consensus/__init__.py b/cvat/apps/consensus/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/cvat/apps/consensus/apps.py b/cvat/apps/consensus/apps.py new file mode 100644 index 00000000000..a3504da9c0f --- /dev/null +++ b/cvat/apps/consensus/apps.py @@ -0,0 +1,18 @@ +# Copyright (C) 2024 CVAT.ai Corporation +# +# SPDX-License-Identifier: MIT + +from django.apps import AppConfig + + +class ConsensusConfig(AppConfig): + name = "cvat.apps.consensus" + + def ready(self) -> None: + + from cvat.apps.iam.permissions import load_app_permissions + + load_app_permissions(self) + + # Required to define signals in the application + from . 
import signals # pylint: disable=unused-import diff --git a/cvat/apps/consensus/consensus_reports.py b/cvat/apps/consensus/consensus_reports.py new file mode 100644 index 00000000000..5993b54feee --- /dev/null +++ b/cvat/apps/consensus/consensus_reports.py @@ -0,0 +1,555 @@ +# Copyright (C) 2024 CVAT.ai Corporation +# +# SPDX-License-Identifier: MIT + +from __future__ import annotations + +import itertools +from collections import Counter +from copy import deepcopy +from functools import cached_property +from types import NoneType +from typing import Any, Dict, List, Optional, Tuple, Union, cast + +import datumaro as dm +import numpy as np +from attrs import define, fields_dict +from datumaro.components.annotation import Annotation +from datumaro.util import dump_json, parse_json +from django.db import transaction + +from cvat.apps.consensus import models +from cvat.apps.consensus.intersect_merge import IntersectMerge +from cvat.apps.consensus.models import ( + AssigneeConsensusReport, + ConsensusConflict, + ConsensusConflictType, + ConsensusReport, + ConsensusSettings, +) +from cvat.apps.dataset_manager.util import bulk_create +from cvat.apps.engine import serializers as engine_serializers +from cvat.apps.engine.models import Job, Task, User +from cvat.apps.quality_control.quality_reports import AnnotationId, JobDataProvider, Serializable + + +@define(kw_only=True) +class AnnotationConflict(Serializable): + frame_id: int + type: models.ConsensusConflictType + annotation_ids: List[AnnotationId] + + def _value_serializer(self, v): + if isinstance(v, models.ConsensusConflictType): + return str(v) + else: + return super()._value_serializer(v) + + @classmethod + def from_dict(cls, d: dict): + return cls( + frame_id=d["frame_id"], + type=models.ConsensusConflictType(d["type"]), + annotation_ids=list(AnnotationId.from_dict(v) for v in d["annotation_ids"]), + ) + + +@define(kw_only=True) +class ComparisonReportComparisonSummary(Serializable): + frames: List[str] + + @property + def mean_conflict_count(self) -> float: + return self.conflict_count / (len(self.frames) or 1) + + conflict_count: int + conflicts_by_type: Dict[models.ConsensusConflictType, int] + + @property + def frame_count(self) -> int: + return len(self.frames) + + def _value_serializer(self, v): + if isinstance(v, models.ConsensusConflictType): + return str(v) + else: + return super()._value_serializer(v) + + def _fields_dict(self, *, include_properties: Optional[List[str]] = None) -> dict: + return super()._fields_dict( + include_properties=include_properties + or [ + "frame_count", + "mean_conflict_count", + "conflict_count", + "conflicts_by_type", + ] + ) + + @classmethod + def from_dict(cls, d: dict): + return cls( + frames=list(d["frames"]), + conflict_count=d["conflict_count"], + conflicts_by_type={ + models.ConsensusConflictType(k): v + for k, v in d.get("conflicts_by_type", {}).items() + }, + ) + + +@define(kw_only=True, init=False) +class ComparisonReportFrameSummary(Serializable): + conflicts: List[AnnotationConflict] + consensus_score: float + + @cached_property + def conflict_count(self) -> int: + return len(self.conflicts) + + @cached_property + def conflicts_by_type(self) -> Dict[models.ConsensusConflictType, int]: + return Counter(c.type for c in self.conflicts) + + _CACHED_FIELDS = ["conflict_count", "conflicts_by_type"] + + def _value_serializer(self, v): + if isinstance(v, models.ConsensusConflictType): + return str(v) + else: + return super()._value_serializer(v) + + def __init__(self, *args, **kwargs): + # these 
fields are optional, but can be computed on access + for field_name in self._CACHED_FIELDS: + if field_name in kwargs: + setattr(self, field_name, kwargs.pop(field_name)) + + self.__attrs_init__(*args, **kwargs) + + def _fields_dict(self, *, include_properties: Optional[List[str]] = None) -> dict: + return super()._fields_dict(include_properties=include_properties or self._CACHED_FIELDS) + + @classmethod + def from_dict(cls, d: dict): + optional_fields = set(cls._CACHED_FIELDS) - { + "conflicts_by_type" # requires extra conversion + } + return cls( + **{field: d[field] for field in optional_fields if field in d}, + **( + dict( + conflicts_by_type={ + models.ConsensusConflictType(k): v + for k, v in d["conflicts_by_type"].items() + } + ) + if "conflicts_by_type" in d + else {} + ), + conflicts=[AnnotationConflict.from_dict(v) for v in d["conflicts"]], + consensus_score=d["consensus_score"], + ) + + +@define(kw_only=True) +class ComparisonParameters(Serializable): + included_annotation_types: List[dm.AnnotationType] = [ + dm.AnnotationType.bbox, + dm.AnnotationType.points, + dm.AnnotationType.mask, + dm.AnnotationType.polygon, + dm.AnnotationType.polyline, + dm.AnnotationType.skeleton, + dm.AnnotationType.label, + ] + + agreement_score_threshold: float + quorum: int + iou_threshold: float + sigma: float + line_thickness: float + + def _value_serializer(self, v): + if isinstance(v, dm.AnnotationType): + return str(v.name) + else: + return super()._value_serializer(v) + + @classmethod + def from_dict(cls, d: dict): + fields = fields_dict(cls) + return cls(**{field_name: d[field_name] for field_name in fields if field_name in d}) + + +@define(kw_only=True) +class ComparisonReport(Serializable): + parameters: ComparisonParameters + comparison_summary: ComparisonReportComparisonSummary + frame_results: Dict[int, ComparisonReportFrameSummary] + + @property + def conflicts(self) -> List[AnnotationConflict]: + return list(itertools.chain.from_iterable(r.conflicts for r in self.frame_results.values())) + + @property + def consensus_score(self) -> int: + mean_consensus_score = 0 + frame_count = 0 + for frame_result in self.frame_results.values(): + if not isinstance(frame_result.consensus_score, NoneType): + mean_consensus_score += frame_result.consensus_score + frame_count += 1 + + return np.round(100 * (mean_consensus_score / (frame_count or 1))) + + def _fields_dict(self, *, include_properties: Optional[List[str]] = None) -> dict: + return super()._fields_dict( + include_properties=include_properties + or [ + "consensus_score", + ] + ) + + @classmethod + def from_dict(cls, d: Dict[str, Any]) -> ComparisonReport: + return cls( + parameters=ComparisonParameters.from_dict(d["parameters"]), + comparison_summary=ComparisonReportComparisonSummary.from_dict(d["comparison_summary"]), + frame_results={ + int(k): ComparisonReportFrameSummary.from_dict(v) + for k, v in d["frame_results"].items() + }, + ) + + def to_json(self) -> str: + d = self.to_dict() + + # String keys are needed for json dumping + d["frame_results"] = {str(k): v for k, v in d["frame_results"].items()} + return dump_json(d).decode() + + @classmethod + def from_json(cls, data: str) -> ComparisonReport: + return cls.from_dict(parse_json(data)) + + +def _get_error_type(error: dm.errors) -> Optional[str]: + error_string = None + if isinstance(error, dm.errors.NoMatchingItemError): + error_string = "NoMatchingItemError" + elif isinstance(error, dm.errors.NoMatchingAnnError): + error_string = "NoMatchingAnnError" + elif isinstance(error, 
dm.errors.AnnotationsTooCloseError): + error_string = "AnnotationsTooCloseError" + elif isinstance(error, dm.errors.FailedLabelVotingError): + error_string = "FailedLabelVotingError" + return ConsensusConflictType[error_string].value if error_string else None + + +def generate_assignee_consensus_report( + consensus_job_ids: List[int], + assignees: List[User], + consensus_datasets: List[dm.Dataset], + dataset_mean_consensus_score: Dict[int, float], +): + assignee_report_data: Dict[User, Dict[str, Union[List[float], float]]] = {} + for idx, _ in enumerate(consensus_job_ids): + if not assignees[idx]: + continue + assignee_report_data.setdefault( + assignees[idx], {"consensus_score": [], "conflict_count": 0} + ) + job_consensus_score = dataset_mean_consensus_score[id(consensus_datasets[idx])] + assignee_report_data[assignees[idx]]["consensus_score"].append( + 0 if np.isnan(job_consensus_score) else job_consensus_score + ) + + for assignee_id, assignee_info in assignee_report_data.items(): + assignee_report_data[assignee_id]["consensus_score"] = sum( + assignee_info["consensus_score"] + ) / (len(assignee_info["consensus_score"]) or 1) + + return assignee_report_data + + +def generate_job_consensus_report( + consensus_settings: ConsensusSettings, + errors, + consensus_job_data_providers: List[JobDataProvider], + merged_dataset: dm.Dataset, + merger: IntersectMerge, + assignees: List[User], + assignee_report_data: Dict[User, Dict[str, float]], +) -> ComparisonReport: + + frame_results: Dict[int, ComparisonReportFrameSummary] = {} + frames = set() + conflicts_count = len(errors) + frame_wise_conflicts: Dict[int, List[AnnotationConflict]] = {} + frame_wise_consensus_score: Dict[int, List[float]] = {} + conflicts: List[AnnotationConflict] = [] + + for error in errors: + error_type = _get_error_type(error) + if not error_type: + continue + annotation_ids = [] + error_annotations = [] + + for arg in error.args: + if isinstance(arg, Annotation): + error_annotations.append(arg) + + for annotation in error_annotations: + # the annotation belongs to which consensus dataset + idx = merger.get_ann_dataset_id(id(annotation)) + annotation_ids.append(consensus_job_data_providers[idx].dm_ann_to_ann_id(annotation)) + if assignees[idx]: + assignee_report_data[assignees[idx]]["conflict_count"] += 1 + + dm_item = consensus_job_data_providers[0].dm_dataset.get(error.item_id[0]) + frame_id: int = consensus_job_data_providers[0].dm_item_id_to_frame_id(dm_item) + frames.add(frame_id) + frame_wise_conflicts.setdefault(frame_id, []).append( + AnnotationConflict( + frame_id=frame_id, + type=error_type, + annotation_ids=annotation_ids, + ) + ) + + # dataset item is a frame in the merged dataset, which corresponds to regular job + for dataset_item in merged_dataset: + frame_id = consensus_job_data_providers[0].dm_item_id_to_frame_id(dataset_item) + frames.add(frame_id) + consensus_score = np.mean( + [ann.attributes.get("score", 0) for ann in dataset_item.annotations] + ) + # if that frame has no annotations, the consensus score is NaN + frame_wise_consensus_score.setdefault(frame_id, []).append( + 0 if np.isnan(consensus_score) else consensus_score + ) + + for frame_id in frames: + conflicts += frame_wise_conflicts.get(frame_id, []) + frame_results[frame_id] = ComparisonReportFrameSummary( + conflicts=frame_wise_conflicts.get(frame_id, []), + consensus_score=np.mean(frame_wise_consensus_score.get(frame_id, [0])), + ) + + return ( + ComparisonReport( + parameters=ComparisonParameters.from_dict(consensus_settings.to_dict()), 
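+                # the summary aggregates conflict counts over all frames of this job;
+                # per-frame conflicts and scores are kept in frame_results below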
+ comparison_summary=ComparisonReportComparisonSummary( + frames=list(frames), + conflict_count=conflicts_count, + conflicts_by_type=Counter(c.type for c in conflicts), + ), + frame_results=frame_results, + ), + assignee_report_data, + ) + + +def generate_task_consensus_report( + job_reports: List[ComparisonReport], +) -> Tuple[ComparisonReport, int]: + task_frames = set() + task_conflicts: List[AnnotationConflict] = [] + task_frame_results = {} + task_frame_results_counts = {} + task_mean_consensus_score = 0 + for r in job_reports: + task_frames.update(r.comparison_summary.frames) + task_conflicts.extend(r.conflicts) + task_mean_consensus_score += r.consensus_score + + for frame_id, job_frame_result in r.frame_results.items(): + task_frame_result = cast( + Optional[ComparisonReportFrameSummary], task_frame_results.get(frame_id) + ) + frame_results_count = task_frame_results_counts.get(frame_id, 0) + + if task_frame_result is None: + task_frame_result = deepcopy(job_frame_result) + else: + task_frame_result.conflicts += job_frame_result.conflicts + task_frame_result.consensus_score = ( + task_frame_result.consensus_score * task_frame_results_counts[frame_id] + + job_frame_result.consensus_score + ) / (task_frame_results_counts[frame_id] + 1) + + task_frame_results_counts[frame_id] = 1 + frame_results_count + task_frame_results[frame_id] = task_frame_result + + task_mean_consensus_score /= len(job_reports) + task_report_data = ComparisonReport( + parameters=job_reports[0].parameters, + comparison_summary=ComparisonReportComparisonSummary( + frames=sorted(task_frames), + conflict_count=len(task_conflicts), + conflicts_by_type=Counter(c.type for c in task_conflicts), + ), + frame_results=task_frame_results, + ) + return task_report_data, np.round(task_mean_consensus_score) + + +@transaction.atomic +def save_report( + task_id: int, + jobs: List[Job], + task_report_data: ComparisonReport, + job_report_data: Dict[int, ComparisonReport], + assignee_report_data: Dict[User, float], + task_mean_consensus_score: int, +): + try: + Task.objects.get(id=task_id) + except Task.DoesNotExist: + return + + task = Task.objects.filter(id=task_id).first() + + job_reports = {} + for job in jobs: + job_comparison_report = job_report_data[job.id] + job_consensus_score = job_comparison_report.consensus_score + job_report = dict( + job=job, + target_last_updated=job.updated_date, + data=job_comparison_report.to_json(), + conflicts=[c.to_dict() for c in job_comparison_report.conflicts], + consensus_score=job_consensus_score, + assignee=job.assignee, + ) + job_reports[job.id] = job_report + + job_reports = list(job_reports.values()) + + task_report = dict( + task=task, + target_last_updated=task.updated_date, + data=task_report_data.to_json(), + conflicts=[], # the task doesn't have own conflicts + consensus_score=task_mean_consensus_score, + assignee=task.assignee, + ) + + db_task_report = ConsensusReport( + task=task_report["task"], + target_last_updated=task_report["target_last_updated"], + data=task_report["data"], + consensus_score=task_report["consensus_score"], + assignee=task_report["assignee"], + ) + db_task_report.save() + + db_job_reports = [] + for job_report in job_reports: + db_job_report = ConsensusReport( + task=task_report["task"], + job=job_report["job"], + target_last_updated=job_report["target_last_updated"], + data=job_report["data"], + consensus_score=job_report["consensus_score"], + assignee=job_report["assignee"], + parent=db_task_report, + ) + db_job_reports.append(db_job_report) + + 
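+    # NOTE: bulk_create follows the same pattern used for the quality reports; it is
+    # expected to insert all job-level reports in one query and return the persisted
+    # objects, so the loops below can attach conflicts only to reports that actually
+    # received database ids.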
db_job_reports = bulk_create(db_model=ConsensusReport, objects=db_job_reports, flt_param={}) + + for assignee, assignee_info in assignee_report_data.items(): + # db_assignee = models.User.objects.get(id=) + db_assignee_report = AssigneeConsensusReport( + task=task_report["task"], + consensus_score=np.round(100 * assignee_info["consensus_score"]), + conflict_count=assignee_info["conflict_count"], + assignee=assignee, + consensus_report_id=db_task_report.id, + ) + db_assignee_report.save() + + db_conflicts = [] + db_report_iter = itertools.chain([db_task_report], db_job_reports) + report_iter = itertools.chain([task_report], job_reports) + for report, db_report in zip(report_iter, db_report_iter): + if not db_report.id: + continue + for conflict in report["conflicts"]: + db_conflict = ConsensusConflict( + report=db_report, + type=conflict["type"], + frame=conflict["frame_id"], + ) + db_conflicts.append(db_conflict) + + db_conflicts = bulk_create(db_model=ConsensusConflict, objects=db_conflicts, flt_param={}) + + db_ann_ids = [] + db_conflicts_iter = iter(db_conflicts) + for report in itertools.chain([task_report], job_reports): + for conflict, db_conflict in zip(report["conflicts"], db_conflicts_iter): + for ann_id in conflict["annotation_ids"]: + db_ann_id = models.AnnotationId( + conflict=db_conflict, + job_id=ann_id["job_id"], + obj_id=ann_id["obj_id"], + type=ann_id["type"], + shape_type=ann_id["shape_type"], + ) + db_ann_ids.append(db_ann_id) + + db_ann_ids = bulk_create(db_model=models.AnnotationId, objects=db_ann_ids, flt_param={}) + + return db_task_report.id + + +def prepare_report_for_downloading(db_report: ConsensusReport, *, host: str) -> str: + # copied from quality_reports.py + # Decorate the report for better usability and readability: + # - add conflicting annotation links like: + # /tasks/62/jobs/82?frame=250&type=shape&serverID=33741 + # - convert some fractions to percents + # - add common report info + + def _serialize_assignee(assignee: Optional[User]) -> Optional[dict]: + if not db_report.assignee: + return None + + reported_keys = ["id", "username", "first_name", "last_name"] + assert set(reported_keys).issubset(engine_serializers.BasicUserSerializer.Meta.fields) + # check that only safe fields are reported + + return {k: getattr(assignee, k) for k in reported_keys} + + task_id = db_report.get_task().id + serialized_data = dict( + job_id=db_report.job.id if db_report.job is not None else None, + task_id=task_id, + parent_id=db_report.parent.id if db_report.parent is not None else None, + created_date=str(db_report.created_date), + target_last_updated=str(db_report.target_last_updated), + assignee=_serialize_assignee(db_report.assignee), + ) + + comparison_report = ComparisonReport.from_json(db_report.get_json_report()) + serialized_data.update(comparison_report.to_dict()) + + for frame_result in serialized_data["frame_results"].values(): + for conflict in frame_result["conflicts"]: + for ann_id in conflict["annotation_ids"]: + ann_id["url"] = ( + f"{host}tasks/{task_id}/jobs/{ann_id['job_id']}" + f"?frame={conflict['frame_id']}" + f"&type={ann_id['type']}" + f"&serverID={ann_id['obj_id']}" + ) + + # String keys are needed for json dumping + serialized_data["frame_results"] = { + str(k): v for k, v in serialized_data["frame_results"].items() + } + return dump_json(serialized_data, indent=True, append_newline=True).decode() diff --git a/cvat/apps/consensus/intersect_merge.py b/cvat/apps/consensus/intersect_merge.py new file mode 100644 index 00000000000..cde75847f66 --- 
/dev/null +++ b/cvat/apps/consensus/intersect_merge.py @@ -0,0 +1,879 @@ +# Copyright (C) 2024 CVAT.ai Corporation +# +# SPDX-License-Identifier: MIT + +import itertools +import logging as log +from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union + +import attr +import datumaro as dm +import numpy as np +from attr import attrib, attrs +from datumaro.components.annotation import AnnotationType, Bbox +from datumaro.components.dataset import Dataset +from datumaro.components.errors import FailedLabelVotingError, NoMatchingItemError +from datumaro.components.operations import ExactMerge +from datumaro.util.annotation_util import find_instances, max_bbox, mean_bbox +from datumaro.util.attrs_util import ensure_cls + +from cvat.apps.engine.models import Label +from cvat.apps.quality_control.quality_reports import DistanceComparator, KeypointsMatcher +from cvat.apps.quality_control.quality_reports import LineMatcher as LineMatcherQualityReports +from cvat.apps.quality_control.quality_reports import match_segments, oks, segment_iou, to_rle + + +@attrs +class IntersectMerge(dm.ops.IntersectMerge): + @attrs(repr_ns="IntersectMerge", kw_only=True) + class Conf: + pairwise_dist = attrib(converter=float, default=0.5) + sigma = attrib(converter=float, factory=float) + + output_conf_thresh = attrib(converter=float, default=0) + quorum = attrib(converter=int, default=0.1) + ignored_attributes = attrib(converter=set, factory=set) + torso_r = attrib(converter=float, default=0.01) + + groups = [] + close_distance = attrib(converter=float, default=0.75) + + conf = attrib(converter=ensure_cls(Conf), factory=Conf) + + # Error trackers: + errors = attrib(factory=list, init=False) + + def add_item_error(self, error, *args, **kwargs): + self.errors.append(error(self._item_id, *args, **kwargs)) + + # Indexes: + _dataset_map = attrib(init=False) # id(dataset) -> (dataset, index) + _item_map = attrib(init=False) # id(item) -> (item, id(dataset)) + _ann_map = attrib(init=False) # id(ann) -> (ann, id(item)) + _item_id = attrib(init=False) + _item = attrib(init=False) + dataset_mean_consensus_score = attrib(init=False) # id(dataset) -> mean consensus score: float + + # Misc. 
+ _categories = attrib(init=False) # merged categories + + def __call__(self, datasets): + self.errors = [] + self._categories = self._merge_categories([d.categories() for d in datasets]) + merged = Dataset( + categories=self._categories, + media_type=ExactMerge.merge_media_types(datasets), + ) + + self._check_groups_definition() + + item_matches, item_map = self.match_items(datasets) + self._item_map = item_map + self.dataset_mean_consensus_score = {id(d): [] for d in datasets} + self._dataset_map = {id(d): (d, i) for i, d in enumerate(datasets)} + self._ann_map = {} + + for item_id, items in item_matches.items(): + self._item_id = item_id + + if len(items) < len(datasets): + missing_sources = set(id(s) for s in datasets) - set(items) + missing_sources = [self._dataset_map[s][1] for s in missing_sources] + self.add_item_error(NoMatchingItemError, sources=missing_sources) + merged.put(self.merge_items(items)) + + # now we have consensus score for all annotations in + for dataset_id in self.dataset_mean_consensus_score: + self.dataset_mean_consensus_score[dataset_id] = np.mean( + self.dataset_mean_consensus_score[dataset_id] + ) + + return merged + + def get_ann_dataset_id(self, ann_id: int) -> int: + return self._dataset_map[self.get_ann_source(ann_id)][1] + + def get_item_media_dims(self, ann_id: int) -> Tuple[int, int]: + return self._item_map[self._ann_map[ann_id][1]][0].image.size + + def get_label_id(self, label): + return self._get_label_id(label) + + def get_src_label_name(self, ann, label_id): + return self._get_src_label_name(ann, label_id) + + def get_dataset_source_id(self, dataset_id: int): + return self._dataset_map[dataset_id][1] + + def dataset_count(self) -> int: + return len(self._dataset_map) + + def merge_items(self, items): + self._item = next(iter(items.values())) + + sources = [] # [annotation of frame 0, frame 1, ...] 
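+        # `sources` holds one list of annotations per consensus job for this matched
+        # frame; merge_annotations() below clusters and votes across these lists.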
+ for item in items.values(): + self._ann_map.update({id(a): (a, id(item)) for a in item.annotations}) + sources.append(item.annotations) + log.debug( + "Merging item %s: source annotations %s" % (self._item_id, list(map(len, sources))) + ) + + annotations = self.merge_annotations(sources) + + annotations = [ + a for a in annotations if self.conf.output_conf_thresh <= a.attributes.get("score", 1) + ] + + for annotation in annotations: + annotation.attributes["source"] = "consensus" + + return self._item.wrap(annotations=annotations) + + def _make_mergers(self, sources): + def _make(c, **kwargs): + kwargs.update(attr.asdict(self.conf)) + fields = attr.fields_dict(c) + return c(**{k: v for k, v in kwargs.items() if k in fields}, context=self) + + def _for_type(t, **kwargs): + if t is AnnotationType.label: + return _make(LabelMerger, **kwargs) + elif t is AnnotationType.bbox: + return _make(BboxMerger, **kwargs) + elif t is AnnotationType.mask: + return _make(MaskMerger, **kwargs) + elif t is AnnotationType.polygon or t is AnnotationType.mask: + return _make(PolygonMerger, **kwargs) + elif t is AnnotationType.polyline: + return _make(LineMerger, **kwargs) + elif t is AnnotationType.points: + return _make(PointsMerger, **kwargs) + elif t is AnnotationType.skeleton: + return _make(SkeletonMerger, **kwargs) + # else: + # pass + # raise NotImplementedError("Type %s is not supported" % t) + + instance_map = {} + for s in sources: + s_instances = find_instances(s) + for inst in s_instances: + inst_bbox = max_bbox( + [ + a + for a in inst + if a.type + in { + AnnotationType.polygon, + AnnotationType.mask, + AnnotationType.bbox, + } + ] + ) + for ann in inst: + instance_map[id(ann)] = [inst, inst_bbox] + + self._mergers = { + t: _for_type(t, instance_map=instance_map, categories=self._categories) + for t in AnnotationType + } + + def get_any_label_name(self, ann, label_id): + if label_id is None: + return None + try: + return self._get_src_label_name(ann, label_id) + except KeyError: + return self._get_label_name(label_id) + + +@attrs(kw_only=True) +class AnnotationMatcher: + _context: Optional[IntersectMerge] = attrib(default=None) + + def match_annotations(self, sources): + raise NotImplementedError() + + +@attrs +class LabelMatcher(AnnotationMatcher): + def distance(self, a, b): + a_label = self._context.get_any_label_name(a, a.label) + b_label = self._context.get_any_label_name(b, b.label) + return a_label == b_label + + def match_annotations(self, sources): + return [sum(sources, [])] + + +class CachedSimilarityFunction: + def __init__( + self, + sim_fn: Callable[[dm.Annotation, dm.Annotation], float], + *, + cache: Optional[Dict[Tuple[int, int], float]] = None, + ) -> None: + self.cache: Dict[Tuple[int, int], float] = cache or {} + self.sim_fn = sim_fn + + def __call__(self, a_ann: dm.Annotation, b_ann: dm.Annotation) -> float: + a_ann_id = id(a_ann) + b_ann_id = id(b_ann) + + if a_ann_id == b_ann_id: + return 1 + key: Tuple[int, int] = ( + a_ann_id, + b_ann_id, + ) # make sure the annotations have stable ids before calling this + key = self._sort_key(key) + cached_value = self.cache.get(key) + + if cached_value is None: + cached_value = self.sim_fn(a_ann, b_ann) + self.cache[key] = cached_value + + return cached_value + + @staticmethod + def _sort_key(key: Tuple[int, int]) -> Tuple[int, int]: + key: Union[List[int, int], Tuple[int, int]] = list(key) + key.sort() + key = tuple(key) + return key + + def pop(self, key: Tuple[int, int]) -> float: + return self.cache.pop(self._sort_key(key), None) 
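+    # Illustrative usage only (not how the merger calls it): keys are sorted pairs of
+    # id() values, so the cached distance is symmetric and computed once per pair.
+    # A minimal sketch, assuming `a` and `b` are datumaro annotations:
+    #
+    #   sim = CachedSimilarityFunction(dm.ops.segment_iou)
+    #   first = sim(a, b)   # computed via sim_fn and stored in the cache
+    #   second = sim(b, a)  # served from the cache, same value as `first`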
+ + def set(self, key: Tuple[int, int], value: float): + self.cache[self._sort_key(key)] = value + + def keys(self): + return self.cache.keys() + + def clear_cache(self): + self.cache.clear() + + +@attrs(kw_only=True) +class _ShapeMatcher(AnnotationMatcher): + pairwise_dist = attrib(converter=float, default=0.9) + cluster_dist = attrib(converter=float, default=-1.0) + categories = attrib(converter=dict, default={}) + distance_index = attrib(converter=dict, default={}) + _distance_comparator = attrib(converter=DistanceComparator, default={}) + _distance = attrib(converter=CachedSimilarityFunction, default=None) + + def __attrs_post_init__(self): + self._distance_comparator = DistanceComparator( + categories=self.categories, + iou_threshold=self._context.conf.pairwise_dist, + oks_sigma=self._context.conf.sigma, + line_torso_radius=self._context.conf.torso_r, + ) + self._distance = CachedSimilarityFunction(self._distance_func) + + def _distance_func(self, item_a, item_b): + return dm.ops.segment_iou(item_a, item_b) + + # def _distance(self) -> CachedSimilarityFunction: + # return CachedSimilarityFunction(self._distance_func) + + def distance(self, a, b): + return self._distance(a, b) + + def label_matcher(self, a, b): + a_label = self._context.get_any_label_name(a, a.label) + b_label = self._context.get_any_label_name(b, b.label) + return a_label == b_label + + @staticmethod + def _get_ann_type(t, item: dm.Annotation) -> Sequence[dm.Annotation]: + return [a for a in item if a.type == t and not a.attributes.get("outside", False)] + + def _match_segments( + self, + t, + item_a: List[dm.Annotation], + item_b: List[dm.Annotation], + *, + distance: Callable = distance, + label_matcher: Callable = None, + a_objs: Optional[Sequence[dm.Annotation]] = None, + b_objs: Optional[Sequence[dm.Annotation]] = None, + dist_thresh: Optional[float] = None, + ): + if label_matcher is None: + label_matcher = self.label_matcher + if dist_thresh is None: + dist_thresh = self.pairwise_dist + item_a = dm.DatasetItem(id=1, annotations=item_a) + item_b = dm.DatasetItem(id=2, annotations=item_b) + return self._distance_comparator.match_segments( + t=t, + item_a=item_a, + item_b=item_b, + distance=distance, + label_matcher=label_matcher, + a_objs=a_objs, + b_objs=b_objs, + dist_thresh=dist_thresh, + ) + + def match_annotations_two_sources( + self, item_a: List[dm.Annotation], item_b: List[dm.Annotation] + ) -> List[dm.Annotation]: + return [] + + def match_annotations(self, sources): + distance = self.distance + pairwise_dist = self.pairwise_dist + cluster_dist = self.cluster_dist + + if cluster_dist < 0: + cluster_dist = pairwise_dist + + id_segm = {id(a): (a, id(s)) for s in sources for a in s} + + def _is_close_enough(cluster, extra_id): + # check if whole cluster IoU will not be broken + # when this segment is added + b = id_segm[extra_id][0] + for a_id in cluster: + a = id_segm[a_id][0] + if distance(a, b) < cluster_dist: + return False + return True + + def _has_same_source(cluster, extra_id): + b = id_segm[extra_id][1] + for a_id in cluster: + a = id_segm[a_id][1] + if a == b: + return True + return False + + # match segments in sources, pairwise + adjacent = {i: [] for i in id_segm} # id(sgm) -> [id(adj_sgm1), ...] 
+ for a_idx, src_a in enumerate(sources): + # matches further sources of same frame for matching annotations + for src_b in sources[a_idx + 1 :]: + # an annotation can be adjacent to multiple annotations + matches = self.match_annotations_two_sources( + src_a, + src_b, + ) + for a, b in matches: + adjacent[id(a)].append(id(b)) + + # join all segments into matching clusters + clusters = [] + visited = set() + for cluster_idx in adjacent: + if cluster_idx in visited: + continue + + cluster = set() + to_visit = {cluster_idx} + while to_visit: + c = to_visit.pop() + cluster.add(c) + visited.add(c) + + for i in adjacent[c]: + if i in visited: + # if that annotation is already in another cluster + continue + if 0 < cluster_dist and not _is_close_enough(cluster, i): + # if positive cluster_dist and this annotation isn't close enough with other annotations in + # cluster + continue + if _has_same_source(cluster, i): + # if both the annotation are belong to the same frame in same consensus job + continue + + to_visit.add( + i + ) # check whether annotations matching this element in cluster can be added in this cluster + + clusters.append([id_segm[i][0] for i in cluster]) + + return clusters + + +@attrs +class BboxMatcher(_ShapeMatcher): + def _distance_func(self, item_a, item_b): + def _bbox_iou(a: dm.Bbox, b: dm.Bbox, *, img_w: int, img_h: int) -> float: + if a.attributes.get("rotation", 0) == b.attributes.get("rotation", 0): + return dm.ops.bbox_iou(a, b) + else: + return segment_iou( + self._distance_comparator.to_polygon(a), + self._distance_comparator.to_polygon(b), + img_h=img_h, + img_w=img_w, + ) + + img_h, img_w = self._context.get_item_media_dims(id(item_a)) + return _bbox_iou(item_a, item_b, img_h=img_h, img_w=img_w) + + def match_annotations_two_sources(self, item_a: List[dm.Bbox], item_b: List[dm.Bbox]): + return self._match_segments( + dm.AnnotationType.bbox, + item_a, + item_b, + distance=self.distance, + )[0] + + +@attrs +class PolygonMatcher(_ShapeMatcher): + def _distance_func(self, item_a, item_b): + from pycocotools import mask as mask_utils + + def _get_segment(item): + img_h, img_w = self._context.get_item_media_dims(id(item)) + object_rle_groups = [to_rle(item, img_h=img_h, img_w=img_w)] + rle = mask_utils.merge(list(itertools.chain.from_iterable(object_rle_groups))) + return rle + + a_segm = _get_segment(item_a) + b_segm = _get_segment(item_b) + return float(mask_utils.iou([b_segm], [a_segm], [0])[0]) + + def match_annotations_two_sources( + self, item_a: List[Union[dm.Polygon, dm.Mask]], item_b: List[Union[dm.Polygon, dm.Mask]] + ): + def _get_segmentations(item): + return self._get_ann_type(dm.AnnotationType.polygon, item) + self._get_ann_type( + dm.AnnotationType.mask, item + ) + + img_h, img_w = self._context.get_item_media_dims(id(item_a[0])) + + def _find_instances(annotations): + # Group instance annotations by label. 
+ # Annotations with the same label and group will be merged, + # and considered item_a single object in comparison + instances = [] + instance_map = {} # ann id -> instance id + for ann_group in dm.ops.find_instances(annotations): + ann_group = sorted(ann_group, key=lambda a: a.label) + for _, label_group in itertools.groupby(ann_group, key=lambda a: a.label): + label_group = list(label_group) + instance_id = len(instances) + instances.append(label_group) + for ann in label_group: + instance_map[id(ann)] = instance_id + + return instances, instance_map + + a_instances, _ = _find_instances(_get_segmentations(item_a)) + b_instances, _ = _find_instances(_get_segmentations(item_b)) + + a_compiled_mask = None + b_compiled_mask = None + + segment_cache = {} + + def _get_segment( + obj_id: int, *, compiled_mask: Optional[dm.CompiledMask] = None, instances + ): + key = (id(instances), obj_id) + rle = segment_cache.get(key) + + if rle is None: + from pycocotools import mask as mask_utils + + if compiled_mask is not None: + mask = compiled_mask.extract(obj_id + 1) + + rle = mask_utils.encode(mask) + else: + # Create merged RLE for the instance shapes + object_anns = instances[obj_id] + object_rle_groups = [ + to_rle(ann, img_h=img_h, img_w=img_w) for ann in object_anns + ] + rle = mask_utils.merge(list(itertools.chain.from_iterable(object_rle_groups))) + + segment_cache[key] = rle + + return rle + + def _segment_comparator(a_inst_id: int, b_inst_id: int) -> float: + a_segm = _get_segment(a_inst_id, compiled_mask=a_compiled_mask, instances=a_instances) + b_segm = _get_segment(b_inst_id, compiled_mask=b_compiled_mask, instances=b_instances) + + from pycocotools import mask as mask_utils + + return float(mask_utils.iou([b_segm], [a_segm], [0])[0]) + + def _label_matcher(a_inst_id: int, b_inst_id: int) -> bool: + # labels are the same in the instance annotations + # instances are required to have the same labels in all shapes + a = a_instances[a_inst_id][0] + b = b_instances[b_inst_id][0] + return a.label == b.label + + results = self._match_segments( + dm.AnnotationType.polygon, + item_a, + item_b, + a_objs=range(len(a_instances)), + b_objs=range(len(b_instances)), + distance=_segment_comparator, + label_matcher=_label_matcher, + ) + + # restore results for original annotations + matched = results[0] + + # i_x ~ instance idx in _x + # ia_x ~ instance annotation in _x + matched = [ + (ia_a, ia_b) + for (i_a, i_b) in matched + for (ia_a, ia_b) in itertools.product(a_instances[i_a], b_instances[i_b]) + ] + + return matched + + +@attrs +class MaskMatcher(PolygonMatcher): + pass + + +@attrs(kw_only=True) +class PointsMatcher(_ShapeMatcher): + sigma: Optional[list] = attrib(default=None) + instance_map = attrib(converter=dict) + + def _distance_func(self, a, b): + for instance_group in [[a], [b]]: + instance_bbox = self._distance_comparator.instance_bbox(instance_group) + + for ann in instance_group: + if ann.type == dm.AnnotationType.points: + self.instance_map[id(ann)] = [instance_group, instance_bbox] + + img_h, img_w = self._context.get_item_media_dims(id(a)) + a_bbox = self.instance_map[id(a)][1] + b_bbox = self.instance_map[id(b)][1] + a_area = a_bbox[2] * a_bbox[3] + b_area = b_bbox[2] * b_bbox[3] + + if a_area == 0 and b_area == 0: + # Simple case: singular points without bbox + # match them in the image space + return oks(a, b, sigma=self.sigma, scale=img_h * img_w) + + else: + # Complex case: multiple points, grouped points, points with item_a bbox + # Try to align points and then return the 
metric + # match them in their bbox space + + if dm.ops.bbox_iou(a_bbox, b_bbox) <= 0: + return 0 + + bbox = dm.ops.mean_bbox([a_bbox, b_bbox]) + scale = bbox[2] * bbox[3] + + a_points = np.reshape(a.points, (-1, 2)) + b_points = np.reshape(b.points, (-1, 2)) + + matches, mismatches, a_extra, b_extra = match_segments( + range(len(a_points)), + range(len(b_points)), + distance=lambda ai, bi: oks( + dm.Points(a_points[ai]), + dm.Points(b_points[bi]), + sigma=self.sigma, + scale=scale, + ), + dist_thresh=self._distance_comparator.iou_threshold, + label_matcher=lambda ai, bi: True, + ) + + # the exact array is determined by the label matcher + # all the points will have the same match status, + # because there is only 1 shared label for all the points + matched_points = matches + mismatches + + a_sorting_indices = [ai for ai, _ in matched_points] + a_points = a_points[a_sorting_indices] + + b_sorting_indices = [bi for _, bi in matched_points] + b_points = b_points[b_sorting_indices] + + # Compute oks for 2 groups of points, matching points aligned + dists = np.linalg.norm(a_points - b_points, axis=1) + return np.sum(np.exp(-(dists**2) / (2 * scale * (2 * self.sigma) ** 2))) / ( + len(matched_points) + len(a_extra) + len(b_extra) + ) + + def match_annotations_two_sources(self, item_a: List[dm.Points], item_b: List[dm.Points]): + a_points = self._get_ann_type(dm.AnnotationType.points, item_a) + b_points = self._get_ann_type(dm.AnnotationType.points, item_b) + + return self._match_segments( + dm.AnnotationType.points, + item_a, + item_b, + a_objs=a_points, + b_objs=b_points, + distance=self.distance, + )[0] + + +class SkeletonMatcher(_ShapeMatcher): + return_distances = True + sigma: float = 0.1 + instance_map = {} + skeleton_map = {} + + def _distance_func(self, a, b): + matcher = KeypointsMatcher(instance_map=self.instance_map, sigma=self.sigma) + if isinstance(a, dm.Skeleton) and isinstance(b, dm.Skeleton): + return self.distance(a, b) + return matcher.distance(a, b) + + def match_annotations_two_sources( + self, a_skeletons: List[dm.Skeleton], b_skeletons: List[dm.Skeleton] + ): + if not a_skeletons and not b_skeletons: + return [] + + # Convert skeletons to point lists for comparison + # This is required to compute correct per-instance distance + # It is assumed that labels are the same in the datasets + skeleton_infos = {} + points_map = {} + a_points = [] + b_points = [] + for source, source_points in [(a_skeletons, a_points), (b_skeletons, b_points)]: + for skeleton in source: + skeleton_info = skeleton_infos.setdefault( + skeleton.label, self._distance_comparator._get_skeleton_info(skeleton.label) + ) + + # Merge skeleton points into item_a single list + # The list is ordered by skeleton_info + skeleton_points = [ + next((p for p in skeleton.elements if p.label == sub_label), None) + for sub_label in skeleton_info + ] + + # Build item_a single Points object for further comparisons + merged_points = dm.Points() + merged_points.points = np.ravel( + [p.points if p else [0, 0] for p in skeleton_points] + ) + merged_points.visibility = np.ravel( + [p.visibility if p else [dm.Points.Visibility.absent] for p in skeleton_points] + ) + merged_points.label = skeleton.label + # no per-point attributes currently in CVAT + + if all(v == dm.Points.Visibility.absent for v in merged_points.visibility): + # The whole skeleton is outside, exclude it + self.skeleton_map[id(skeleton)] = None + continue + + points_map[id(merged_points)] = skeleton + self.skeleton_map[id(skeleton)] = merged_points + 
source_points.append(merged_points) + + for source in [a_skeletons, b_skeletons]: + for instance_group in dm.ops.find_instances(source): + instance_bbox = self._distance_comparator.instance_bbox(instance_group) + + instance_group = [ + self.skeleton_map[id(a)] if isinstance(a, dm.Skeleton) else a + for a in instance_group + if not isinstance(a, dm.Skeleton) or self.skeleton_map[id(a)] is not None + ] + for ann in instance_group: + self.instance_map[id(ann)] = [instance_group, instance_bbox] + + results = self._match_segments( + dm.AnnotationType.points, + a_skeletons, + b_skeletons, + a_objs=a_points, + b_objs=b_points, + distance=self.distance, + ) + + # Map points back to skeletons + if self.return_distances: + distances = self._distance + for p_a_id, p_b_id in list(distances.keys()): + dist = distances.pop((p_a_id, p_b_id)) + distances.set((id(points_map[p_a_id]), id(points_map[p_b_id])), dist) + + return [(points_map[id(p_a)], points_map[id(p_b)]) for (p_a, p_b) in results[0]] + + +@attrs +class LineMatcher(_ShapeMatcher): + def _distance_func(self, item_a, item_b): + img_h, img_w = self._context.get_item_media_dims(id(item_a)) + matcher = LineMatcherQualityReports( + torso_r=self._distance_comparator.line_torso_radius, + scale=np.prod([img_h, img_w]), + ) + return matcher.distance(item_a, item_b) + + def match_annotations_two_sources(self, item_a: List[dm.PolyLine], item_b: List[dm.PolyLine]): + return self._match_segments( + dm.AnnotationType.polyline, item_a, item_b, distance=self.distance + )[0] + + +@attrs(kw_only=True) +class LabelMerger(LabelMatcher): + quorum = attrib(converter=int, default=0) + + def merge_clusters(self, clusters): + assert len(clusters) <= 1 + if len(clusters) == 0: + return [] + + votes = {} # label -> score + for ann in clusters[0]: + label = self._context.get_src_label_name(ann, ann.label) + votes[label] = 1 + votes.get(label, 0) + + merged = [] + for label, count in votes.items(): + if count < self.quorum: + sources = set( + self._context.get_ann_source(id(a)) + for a in clusters[0] + if label not in [self._context.get_src_label_name(l, l.label) for l in a] + ) + sources = [self._context.get_dataset_source_id(s) for s in sources] + self._context.add_item_error(FailedLabelVotingError, votes, sources=sources) + continue + + merged.append( + Label( + self._context.get_label_id(label), + attributes={"score": count / self._context.dataset_count()}, + ) + ) + + return merged + + +@attrs(kw_only=True) +class _ShapeMerger(_ShapeMatcher): + quorum = attrib(converter=int, default=0) + + def merge_clusters(self, clusters): + return list(filter(lambda x: x is not None, map(self.merge_cluster, clusters))) + + def find_cluster_label(self, cluster): + votes = {} + for s in cluster: + label = self._context.get_src_label_name(s, s.label) + state = votes.setdefault(label, [0, 0]) + state[0] += s.attributes.get("score", 1.0) + state[1] += 1 + + label, (score, count) = max(votes.items(), key=lambda e: e[1][0]) + if count < self.quorum: + self._context.add_item_error(FailedLabelVotingError, votes) + label = None + score = score / self._context.dataset_count() + label = self._context.get_label_id(label) + return label, score + + def _merge_cluster_shape_mean_box_nearest(self, cluster): + mbbox = Bbox(*mean_bbox(cluster)) + a = cluster[0] + img_h, img_w = self._context.get_item_media_dims(id(a)) + dist = [] + for s in cluster: + if isinstance(s, dm.Points) or isinstance(s, dm.PolyLine): + s = self._distance_comparator.to_polygon(Bbox(*s.get_bbox())) + elif isinstance(s, 
dm.Bbox): + s = self._distance_comparator.to_polygon(s) + dist.append( + segment_iou( + self._distance_comparator.to_polygon(mbbox), s, img_h=img_h, img_w=img_w + ) + ) + nearest_pos, _ = max(enumerate(dist), key=lambda e: e[1]) + return cluster[nearest_pos] + + def merge_cluster_shape(self, cluster): + shape = self._merge_cluster_shape_mean_box_nearest(cluster) + for ann in cluster: + dataset_id = self._context.get_ann_source(id(ann)) + self._context.dataset_mean_consensus_score.setdefault(dataset_id, []).append( + max(0, self.distance(ann, shape)) + ) + shape_score = sum(max(0, self.distance(shape, s)) for s in cluster) / len(cluster) + return shape, shape_score + + def merge_cluster(self, cluster): + label, label_score = self.find_cluster_label(cluster) + + # when the merged annotation is rejected due to quorum constraint + if label is None: + return None + + shape, shape_score = self.merge_cluster_shape(cluster) + shape.z_order = max(cluster, key=lambda a: a.z_order).z_order + shape.label = label + shape.attributes["score"] = label_score * shape_score + + return shape + + +@attrs +class BboxMerger(_ShapeMerger, BboxMatcher): + pass + + +@attrs +class PolygonMerger(_ShapeMerger, PolygonMatcher): + pass + + +@attrs +class MaskMerger(_ShapeMerger, MaskMatcher): + pass + + +@attrs +class PointsMerger(_ShapeMerger, PointsMatcher): + pass + + +@attrs +class LineMerger(_ShapeMerger, LineMatcher): + pass + + +class SkeletonMerger(_ShapeMerger, SkeletonMatcher): + def _merge_cluster_shape_nearest(self, cluster): + dist = {} + for idx, skeleton1 in enumerate(cluster): + skeleton_distance = 0 + for skeleton2 in cluster: + skeleton_distance += self.distance(skeleton1, skeleton2) + + dist[idx] = skeleton_distance / len(cluster) + + return cluster[min(dist, key=dist.get)] + + def merge_cluster_shape(self, cluster): + shape = self._merge_cluster_shape_nearest(cluster) + shape_score = sum(max(0, self.distance(shape, s)) for s in cluster) / len(cluster) + return shape, shape_score diff --git a/cvat/apps/consensus/merging_manager.py b/cvat/apps/consensus/merging_manager.py new file mode 100644 index 00000000000..f76fbcacb90 --- /dev/null +++ b/cvat/apps/consensus/merging_manager.py @@ -0,0 +1,282 @@ +# Copyright (C) 2024 CVAT.ai Corporation +# +# SPDX-License-Identifier: MIT + +from typing import Dict, List, Tuple, Union +from uuid import uuid4 + +import datumaro as dm +import django_rq +from django.conf import settings +from django.db import transaction +from rest_framework import status +from rest_framework.exceptions import ValidationError +from rest_framework.response import Response + +from cvat.apps.consensus.consensus_reports import ( + ComparisonReport, + generate_assignee_consensus_report, + generate_job_consensus_report, + generate_task_consensus_report, + save_report, +) +from cvat.apps.consensus.intersect_merge import IntersectMerge +from cvat.apps.consensus.models import AssigneeConsensusReport, ConsensusReport, ConsensusSettings +from cvat.apps.dataset_manager.bindings import import_dm_annotations +from cvat.apps.dataset_manager.task import PatchAction, patch_job_data +from cvat.apps.engine.models import Job, JobType, StageChoice, StateChoice, Task, User +from cvat.apps.engine.serializers import RqIdSerializer +from cvat.apps.engine.utils import ( + define_dependent_job, + get_rq_job_meta, + get_rq_lock_by_user, + process_failed_job, +) +from cvat.apps.quality_control.quality_reports import JobDataProvider + + +class MergeConsensusJobs: + def __init__(self, task_id: int) -> None: + 
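+        # Group annotated consensus jobs by their parent (regular) job and configure the
+        # merger from the task's consensus settings; fails fast if there is nothing to merge.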
self.jobs: Dict[int, List[Tuple[int, User]]] + self.parent_jobs: List[Job] + self.merger: IntersectMerge + self.consensus_settings: ConsensusSettings + self.assignee_jobs_count: Dict[User, int] + self.task_id = task_id + self._get_consensus_jobs(task_id) + self._get_assignee_jobs_count() + self.consensus_settings = ConsensusSettings.objects.filter(task=task_id).first() + self.merger = IntersectMerge( + conf=IntersectMerge.Conf( + pairwise_dist=self.consensus_settings.iou_threshold, + output_conf_thresh=self.consensus_settings.agreement_score_threshold, + quorum=self.consensus_settings.quorum, + sigma=self.consensus_settings.sigma, + torso_r=self.consensus_settings.line_thickness, + ) + ) + if not self.jobs: + raise ValidationError( + "No annotated consensus jobs found or no regular jobs in annotation stage" + ) + + def _get_consensus_jobs(self, task_id: int) -> None: + job_map = {} # parent_job_id -> [(consensus_job_id, assignee)] + parent_jobs: dict[int, Job] = {} + for job in ( + Job.objects.prefetch_related("segment", "parent_job", "assignee") + .filter( + segment__task_id=task_id, + type=JobType.CONSENSUS.value, + parent_job__stage=StageChoice.ANNOTATION.value, + parent_job__isnull=False, + ) + .exclude(state=StateChoice.NEW.value) + ): + job_map.setdefault(job.parent_job_id, []).append((job.id, job.assignee)) + parent_jobs.setdefault(job.parent_job_id, job.parent_job) + + self.jobs = job_map + self.parent_jobs = list(parent_jobs.values()) + + def _get_assignee_jobs_count(self) -> None: + assignee_jobs_count = {} + for assignees in self.jobs.values(): + for _, assignee in assignees: + if assignee not in assignee_jobs_count: + assignee_jobs_count[assignee] = 1 + else: + assignee_jobs_count[assignee] += 1 + self.assignee_jobs_count = assignee_jobs_count + + @staticmethod + def _get_annotations(job_id: int) -> dm.Dataset: + return JobDataProvider(job_id).dm_dataset + + def _merge_consensus_jobs(self, parent_job_id: int): + consensus_job_info = self.jobs.get(parent_job_id) + if not consensus_job_info: + raise ValidationError(f"No consensus jobs found for parent job {parent_job_id}") + + consensus_job_ids = [consensus_job_id for consensus_job_id, _ in consensus_job_info] + assignees = [assignee for _, assignee in consensus_job_info] + + consensus_job_data_providers = list(map(JobDataProvider, consensus_job_ids)) + consensus_datasets = [ + consensus_job_data_provider.dm_dataset + for consensus_job_data_provider in consensus_job_data_providers + ] + + merged_dataset = self.merger(consensus_datasets) + + assignee_report_data = generate_assignee_consensus_report( + consensus_job_ids, + assignees, + consensus_datasets, + self.merger.dataset_mean_consensus_score, + ) + + # delete the existing annotations in the job + patch_job_data(parent_job_id, None, PatchAction.DELETE) + # if we don't delete exising annotations, the imported annotations + # will be appended to the existing annotations, and thus updated annotation + # would have both existing + imported annotations, but we only want the + # imported annotations + + parent_job_data_provider = JobDataProvider(parent_job_id) + + # imports the annotations in the `parent_job.job_data` instance + import_dm_annotations(merged_dataset, parent_job_data_provider.job_data) + + # updates the annotations in the job + patch_job_data( + parent_job_id, parent_job_data_provider.job_data.data.serialize(), PatchAction.UPDATE + ) + + job_comparison_report, assignee_report_data = generate_job_consensus_report( + consensus_settings=self.consensus_settings, + 
errors=self.merger.errors,
+            consensus_job_data_providers=consensus_job_data_providers,
+            merged_dataset=merged_dataset,
+            merger=self.merger,
+            assignees=assignees,
+            assignee_report_data=assignee_report_data,
+        )
+
+        for parent_job in self.parent_jobs:
+            if parent_job.id == parent_job_id and parent_job.type == JobType.ANNOTATION.value:
+                parent_job.state = StateChoice.COMPLETED.value
+                parent_job.save()
+
+        return job_comparison_report, assignee_report_data
+
+    @transaction.atomic
+    def merge_all_consensus_jobs(self, task_id: int) -> None:
+        job_comparison_reports: Dict[int, ComparisonReport] = {}
+        assignee_reports: Dict[User, Dict[str, float]] = {}
+
+        for parent_job_id in self.jobs.keys():
+            job_comparison_report, assignee_report_data = self._merge_consensus_jobs(parent_job_id)
+            job_comparison_reports[parent_job_id] = job_comparison_report
+
+            for assignee in assignee_report_data:
+                if assignee not in assignee_reports:
+                    assignee_reports[assignee] = assignee_report_data[assignee]
+                else:
+                    assignee_reports[assignee]["conflict_count"] += assignee_report_data[assignee][
+                        "conflict_count"
+                    ]
+                    assignee_reports[assignee]["consensus_score"] += assignee_report_data[assignee][
+                        "consensus_score"
+                    ]
+
+        for assignee in assignee_reports:
+            assignee_reports[assignee]["consensus_score"] /= self.assignee_jobs_count[assignee]
+
+        task_report_data, task_mean_consensus_score = generate_task_consensus_report(
+            list(job_comparison_reports.values())
+        )
+        return save_report(
+            self.task_id,
+            self.parent_jobs,
+            task_report_data,
+            job_comparison_reports,
+            assignee_reports,
+            task_mean_consensus_score,
+        )
+
+    @transaction.atomic
+    def merge_single_consensus_job(self, parent_job_id: int) -> None:
+        job_comparison_reports: Dict[int, ComparisonReport] = {}
+        assignee_reports: Dict[User, Dict[str, float]] = {}
+
+        job_comparison_report, assignee_report_data = self._merge_consensus_jobs(parent_job_id)
+
+        job_comparison_reports[parent_job_id] = job_comparison_report
+
+        for assignee in self.assignee_jobs_count:
+            assignee_report = (
+                AssigneeConsensusReport.objects.filter(assignee=assignee).order_by("-id").first()
+            )
+            assignee_reports[assignee] = (
+                assignee_report.to_dict()
+                if assignee_report
+                else {"conflict_count": 0, "consensus_score": 0}
+            )
+
+        for assignee in assignee_report_data:
+            if assignee not in assignee_reports:
+                assignee_reports[assignee] = assignee_report_data[assignee]
+            else:
+                assignee_reports[assignee]["conflict_count"] += assignee_report_data[assignee][
+                    "conflict_count"
+                ]
+                assignee_reports[assignee]["consensus_score"] += assignee_report_data[assignee][
+                    "consensus_score"
+                ]
+
+        for assignee in assignee_reports:
+            assignee_reports[assignee]["consensus_score"] /= self.assignee_jobs_count[assignee]
+
+        for parent_job in self.parent_jobs:
+            if parent_job.id == parent_job_id:
+                continue
+
+            job_comparison_report = (
+                ConsensusReport.objects.filter(job_id=parent_job.id).order_by("-id").first()
+            )
+            job_comparison_reports[parent_job.id] = ComparisonReport.from_json(
+                job_comparison_report.get_json_report()
+            )
+
+        task_report_data, task_mean_consensus_score = generate_task_consensus_report(
+            list(job_comparison_reports.values())
+        )
+        return save_report(
+            self.task_id,
+            self.parent_jobs,
+            task_report_data,
+            job_comparison_reports,
+            assignee_reports,
+            task_mean_consensus_score,
+        )
+
+
+def scehdule_consensus_merging(instance: Union[Job, Task], request) -> Response:
+    queue_name = settings.CVAT_QUEUES.CONSENSUS.value
+    queue = django_rq.get_queue(queue_name)
+
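+    # Reusing a client-supplied rq_id keeps the endpoint idempotent: a finished RQ job is
+    # cleaned up and reported as created, a failed one returns its exception info, and a
+    # queued or running one is simply acknowledged again.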
rq_id = request.query_params.get("rq_id", uuid4().hex) + rq_job = queue.fetch_job(rq_id) + user_id = request.user.id + serializer = RqIdSerializer({"rq_id": rq_id}) + + if rq_job: + if rq_job.is_finished: + rq_job.delete() + return Response(serializer.data, status=status.HTTP_201_CREATED) + elif rq_job.is_failed: + exc_info = process_failed_job(rq_job) + return Response(data=exc_info, status=status.HTTP_500_INTERNAL_SERVER_ERROR) + else: + # rq_job is in queued stage or might be running + return Response(serializer.data, status=status.HTTP_202_ACCEPTED) + + if isinstance(instance, Task): + consensus_job_merger = MergeConsensusJobs(task_id=instance.id) + func = consensus_job_merger.merge_all_consensus_jobs + else: + consensus_job_merger = MergeConsensusJobs(task_id=instance.get_task_id()) + func = consensus_job_merger.merge_single_consensus_job + + func_args = [instance.id] + + with get_rq_lock_by_user(queue, user_id): + queue.enqueue_call( + func=func, + args=func_args, + job_id=rq_id, + meta=get_rq_job_meta(request=request, db_obj=instance), + depends_on=define_dependent_job(queue, user_id), + ) + + return Response(serializer.data, status=status.HTTP_202_ACCEPTED) diff --git a/cvat/apps/consensus/migrations/0001_initial.py b/cvat/apps/consensus/migrations/0001_initial.py new file mode 100644 index 00000000000..cedba9e9e28 --- /dev/null +++ b/cvat/apps/consensus/migrations/0001_initial.py @@ -0,0 +1,211 @@ +# Generated by Django 4.2.15 on 2024-09-08 23:28 + +import django.db.models.deletion +from django.conf import settings +from django.db import migrations, models + + +class Migration(migrations.Migration): + + initial = True + + dependencies = [ + migrations.swappable_dependency(settings.AUTH_USER_MODEL), + ("engine", "0084_job_parent_job_task_consensus_jobs_per_regular_job_and_more"), + ] + + operations = [ + migrations.CreateModel( + name="ConsensusSettings", + fields=[ + ( + "id", + models.AutoField( + auto_created=True, primary_key=True, serialize=False, verbose_name="ID" + ), + ), + ("agreement_score_threshold", models.FloatField(default=0)), + ("quorum", models.IntegerField(default=-1)), + ("iou_threshold", models.FloatField(default=0.5)), + ("sigma", models.FloatField(default=0.1)), + ("line_thickness", models.FloatField(default=0.01)), + ( + "task", + models.ForeignKey( + blank=True, + null=True, + on_delete=django.db.models.deletion.CASCADE, + related_name="consensus_settings", + to="engine.task", + ), + ), + ], + ), + migrations.CreateModel( + name="ConsensusReport", + fields=[ + ( + "id", + models.AutoField( + auto_created=True, primary_key=True, serialize=False, verbose_name="ID" + ), + ), + ("created_date", models.DateTimeField(auto_now_add=True)), + ("target_last_updated", models.DateTimeField()), + ("consensus_score", models.IntegerField()), + ("data", models.JSONField()), + ( + "assignee", + models.ForeignKey( + blank=True, + null=True, + on_delete=django.db.models.deletion.SET_NULL, + related_name="consensus", + to=settings.AUTH_USER_MODEL, + ), + ), + ( + "job", + models.ForeignKey( + blank=True, + null=True, + on_delete=django.db.models.deletion.CASCADE, + related_name="consensus_reports", + to="engine.job", + ), + ), + ( + "parent", + models.ForeignKey( + blank=True, + null=True, + on_delete=django.db.models.deletion.CASCADE, + related_name="children_reports", + to="consensus.consensusreport", + ), + ), + ( + "task", + models.ForeignKey( + blank=True, + null=True, + on_delete=django.db.models.deletion.CASCADE, + related_name="consensus_reports", + to="engine.task", + 
), + ), + ], + ), + migrations.CreateModel( + name="ConsensusConflict", + fields=[ + ( + "id", + models.AutoField( + auto_created=True, primary_key=True, serialize=False, verbose_name="ID" + ), + ), + ("frame", models.PositiveIntegerField()), + ( + "type", + models.CharField( + choices=[ + ("no_matching_item", "NoMatchingItemError"), + ("no_matching_annotation", "NoMatchingAnnError"), + ("annotation_too_close", "AnnotationsTooCloseError"), + ("failed_label_voting", "FailedLabelVotingError"), + ], + max_length=32, + ), + ), + ( + "report", + models.ForeignKey( + on_delete=django.db.models.deletion.CASCADE, + related_name="conflicts", + to="consensus.consensusreport", + ), + ), + ], + ), + migrations.CreateModel( + name="AssigneeConsensusReport", + fields=[ + ( + "id", + models.AutoField( + auto_created=True, primary_key=True, serialize=False, verbose_name="ID" + ), + ), + ("consensus_score", models.IntegerField()), + ("conflict_count", models.IntegerField()), + ("consensus_report_id", models.PositiveIntegerField()), + ( + "assignee", + models.ForeignKey( + blank=True, + null=True, + on_delete=django.db.models.deletion.SET_NULL, + related_name="assignee_consensus_reports", + to=settings.AUTH_USER_MODEL, + ), + ), + ( + "task", + models.ForeignKey( + blank=True, + null=True, + on_delete=django.db.models.deletion.CASCADE, + related_name="assignee_consensus_reports", + to="engine.task", + ), + ), + ], + ), + migrations.CreateModel( + name="AnnotationId", + fields=[ + ( + "id", + models.AutoField( + auto_created=True, primary_key=True, serialize=False, verbose_name="ID" + ), + ), + ("obj_id", models.PositiveIntegerField()), + ("job_id", models.PositiveIntegerField()), + ( + "type", + models.CharField( + choices=[("tag", "TAG"), ("shape", "SHAPE"), ("track", "TRACK")], + max_length=32, + ), + ), + ( + "shape_type", + models.CharField( + choices=[ + ("rectangle", "RECTANGLE"), + ("polygon", "POLYGON"), + ("polyline", "POLYLINE"), + ("points", "POINTS"), + ("ellipse", "ELLIPSE"), + ("cuboid", "CUBOID"), + ("mask", "MASK"), + ("skeleton", "SKELETON"), + ], + default=None, + max_length=32, + null=True, + ), + ), + ( + "conflict", + models.ForeignKey( + on_delete=django.db.models.deletion.CASCADE, + related_name="annotation_ids", + to="consensus.consensusconflict", + ), + ), + ], + ), + ] diff --git a/cvat/apps/consensus/migrations/__init__.py b/cvat/apps/consensus/migrations/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/cvat/apps/consensus/models.py b/cvat/apps/consensus/models.py new file mode 100644 index 00000000000..7f1deb312aa --- /dev/null +++ b/cvat/apps/consensus/models.py @@ -0,0 +1,213 @@ +# Copyright (C) 2024 CVAT.ai Corporation +# +# SPDX-License-Identifier: MIT + +from __future__ import annotations # this allows forward references + +from enum import Enum +from typing import Any, Sequence + +from django.core.exceptions import ValidationError +from django.db import models +from django.forms.models import model_to_dict + +from cvat.apps.engine.models import Job, ShapeType, Task, User + + +class ConsensusConflictType(str, Enum): + NoMatchingItemError = "no_matching_item" + NoMatchingAnnError = "no_matching_annotation" + AnnotationsTooCloseError = "annotation_too_close" + FailedLabelVotingError = "failed_label_voting" + + def __str__(self) -> str: + return self.value + + @classmethod + def choices(cls): + return tuple((x.value, x.name) for x in cls) + + +class AnnotationType(str, Enum): + TAG = "tag" + SHAPE = "shape" + TRACK = "track" + + def __str__(self) -> 
str: + return self.value + + @classmethod + def choices(cls): + return tuple((x.value, x.name) for x in cls) + + +class ConsensusReportTarget(str, Enum): + JOB = "job" + TASK = "task" + + def __str__(self) -> str: + return self.value + + @classmethod + def choices(cls): + return tuple((x.value, x.name) for x in cls) + + +class ConsensusSettings(models.Model): + task = models.ForeignKey( + Task, + on_delete=models.CASCADE, + related_name="consensus_settings", + null=True, + blank=True, + ) + agreement_score_threshold = models.FloatField(default=0) + quorum = models.IntegerField(default=-1) + iou_threshold = models.FloatField(default=0.5) + sigma = models.FloatField(default=0.1) + line_thickness = models.FloatField(default=0.01) + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + + def to_dict(self): + return model_to_dict(self) + + @property + def organization_id(self): + return getattr(self.task.organization, "id", None) + + +class ConsensusReport(models.Model): + task = models.ForeignKey( + Task, + on_delete=models.CASCADE, + related_name="consensus_reports", + null=True, + blank=True, + ) + job = models.ForeignKey( + Job, + on_delete=models.CASCADE, + related_name="consensus_reports", + null=True, + blank=True, + ) + + created_date = models.DateTimeField(auto_now_add=True) + target_last_updated = models.DateTimeField() + consensus_score = models.IntegerField() + assignee = models.ForeignKey( + User, on_delete=models.SET_NULL, related_name="consensus", null=True, blank=True + ) + parent = models.ForeignKey( + "self", on_delete=models.CASCADE, related_name="children_reports", null=True, blank=True + ) + data = models.JSONField() + + conflicts: Sequence[ConsensusConflict] + + def _parse_report(self): + from cvat.apps.consensus.consensus_reports import ComparisonReport + + return ComparisonReport.from_json(self.data) + + @property + def summary(self): + report = self._parse_report() + return report.comparison_summary + + @property + def target(self) -> ConsensusReportTarget: + if self.job: + return ConsensusReportTarget.JOB + elif self.task: + return ConsensusReportTarget.TASK + else: + assert False + + def get_task(self) -> Task: + if self.task is not None: + return self.task + else: + return self.job.segment.task + + def get_json_report(self) -> str: + return self.data + + def clean(self): + if not (self.job is not None) ^ (self.task is not None): + raise ValidationError("One of the 'job' and 'task' fields must be set") + + @property + def organization_id(self): + if task := self.get_task(): + return getattr(task.organization, "id", None) + return None + + +class ConsensusConflict(models.Model): + report = models.ForeignKey(ConsensusReport, on_delete=models.CASCADE, related_name="conflicts") + frame = models.PositiveIntegerField() + type = models.CharField(max_length=32, choices=ConsensusConflictType.choices()) + + annotation_ids: Sequence[AnnotationId] + + @property + def organization_id(self): + return self.report.organization_id + + +class AnnotationId(models.Model): + conflict = models.ForeignKey( + ConsensusConflict, on_delete=models.CASCADE, related_name="annotation_ids" + ) + + obj_id = models.PositiveIntegerField() + job_id = models.PositiveIntegerField() + type = models.CharField(max_length=32, choices=AnnotationType.choices()) + shape_type = models.CharField( + max_length=32, choices=ShapeType.choices(), null=True, default=None + ) + + def clean(self) -> None: + if self.type in [AnnotationType.SHAPE, AnnotationType.TRACK]: + if not 
self.shape_type: + raise ValidationError("Annotation kind must be specified") + elif self.type == AnnotationType.TAG: + if self.shape_type: + raise ValidationError("Annotation kind must be empty") + else: + raise ValidationError(f"Unexpected type value '{self.type}'") + + +class AssigneeConsensusReport(models.Model): + task = models.ForeignKey( + Task, + on_delete=models.CASCADE, + related_name="assignee_consensus_reports", + null=True, + blank=True, + ) + assignee = models.ForeignKey( + User, + on_delete=models.SET_NULL, + related_name="assignee_consensus_reports", + null=True, + blank=True, + ) + consensus_score = models.IntegerField() + conflict_count = models.IntegerField() + consensus_report_id = models.PositiveIntegerField() + + def get_task(self) -> Task: + return self.task + + def to_dict(self): + return model_to_dict(self) + + @property + def organization_id(self): + if task := self.get_task(): + return getattr(task.organization, "id", None) + return None diff --git a/cvat/apps/consensus/permissions.py b/cvat/apps/consensus/permissions.py new file mode 100644 index 00000000000..cf8d01cb1b3 --- /dev/null +++ b/cvat/apps/consensus/permissions.py @@ -0,0 +1,363 @@ +# Copyright (C) 2024 CVAT.ai Corporation +# +# SPDX-License-Identifier: MIT + +from typing import Optional, Union, cast + +from django.conf import settings +from rest_framework.exceptions import ValidationError + +from cvat.apps.engine.models import Task +from cvat.apps.engine.permissions import JobPermission, TaskPermission +from cvat.apps.iam.permissions import OpenPolicyAgentPermission, StrEnum, get_iam_context + +from .models import AssigneeConsensusReport, ConsensusConflict, ConsensusReport, ConsensusSettings + + +class ConsensusReportPermission(OpenPolicyAgentPermission): + obj: Optional[ConsensusReport] + job_owner_id: Optional[int] + + class Scopes(StrEnum): + LIST = "list" + CREATE = "create" + VIEW = "view" + VIEW_STATUS = "view:status" + + @classmethod + def create_scope_check_status(cls, request, job_owner_id: int, iam_context=None): + if not iam_context and request: + iam_context = get_iam_context(request, None) + return cls(**iam_context, scope="view:status", job_owner_id=job_owner_id) + + @classmethod + def create_scope_view(cls, request, report: Union[int, ConsensusReport], iam_context=None): + if isinstance(report, int): + try: + report = ConsensusReport.objects.get(id=report) + except ConsensusReport.DoesNotExist as ex: + raise ValidationError(str(ex)) + + # Access rights are the same as in the owning task + # This component doesn't define its own rules in this case + return TaskPermission.create_scope_view( + request, + task=report.get_task(), + iam_context=iam_context, + ) + + @classmethod + def create(cls, request, view, obj, iam_context): + Scopes = __class__.Scopes + + permissions = [] + if view.basename == "consensus_reports": + for scope in cls.get_scopes(request, view, obj): + if scope == Scopes.VIEW: + permissions.append(cls.create_scope_view(request, obj, iam_context=iam_context)) + elif scope == Scopes.LIST and isinstance(obj, Task): + permissions.append(TaskPermission.create_scope_view(request, task=obj)) + elif scope == Scopes.CREATE: + if task_id := request.data.get("task_id"): + permissions.append(TaskPermission.create_scope_view(request, task_id)) + elif job_id := request.data.get("job_id"): + permissions.append(JobPermission.create_scope_view(request, job_id)) + + permissions.append(cls.create_base_perm(request, view, scope, iam_context, obj)) + else: + 
permissions.append(cls.create_base_perm(request, view, scope, iam_context, obj)) + + return permissions + + def __init__(self, **kwargs): + if "job_owner_id" in kwargs: + self.job_owner_id = int(kwargs.pop("job_owner_id")) + + super().__init__(**kwargs) + self.url = settings.IAM_OPA_DATA_URL + "/consensus_reports/allow" + + @staticmethod + def get_scopes(request, view, obj): + Scopes = __class__.Scopes + return [ + { + "list": Scopes.LIST, + "create": Scopes.CREATE, + "retrieve": Scopes.VIEW, + "data": Scopes.VIEW, + }[view.action] + ] + + def get_resource(self): + data = None + + if self.obj: + task = self.obj.get_task() + if task.project: + organization = task.project.organization + else: + organization = task.organization + + data = { + "id": self.obj.id, + "organization": {"id": getattr(organization, "id", None)}, + "task": ( + { + "owner": {"id": getattr(task.owner, "id", None)}, + "assignee": {"id": getattr(task.assignee, "id", None)}, + } + if task + else None + ), + "project": ( + { + "owner": {"id": getattr(task.project.owner, "id", None)}, + "assignee": {"id": getattr(task.project.assignee, "id", None)}, + } + if task.project + else None + ), + } + + return data + + +class ConsensusConflictPermission(OpenPolicyAgentPermission): + obj: Optional[ConsensusConflict] + + class Scopes(StrEnum): + LIST = "list" + + @classmethod + def create(cls, request, view, obj, iam_context): + permissions = [] + if view.basename == "conflicts": + for scope in cls.get_scopes(request, view, obj): + if scope == cls.Scopes.LIST and isinstance(obj, ConsensusReport): + permissions.append( + ConsensusReportPermission.create_scope_view( + request, + obj, + iam_context=iam_context, + ) + ) + else: + permissions.append(cls.create_base_perm(request, view, scope, iam_context, obj)) + + return permissions + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.url = settings.IAM_OPA_DATA_URL + "/conflicts/allow" + + @staticmethod + def get_scopes(request, view, obj): + Scopes = __class__.Scopes + return [ + { + "list": Scopes.LIST, + }[view.action] + ] + + def get_resource(self): + return None + + +class ConsensusSettingPermission(OpenPolicyAgentPermission): + obj: Optional[ConsensusSettings] + + class Scopes(StrEnum): + LIST = "list" + VIEW = "view" + UPDATE = "update" + + @classmethod + def create(cls, request, view, obj, iam_context): + Scopes = __class__.Scopes + + permissions = [] + if view.basename == "consensus_settings": + for scope in cls.get_scopes(request, view, obj): + if scope in [Scopes.VIEW, Scopes.UPDATE]: + obj = cast(ConsensusSettings, obj) + + if scope == Scopes.VIEW: + task_scope = TaskPermission.Scopes.VIEW + elif scope == Scopes.UPDATE: + task_scope = TaskPermission.Scopes.UPDATE_DESC + else: + assert False + + # Access rights are the same as in the owning task + # This component doesn't define its own rules in this case + permissions.append( + TaskPermission.create_base_perm( + request, + view, + iam_context=iam_context, + scope=task_scope, + obj=obj.task, + ) + ) + elif scope == cls.Scopes.LIST: + if task_id := request.query_params.get("task_id", None): + permissions.append( + TaskPermission.create_scope_view( + request, + int(task_id), + iam_context=iam_context, + ) + ) + + permissions.append(cls.create_scope_list(request, iam_context)) + else: + permissions.append(cls.create_base_perm(request, view, scope, iam_context, obj)) + + return permissions + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.url = settings.IAM_OPA_DATA_URL + 
"/consensus_settings/allow" + + @staticmethod + def get_scopes(request, view, obj): + Scopes = __class__.Scopes + return [ + { + "list": Scopes.LIST, + "retrieve": Scopes.VIEW, + "partial_update": Scopes.UPDATE, + }.get(view.action, None) + ] + + def get_resource(self): + data = None + + if self.obj: + task = self.obj.task + if task.project: + organization = task.project.organization + else: + organization = task.organization + + data = { + "id": self.obj.id, + "organization": {"id": getattr(organization, "id", None)}, + "task": ( + { + "owner": {"id": getattr(task.owner, "id", None)}, + "assignee": {"id": getattr(task.assignee, "id", None)}, + } + if task + else None + ), + "project": ( + { + "owner": {"id": getattr(task.project.owner, "id", None)}, + "assignee": {"id": getattr(task.project.assignee, "id", None)}, + } + if task.project + else None + ), + } + + return data + + +class AssigneeConsensusReportPermission(OpenPolicyAgentPermission): + obj: Optional[AssigneeConsensusReport] + job_owner_id: Optional[int] + + class Scopes(StrEnum): + LIST = "list" + VIEW = "view" + VIEW_STATUS = "view:status" + + @classmethod + def create_scope_check_status(cls, request, job_owner_id: int, iam_context=None): + if not iam_context and request: + iam_context = get_iam_context(request, None) + return cls(**iam_context, scope="view:status", job_owner_id=job_owner_id) + + @classmethod + def create_scope_view(cls, request, report: Union[int, ConsensusReport], iam_context=None): + if isinstance(report, int): + try: + report = AssigneeConsensusReport.objects.get(id=report) + except AssigneeConsensusReport.DoesNotExist as ex: + raise ValidationError(str(ex)) + + # Access rights are the same as in the owning task + # This component doesn't define its own rules in this case + return TaskPermission.create_scope_view( + request, + task=report.get_task(), + iam_context=iam_context, + ) + + @classmethod + def create(cls, request, view, obj, iam_context): + Scopes = __class__.Scopes + + permissions = [] + if view.basename == "assignee_consensus_reports": + for scope in cls.get_scopes(request, view, obj): + if scope == Scopes.VIEW: + permissions.append(cls.create_scope_view(request, obj, iam_context=iam_context)) + elif scope == Scopes.LIST and isinstance(obj, Task): + permissions.append(TaskPermission.create_scope_view(request, task=obj)) + else: + permissions.append(cls.create_base_perm(request, view, scope, iam_context, obj)) + + return permissions + + def __init__(self, **kwargs): + if "job_owner_id" in kwargs: + self.job_owner_id = int(kwargs.pop("job_owner_id")) + + super().__init__(**kwargs) + self.url = settings.IAM_OPA_DATA_URL + "/assignee_consensus_reports/allow" + + @staticmethod + def get_scopes(request, view, obj): + Scopes = __class__.Scopes + return [ + { + "list": Scopes.LIST, + "retrieve": Scopes.VIEW, + "data": Scopes.VIEW, + }[view.action] + ] + + def get_resource(self): + data = None + + if self.obj: + task = self.obj.get_task() + if task.project: + organization = task.project.organization + else: + organization = task.organization + + data = { + "id": self.obj.id, + "organization": {"id": getattr(organization, "id", None)}, + "task": ( + { + "owner": {"id": getattr(task.owner, "id", None)}, + "assignee": {"id": getattr(task.assignee, "id", None)}, + } + if task + else None + ), + "project": ( + { + "owner": {"id": getattr(task.project.owner, "id", None)}, + "assignee": {"id": getattr(task.project.assignee, "id", None)}, + } + if task.project + else None + ), + } + + return data diff --git 
a/cvat/apps/consensus/pyproject.toml b/cvat/apps/consensus/pyproject.toml new file mode 100644 index 00000000000..567b7836258 --- /dev/null +++ b/cvat/apps/consensus/pyproject.toml @@ -0,0 +1,12 @@ +[tool.isort] +profile = "black" +forced_separate = ["tests"] +line_length = 100 +skip_gitignore = true # align tool behavior with Black +known_first_party = ["cvat"] + +# Can't just use a pyproject in the root dir, so duplicate +# https://github.com/psf/black/issues/2863 +[tool.black] +line-length = 100 +target-version = ['py38'] diff --git a/cvat/apps/consensus/rules/assignee_consensus_reports.rego b/cvat/apps/consensus/rules/assignee_consensus_reports.rego new file mode 100644 index 00000000000..e94a3981498 --- /dev/null +++ b/cvat/apps/consensus/rules/assignee_consensus_reports.rego @@ -0,0 +1,104 @@ +package assignee_consensus_reports + +import rego.v1 + +import data.utils +import data.organizations + +# input: { +# "scope": <"view"|"list"|"view:status"> or null, +# "auth": { +# "user": { +# "id": , +# "privilege": <"admin"|"business"|"user"|"worker"> or null +# }, +# "organization": { +# "id": , +# "owner": { +# "id": +# }, +# "user": { +# "role": <"owner"|"maintainer"|"supervisor"|"worker"> or null +# } +# } or null, +# }, +# "resource": { +# "id": , +# "owner": { "id": }, +# "organization": { "id": } or null, +# "task": { +# "id": , +# "owner": { "id": }, +# "assignee": { "id": }, +# "organization": { "id": } or null, +# } or null, +# "project": { +# "id": , +# "owner": { "id": }, +# "assignee": { "id": }, +# "organization": { "id": } or null, +# } or null, +# } +# } + +default allow := false + +allow if { + utils.is_admin +} + +allow if { + input.scope == utils.LIST + utils.is_sandbox +} + +allow if { + input.scope == utils.LIST + organizations.is_member +} + +filter := [] if { # Django Q object to filter list of entries + utils.is_admin + utils.is_sandbox +} else := qobject if { + utils.is_admin + utils.is_organization + org := input.auth.organization + qobject := [ + {"task__organization": org.id}, + {"task__project__organization": org.id}, "|", + ] +} else := qobject if { + utils.is_sandbox + user := input.auth.user + qobject := [ + {"task__owner_id": user.id}, + {"task__assignee_id": user.id}, "|", + {"task__project__owner_id": user.id}, "|", + {"task__project__assignee_id": user.id}, "|", + ] +} else := qobject if { + utils.is_organization + utils.has_perm(utils.USER) + organizations.has_perm(organizations.MAINTAINER) + org := input.auth.organization + qobject := [ + {"task__organization": org.id}, + {"task__project__organization": org.id}, "|", + ] +} else := qobject if { + organizations.has_perm(organizations.WORKER) + user := input.auth.user + org := input.auth.organization + qobject := [ + {"task__organization": org.id}, + {"task__project__organization": org.id}, "|", + + {"task__owner_id": user.id}, "|", + {"task__assignee_id": user.id}, "|", + {"task__project__owner_id": user.id}, "|", + {"task__project__assignee_id": user.id}, "|", + + "&" + ] +} diff --git a/cvat/apps/consensus/rules/conflicts.rego b/cvat/apps/consensus/rules/conflicts.rego new file mode 100644 index 00000000000..6661a35db3a --- /dev/null +++ b/cvat/apps/consensus/rules/conflicts.rego @@ -0,0 +1,118 @@ +package consensus_conflicts + +import rego.v1 + +import data.utils +import data.organizations + +# input: { +# "scope": <"list"> or null, +# "auth": { +# "user": { +# "id": , +# "privilege": <"admin"|"business"|"user"|"worker"> or null +# }, +# "organization": { +# "id": , +# "owner": { +# "id": +# }, +# 
"user": { +# "role": <"owner"|"maintainer"|"supervisor"|"worker"> or null +# } +# } or null, +# }, +# "resource": { +# "id": , +# "owner": { "id": }, +# "organization": { "id": } or null, +# "task": { +# "id": , +# "owner": { "id": }, +# "assignee": { "id": }, +# "organization": { "id": } or null, +# } or null, +# "project": { +# "id": , +# "owner": { "id": }, +# "assignee": { "id": }, +# "organization": { "id": } or null, +# } or null, +# } +# } + +default allow := false + +allow if { + utils.is_admin +} + +allow if { + input.scope == utils.LIST + utils.is_sandbox +} + +allow if { + input.scope == utils.LIST + organizations.is_member +} + +filter := [] if { # Django Q object to filter list of entries + utils.is_admin + utils.is_sandbox +} else := qobject if { + utils.is_admin + utils.is_organization + org := input.auth.organization + qobject := [ + {"report__job__segment__task__organization": org.id}, + {"report__job__segment__task__project__organization": org.id}, "|", + {"report__task__organization": org.id}, "|", + {"report__task__project__organization": org.id}, "|", + ] +} else := qobject if { + utils.is_sandbox + user := input.auth.user + qobject := [ + {"report__job__segment__task__owner_id": user.id}, + {"report__job__segment__task__assignee_id": user.id}, "|", + {"report__job__segment__task__project__owner_id": user.id}, "|", + {"report__job__segment__task__project__assignee_id": user.id}, "|", + {"report__task__owner_id": user.id}, "|", + {"report__task__assignee_id": user.id}, "|", + {"report__task__project__owner_id": user.id}, "|", + {"report__task__project__assignee_id": user.id}, "|", + ] +} else := qobject if { + utils.is_organization + utils.has_perm(utils.USER) + organizations.has_perm(organizations.MAINTAINER) + org := input.auth.organization + qobject := [ + {"report__job__segment__task__organization": org.id}, + {"report__job__segment__task__project__organization": org.id}, "|", + {"report__task__organization": org.id}, "|", + {"report__task__project__organization": org.id}, "|", + ] +} else := qobject if { + organizations.has_perm(organizations.WORKER) + user := input.auth.user + org := input.auth.organization + qobject := [ + {"report__job__segment__task__organization": org.id}, + {"report__job__segment__task__project__organization": org.id}, "|", + {"report__task__organization": org.id}, "|", + {"report__task__project__organization": org.id}, "|", + + {"report__job__segment__task__owner_id": user.id}, + {"report__job__segment__task__assignee_id": user.id}, "|", + {"report__job__segment__task__project__owner_id": user.id}, "|", + {"report__job__segment__task__project__assignee_id": user.id}, "|", + {"report__task__owner_id": user.id}, "|", + {"report__task__assignee_id": user.id}, "|", + {"report__task__project__owner_id": user.id}, "|", + {"report__task__project__assignee_id": user.id}, "|", + + "&" + ] +} diff --git a/cvat/apps/consensus/rules/consensus_reports.rego b/cvat/apps/consensus/rules/consensus_reports.rego new file mode 100644 index 00000000000..c37b70205a2 --- /dev/null +++ b/cvat/apps/consensus/rules/consensus_reports.rego @@ -0,0 +1,118 @@ +package consensus_reports + +import rego.v1 + +import data.utils +import data.organizations + +# input: { +# "scope": <"view"|"list"|"create"|"view:status"> or null, +# "auth": { +# "user": { +# "id": , +# "privilege": <"admin"|"business"|"user"|"worker"> or null +# }, +# "organization": { +# "id": , +# "owner": { +# "id": +# }, +# "user": { +# "role": <"owner"|"maintainer"|"supervisor"|"worker"> or null +# } +# } 
or null, +# }, +# "resource": { +# "id": , +# "owner": { "id": }, +# "organization": { "id": } or null, +# "task": { +# "id": , +# "owner": { "id": }, +# "assignee": { "id": }, +# "organization": { "id": } or null, +# } or null, +# "project": { +# "id": , +# "owner": { "id": }, +# "assignee": { "id": }, +# "organization": { "id": } or null, +# } or null, +# } +# } + +default allow := false + +allow if { + utils.is_admin +} + +allow if { + input.scope == utils.LIST + utils.is_sandbox +} + +allow if { + input.scope == utils.LIST + organizations.is_member +} + +filter := [] if { # Django Q object to filter list of entries + utils.is_admin + utils.is_sandbox +} else := qobject if { + utils.is_admin + utils.is_organization + org := input.auth.organization + qobject := [ + {"job__segment__task__organization": org.id}, + {"job__segment__task__project__organization": org.id}, "|", + {"task__organization": org.id}, "|", + {"task__project__organization": org.id}, "|", + ] +} else := qobject if { + utils.is_sandbox + user := input.auth.user + qobject := [ + {"job__segment__task__owner_id": user.id}, + {"job__segment__task__assignee_id": user.id}, "|", + {"job__segment__task__project__owner_id": user.id}, "|", + {"job__segment__task__project__assignee_id": user.id}, "|", + {"task__owner_id": user.id}, "|", + {"task__assignee_id": user.id}, "|", + {"task__project__owner_id": user.id}, "|", + {"task__project__assignee_id": user.id}, "|", + ] +} else := qobject if { + utils.is_organization + utils.has_perm(utils.USER) + organizations.has_perm(organizations.MAINTAINER) + org := input.auth.organization + qobject := [ + {"job__segment__task__organization": org.id}, + {"job__segment__task__project__organization": org.id}, "|", + {"task__organization": org.id}, "|", + {"task__project__organization": org.id}, "|", + ] +} else := qobject if { + organizations.has_perm(organizations.WORKER) + user := input.auth.user + org := input.auth.organization + qobject := [ + {"job__segment__task__organization": org.id}, + {"job__segment__task__project__organization": org.id}, "|", + {"task__organization": org.id}, "|", + {"task__project__organization": org.id}, "|", + + {"job__segment__task__owner_id": user.id}, + {"job__segment__task__assignee_id": user.id}, "|", + {"job__segment__task__project__owner_id": user.id}, "|", + {"job__segment__task__project__assignee_id": user.id}, "|", + {"task__owner_id": user.id}, "|", + {"task__assignee_id": user.id}, "|", + {"task__project__owner_id": user.id}, "|", + {"task__project__assignee_id": user.id}, "|", + + "&" + ] +} diff --git a/cvat/apps/consensus/rules/consensus_settings.rego b/cvat/apps/consensus/rules/consensus_settings.rego new file mode 100644 index 00000000000..fd67061c6e1 --- /dev/null +++ b/cvat/apps/consensus/rules/consensus_settings.rego @@ -0,0 +1,104 @@ +package consensus_settings + +import rego.v1 + +import data.utils +import data.organizations + +# input: { +# "scope": <"view"> or null, +# "auth": { +# "user": { +# "id": , +# "privilege": <"admin"|"business"|"user"|"worker"> or null +# }, +# "organization": { +# "id": , +# "owner": { +# "id": +# }, +# "user": { +# "role": <"owner"|"maintainer"|"supervisor"|"worker"> or null +# } +# } or null, +# }, +# "resource": { +# "id": , +# "owner": { "id": }, +# "organization": { "id": } or null, +# "task": { +# "id": , +# "owner": { "id": }, +# "assignee": { "id": }, +# "organization": { "id": } or null, +# } or null, +# "project": { +# "id": , +# "owner": { "id": }, +# "assignee": { "id": }, +# "organization": { "id": } 
or null, +# } or null, +# } +# } + +default allow := false + +allow if { + utils.is_admin +} + +allow if { + input.scope == utils.LIST + utils.is_sandbox +} + +allow if { + input.scope == utils.LIST + organizations.is_member +} + +filter := [] if { # Django Q object to filter list of entries + utils.is_admin + utils.is_sandbox +} else := qobject if { + utils.is_admin + utils.is_organization + org := input.auth.organization + qobject := [ + {"task__organization": org.id}, + {"task__project__organization": org.id}, "|", + ] +} else := qobject if { + utils.is_sandbox + user := input.auth.user + qobject := [ + {"task__owner_id": user.id}, + {"task__assignee_id": user.id}, "|", + {"task__project__owner_id": user.id}, "|", + {"task__project__assignee_id": user.id}, "|", + ] +} else := qobject if { + utils.is_organization + utils.has_perm(utils.USER) + organizations.has_perm(organizations.MAINTAINER) + org := input.auth.organization + qobject := [ + {"task__organization": org.id}, + {"task__project__organization": org.id}, "|", + ] +} else := qobject if { + organizations.has_perm(organizations.WORKER) + user := input.auth.user + org := input.auth.organization + qobject := [ + {"task__organization": org.id}, + {"task__project__organization": org.id}, "|", + + {"task__owner_id": user.id}, + {"task__assignee_id": user.id}, "|", + {"task__project__owner_id": user.id}, "|", + {"task__project__assignee_id": user.id}, "|", + + "&" + ] +} diff --git a/cvat/apps/consensus/serializers.py b/cvat/apps/consensus/serializers.py new file mode 100644 index 00000000000..678ae958228 --- /dev/null +++ b/cvat/apps/consensus/serializers.py @@ -0,0 +1,126 @@ +# Copyright (C) 2024 CVAT.ai Corporation +# +# SPDX-License-Identifier: MIT + +import textwrap + +from rest_framework import serializers + +from cvat.apps.consensus import models +from cvat.apps.consensus.models import AnnotationId +from cvat.apps.engine import serializers as engine_serializers + + +class ConsensusAnnotationIdSerializer(serializers.ModelSerializer): + class Meta: + model = AnnotationId + fields = ("obj_id", "job_id", "type", "shape_type") + read_only_fields = fields + + +class ConsensusConflictSerializer(serializers.ModelSerializer): + annotation_ids = ConsensusAnnotationIdSerializer(many=True) + + class Meta: + model = models.ConsensusConflict + fields = ("id", "frame", "type", "annotation_ids", "report_id") + read_only_fields = fields + + +class ConsensusReportSummarySerializer(serializers.Serializer): + frame_count = serializers.IntegerField() + conflict_count = serializers.IntegerField() + conflicts_by_type = serializers.DictField(child=serializers.IntegerField()) + + +class ConsensusReportSerializer(serializers.ModelSerializer): + target = serializers.ChoiceField(models.ConsensusReportTarget.choices()) + assignee = engine_serializers.BasicUserSerializer(allow_null=True, read_only=True) + summary = ConsensusReportSummarySerializer() + + class Meta: + model = models.ConsensusReport + fields = ( + "id", + "job_id", + "task_id", + "parent_id", + "summary", + "created_date", + "target_last_updated", + "target", + "assignee", + "consensus_score", + ) + read_only_fields = fields + + +class ConsensusReportCreateSerializer(serializers.Serializer): + task_id = serializers.IntegerField(write_only=True, required=False) + job_id = serializers.IntegerField(write_only=True, required=False) + + +class AssigneeConsensusReportSerializer(serializers.ModelSerializer): + assignee = engine_serializers.BasicUserSerializer(allow_null=True, read_only=True) + + class 
Meta:
+        model = models.AssigneeConsensusReport
+        fields = (
+            "task_id",
+            "assignee",
+            "consensus_score",
+            "consensus_report_id",
+            "conflict_count",
+        )
+        read_only_fields = fields
+
+
+class ConsensusSettingsSerializer(serializers.ModelSerializer):
+    class Meta:
+        model = models.ConsensusSettings
+        fields = (
+            "id",
+            "task_id",
+            "iou_threshold",
+            "agreement_score_threshold",
+            "quorum",
+            "sigma",
+            "line_thickness",
+        )
+        read_only_fields = (
+            "id",
+            "task_id",
+        )
+
+        extra_kwargs = {k: {"required": False} for k in fields}
+
+        for field_name, help_text in {
+            "iou_threshold": "Used for distinction between matched / unmatched shapes",
+            "agreement_score_threshold": """
+                Confidence threshold for output annotations
+            """,
+            "quorum": """
+                Minimum number of votes required to accept a label or attribute value
+            """,
+            "sigma": """
+                Sigma value for OKS calculation
+            """,
+            "line_thickness": """
+                Thickness of polylines, relative to (image area) ^ 0.5
+            """,
+        }.items():
+            extra_kwargs.setdefault(field_name, {}).setdefault(
+                "help_text", textwrap.dedent(help_text.lstrip("\n"))
+            )
+
+    def validate(self, attrs):
+        for k, v in attrs.items():
+            if (k.endswith("_threshold") or k == "line_thickness") and not 0 <= v <= 1:
+                raise serializers.ValidationError(f"{k} must be in the range [0; 1]")
+            elif k == "quorum" and not 0 <= v <= 10:
+                # since we have constrained max. consensus jobs per regular job to 10
+                raise serializers.ValidationError(f"{k} must be in the range [0; 10]")
+            elif k == "sigma" and not 0 < v < 1:
+                raise serializers.ValidationError(f"{k} must be in the range (0; 1)")
+
+        return super().validate(attrs)
diff --git a/cvat/apps/consensus/signals.py b/cvat/apps/consensus/signals.py
new file mode 100644
index 00000000000..9d8b60fecc6
--- /dev/null
+++ b/cvat/apps/consensus/signals.py
@@ -0,0 +1,38 @@
+# Copyright (C) 2024 CVAT.ai Corporation
+#
+# SPDX-License-Identifier: MIT
+
+import math
+
+from django.db.models.signals import post_save
+from django.dispatch import receiver
+
+from cvat.apps.consensus.models import ConsensusSettings
+from cvat.apps.engine.models import Job, Task
+
+
+@receiver(
+    post_save,
+    sender=Task,
+    dispatch_uid=__name__ + ".save_task-initialize_consensus_settings",
+)
+@receiver(
+    post_save,
+    sender=Job,
+    dispatch_uid=__name__ + ".save_job-initialize_consensus_settings",
+)
+def __save_task__initialize_consensus_settings(instance, created, **kwargs):
+    # Initializes default consensus settings for the task.
+    # This is done in a signal to decouple this component from the engine app.
+
+    if created:
+        if isinstance(instance, Task):
+            task = instance
+        elif isinstance(instance, Job):
+            task = instance.segment.task
+        else:
+            assert False
+
+        ConsensusSettings.objects.get_or_create(
+            task=task, quorum=math.ceil(task.consensus_jobs_per_regular_job / 2)
+        )
diff --git a/cvat/apps/consensus/urls.py b/cvat/apps/consensus/urls.py
new file mode 100644
index 00000000000..8d5da0a6f13
--- /dev/null
+++ b/cvat/apps/consensus/urls.py
@@ -0,0 +1,21 @@
+# Copyright (C) 2024 CVAT.ai Corporation
+#
+# SPDX-License-Identifier: MIT
+
+from django.urls import include, path
+from rest_framework import routers
+
+from cvat.apps.consensus import views
+
+router = routers.DefaultRouter(trailing_slash=False)
+router.register("reports", views.ConsensusReportViewSet, basename="consensus_reports")
+router.register("settings", views.ConsensusSettingsViewSet, basename="consensus_settings")
+router.register("conflicts", views.ConsensusConflictsViewSet,
basename="conflicts") +router.register( + "assignee_reports", views.AssigneeConsensusReportViewSet, basename="assignee_consensus_reports" +) + +urlpatterns = [ + # entry point for API + path("consensus/", include(router.urls)), +] diff --git a/cvat/apps/consensus/views.py b/cvat/apps/consensus/views.py new file mode 100644 index 00000000000..3b2b8e972b1 --- /dev/null +++ b/cvat/apps/consensus/views.py @@ -0,0 +1,472 @@ +# Copyright (C) 2024 CVAT.ai Corporation +# +# SPDX-License-Identifier: MIT + +import textwrap + +import django_rq +from django.conf import settings +from django.db.models import Q +from django.http import HttpResponse +from drf_spectacular.types import OpenApiTypes +from drf_spectacular.utils import ( + OpenApiParameter, + OpenApiResponse, + extend_schema, + extend_schema_view, +) +from rest_framework import mixins, status, viewsets +from rest_framework.decorators import action +from rest_framework.exceptions import NotFound, ValidationError +from rest_framework.response import Response + +from cvat.apps.consensus.consensus_reports import prepare_report_for_downloading +from cvat.apps.consensus.merging_manager import scehdule_consensus_merging +from cvat.apps.consensus.models import ( + AssigneeConsensusReport, + ConsensusConflict, + ConsensusReport, + ConsensusReportTarget, + ConsensusSettings, +) +from cvat.apps.consensus.permissions import ( + AssigneeConsensusReportPermission, + ConsensusConflictPermission, + ConsensusReportPermission, + ConsensusSettingPermission, +) +from cvat.apps.consensus.serializers import ( + AssigneeConsensusReportSerializer, + ConsensusConflictSerializer, + ConsensusReportCreateSerializer, + ConsensusReportSerializer, + ConsensusSettingsSerializer, +) +from cvat.apps.engine.mixins import PartialUpdateModelMixin +from cvat.apps.engine.models import Job, Task +from cvat.apps.engine.serializers import RqIdSerializer +from cvat.apps.engine.utils import get_server_url + + +@extend_schema(tags=["consensus"]) +@extend_schema_view( + list=extend_schema( + summary="List annotation conflicts in a consensus report", + parameters=[ + # These filters are implemented differently from others + OpenApiParameter( + "report_id", + type=OpenApiTypes.INT, + description="A simple equality filter for report id", + ), + ], + responses={ + "200": ConsensusConflictSerializer(many=True), + }, + ), +) +class ConsensusConflictsViewSet(viewsets.GenericViewSet, mixins.ListModelMixin): + queryset = ( + ConsensusConflict.objects.select_related( + "report", + "report__parent", + "report__job", + "report__job__segment", + "report__job__segment__task", + "report__job__segment__task__organization", + "report__task", + "report__task__organization", + ) + .prefetch_related( + "annotation_ids", + ) + .all() + ) + + iam_organization_field = [ + "report__job__segment__task__organization", + "report__task__organization", + ] + + search_fields = [] + filter_fields = list(search_fields) + ["id", "frame", "type", "job_id", "task_id"] + simple_filters = set(filter_fields) - {"id"} + lookup_fields = { + "job_id": "report__job__id", + "task_id": "report__job__segment__task__id", # task reports do not contain own conflicts + } + ordering_fields = list(filter_fields) + ordering = "-id" + serializer_class = ConsensusConflictSerializer + + def get_queryset(self): + queryset = super().get_queryset() + + if self.action == "list": + if report_id := self.request.query_params.get("report_id", None): + # NOTE: This filter is too complex to be implemented by other means, + # it has a dependency on the 
report type + try: + report = ConsensusReport.objects.get(id=report_id) + except ConsensusReport.DoesNotExist as ex: + raise NotFound(f"Report {report_id} does not exist") from ex + + self.check_object_permissions(self.request, report) + + if report.target == ConsensusReportTarget.TASK: + queryset = queryset.filter( + Q(report=report) | Q(report__parent=report) + ).distinct() + elif report.target == ConsensusReportTarget.JOB: + queryset = queryset.filter(report=report) + else: + assert False + else: + perm = ConsensusConflictPermission.create_scope_list(self.request) + queryset = perm.filter(queryset) + + return queryset + + +@extend_schema(tags=["consensus"]) +@extend_schema_view( + retrieve=extend_schema( + operation_id="consensus_retrieve_report", # the default produces the plural + summary="Get consensus report details", + responses={ + "200": ConsensusReportSerializer, + }, + ), + list=extend_schema( + summary="List consensus reports", + parameters=[ + # These filters are implemented differently from others + OpenApiParameter( + "task_id", type=OpenApiTypes.INT, description="A simple equality filter for task id" + ), + OpenApiParameter( + "target", type=OpenApiTypes.STR, description="A simple equality filter for target" + ), + ], + responses={ + "200": ConsensusReportSerializer(many=True), + }, + ), +) +class ConsensusReportViewSet( + viewsets.GenericViewSet, + mixins.ListModelMixin, + mixins.RetrieveModelMixin, + mixins.CreateModelMixin, +): + queryset = ConsensusReport.objects.prefetch_related( + "job", + "job__segment", + "job__segment__task", + "job__segment__task__organization", + "task", + "task__organization", + ).all() + + iam_organization_field = ["job__segment__task__organization", "task__organization"] + + search_fields = [] + filter_fields = list(search_fields) + [ + "id", + "job_id", + "created_date", + "target_last_updated", + "parent_id", + ] + simple_filters = list(set(filter_fields) - {"id", "created_date", "target_last_updated"}) + ordering_fields = list(filter_fields) + ordering = "id" + + def get_serializer_class(self): + # a separate method is required for drf-spectacular to work + return ConsensusReportSerializer + + def get_queryset(self): + queryset = super().get_queryset() + + if self.action == "list": + if task_id := self.request.query_params.get("task_id", None): + # NOTE: This filter is too complex to be implemented by other means + try: + task = Task.objects.get(id=task_id) + except Task.DoesNotExist as ex: + raise NotFound(f"Task {task_id} does not exist") from ex + + self.check_object_permissions(self.request, task) + + queryset = queryset.filter( + Q(job__segment__task__id=task_id) | Q(task__id=task_id) + ).distinct() + else: + perm = ConsensusReportPermission.create_scope_list(self.request) + queryset = perm.filter(queryset) + + if target := self.request.query_params.get("target", None): + if target == ConsensusReportTarget.JOB: + queryset = queryset.filter(job__isnull=False) + elif target == ConsensusReportTarget.TASK: + queryset = queryset.filter(job__isnull=True) + else: + raise ValidationError( + "Unexpected 'target' filter value '{}'. 
Valid values are: {}".format( + target, ", ".join(m[0] for m in ConsensusReportTarget.choices()) + ) + ) + + return queryset + + CREATE_REPORT_RQ_ID_PARAMETER = "rq_id" + + @extend_schema( + operation_id="consensus_create_report", + summary="Create a consensus report", + parameters=[ + OpenApiParameter( + CREATE_REPORT_RQ_ID_PARAMETER, + type=str, + description=textwrap.dedent( + """\ + The report creation request id. Can be specified to check the report + creation status. + """ + ), + ) + ], + request=ConsensusReportCreateSerializer(required=False), + responses={ + "201": ConsensusReportSerializer, + "202": OpenApiResponse( + RqIdSerializer, + description=textwrap.dedent( + """\ + A consensus report request has been enqueued, the request id is returned. + The request status can be checked at this endpoint by passing the {} + as the query parameter. If the request id is specified, this response + means the consensus report request is queued or is being processed. + """.format( + CREATE_REPORT_RQ_ID_PARAMETER + ) + ), + ), + "400": OpenApiResponse( + description="Invalid or failed request, check the response data for details" + ), + }, + ) + def create(self, request, *args, **kwargs): + self.check_permissions(request) + input_serializer = ConsensusReportCreateSerializer(data=request.data) + input_serializer.is_valid(raise_exception=True) + + queue_name = settings.CVAT_QUEUES.CONSENSUS.value + queue = django_rq.get_queue(queue_name) + rq_id = request.query_params.get(self.CREATE_REPORT_RQ_ID_PARAMETER, None) + + if rq_id is None: + try: + task_id = input_serializer.validated_data.get("task_id", 0) + job_id = input_serializer.validated_data.get("job_id", 0) + if task_id: + instance = Task.objects.get(pk=task_id) + elif job_id: + instance = Job.objects.get(pk=job_id) + else: + raise ValidationError("Task or Job id is required") + except Task.DoesNotExist as ex: + raise NotFound(f"Task {task_id} does not exist") from ex + + try: + return scehdule_consensus_merging(instance, request) + except Exception as ex: + raise ValidationError(str(ex)) + + else: + rq_job = queue.fetch_job(rq_id) + if ( + not rq_job + or not ConsensusReportPermission.create_scope_check_status( + request, job_owner_id=rq_job.meta["user"]["id"] + ) + .check_access() + .allow + ): + # We should not provide job existence information to unauthorized users + raise NotFound("Unknown request id") + + if rq_job.is_failed: + message = str(rq_job.exc_info) + rq_job.delete() + raise ValidationError(message) + elif rq_job.is_queued or rq_job.is_started: + return Response(status=status.HTTP_202_ACCEPTED) + elif rq_job.is_finished: + return_value = rq_job.return_value() + rq_job.delete() + if not return_value: + raise ValidationError("No report has been computed") + + report = self.get_queryset().get(pk=return_value) + report_serializer = ConsensusReportSerializer(instance=report) + return Response( + data=report_serializer.data, + status=status.HTTP_201_CREATED, + headers=self.get_success_headers(report_serializer.data), + ) + + @extend_schema( + operation_id="consensus_retrieve_report_data", + summary="Get consensus report contents", + responses={"200": OpenApiTypes.OBJECT}, + ) + @action(detail=True, methods=["GET"], url_path="data", serializer_class=None) + def data(self, request, pk): + report = self.get_object() # check permissions + json_report = prepare_report_for_downloading(report, host=get_server_url(request)) + return HttpResponse(json_report.encode(), content_type="application/json") + + +@extend_schema(tags=["consensus"]) 
+@extend_schema_view( + list=extend_schema( + summary="List consensus settings instances", + responses={ + "200": ConsensusSettingsSerializer(many=True), + }, + ), + retrieve=extend_schema( + summary="Get consensus settings instance details", + parameters=[ + OpenApiParameter( + "id", + type=OpenApiTypes.INT, + location="path", + description="An id of a consensus settings instance", + ) + ], + responses={ + "200": ConsensusSettingsSerializer, + }, + ), + partial_update=extend_schema( + summary="Update a consensus settings instance", + parameters=[ + OpenApiParameter( + "id", + type=OpenApiTypes.INT, + location="path", + description="An id of a consensus settings instance", + ) + ], + request=ConsensusSettingsSerializer(partial=True), + responses={ + "200": ConsensusSettingsSerializer, + }, + ), +) +class ConsensusSettingsViewSet( + viewsets.GenericViewSet, + mixins.ListModelMixin, + mixins.RetrieveModelMixin, + PartialUpdateModelMixin, +): + queryset = ConsensusSettings.objects.select_related("task", "task__organization").all() + + iam_organization_field = "task__organization" + + search_fields = [] + filter_fields = ["id", "task_id"] + simple_filters = ["task_id"] + ordering_fields = ["id"] + ordering = "id" + + serializer_class = ConsensusSettingsSerializer + + def get_queryset(self): + queryset = super().get_queryset() + + if self.action == "list": + permissions = ConsensusSettingPermission.create_scope_list(self.request) + queryset = permissions.filter(queryset) + + return queryset + + +@extend_schema(tags=["consensus"]) +@extend_schema_view( + retrieve=extend_schema( + operation_id="assignee_consensus_retrieve_report", + summary="Get assignee consensus report details", + responses={ + "200": AssigneeConsensusReportSerializer, + }, + ), + list=extend_schema( + summary="List assignee consensus reports", + parameters=[ + # These filters are implemented differently from others + OpenApiParameter( + "task_id", type=OpenApiTypes.INT, description="A simple equality filter for task id" + ), + ], + responses={ + "200": AssigneeConsensusReportSerializer(many=True), + }, + ), +) +class AssigneeConsensusReportViewSet( + viewsets.GenericViewSet, + mixins.ListModelMixin, + mixins.RetrieveModelMixin, +): + queryset = AssigneeConsensusReport.objects.prefetch_related( + "task", + "task__organization", + ).all() + + iam_organization_field = ["task__organization"] + + search_fields = [] + filter_fields = list(search_fields) + ["id", "consensus_report_id"] + simple_filters = list(set(filter_fields) - {"id"}) + ordering_fields = list(filter_fields) + ordering = "id" + + def get_serializer_class(self): + # a separate method is required for drf-spectacular to work + return AssigneeConsensusReportSerializer + + def get_queryset(self): + queryset = super().get_queryset() + + if self.action == "list": + if task_id := self.request.query_params.get("task_id", None): + # NOTE: This filter is too complex to be implemented by other means + try: + task = Task.objects.get(id=task_id) + except Task.DoesNotExist as ex: + raise NotFound(f"Task {task_id} does not exist") from ex + + self.check_object_permissions(self.request, task) + + queryset = queryset.filter(Q(task__id=task_id)).distinct() + else: + perm = AssigneeConsensusReportPermission.create_scope_list(self.request) + queryset = perm.filter(queryset) + + return queryset + + @extend_schema( + operation_id="assignee_consensus_retrieve_report_data", + summary="Get assignee consensus report contents", + responses={"200": OpenApiTypes.OBJECT}, + ) + 
@action(detail=True, methods=["GET"], url_path="data", serializer_class=None) + def data(self, request, pk): + report = self.get_object() # check permissions + json_report = prepare_report_for_downloading(report, host=get_server_url(request)) + return HttpResponse(json_report.encode(), content_type="application/json") diff --git a/cvat/apps/dataset_manager/bindings.py b/cvat/apps/dataset_manager/bindings.py index eb8fdf26b52..5e8c68286e4 100644 --- a/cvat/apps/dataset_manager/bindings.py +++ b/cvat/apps/dataset_manager/bindings.py @@ -766,6 +766,7 @@ def _init_meta(self): ("start_frame", str(self._db_data.start_frame + db_segment.start_frame * self._frame_step)), ("stop_frame", str(self._db_data.start_frame + db_segment.stop_frame * self._frame_step)), ("frame_filter", self._db_data.frame_filter), + ("parent_job_id", str(self._db_job.parent_job_id)), ("segments", [ ("segment", OrderedDict([ ("id", str(db_segment.id)), @@ -2085,6 +2086,8 @@ def import_dm_annotations(dm_dataset: dm.Dataset, instance_data: Union[ProjectDa dm.AnnotationType.mask: ShapeType.MASK } + sources = {'auto', 'semi-auto', 'manual', 'file', 'consensus'} + track_formats = [ 'cvat', 'datumaro', @@ -2161,7 +2164,7 @@ def import_dm_annotations(dm_dataset: dm.Dataset, instance_data: Union[ProjectDa track_id = ann.attributes.pop('track_id', None) source = ann.attributes.pop('source').lower() \ - if ann.attributes.get('source', '').lower() in {'auto', 'semi-auto', 'manual', 'file'} else 'manual' + if ann.attributes.get('source', '').lower() in sources else 'manual' shape_type = shapes[ann.type] if track_id is None or 'keyframe' not in ann.attributes or dm_dataset.format not in track_formats: @@ -2175,7 +2178,7 @@ def import_dm_annotations(dm_dataset: dm.Dataset, instance_data: Union[ProjectDa element_occluded = element.visibility[0] == dm.Points.Visibility.hidden element_outside = element.visibility[0] == dm.Points.Visibility.absent element_source = element.attributes.pop('source').lower() \ - if element.attributes.get('source', '').lower() in {'auto', 'semi-auto', 'manual', 'file'} else 'manual' + if element.attributes.get('source', '').lower() in sources else 'manual' elements.append(instance_data.LabeledShape( type=shapes[element.type], frame=frame_number, @@ -2247,7 +2250,7 @@ def import_dm_annotations(dm_dataset: dm.Dataset, instance_data: Union[ProjectDa for n, v in element.attributes.items() ] element_source = element.attributes.pop('source').lower() \ - if element.attributes.get('source', '').lower() in {'auto', 'semi-auto', 'manual', 'file'} else 'manual' + if element.attributes.get('source', '').lower() in sources else 'manual' tracks[track_id]['elements'][element.label].shapes.append(instance_data.TrackedShape( type=shapes[element.type], diff --git a/cvat/apps/dataset_manager/task.py b/cvat/apps/dataset_manager/task.py index d7d16f8cba3..1c54b68e66d 100644 --- a/cvat/apps/dataset_manager/task.py +++ b/cvat/apps/dataset_manager/task.py @@ -769,6 +769,7 @@ def __init__(self, pk): ).get(id=pk) # Postgres doesn't guarantee an order by default without explicit order_by + # Only select regular jobs, not ground truth or consensus jobs self.db_jobs = models.Job.objects.select_related("segment").filter( segment__task_id=pk, type=models.JobType.ANNOTATION.value, ).order_by('id') diff --git a/cvat/apps/engine/backup.py b/cvat/apps/engine/backup.py index bd3918a037a..810ff90db2e 100644 --- a/cvat/apps/engine/backup.py +++ b/cvat/apps/engine/backup.py @@ -185,6 +185,7 @@ def _prepare_task_meta(self, task): 'status', 'subset', 
'labels', + 'consensus_jobs_per_regular_job', } return self._prepare_meta(allowed_fields, task) @@ -292,7 +293,7 @@ def _get_db_jobs(self): if self._db_task: db_segments = list(self._db_task.segment_set.all().prefetch_related('job_set')) db_segments.sort(key=lambda i: i.job_set.first().id) - db_jobs = (s.job_set.first() for s in db_segments) + db_jobs = (job for s in db_segments for job in s.job_set.all()) return db_jobs return () @@ -397,27 +398,31 @@ def serialize_task(): return task def serialize_segment(db_segment): - db_job = db_segment.job_set.first() - job_serializer = SimpleJobSerializer(db_job) - for field in ('url', 'assignee'): - job_serializer.fields.pop(field) - job_data = self._prepare_job_meta(job_serializer.data) + segments = [] + db_jobs = db_segment.job_set.all() + for db_job in db_jobs: + job_serializer = SimpleJobSerializer(db_job) + for field in ('url', 'assignee'): + job_serializer.fields.pop(field) + job_data = self._prepare_job_meta(job_serializer.data) - segment_serializer = SegmentSerializer(db_segment) - segment_serializer.fields.pop('jobs') - segment = segment_serializer.data - segment_type = segment.pop("type") - segment.update(job_data) + segment_serializer = SegmentSerializer(db_segment) + segment_serializer.fields.pop('jobs') + segment = segment_serializer.data + segment_type = segment.pop("type") + segment.update(job_data) - if self._db_task.segment_size == 0 and segment_type == models.SegmentType.RANGE: - segment.update(serialize_custom_file_mapping(db_segment)) + if self._db_task.segment_size == 0 and segment_type == models.SegmentType.RANGE: + segment.update(serialize_custom_file_mapping(db_segment)) - return segment + segments.append(segment) + + return segments def serialize_jobs(): db_segments = list(self._db_task.segment_set.all()) db_segments.sort(key=lambda i: i.job_set.first().id) - return (serialize_segment(s) for s in db_segments) + return (serialized_job for s in db_segments for serialized_job in serialize_segment(s)) def serialize_custom_file_mapping(db_segment: models.Segment): if self._db_task.mode == 'annotation': @@ -725,10 +730,6 @@ def _import_gt_jobs(self, jobs): }) job_serializer.is_valid(raise_exception=True) job_serializer.save() - elif job_type == models.JobType.ANNOTATION: - continue - else: - assert False def _import_annotations(self): db_jobs = self._get_db_jobs() diff --git a/cvat/apps/engine/migrations/0084_job_parent_job_task_consensus_jobs_per_regular_job_and_more.py b/cvat/apps/engine/migrations/0084_job_parent_job_task_consensus_jobs_per_regular_job_and_more.py new file mode 100644 index 00000000000..846087f62fa --- /dev/null +++ b/cvat/apps/engine/migrations/0084_job_parent_job_task_consensus_jobs_per_regular_job_and_more.py @@ -0,0 +1,44 @@ +# Generated by Django 4.2.13 on 2024-08-23 05:25 + +import cvat.apps.engine.models +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + + dependencies = [ + ("engine", "0083_move_to_segment_chunks"), + ] + + operations = [ + migrations.AddField( + model_name="job", + name="parent_job", + field=models.ForeignKey( + blank=True, + null=True, + on_delete=django.db.models.deletion.CASCADE, + related_name="children_jobs", + to="engine.job", + ), + ), + migrations.AddField( + model_name="task", + name="consensus_jobs_per_regular_job", + field=models.IntegerField(blank=True, default=0), + ), + migrations.AlterField( + model_name="job", + name="type", + field=models.CharField( + choices=[ + ("annotation", "ANNOTATION"), + 
("ground_truth", "GROUND_TRUTH"), + ("consensus", "CONSENSUS"), + ], + default=cvat.apps.engine.models.JobType["ANNOTATION"], + max_length=32, + ), + ), + ] diff --git a/cvat/apps/engine/models.py b/cvat/apps/engine/models.py index c57eb0371d5..0fba8ab2122 100644 --- a/cvat/apps/engine/models.py +++ b/cvat/apps/engine/models.py @@ -163,6 +163,7 @@ def __str__(self): class JobType(str, Enum): ANNOTATION = 'annotation' GROUND_TRUTH = 'ground_truth' + CONSENSUS = 'consensus' @classmethod def choices(cls): @@ -412,6 +413,7 @@ def with_job_summary(self): return self.prefetch_related( 'segment_set__job_set', ).annotate( + total_jobs_count=models.Count('segment__job', distinct=True), completed_jobs_count=models.Count( 'segment__job', filter=models.Q(segment__job__state=StateChoice.COMPLETED.value) & @@ -454,6 +456,7 @@ class Task(TimestampedModel): blank=True, on_delete=models.SET_NULL, related_name='+') target_storage = models.ForeignKey('Storage', null=True, default=None, blank=True, on_delete=models.SET_NULL, related_name='+') + consensus_jobs_per_regular_job = models.IntegerField(default=0, blank=True) # Extend default permission model class Meta: @@ -731,6 +734,7 @@ class Job(TimestampedModel): type = models.CharField(max_length=32, choices=JobType.choices(), default=JobType.ANNOTATION) + parent_job = models.ForeignKey('self', on_delete=models.CASCADE, null=True, blank=True, related_name='children_jobs') def get_target_storage(self) -> Optional[Storage]: return self.segment.task.target_storage diff --git a/cvat/apps/engine/serializers.py b/cvat/apps/engine/serializers.py index ed937a993ff..88aa52140c7 100644 --- a/cvat/apps/engine/serializers.py +++ b/cvat/apps/engine/serializers.py @@ -140,13 +140,15 @@ def bind(self, field_name, parent): def get_fields(self): fields = super().get_fields() fields['url'] = HyperlinkedEndpointSerializer(self._model, filter_key=self._url_filter_key) - fields['count'].source = self._collection_key + '.count' + if not fields['count'].source: + fields['count'].source = self._collection_key + '.count' return fields def get_attribute(self, instance): return instance class JobsSummarySerializer(_CollectionSummarySerializer): + count = serializers.IntegerField(source='total_jobs_count', allow_null=True) completed = serializers.IntegerField(source='completed_jobs_count', allow_null=True) validation = serializers.IntegerField(source='validation_jobs_count', allow_null=True) @@ -603,6 +605,7 @@ class JobReadSerializer(serializers.ModelSerializer): issues = IssuesSummarySerializer(source='*') target_storage = StorageSerializer(required=False, allow_null=True) source_storage = StorageSerializer(required=False, allow_null=True) + parent_job_id = serializers.ReadOnlyField(allow_null=True) class Meta: model = models.Job @@ -611,7 +614,7 @@ class Meta: 'start_frame', 'stop_frame', 'data_chunk_size', 'data_compressed_chunk_type', 'data_original_chunk_type', 'created_date', 'updated_date', 'issues', 'labels', 'type', 'organization', - 'target_storage', 'source_storage', 'assignee_updated_date') + 'target_storage', 'source_storage', 'assignee_updated_date', 'parent_job_id') read_only_fields = fields def to_representation(self, instance): @@ -1118,6 +1121,7 @@ class TaskReadSerializer(serializers.ModelSerializer): source_storage = StorageSerializer(required=False, allow_null=True) jobs = JobsSummarySerializer(url_filter_key='task_id', source='segment_set') labels = LabelsSummarySerializer(source='*') + consensus_jobs_per_regular_job = serializers.ReadOnlyField(required=False, 
allow_null=True)
 
     class Meta:
         model = models.Task
@@ -1126,7 +1130,7 @@ class Meta:
             'status', 'data_chunk_size', 'data_compressed_chunk_type', 'guide_id',
             'data_original_chunk_type', 'size', 'image_quality', 'data', 'dimension',
             'subset', 'organization', 'target_storage', 'source_storage', 'jobs', 'labels',
-            'assignee_updated_date'
+            'assignee_updated_date', 'consensus_jobs_per_regular_job',
         )
         read_only_fields = fields
         extra_kwargs = {
@@ -1142,12 +1146,13 @@ class TaskWriteSerializer(WriteOnceMixin, serializers.ModelSerializer):
     project_id = serializers.IntegerField(required=False, allow_null=True)
     target_storage = StorageSerializer(required=False, allow_null=True)
     source_storage = StorageSerializer(required=False, allow_null=True)
+    consensus_jobs_per_regular_job = serializers.IntegerField(required=False, allow_null=True)
 
     class Meta:
         model = models.Task
         fields = ('url', 'id', 'name', 'project_id', 'owner_id', 'assignee_id',
             'bug_tracker', 'overlap', 'segment_size', 'labels', 'subset',
-            'target_storage', 'source_storage',
+            'target_storage', 'source_storage', 'consensus_jobs_per_regular_job',
         )
         write_once_fields = ('overlap', 'segment_size')
@@ -1336,6 +1341,11 @@ def validate(self, attrs):
                 if sublabels != target_project_sublabel_names.get(label):
                     raise serializers.ValidationError('All task or project label names must be mapped to the target project')
 
+        consensus_jobs_per_regular_job = attrs.get('consensus_jobs_per_regular_job', self.instance.consensus_jobs_per_regular_job if self.instance else None)
+
+        if consensus_jobs_per_regular_job and (consensus_jobs_per_regular_job == 1 or consensus_jobs_per_regular_job < 0 or consensus_jobs_per_regular_job > 10):
+            raise serializers.ValidationError("consensus_jobs_per_regular_job must be 0 (disabled) or an integer in the range [2; 10]")
+
         return attrs
 
 class ProjectReadSerializer(serializers.ModelSerializer):
diff --git a/cvat/apps/engine/task.py b/cvat/apps/engine/task.py
index f24cd686a58..9d3e0034869 100644
--- a/cvat/apps/engine/task.py
+++ b/cvat/apps/engine/task.py
@@ -203,6 +203,12 @@ def _create_segments_and_jobs(
         db_job.save()
         db_job.make_dirs()
 
+        # Consensus jobs reuse the regular job's `db_segment`, so their data is not duplicated in backups and exports
+        for _ in range(db_task.consensus_jobs_per_regular_job):
+            consensus_db_job = models.Job(segment=db_segment, parent_job_id=db_job.id, type=models.JobType.CONSENSUS)
+            consensus_db_job.save()
+            consensus_db_job.make_dirs()
+
     db_task.data.save()
     db_task.save()
diff --git a/cvat/apps/engine/views.py b/cvat/apps/engine/views.py
index 3cb7e34c5c4..ed276c955d3 100644
--- a/cvat/apps/engine/views.py
+++ b/cvat/apps/engine/views.py
@@ -110,6 +110,7 @@
     CommentPermission, IssuePermission, JobPermission, LabelPermission,
     ProjectPermission, TaskPermission, UserPermission)
 from cvat.apps.engine.view_utils import tus_chunk_action
+from cvat.apps.consensus.merging_manager import scehdule_consensus_merging
 
 slogger = ServerLogManager(__name__)
 
@@ -1712,7 +1713,7 @@ class JobViewSet(viewsets.GenericViewSet, mixins.ListModelMixin, mixins.CreateMo
     iam_organization_field = 'segment__task__organization'
     search_fields = ('task_name', 'project_name', 'assignee', 'state', 'stage')
     filter_fields = list(search_fields) + [
-        'id', 'task_id', 'project_id', 'updated_date', 'dimension', 'type'
+        'id', 'task_id', 'project_id', 'updated_date', 'dimension', 'type', 'parent_job_id',
     ]
     simple_filters = list(set(filter_fields) - {'id', 'updated_date'})
     ordering_fields = list(filter_fields)
@@ -1967,7 +1968,6 @@ def annotations(self, 
request, pk): return Response(data=str(e), status=status.HTTP_400_BAD_REQUEST) return Response(data) - @tus_chunk_action(detail=True, suffix_base="annotations") def append_annotations_chunk(self, request, pk, file_id): self._object = self.get_object() diff --git a/cvat/apps/quality_control/quality_reports.py b/cvat/apps/quality_control/quality_reports.py index a62e46f52cc..958e9c87b8a 100644 --- a/cvat/apps/quality_control/quality_reports.py +++ b/cvat/apps/quality_control/quality_reports.py @@ -55,9 +55,9 @@ from cvat.utils.background_jobs import schedule_job_with_throttling -class _Serializable: +class Serializable: def _value_serializer(self, v): - if isinstance(v, _Serializable): + if isinstance(v, Serializable): return v.to_dict() elif isinstance(v, (list, tuple, set, frozenset)): return [self._value_serializer(vv) for vv in v] @@ -83,7 +83,7 @@ def from_dict(cls, d: dict): @define(kw_only=True) -class AnnotationId(_Serializable): +class AnnotationId(Serializable): obj_id: int job_id: int type: AnnotationType @@ -106,7 +106,7 @@ def from_dict(cls, d: dict): @define(kw_only=True) -class AnnotationConflict(_Serializable): +class AnnotationConflict(Serializable): frame_id: int type: AnnotationConflictType annotation_ids: List[AnnotationId] @@ -151,7 +151,7 @@ def from_dict(cls, d: dict): @define(kw_only=True) -class ComparisonParameters(_Serializable): +class ComparisonParameters(Serializable): included_annotation_types: List[dm.AnnotationType] = [ dm.AnnotationType.bbox, dm.AnnotationType.points, @@ -219,7 +219,7 @@ def from_dict(cls, d: dict): @define(kw_only=True) -class ConfusionMatrix(_Serializable): +class ConfusionMatrix(Serializable): labels: List[str] rows: np.array precision: np.array @@ -255,7 +255,7 @@ def from_dict(cls, d: dict): @define(kw_only=True) -class ComparisonReportAnnotationsSummary(_Serializable): +class ComparisonReportAnnotationsSummary(Serializable): valid_count: int missing_count: int extra_count: int @@ -306,7 +306,7 @@ def from_dict(cls, d: dict): @define(kw_only=True) -class ComparisonReportAnnotationShapeSummary(_Serializable): +class ComparisonReportAnnotationShapeSummary(Serializable): valid_count: int missing_count: int extra_count: int @@ -347,7 +347,7 @@ def from_dict(cls, d: dict): @define(kw_only=True) -class ComparisonReportAnnotationLabelSummary(_Serializable): +class ComparisonReportAnnotationLabelSummary(Serializable): valid_count: int invalid_count: int total_count: int @@ -373,7 +373,7 @@ def from_dict(cls, d: dict): @define(kw_only=True) -class ComparisonReportAnnotationComponentsSummary(_Serializable): +class ComparisonReportAnnotationComponentsSummary(Serializable): shape: ComparisonReportAnnotationShapeSummary label: ComparisonReportAnnotationLabelSummary @@ -390,7 +390,7 @@ def from_dict(cls, d: dict): @define(kw_only=True) -class ComparisonReportComparisonSummary(_Serializable): +class ComparisonReportComparisonSummary(Serializable): frame_share: float frames: List[str] @@ -447,7 +447,7 @@ def from_dict(cls, d: dict): @define(kw_only=True, init=False) -class ComparisonReportFrameSummary(_Serializable): +class ComparisonReportFrameSummary(Serializable): conflicts: List[AnnotationConflict] @cached_property @@ -513,7 +513,7 @@ def from_dict(cls, d: dict): @define(kw_only=True) -class ComparisonReport(_Serializable): +class ComparisonReport(Serializable): parameters: ComparisonParameters comparison_summary: ComparisonReportComparisonSummary frame_results: Dict[int, ComparisonReportFrameSummary] @@ -641,7 +641,7 @@ def _convert_shape(self, 
shape, *, index): return converted -def _match_segments( +def match_segments( a_segms, b_segms, distance=dm.ops.segment_iou, @@ -700,7 +700,7 @@ def _match_segments( return matches, mispred, a_unmatched, b_unmatched -def _OKS(a, b, sigma=0.1, bbox=None, scale=None, visibility_a=None, visibility_b=None): +def oks(a, b, sigma=0.1, bbox=None, scale=None, visibility_a=None, visibility_b=None): """ Object Keypoint Similarity metric. https://cocodataset.org/#keypoints-eval @@ -728,12 +728,14 @@ def _OKS(a, b, sigma=0.1, bbox=None, scale=None, visibility_a=None, visibility_b dists = np.linalg.norm(p1 - p2, axis=1) return np.sum( - visibility_a * visibility_b * np.exp(-(dists**2) / (2 * scale * (2 * sigma) ** 2)) + visibility_a + * visibility_b + * np.exp((visibility_a == visibility_b) * (-(dists**2) / (2 * scale * (2 * sigma) ** 2))) ) / np.sum(visibility_a | visibility_b, dtype=float) @define(kw_only=True) -class _KeypointsMatcher(dm.ops.PointsMatcher): +class KeypointsMatcher(dm.ops.PointsMatcher): def distance(self, a: dm.Points, b: dm.Points) -> float: a_bbox = self.instance_map[id(a)][1] b_bbox = self.instance_map[id(b)][1] @@ -741,7 +743,7 @@ def distance(self, a: dm.Points, b: dm.Points) -> float: return 0 bbox = dm.ops.mean_bbox([a_bbox, b_bbox]) - return _OKS( + return oks( a, b, sigma=self.sigma, @@ -757,7 +759,7 @@ def _arr_div(a_arr: np.ndarray, b_arr: np.ndarray) -> np.ndarray: return a_arr / divisor -def _to_rle(ann: dm.Annotation, *, img_h: int, img_w: int): +def to_rle(ann: dm.Annotation, *, img_h: int, img_w: int): from pycocotools import mask as mask_utils if ann.type == dm.AnnotationType.polygon: @@ -770,7 +772,7 @@ def _to_rle(ann: dm.Annotation, *, img_h: int, img_w: int): assert False -def _segment_iou(a: dm.Annotation, b: dm.Annotation, *, img_h: int, img_w: int) -> float: +def segment_iou(a: dm.Annotation, b: dm.Annotation, *, img_h: int, img_w: int) -> float: """ Generic IoU computation with masks and polygons. Returns -1 if no intersection, [0; 1] otherwise @@ -781,15 +783,15 @@ def _segment_iou(a: dm.Annotation, b: dm.Annotation, *, img_h: int, img_w: int) from pycocotools import mask as mask_utils - a = _to_rle(a, img_h=img_h, img_w=img_w) - b = _to_rle(b, img_h=img_h, img_w=img_w) + a = to_rle(a, img_h=img_h, img_w=img_w) + b = to_rle(b, img_h=img_h, img_w=img_w) # Note that mask_utils.iou expects (dt, gt). 
Check this if the 3rd param is True return float(mask_utils.iou(b, a, [0])) @define(kw_only=True) -class _LineMatcher(dm.ops.LineMatcher): +class LineMatcher(dm.ops.LineMatcher): EPSILON = 1e-7 torso_r: float = 0.25 @@ -936,7 +938,7 @@ def approximate_points(cls, a: np.ndarray, b: np.ndarray) -> Tuple[np.ndarray, n return a_new_points, b_new_points -class _DistanceComparator(dm.ops.DistanceComparator): +class DistanceComparator(dm.ops.DistanceComparator): def __init__( self, categories: dm.CategoriesInfo, @@ -970,7 +972,7 @@ def __init__( self.panoptic_comparison = panoptic_comparison "Compare only the visible parts of polygons and masks" - def _instance_bbox( + def instance_bbox( self, instance_anns: Sequence[dm.Annotation] ) -> Tuple[float, float, float, float]: return dm.ops.max_bbox( @@ -980,7 +982,24 @@ def _instance_bbox( ) @staticmethod - def _get_ann_type(t, item: dm.Annotation) -> Sequence[dm.Annotation]: + def to_polygon(bbox_ann: dm.Bbox): + points = bbox_ann.as_polygon() + angle = bbox_ann.attributes.get("rotation", 0) / 180 * math.pi + + if angle: + points = np.reshape(points, (-1, 2)) + center = (points[0] + points[2]) / 2 + rel_points = points - center + cos = np.cos(angle) + sin = np.sin(angle) + rotation_matrix = ((cos, sin), (-sin, cos)) + points = np.matmul(rel_points, rotation_matrix) + center + points = points.flatten() + + return dm.Polygon(points) + + @staticmethod + def _get_ann_type(t, item: dm.DatasetItem) -> Sequence[dm.Annotation]: return [ a for a in item.annotations if a.type == t and not a.attributes.get("outside", False) ] @@ -1012,7 +1031,7 @@ def label_distance(a, b): return 0 return 0.5 + (a.label == b.label) / 2 - return self._match_segments( + return self.match_segments( dm.AnnotationType.label, item_a, item_b, @@ -1021,7 +1040,7 @@ def label_distance(a, b): dist_thresh=0.5, ) - def _match_segments( + def match_segments( self, t, item_a, @@ -1049,7 +1068,7 @@ def _match_segments( if label_matcher: extra_args["label_matcher"] = label_matcher - returned_values = _match_segments( + returned_values = match_segments( a_objs, b_objs, distance=distance, @@ -1063,30 +1082,15 @@ def _match_segments( return returned_values def match_boxes(self, item_a, item_b): - def _to_polygon(bbox_ann: dm.Bbox): - points = bbox_ann.as_polygon() - angle = bbox_ann.attributes.get("rotation", 0) / 180 * math.pi - - if angle: - points = np.reshape(points, (-1, 2)) - center = (points[0] + points[2]) / 2 - rel_points = points - center - cos = np.cos(angle) - sin = np.sin(angle) - rotation_matrix = ((cos, sin), (-sin, cos)) - points = np.matmul(rel_points, rotation_matrix) + center - points = points.flatten() - - return dm.Polygon(points) def _bbox_iou(a: dm.Bbox, b: dm.Bbox, *, img_w: int, img_h: int) -> float: if a.attributes.get("rotation", 0) == b.attributes.get("rotation", 0): return dm.ops.bbox_iou(a, b) else: - return _segment_iou(_to_polygon(a), _to_polygon(b), img_h=img_h, img_w=img_w) + return segment_iou(self.to_polygon(a), self.to_polygon(b), img_h=img_h, img_w=img_w) img_h, img_w = item_a.image.size - return self._match_segments( + return self.match_segments( dm.AnnotationType.bbox, item_a, item_b, @@ -1126,7 +1130,7 @@ def _get_compiled_mask( from pycocotools import mask as mask_utils - object_rle_groups = [_to_rle(ann, img_h=img_h, img_w=img_w) for ann in anns] + object_rle_groups = [to_rle(ann, img_h=img_h, img_w=img_w) for ann in anns] object_rles = [mask_utils.merge(g) for g in object_rle_groups] object_masks = mask_utils.decode(object_rles) @@ -1178,7 +1182,7 @@ 
def _get_segment( # Create merged RLE for the instance shapes object_anns = instances[obj_id] object_rle_groups = [ - _to_rle(ann, img_h=img_h, img_w=img_w) for ann in object_anns + to_rle(ann, img_h=img_h, img_w=img_w) for ann in object_anns ] rle = mask_utils.merge(list(itertools.chain.from_iterable(object_rle_groups))) @@ -1201,7 +1205,7 @@ def _label_matcher(a_inst_id: int, b_inst_id: int) -> bool: b = b_instances[b_inst_id][0] return a.label == b.label - results = self._match_segments( + results = self.match_segments( dm.AnnotationType.polygon, item_a, item_b, @@ -1246,12 +1250,12 @@ def _label_matcher(a_inst_id: int, b_inst_id: int) -> bool: return returned_values def match_lines(self, item_a, item_b): - matcher = _LineMatcher( + matcher = LineMatcher( oriented=self.compare_line_orientation, torso_r=self.line_torso_radius, scale=np.prod(item_a.image.size), ) - return self._match_segments( + return self.match_segments( dm.AnnotationType.polyline, item_a, item_b, distance=matcher.distance ) @@ -1263,7 +1267,7 @@ def match_points(self, item_a, item_b): for source_anns in [item_a.annotations, item_b.annotations]: source_instances = dm.ops.find_instances(source_anns) for instance_group in source_instances: - instance_bbox = self._instance_bbox(instance_group) + instance_bbox = self.instance_bbox(instance_group) for ann in instance_group: if ann.type == dm.AnnotationType.points: @@ -1280,7 +1284,7 @@ def _distance(a: dm.Points, b: dm.Points) -> float: if a_area == 0 and b_area == 0: # Simple case: singular points without bbox # match them in the image space - return _OKS(a, b, sigma=self.oks_sigma, scale=img_h * img_w) + return oks(a, b, sigma=self.oks_sigma, scale=img_h * img_w) else: # Complex case: multiple points, grouped points, points with a bbox @@ -1296,10 +1300,10 @@ def _distance(a: dm.Points, b: dm.Points) -> float: a_points = np.reshape(a.points, (-1, 2)) b_points = np.reshape(b.points, (-1, 2)) - matches, mismatches, a_extra, b_extra = _match_segments( + matches, mismatches, a_extra, b_extra = match_segments( range(len(a_points)), range(len(b_points)), - distance=lambda ai, bi: _OKS( + distance=lambda ai, bi: oks( dm.Points(a_points[ai]), dm.Points(b_points[bi]), sigma=self.oks_sigma, @@ -1326,7 +1330,7 @@ def _distance(a: dm.Points, b: dm.Points) -> float: len(matched_points) + len(a_extra) + len(b_extra) ) - return self._match_segments( + return self.match_segments( dm.AnnotationType.points, item_a, item_b, @@ -1400,7 +1404,7 @@ def match_skeletons(self, item_a, item_b): instance_map = {} for source in [item_a.annotations, item_b.annotations]: for instance_group in dm.ops.find_instances(source): - instance_bbox = self._instance_bbox(instance_group) + instance_bbox = self.instance_bbox(instance_group) instance_group = [ skeleton_map[id(a)] if isinstance(a, dm.Skeleton) else a @@ -1410,9 +1414,9 @@ def match_skeletons(self, item_a, item_b): for ann in instance_group: instance_map[id(ann)] = [instance_group, instance_bbox] - matcher = _KeypointsMatcher(instance_map=instance_map, sigma=self.oks_sigma) + matcher = KeypointsMatcher(instance_map=instance_map, sigma=self.oks_sigma) - results = self._match_segments( + results = self.match_segments( dm.AnnotationType.points, item_a, item_b, @@ -1509,7 +1513,7 @@ def __init__(self, categories: dm.CategoriesInfo, *, settings: ComparisonParamet } self.included_ann_types = settings.included_annotation_types self.non_groupable_ann_type = settings.non_groupable_ann_type - self._annotation_comparator = _DistanceComparator( + 
self._annotation_comparator = DistanceComparator( categories, included_ann_types=set(self.included_ann_types) - {dm.AnnotationType.mask}, # masks are compared together with polygons @@ -1584,7 +1588,7 @@ def _group_distance(gt_group_id, ds_group_id): union = len(gt_groups[gt_group_id]) + len(ds_groups[ds_group_id]) - intersection return intersection / (union or 1) - matches, mismatches, gt_unmatched, ds_unmatched = _match_segments( + matches, mismatches, gt_unmatched, ds_unmatched = match_segments( list(gt_groups), list(ds_groups), distance=_group_distance, @@ -1826,7 +1830,7 @@ def _find_closest_unmatched_shape(shape: dm.Annotation): and dm.AnnotationType.polyline in self.comparator.included_ann_types ): # Check line directions - line_matcher = _LineMatcher( + line_matcher = LineMatcher( torso_r=self.settings.line_thickness, oriented=True, scale=np.prod(gt_item.image.size), diff --git a/cvat/schema.yml b/cvat/schema.yml index 779b08fe376..1c7c12bf10c 100644 --- a/cvat/schema.yml +++ b/cvat/schema.yml @@ -1049,6 +1049,554 @@ paths: responses: '204': description: The comment has been deleted + /api/consensus/assignee_reports: + get: + operationId: consensus_list_assignee_reports + summary: List assignee consensus reports + parameters: + - name: X-Organization + in: header + description: Organization unique slug + schema: + type: string + - name: consensus_report_id + in: query + description: A simple equality filter for the consensus_report_id field + schema: + type: integer + - name: filter + required: false + in: query + description: |2- + + JSON Logic filter. This filter can be used to perform complex filtering by grouping rules. + + For example, using such a filter you can get all resources created by you: + + - {"and":[{"==":[{"var":"owner"},""]}]} + + Details about the syntax used can be found at the link: https://jsonlogic.com/ + + Available filter_fields: ['id', 'consensus_report_id']. + schema: + type: string + - name: org + in: query + description: Organization unique slug + schema: + type: string + - name: org_id + in: query + description: Organization identifier + schema: + type: integer + - name: page + required: false + in: query + description: A page number within the paginated result set. + schema: + type: integer + - name: page_size + required: false + in: query + description: Number of results to return per page. + schema: + type: integer + - name: sort + required: false + in: query + description: 'Which field to use when ordering the results. Available ordering_fields: + [''id'', ''consensus_report_id'']' + schema: + type: string + - in: query + name: task_id + schema: + type: integer + description: A simple equality filter for task id + tags: + - consensus + security: + - sessionAuth: [] + csrfAuth: [] + tokenAuth: [] + - signatureAuth: [] + - basicAuth: [] + responses: + '200': + content: + application/vnd.cvat+json: + schema: + $ref: '#/components/schemas/PaginatedAssigneeConsensusReportList' + description: '' + /api/consensus/assignee_reports/{id}: + get: + operationId: assignee_consensus_retrieve_report + summary: Get assignee consensus report details + parameters: + - in: path + name: id + schema: + type: integer + description: A unique integer value identifying this assignee consensus report. 
+ required: true + tags: + - consensus + security: + - sessionAuth: [] + csrfAuth: [] + tokenAuth: [] + - signatureAuth: [] + - basicAuth: [] + responses: + '200': + content: + application/vnd.cvat+json: + schema: + $ref: '#/components/schemas/AssigneeConsensusReport' + description: '' + /api/consensus/assignee_reports/{id}/data: + get: + operationId: assignee_consensus_retrieve_report_data + summary: Get assignee consensus report contents + parameters: + - in: path + name: id + schema: + type: integer + description: A unique integer value identifying this assignee consensus report. + required: true + tags: + - consensus + security: + - sessionAuth: [] + csrfAuth: [] + tokenAuth: [] + - signatureAuth: [] + - basicAuth: [] + responses: + '200': + content: + application/vnd.cvat+json: + schema: + type: object + description: '' + /api/consensus/conflicts: + get: + operationId: consensus_list_conflicts + summary: List annotation conflicts in a consensus report + parameters: + - name: X-Organization + in: header + description: Organization unique slug + schema: + type: string + - name: filter + required: false + in: query + description: |2- + + JSON Logic filter. This filter can be used to perform complex filtering by grouping rules. + + For example, using such a filter you can get all resources created by you: + + - {"and":[{"==":[{"var":"owner"},""]}]} + + Details about the syntax used can be found at the link: https://jsonlogic.com/ + + Available filter_fields: ['id', 'frame', 'type', 'job_id', 'task_id']. + schema: + type: string + - name: frame + in: query + description: A simple equality filter for the frame field + schema: + type: integer + - name: job_id + in: query + description: A simple equality filter for the job_id field + schema: + type: integer + - name: org + in: query + description: Organization unique slug + schema: + type: string + - name: org_id + in: query + description: Organization identifier + schema: + type: integer + - name: page + required: false + in: query + description: A page number within the paginated result set. + schema: + type: integer + - name: page_size + required: false + in: query + description: Number of results to return per page. + schema: + type: integer + - in: query + name: report_id + schema: + type: integer + description: A simple equality filter for report id + - name: sort + required: false + in: query + description: 'Which field to use when ordering the results. Available ordering_fields: + [''id'', ''frame'', ''type'', ''job_id'', ''task_id'']' + schema: + type: string + - name: task_id + in: query + description: A simple equality filter for the task_id field + schema: + type: integer + - name: type + in: query + description: A simple equality filter for the type field + schema: + type: string + enum: + - no_matching_item + - no_matching_annotation + - annotation_too_close + - failed_label_voting + tags: + - consensus + security: + - sessionAuth: [] + csrfAuth: [] + tokenAuth: [] + - signatureAuth: [] + - basicAuth: [] + responses: + '200': + content: + application/vnd.cvat+json: + schema: + $ref: '#/components/schemas/PaginatedConsensusConflictList' + description: '' + /api/consensus/reports: + get: + operationId: consensus_list_reports + summary: List consensus reports + parameters: + - name: X-Organization + in: header + description: Organization unique slug + schema: + type: string + - name: filter + required: false + in: query + description: |2- + + JSON Logic filter. This filter can be used to perform complex filtering by grouping rules. 
+ + For example, using such a filter you can get all resources created by you: + + - {"and":[{"==":[{"var":"owner"},""]}]} + + Details about the syntax used can be found at the link: https://jsonlogic.com/ + + Available filter_fields: ['id', 'job_id', 'created_date', 'target_last_updated', 'parent_id']. + schema: + type: string + - name: job_id + in: query + description: A simple equality filter for the job_id field + schema: + type: integer + - name: org + in: query + description: Organization unique slug + schema: + type: string + - name: org_id + in: query + description: Organization identifier + schema: + type: integer + - name: page + required: false + in: query + description: A page number within the paginated result set. + schema: + type: integer + - name: page_size + required: false + in: query + description: Number of results to return per page. + schema: + type: integer + - name: parent_id + in: query + description: A simple equality filter for the parent_id field + schema: + type: integer + - name: sort + required: false + in: query + description: 'Which field to use when ordering the results. Available ordering_fields: + [''id'', ''job_id'', ''created_date'', ''target_last_updated'', ''parent_id'']' + schema: + type: string + - in: query + name: target + schema: + type: string + description: A simple equality filter for target + - in: query + name: task_id + schema: + type: integer + description: A simple equality filter for task id + tags: + - consensus + security: + - sessionAuth: [] + csrfAuth: [] + tokenAuth: [] + - signatureAuth: [] + - basicAuth: [] + responses: + '200': + content: + application/vnd.cvat+json: + schema: + $ref: '#/components/schemas/PaginatedConsensusReportList' + description: '' + post: + operationId: consensus_create_report + summary: Create a consensus report + parameters: + - in: query + name: rq_id + schema: + type: string + description: | + The report creation request id. Can be specified to check the report + creation status. + tags: + - consensus + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/ConsensusReportCreateRequest' + security: + - sessionAuth: [] + csrfAuth: [] + tokenAuth: [] + - signatureAuth: [] + - basicAuth: [] + responses: + '201': + content: + application/vnd.cvat+json: + schema: + $ref: '#/components/schemas/ConsensusReport' + description: '' + '202': + content: + application/vnd.cvat+json: + schema: + $ref: '#/components/schemas/RqId' + description: | + A consensus report request has been enqueued, the request id is returned. + The request status can be checked at this endpoint by passing the rq_id + as the query parameter. If the request id is specified, this response + means the consensus report request is queued or is being processed. + '400': + description: Invalid or failed request, check the response data for details + /api/consensus/reports/{id}: + get: + operationId: consensus_retrieve_report + summary: Get consensus report details + parameters: + - in: path + name: id + schema: + type: integer + description: A unique integer value identifying this consensus report. 
+ required: true + tags: + - consensus + security: + - sessionAuth: [] + csrfAuth: [] + tokenAuth: [] + - signatureAuth: [] + - basicAuth: [] + responses: + '200': + content: + application/vnd.cvat+json: + schema: + $ref: '#/components/schemas/ConsensusReport' + description: '' + /api/consensus/reports/{id}/data: + get: + operationId: consensus_retrieve_report_data + summary: Get consensus report contents + parameters: + - in: path + name: id + schema: + type: integer + description: A unique integer value identifying this consensus report. + required: true + tags: + - consensus + security: + - sessionAuth: [] + csrfAuth: [] + tokenAuth: [] + - signatureAuth: [] + - basicAuth: [] + responses: + '200': + content: + application/vnd.cvat+json: + schema: + type: object + description: '' + /api/consensus/settings: + get: + operationId: consensus_list_settings + summary: List consensus settings instances + parameters: + - name: X-Organization + in: header + description: Organization unique slug + schema: + type: string + - name: filter + required: false + in: query + description: |2- + + JSON Logic filter. This filter can be used to perform complex filtering by grouping rules. + + For example, using such a filter you can get all resources created by you: + + - {"and":[{"==":[{"var":"owner"},""]}]} + + Details about the syntax used can be found at the link: https://jsonlogic.com/ + + Available filter_fields: ['id', 'task_id']. + schema: + type: string + - name: org + in: query + description: Organization unique slug + schema: + type: string + - name: org_id + in: query + description: Organization identifier + schema: + type: integer + - name: page + required: false + in: query + description: A page number within the paginated result set. + schema: + type: integer + - name: page_size + required: false + in: query + description: Number of results to return per page. + schema: + type: integer + - name: sort + required: false + in: query + description: 'Which field to use when ordering the results. 
Available ordering_fields: + [''id'']' + schema: + type: string + - name: task_id + in: query + description: A simple equality filter for the task_id field + schema: + type: integer + tags: + - consensus + security: + - sessionAuth: [] + csrfAuth: [] + tokenAuth: [] + - signatureAuth: [] + - basicAuth: [] + responses: + '200': + content: + application/vnd.cvat+json: + schema: + $ref: '#/components/schemas/PaginatedConsensusSettingsList' + description: '' + /api/consensus/settings/{id}: + get: + operationId: consensus_retrieve_settings + summary: Get consensus settings instance details + parameters: + - in: path + name: id + schema: + type: integer + description: An id of a consensus settings instance + required: true + tags: + - consensus + security: + - sessionAuth: [] + csrfAuth: [] + tokenAuth: [] + - signatureAuth: [] + - basicAuth: [] + responses: + '200': + content: + application/vnd.cvat+json: + schema: + $ref: '#/components/schemas/ConsensusSettings' + description: '' + patch: + operationId: consensus_partial_update_settings + summary: Update a consensus settings instance + parameters: + - in: path + name: id + schema: + type: integer + description: An id of a consensus settings instance + required: true + tags: + - consensus + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/PatchedConsensusSettingsRequest' + security: + - sessionAuth: [] + csrfAuth: [] + tokenAuth: [] + - signatureAuth: [] + - basicAuth: [] + responses: + '200': + content: + application/vnd.cvat+json: + schema: + $ref: '#/components/schemas/ConsensusSettings' + description: '' /api/events: get: operationId: events_list @@ -1804,7 +2352,7 @@ paths: Details about the syntax used can be found at the link: https://jsonlogic.com/ - Available filter_fields: ['task_name', 'project_name', 'assignee', 'state', 'stage', 'id', 'task_id', 'project_id', 'updated_date', 'dimension', 'type']. + Available filter_fields: ['task_name', 'project_name', 'assignee', 'state', 'stage', 'id', 'task_id', 'project_id', 'updated_date', 'dimension', 'type', 'parent_job_id']. schema: type: string - name: org @@ -1829,6 +2377,11 @@ paths: description: Number of results to return per page. schema: type: integer + - name: parent_job_id + in: query + description: A simple equality filter for the parent_job_id field + schema: + type: integer - name: project_id in: query description: A simple equality filter for the project_id field @@ -1851,7 +2404,8 @@ paths: in: query description: 'Which field to use when ordering the results. 
Available ordering_fields: [''task_name'', ''project_name'', ''assignee'', ''state'', ''stage'', ''id'', - ''task_id'', ''project_id'', ''updated_date'', ''dimension'', ''type'']' + ''task_id'', ''project_id'', ''updated_date'', ''dimension'', ''type'', + ''parent_job_id'']' schema: type: string - name: stage @@ -1891,6 +2445,7 @@ paths: enum: - annotation - ground_truth + - consensus tags: - jobs security: @@ -6829,7 +7384,7 @@ components: readOnly: true type: allOf: - - $ref: '#/components/schemas/AnnotationIdTypeEnum' + - $ref: '#/components/schemas/Type457Enum' readOnly: true shape_type: readOnly: true @@ -6837,16 +7392,6 @@ components: oneOf: - $ref: '#/components/schemas/ShapeType' - $ref: '#/components/schemas/NullEnum' - AnnotationIdTypeEnum: - enum: - - tag - - shape - - track - type: string - description: |- - * `tag` - TAG - * `shape` - SHAPE - * `track` - TRACK AnnotationsRead: oneOf: - $ref: '#/components/schemas/LabeledData' @@ -6859,20 +7404,41 @@ components: type: string format: uuid readOnly: true - filename: - type: string - maxLength: 1024 - created_date: - type: string - format: date-time + filename: + type: string + maxLength: 1024 + created_date: + type: string + format: date-time + readOnly: true + owner: + $ref: '#/components/schemas/BasicUser' + guide_id: + type: integer + readOnly: true + required: + - filename + AssigneeConsensusReport: + type: object + properties: + task_id: + type: integer + nullable: true + readOnly: true + assignee: + allOf: + - $ref: '#/components/schemas/BasicUser' + readOnly: true + nullable: true + consensus_score: + type: integer + readOnly: true + consensus_report_id: + type: integer readOnly: true - owner: - $ref: '#/components/schemas/BasicUser' - guide_id: + conflict_count: type: integer readOnly: true - required: - - filename Attribute: type: object properties: @@ -7217,6 +7783,159 @@ components: type: string format: uri readOnly: true + ConsensusAnnotationId: + type: object + properties: + obj_id: + type: integer + readOnly: true + job_id: + type: integer + readOnly: true + type: + allOf: + - $ref: '#/components/schemas/Type457Enum' + readOnly: true + shape_type: + readOnly: true + nullable: true + oneOf: + - $ref: '#/components/schemas/ShapeType' + - $ref: '#/components/schemas/NullEnum' + ConsensusConflict: + type: object + properties: + id: + type: integer + readOnly: true + frame: + type: integer + readOnly: true + type: + allOf: + - $ref: '#/components/schemas/ConsensusConflictTypeEnum' + readOnly: true + annotation_ids: + type: array + items: + $ref: '#/components/schemas/ConsensusAnnotationId' + report_id: + type: integer + readOnly: true + required: + - annotation_ids + ConsensusConflictTypeEnum: + enum: + - no_matching_item + - no_matching_annotation + - annotation_too_close + - failed_label_voting + type: string + description: |- + * `no_matching_item` - NoMatchingItemError + * `no_matching_annotation` - NoMatchingAnnError + * `annotation_too_close` - AnnotationsTooCloseError + * `failed_label_voting` - FailedLabelVotingError + ConsensusReport: + type: object + properties: + id: + type: integer + readOnly: true + job_id: + type: integer + nullable: true + readOnly: true + task_id: + type: integer + nullable: true + readOnly: true + parent_id: + type: integer + nullable: true + readOnly: true + summary: + $ref: '#/components/schemas/ConsensusReportSummary' + created_date: + type: string + format: date-time + readOnly: true + target_last_updated: + type: string + format: date-time + readOnly: true + target: + $ref: 
'#/components/schemas/QualityReportTarget' + assignee: + allOf: + - $ref: '#/components/schemas/BasicUser' + readOnly: true + nullable: true + consensus_score: + type: integer + readOnly: true + required: + - summary + - target + ConsensusReportCreateRequest: + type: object + properties: + task_id: + type: integer + writeOnly: true + job_id: + type: integer + writeOnly: true + ConsensusReportSummary: + type: object + properties: + frame_count: + type: integer + conflict_count: + type: integer + conflicts_by_type: + type: object + additionalProperties: + type: integer + required: + - conflict_count + - conflicts_by_type + - frame_count + ConsensusSettings: + type: object + properties: + id: + type: integer + readOnly: true + task_id: + type: integer + nullable: true + readOnly: true + iou_threshold: + type: number + format: double + description: Used for distinction between matched / unmatched shapes + agreement_score_threshold: + type: number + format: double + description: | + Confidence threshold for output annotations + quorum: + type: integer + maximum: 2147483647 + minimum: -2147483648 + description: | + Minimum count for a label and attribute voting results to be counted + sigma: + type: number + format: double + description: | + Sigma value for OKS calculation + line_thickness: + type: number + format: double + description: | + thickness of polylines, relatively to the (image area) ^ 0.5. CredentialsTypeEnum: enum: - KEY_SECRET_KEY_PAIR @@ -8118,6 +8837,10 @@ components: format: date-time readOnly: true nullable: true + parent_job_id: + type: integer + nullable: true + readOnly: true required: - issues - labels @@ -8145,10 +8868,12 @@ components: enum: - annotation - ground_truth + - consensus type: string description: |- * `annotation` - ANNOTATION * `ground_truth` - GROUND_TRUTH + * `consensus` - CONSENSUS JobWriteRequest: type: object properties: @@ -8194,7 +8919,7 @@ components: properties: count: type: integer - default: 0 + nullable: true completed: type: integer nullable: true @@ -8207,6 +8932,7 @@ components: readOnly: true required: - completed + - count - validation Label: type: object @@ -8749,6 +9475,26 @@ components: type: array items: $ref: '#/components/schemas/AnnotationConflict' + PaginatedAssigneeConsensusReportList: + type: object + properties: + count: + type: integer + example: 123 + next: + type: string + nullable: true + format: uri + example: http://api.example.org/accounts/?page=4 + previous: + type: string + nullable: true + format: uri + example: http://api.example.org/accounts/?page=2 + results: + type: array + items: + $ref: '#/components/schemas/AssigneeConsensusReport' PaginatedCloudStorageReadList: type: object properties: @@ -8789,6 +9535,66 @@ components: type: array items: $ref: '#/components/schemas/CommentRead' + PaginatedConsensusConflictList: + type: object + properties: + count: + type: integer + example: 123 + next: + type: string + nullable: true + format: uri + example: http://api.example.org/accounts/?page=4 + previous: + type: string + nullable: true + format: uri + example: http://api.example.org/accounts/?page=2 + results: + type: array + items: + $ref: '#/components/schemas/ConsensusConflict' + PaginatedConsensusReportList: + type: object + properties: + count: + type: integer + example: 123 + next: + type: string + nullable: true + format: uri + example: http://api.example.org/accounts/?page=4 + previous: + type: string + nullable: true + format: uri + example: http://api.example.org/accounts/?page=2 + results: + type: array + items: 
+ $ref: '#/components/schemas/ConsensusReport' + PaginatedConsensusSettingsList: + type: object + properties: + count: + type: integer + example: 123 + next: + type: string + nullable: true + format: uri + example: http://api.example.org/accounts/?page=4 + previous: + type: string + nullable: true + format: uri + example: http://api.example.org/accounts/?page=2 + results: + type: array + items: + $ref: '#/components/schemas/ConsensusSettings' PaginatedInvitationReadList: type: object properties: @@ -9185,6 +9991,34 @@ components: message: type: string minLength: 1 + PatchedConsensusSettingsRequest: + type: object + properties: + iou_threshold: + type: number + format: double + description: Used for distinction between matched / unmatched shapes + agreement_score_threshold: + type: number + format: double + description: | + Confidence threshold for output annotations + quorum: + type: integer + maximum: 2147483647 + minimum: -2147483648 + description: | + Minimum count for a label and attribute voting results to be counted + sigma: + type: number + format: double + description: | + Sigma value for OKS calculation + line_thickness: + type: number + format: double + description: | + thickness of polylines, relatively to the (image area) ^ 0.5. PatchedDataMetaWriteRequest: type: object properties: @@ -9460,6 +10294,9 @@ components: allOf: - $ref: '#/components/schemas/StorageRequest' nullable: true + consensus_jobs_per_regular_job: + type: integer + nullable: true PatchedUserRequest: type: object properties: @@ -10537,6 +11374,12 @@ components: format: date-time readOnly: true nullable: true + consensus_jobs_per_regular_job: + type: integer + maximum: 2147483647 + minimum: -2147483648 + readOnly: true + nullable: true required: - jobs - labels @@ -10585,6 +11428,9 @@ components: allOf: - $ref: '#/components/schemas/StorageRequest' nullable: true + consensus_jobs_per_regular_job: + type: integer + nullable: true required: - name TasksSummary: @@ -10695,6 +11541,16 @@ components: nullable: true required: - name + Type457Enum: + enum: + - tag + - shape + - track + type: string + description: |- + * `tag` - TAG + * `shape` - SHAPE + * `track` - TRACK User: type: object properties: diff --git a/cvat/settings/base.py b/cvat/settings/base.py index 1cd454564ff..943a2f7543a 100644 --- a/cvat/settings/base.py +++ b/cvat/settings/base.py @@ -117,6 +117,7 @@ def generate_secret_key(): 'cvat.apps.events', 'cvat.apps.quality_control', 'cvat.apps.analytics_report', + 'cvat.apps.consensus', ] SITE_ID = 1 @@ -274,6 +275,7 @@ class CVAT_QUEUES(Enum): QUALITY_REPORTS = 'quality_reports' ANALYTICS_REPORTS = 'analytics_reports' CLEANING = 'cleaning' + CONSENSUS = 'consensus' redis_inmem_host = os.getenv('CVAT_REDIS_INMEM_HOST', 'localhost') redis_inmem_port = os.getenv('CVAT_REDIS_INMEM_PORT', 6379) @@ -319,6 +321,10 @@ class CVAT_QUEUES(Enum): **shared_queue_settings, 'DEFAULT_TIMEOUT': '1h', }, + CVAT_QUEUES.CONSENSUS.value: { + **shared_queue_settings, + 'DEFAULT_TIMEOUT': '1h', + }, } NUCLIO = { diff --git a/cvat/urls.py b/cvat/urls.py index 144ed619f76..0eb10869f84 100644 --- a/cvat/urls.py +++ b/cvat/urls.py @@ -51,3 +51,6 @@ if apps.is_installed('cvat.apps.analytics_report'): urlpatterns.append(path('api/', include('cvat.apps.analytics_report.urls'))) + +if apps.is_installed('cvat.apps.consensus'): + urlpatterns.append(path('api/', include('cvat.apps.consensus.urls'))) diff --git a/dev/format_python_code.sh b/dev/format_python_code.sh index 7eff923abb8..911fab3abea 100755 --- a/dev/format_python_code.sh +++ 
b/dev/format_python_code.sh @@ -23,6 +23,7 @@ for paths in \ "tests/python/" \ "cvat/apps/quality_control" \ "cvat/apps/analytics_report" \ + "cvat/apps/consensus" \ "cvat/apps/engine/lazy_list.py" \ "cvat/apps/engine/background.py" \ "cvat/apps/engine/frame_provider.py" \ diff --git a/docker-compose.dev.yml b/docker-compose.dev.yml index db9a8bbac01..e2d960aa25d 100644 --- a/docker-compose.dev.yml +++ b/docker-compose.dev.yml @@ -78,6 +78,18 @@ services: ports: - '9095:9095' + cvat_worker_consensus: + environment: + # For debugging, make sure to set 1 process + # Due to the supervisord specifics, the extra processes will fail and + # after few attempts supervisord will give up restarting, leaving only 1 process + # NUMPROCS: 1 + CVAT_DEBUG_ENABLED: '${CVAT_DEBUG_ENABLED:-no}' + CVAT_DEBUG_PORT: '9096' + COVERAGE_PROCESS_START: + ports: + - '9096:9096' + cvat_worker_annotation: environment: # For debugging, make sure to set 1 process diff --git a/docker-compose.external_db.yml b/docker-compose.external_db.yml index decd1e9ed14..6e56247c43f 100644 --- a/docker-compose.external_db.yml +++ b/docker-compose.external_db.yml @@ -27,6 +27,7 @@ services: cvat_worker_import: *backend-settings cvat_worker_quality_reports: *backend-settings cvat_worker_webhooks: *backend-settings + cvat_worker_consensus: *backend-settings secrets: postgres_password: diff --git a/docker-compose.tests.yml b/docker-compose.tests.yml new file mode 100644 index 00000000000..24c3792973c --- /dev/null +++ b/docker-compose.tests.yml @@ -0,0 +1,454 @@ +networks: + cvat: null +services: + cvat_clickhouse: + environment: + CLICKHOUSE_DB: cvat + CLICKHOUSE_HOST: clickhouse + CLICKHOUSE_PASSWORD: user + CLICKHOUSE_PORT: 8123 + CLICKHOUSE_USER: user + image: clickhouse/clickhouse-server:23.11-alpine + networks: + cvat: + aliases: + - clickhouse + restart: always + volumes: + - ./components/analytics/clickhouse/init.sh:/docker-entrypoint-initdb.d/init.sh:ro + - cvat_events_db:/var/lib/clickhouse/ + cvat_db: + environment: + POSTGRES_DB: cvat + POSTGRES_HOST_AUTH_METHOD: trust + POSTGRES_USER: root + image: postgres:15-alpine + networks: + - cvat + restart: always + volumes: + - cvat_db:/var/lib/postgresql/data + cvat_grafana: + entrypoint: + - sh + - -euc + - "mkdir -p /etc/grafana/provisioning/datasources\ncat << 'EOF' > /etc/grafana/provisioning/datasources/ds.yaml\n\ + apiVersion: 1\ndatasources:\n - name: 'ClickHouse'\n type: 'grafana-clickhouse-datasource'\n\ + \ isDefault: true\n jsonData:\n defaultDatabase: $${CLICKHOUSE_DB}\n\ + \ port: $${CLICKHOUSE_PORT}\n server: $${CLICKHOUSE_HOST}\n username:\ + \ $${CLICKHOUSE_USER}\n tlsSkipVerify: false\n protocol: http\n \ + \ secureJsonData:\n password: $${CLICKHOUSE_PASSWORD}\n editable: true\n\ + EOF\nmkdir -p /etc/grafana/provisioning/dashboards\ncat << EOF > /etc/grafana/provisioning/dashboards/dashboard.yaml\n\ + apiVersion: 1\nproviders:\n - name: cvat-logs\n type: file\n updateIntervalSeconds:\ + \ 30\n options:\n path: /var/lib/grafana/dashboards\n foldersFromFilesStructure:\ + \ true\nEOF\nexec /run.sh\n" + environment: + CLICKHOUSE_DB: cvat + CLICKHOUSE_HOST: clickhouse + CLICKHOUSE_PASSWORD: user + CLICKHOUSE_PORT: 8123 + CLICKHOUSE_USER: user + GF_AUTH_ANONYMOUS_ENABLED: true + GF_AUTH_ANONYMOUS_ORG_ROLE: Admin + GF_AUTH_BASIC_ENABLED: false + GF_AUTH_DISABLE_LOGIN_FORM: true + GF_DASHBOARDS_DEFAULT_HOME_DASHBOARD_PATH: /var/lib/grafana/dashboards/all_events.json + GF_INSTALL_PLUGINS: 
https://github.com/grafana/clickhouse-datasource/releases/download/v4.0.8/grafana-clickhouse-datasource-4.0.8.linux_amd64.zip;grafana-clickhouse-datasource + GF_PATHS_PROVISIONING: /etc/grafana/provisioning + GF_PLUGINS_ALLOW_LOADING_UNSIGNED_PLUGINS: grafana-clickhouse-datasource + GF_SERVER_ROOT_URL: http://${CVAT_HOST:-localhost}/analytics + image: grafana/grafana-oss:10.1.2 + networks: + cvat: + aliases: + - grafana + volumes: + - ./components/analytics/grafana/dashboards/:/var/lib/grafana/dashboards/:ro + cvat_opa: + command: + - run + - --server + - --log-level=error + - --set=services.cvat.url=http://cvat-server:8080 + - --set=bundles.cvat.service=cvat + - --set=bundles.cvat.resource=/api/auth/rules + - --set=bundles.cvat.polling.min_delay_seconds=5 + - --set=bundles.cvat.polling.max_delay_seconds=15 + image: openpolicyagent/opa:0.63.0 + networks: + cvat: + aliases: + - opa + restart: always + cvat_redis_inmem: + command: + - redis-server + - --save + - '60' + - '100' + - --appendonly + - 'yes' + image: redis:7.2.3-alpine + networks: + - cvat + restart: always + volumes: + - cvat_inmem_db:/data + cvat_redis_ondisk: + command: + - --dir + - /var/lib/kvrocks/data + image: apache/kvrocks:2.7.0 + init: true + networks: + - cvat + restart: always + volumes: + - cvat_cache_db:/var/lib/kvrocks/data + cvat_server: + command: init run server + depends_on: + cvat_db: &id001 + condition: service_started + cvat_opa: + condition: service_started + cvat_redis_inmem: &id002 + condition: service_started + cvat_redis_ondisk: &id003 + condition: service_started + environment: + ADAPTIVE_AUTO_ANNOTATION: 'false' + ALLOWED_HOSTS: '*' + CLICKHOUSE_DB: cvat + CLICKHOUSE_HOST: clickhouse + CLICKHOUSE_PASSWORD: user + CLICKHOUSE_PORT: 8123 + CLICKHOUSE_USER: user + COVERAGE_PROCESS_START: .coveragerc + CVAT_ANALYTICS: 1 + CVAT_BASE_URL: null + CVAT_LOG_IMPORT_ERRORS: 'true' + CVAT_POSTGRES_HOST: cvat_db + CVAT_REDIS_INMEM_HOST: cvat_redis_inmem + CVAT_REDIS_INMEM_PORT: 6379 + CVAT_REDIS_ONDISK_HOST: cvat_redis_ondisk + CVAT_REDIS_ONDISK_PORT: 6666 + DJANGO_LOG_SERVER_HOST: vector + DJANGO_LOG_SERVER_PORT: 80 + DJANGO_MODWSGI_EXTRA_ARGS: '' + DJANGO_SETTINGS_MODULE: cvat.settings.testing_rest + NUMPROCS: 2 + ONE_RUNNING_JOB_IN_QUEUE_PER_USER: null + SMOKESCREEN_OPTS: ${SMOKESCREEN_OPTS:-} + no_proxy: clickhouse,grafana,vector,nuclio,opa,${no_proxy:-} + image: cvat/server:${CVAT_VERSION:-dev} + labels: + - traefik.enable=true + - traefik.http.services.cvat.loadbalancer.server.port=8080 + - traefik.http.routers.cvat.rule=Host(`${CVAT_HOST:-localhost}`) && PathPrefix(`/api/`, + `/static/`, `/admin`, `/documentation/`, `/django-rq`) + - traefik.http.routers.cvat.entrypoints=web + networks: + cvat: + aliases: + - cvat-server + restart: always + volumes: + - cvat_data:/home/django/data + - cvat_keys:/home/django/keys + - cvat_logs:/home/django/logs + - ./tests/python/.coveragerc:/home/django/.coveragerc + cvat_ui: + depends_on: + - cvat_server + image: cvat/ui:${CVAT_VERSION:-dev} + labels: + - traefik.enable=true + - traefik.http.services.cvat-ui.loadbalancer.server.port=80 + - traefik.http.routers.cvat-ui.rule=Host(`${CVAT_HOST:-localhost}`) + - traefik.http.routers.cvat-ui.entrypoints=web + networks: + - cvat + restart: always + cvat_utils: + command: run utils + depends_on: &id004 + cvat_db: *id001 + cvat_redis_inmem: *id002 + cvat_redis_ondisk: *id003 + environment: + COVERAGE_PROCESS_START: .coveragerc + CVAT_LOG_IMPORT_ERRORS: 'true' + CVAT_POSTGRES_HOST: cvat_db + CVAT_REDIS_INMEM_HOST: cvat_redis_inmem + 
CVAT_REDIS_INMEM_PASSWORD: '' + CVAT_REDIS_INMEM_PORT: 6379 + CVAT_REDIS_ONDISK_HOST: cvat_redis_ondisk + CVAT_REDIS_ONDISK_PORT: 6666 + DJANGO_LOG_SERVER_HOST: vector + DJANGO_LOG_SERVER_PORT: 80 + DJANGO_SETTINGS_MODULE: cvat.settings.testing_rest + NUMPROCS: 1 + SMOKESCREEN_OPTS: ${SMOKESCREEN_OPTS:-} + no_proxy: clickhouse,grafana,vector,nuclio,opa,${no_proxy:-} + image: cvat/server:${CVAT_VERSION:-dev} + networks: + - cvat + restart: always + volumes: + - cvat_data:/home/django/data + - cvat_keys:/home/django/keys + - cvat_logs:/home/django/logs + - ./tests/python/.coveragerc:/home/django/.coveragerc + cvat_vector: + depends_on: + - cvat_clickhouse + environment: + CLICKHOUSE_DB: cvat + CLICKHOUSE_HOST: clickhouse + CLICKHOUSE_PASSWORD: user + CLICKHOUSE_PORT: 8123 + CLICKHOUSE_USER: user + image: timberio/vector:0.26.0-alpine + networks: + cvat: + aliases: + - vector + restart: always + volumes: + - ./components/analytics/vector/vector.toml:/etc/vector/vector.toml:ro + cvat_worker_analytics_reports: + command: run worker.analytics_reports + depends_on: *id004 + environment: + CLICKHOUSE_DB: cvat + CLICKHOUSE_HOST: clickhouse + CLICKHOUSE_PASSWORD: user + CLICKHOUSE_PORT: 8123 + CLICKHOUSE_USER: user + CVAT_LOG_IMPORT_ERRORS: 'true' + CVAT_POSTGRES_HOST: cvat_db + CVAT_REDIS_INMEM_HOST: cvat_redis_inmem + CVAT_REDIS_INMEM_PORT: 6379 + CVAT_REDIS_ONDISK_HOST: cvat_redis_ondisk + CVAT_REDIS_ONDISK_PORT: 6666 + DJANGO_LOG_SERVER_HOST: vector + DJANGO_LOG_SERVER_PORT: 80 + NUMPROCS: 2 + SMOKESCREEN_OPTS: ${SMOKESCREEN_OPTS:-} + no_proxy: clickhouse,grafana,vector,nuclio,opa,${no_proxy:-} + image: cvat/server:${CVAT_VERSION:-dev} + networks: + - cvat + restart: always + volumes: + - cvat_data:/home/django/data + - cvat_keys:/home/django/keys + - cvat_logs:/home/django/logs + cvat_worker_annotation: + command: run worker.annotation + depends_on: *id004 + environment: + COVERAGE_PROCESS_START: .coveragerc + CVAT_LOG_IMPORT_ERRORS: 'true' + CVAT_POSTGRES_HOST: cvat_db + CVAT_REDIS_INMEM_HOST: cvat_redis_inmem + CVAT_REDIS_INMEM_PORT: 6379 + CVAT_REDIS_ONDISK_HOST: cvat_redis_ondisk + CVAT_REDIS_ONDISK_PORT: 6666 + DJANGO_LOG_SERVER_HOST: vector + DJANGO_LOG_SERVER_PORT: 80 + NUMPROCS: 1 + SMOKESCREEN_OPTS: ${SMOKESCREEN_OPTS:-} + no_proxy: clickhouse,grafana,vector,nuclio,opa,${no_proxy:-} + image: cvat/server:${CVAT_VERSION:-dev} + networks: + - cvat + restart: always + volumes: + - cvat_data:/home/django/data + - cvat_keys:/home/django/keys + - cvat_logs:/home/django/logs + - ./tests/python/.coveragerc:/home/django/.coveragerc + cvat_worker_export: + command: run worker.export + depends_on: *id004 + environment: + CLICKHOUSE_DB: cvat + CLICKHOUSE_HOST: clickhouse + CLICKHOUSE_PASSWORD: user + CLICKHOUSE_PORT: 8123 + CLICKHOUSE_USER: user + COVERAGE_PROCESS_START: .coveragerc + CVAT_LOG_IMPORT_ERRORS: 'true' + CVAT_POSTGRES_HOST: cvat_db + CVAT_REDIS_INMEM_HOST: cvat_redis_inmem + CVAT_REDIS_INMEM_PORT: 6379 + CVAT_REDIS_ONDISK_HOST: cvat_redis_ondisk + CVAT_REDIS_ONDISK_PORT: 6666 + DJANGO_LOG_SERVER_HOST: vector + DJANGO_LOG_SERVER_PORT: 80 + NUMPROCS: 2 + SMOKESCREEN_OPTS: ${SMOKESCREEN_OPTS:-} + no_proxy: clickhouse,grafana,vector,nuclio,opa,${no_proxy:-} + image: cvat/server:${CVAT_VERSION:-dev} + networks: + - cvat + restart: always + volumes: + - cvat_data:/home/django/data + - cvat_keys:/home/django/keys + - cvat_logs:/home/django/logs + - ./tests/python/.coveragerc:/home/django/.coveragerc + cvat_worker_import: + command: run worker.import + depends_on: *id004 + environment: + 
COVERAGE_PROCESS_START: .coveragerc + CVAT_LOG_IMPORT_ERRORS: 'true' + CVAT_POSTGRES_HOST: cvat_db + CVAT_REDIS_INMEM_HOST: cvat_redis_inmem + CVAT_REDIS_INMEM_PORT: 6379 + CVAT_REDIS_ONDISK_HOST: cvat_redis_ondisk + CVAT_REDIS_ONDISK_PORT: 6666 + DJANGO_LOG_SERVER_HOST: vector + DJANGO_LOG_SERVER_PORT: 80 + NUMPROCS: 2 + SMOKESCREEN_OPTS: ${SMOKESCREEN_OPTS:-} + no_proxy: clickhouse,grafana,vector,nuclio,opa,${no_proxy:-} + image: cvat/server:${CVAT_VERSION:-dev} + networks: + - cvat + restart: always + volumes: + - cvat_data:/home/django/data + - cvat_keys:/home/django/keys + - cvat_logs:/home/django/logs + - ./tests/python/.coveragerc:/home/django/.coveragerc + cvat_worker_quality_reports: + command: run worker.quality_reports + depends_on: *id004 + environment: + COVERAGE_PROCESS_START: .coveragerc + CVAT_LOG_IMPORT_ERRORS: 'true' + CVAT_POSTGRES_HOST: cvat_db + CVAT_REDIS_INMEM_HOST: cvat_redis_inmem + CVAT_REDIS_INMEM_PORT: 6379 + CVAT_REDIS_ONDISK_HOST: cvat_redis_ondisk + CVAT_REDIS_ONDISK_PORT: 6666 + DJANGO_LOG_SERVER_HOST: vector + DJANGO_LOG_SERVER_PORT: 80 + NUMPROCS: 1 + SMOKESCREEN_OPTS: ${SMOKESCREEN_OPTS:-} + no_proxy: clickhouse,grafana,vector,nuclio,opa,${no_proxy:-} + image: cvat/server:${CVAT_VERSION:-dev} + networks: + - cvat + restart: always + volumes: + - cvat_data:/home/django/data + - cvat_keys:/home/django/keys + - cvat_logs:/home/django/logs + - ./tests/python/.coveragerc:/home/django/.coveragerc + cvat_worker_consensus: + command: run worker.consensus + depends_on: *id004 + environment: + COVERAGE_PROCESS_START: .coveragerc + CVAT_LOG_IMPORT_ERRORS: 'true' + CVAT_POSTGRES_HOST: cvat_db + CVAT_REDIS_INMEM_HOST: cvat_redis_inmem + CVAT_REDIS_INMEM_PORT: 6379 + CVAT_REDIS_ONDISK_HOST: cvat_redis_ondisk + CVAT_REDIS_ONDISK_PORT: 6666 + DJANGO_LOG_SERVER_HOST: vector + DJANGO_LOG_SERVER_PORT: 80 + NUMPROCS: 1 + SMOKESCREEN_OPTS: ${SMOKESCREEN_OPTS:-} + no_proxy: clickhouse,grafana,vector,nuclio,opa,${no_proxy:-} + image: cvat/server:${CVAT_VERSION:-dev} + networks: + - cvat + restart: always + volumes: + - cvat_data:/home/django/data + - cvat_keys:/home/django/keys + - cvat_logs:/home/django/logs + - ./tests/python/.coveragerc:/home/django/.coveragerc + cvat_worker_webhooks: + command: run worker.webhooks + depends_on: *id004 + environment: + COVERAGE_PROCESS_START: .coveragerc + CVAT_LOG_IMPORT_ERRORS: 'true' + CVAT_POSTGRES_HOST: cvat_db + CVAT_REDIS_INMEM_HOST: cvat_redis_inmem + CVAT_REDIS_INMEM_PORT: 6379 + CVAT_REDIS_ONDISK_HOST: cvat_redis_ondisk + CVAT_REDIS_ONDISK_PORT: 6666 + DJANGO_LOG_SERVER_HOST: vector + DJANGO_LOG_SERVER_PORT: 80 + NUMPROCS: 1 + SMOKESCREEN_OPTS: ${SMOKESCREEN_OPTS:-} + no_proxy: clickhouse,grafana,vector,nuclio,opa,${no_proxy:-} + image: cvat/server:${CVAT_VERSION:-dev} + networks: + - cvat + restart: always + volumes: + - cvat_data:/home/django/data + - cvat_keys:/home/django/keys + - cvat_logs:/home/django/logs + - ./tests/python/.coveragerc:/home/django/.coveragerc + traefik: + environment: + CVAT_HOST: ${CVAT_HOST:-localhost} + DJANGO_LOG_VIEWER_HOST: grafana + DJANGO_LOG_VIEWER_PORT: 3000 + TRAEFIK_ACCESSLOG_FORMAT: json + TRAEFIK_ENTRYPOINTS_web_ADDRESS: :8080 + TRAEFIK_LOG_FORMAT: json + TRAEFIK_PROVIDERS_DOCKER_EXPOSEDBYDEFAULT: 'false' + TRAEFIK_PROVIDERS_DOCKER_NETWORK: cvat + TRAEFIK_PROVIDERS_FILE_DIRECTORY: /etc/traefik/rules + image: traefik:v2.10 + logging: + driver: json-file + options: + max-file: '10' + max-size: 100m + networks: + - cvat + ports: + - 8080:8080 + - 8090:8090 + restart: always + volumes: + - 
/var/run/docker.sock:/var/run/docker.sock:ro + - ./components/analytics/grafana_conf.yml:/etc/traefik/rules/grafana_conf.yml:ro +volumes: + cvat_cache_db: null + cvat_data: null + cvat_db: null + cvat_events_db: null + cvat_inmem_db: null + cvat_keys: null + cvat_logs: null +x-backend-deps: *id004 +x-backend-env: + CVAT_LOG_IMPORT_ERRORS: 'true' + CVAT_POSTGRES_HOST: cvat_db + CVAT_REDIS_INMEM_HOST: cvat_redis_inmem + CVAT_REDIS_INMEM_PORT: 6379 + CVAT_REDIS_ONDISK_HOST: cvat_redis_ondisk + CVAT_REDIS_ONDISK_PORT: 6666 + DJANGO_LOG_SERVER_HOST: vector + DJANGO_LOG_SERVER_PORT: 80 + SMOKESCREEN_OPTS: ${SMOKESCREEN_OPTS:-} + no_proxy: clickhouse,grafana,vector,nuclio,opa,${no_proxy:-} +x-clickhouse-env: + CLICKHOUSE_DB: cvat + CLICKHOUSE_HOST: clickhouse + CLICKHOUSE_PASSWORD: user + CLICKHOUSE_PORT: 8123 + CLICKHOUSE_USER: user diff --git a/docker-compose.yml b/docker-compose.yml index 569e163e9fe..1a7503d33db 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -224,6 +224,22 @@ services: networks: - cvat + cvat_worker_consensus: + container_name: cvat_worker_consensus + image: cvat/server:${CVAT_VERSION:-dev} + restart: always + depends_on: *backend-deps + environment: + <<: *backend-env + NUMPROCS: 1 + command: run worker.consensus + volumes: + - cvat_data:/home/django/data + - cvat_keys:/home/django/keys + - cvat_logs:/home/django/logs + networks: + - cvat + cvat_ui: container_name: cvat_ui image: cvat/ui:${CVAT_VERSION:-dev} diff --git a/helm-chart/templates/cvat_backend/worker_consensus/deployment.yml b/helm-chart/templates/cvat_backend/worker_consensus/deployment.yml new file mode 100644 index 00000000000..5d898a0149a --- /dev/null +++ b/helm-chart/templates/cvat_backend/worker_consensus/deployment.yml @@ -0,0 +1,82 @@ +{{- $localValues := .Values.cvat.backend.worker.consensus -}} + +apiVersion: apps/v1 +kind: Deployment +metadata: + name: {{ .Release.Name }}-backend-worker-consensus + namespace: {{ .Release.Namespace }} + labels: + app: cvat-app + tier: backend + component: worker-consensus + {{- include "cvat.labels" . | nindent 4 }} + {{- with merge $localValues.labels .Values.cvat.backend.labels }} + {{- toYaml . | nindent 4 }} + {{- end }} + {{- with merge $localValues.annotations .Values.cvat.backend.annotations }} + annotations: + {{- toYaml . | nindent 4 }} + {{- end }} +spec: + replicas: {{ $localValues.replicas }} + strategy: + type: Recreate + selector: + matchLabels: + {{- include "cvat.labels" . | nindent 6 }} + {{- with merge $localValues.labels .Values.cvat.backend.labels }} + {{- toYaml . | nindent 6 }} + {{- end }} + app: cvat-app + tier: backend + component: worker-consensus + template: + metadata: + labels: + app: cvat-app + tier: backend + component: worker-consensus + {{- include "cvat.labels" . | nindent 8 }} + {{- with merge $localValues.labels .Values.cvat.backend.labels }} + {{- toYaml . | nindent 8 }} + {{- end }} + {{- with merge $localValues.annotations .Values.cvat.backend.annotations }} + annotations: + {{- toYaml . | nindent 8 }} + {{- end }} + spec: + serviceAccountName: {{ include "cvat.backend.serviceAccountName" . }} + containers: + - name: cvat-backend + image: {{ .Values.cvat.backend.image }}:{{ .Values.cvat.backend.tag }} + imagePullPolicy: {{ .Values.cvat.backend.imagePullPolicy }} + {{- with merge $localValues.resources .Values.cvat.backend.resources }} + resources: + {{- toYaml . | nindent 12 }} + {{- end }} + args: ["run", "worker.consensus"] + env: + {{ include "cvat.sharedBackendEnv" . 
| indent 10 }} + {{- with concat .Values.cvat.backend.additionalEnv $localValues.additionalEnv }} + {{- toYaml . | nindent 10 }} + {{- end }} + {{- with concat .Values.cvat.backend.additionalVolumeMounts $localValues.additionalVolumeMounts }} + volumeMounts: + {{- toYaml . | nindent 10 }} + {{- end }} + {{- with merge $localValues.affinity .Values.cvat.backend.affinity }} + affinity: + {{- toYaml . | nindent 8 }} + {{- end }} + {{- with concat .Values.cvat.backend.tolerations $localValues.tolerations }} + tolerations: + {{- toYaml . | nindent 8 }} + {{- end }} + {{- with concat .Values.cvat.backend.additionalVolumes $localValues.additionalVolumes }} + volumes: + {{- toYaml . | nindent 8 }} + {{- end }} + {{- with .Values.imagePullSecrets }} + imagePullSecrets: + {{- toYaml . | nindent 8 }} + {{- end }} diff --git a/supervisord/worker.consensus.conf b/supervisord/worker.consensus.conf new file mode 100644 index 00000000000..1072226b778 --- /dev/null +++ b/supervisord/worker.consensus.conf @@ -0,0 +1,27 @@ +[unix_http_server] +file = /tmp/supervisord/supervisor.sock + +[supervisorctl] +serverurl = unix:///tmp/supervisord/supervisor.sock + + +[rpcinterface:supervisor] +supervisor.rpcinterface_factory = supervisor.rpcinterface:make_main_rpcinterface + +[supervisord] +nodaemon=true +logfile=%(ENV_HOME)s/logs/supervisord.log ; supervisord log file +logfile_maxbytes=50MB ; maximum size of logfile before rotation +logfile_backups=10 ; number of backed up logfiles +loglevel=debug ; info, debug, warn, trace +pidfile=/tmp/supervisord/supervisord.pid ; pidfile location +childlogdir=%(ENV_HOME)s/logs/ ; where child log files will live + +[program:rqworker-consensus] +command=%(ENV_HOME)s/wait_for_deps.sh + python3 %(ENV_HOME)s/manage.py rqworker -v 3 consensus + --worker-class cvat.rqworker.DefaultWorker +environment=VECTOR_EVENT_HANDLER="SynchronousLogstashHandler",CVAT_POSTGRES_APPLICATION_NAME="cvat:worker:consensus" +numprocs=%(ENV_NUMPROCS)s +process_name=%(program_name)s-%(process_num)d +autorestart=true diff --git a/tests/python/rest_api/test_tasks.py b/tests/python/rest_api/test_tasks.py index fd5db6388b3..456d3a14cd5 100644 --- a/tests/python/rest_api/test_tasks.py +++ b/tests/python/rest_api/test_tasks.py @@ -443,6 +443,28 @@ def test_can_create_with_assignee(self, admin_user, users_by_name, assignee): assert task.assignee is None assert task.assignee_updated_date is None + @pytest.mark.parametrize( + "consensus_jobs_per_regular_job, success", [(0, True), (1, False), (2, True), (11, False)] + ) + def test_can_create_with_consensus_jobs_per_regular_job( + self, admin_user, consensus_jobs_per_regular_job, success + ): + task_spec = { + "name": "test task creation with assignee", + "labels": [{"name": "car"}], + "consensus_jobs_per_regular_job": consensus_jobs_per_regular_job, + } + + with make_api_client(admin_user) as api_client: + if success: + (task, response) = api_client.tasks_api.create(task_write_request=task_spec) + assert response.status == HTTPStatus.CREATED + assert task.consensus_jobs_per_regular_job == consensus_jobs_per_regular_job + else: + with pytest.raises(ApiException) as exc: + _ = api_client.tasks_api.create(task_write_request=task_spec) + assert exc.status == HTTPStatus.BAD_REQUEST + @pytest.mark.usefixtures("restore_db_per_class") class TestGetData: @@ -2884,6 +2906,7 @@ def test_can_import_backup_for_task_in_nondefault_state(self, tasks, mode): task = self.client.tasks.retrieve(task_json["id"]) jobs = task.get_jobs() for j in jobs: + # print(j) j.update({"stage": 
"validation"}) self._test_can_restore_backup_task(task_json["id"]) diff --git a/tests/python/shared/fixtures/data.py b/tests/python/shared/fixtures/data.py index 6c328336dd1..22c1f8cbde2 100644 --- a/tests/python/shared/fixtures/data.py +++ b/tests/python/shared/fixtures/data.py @@ -226,6 +226,30 @@ def quality_settings(): return Container(json.load(f)["results"]) +@pytest.fixture(scope="session") +def consensus_reports(): + with open(ASSETS_DIR / "consensus_reports.json") as f: + return Container(json.load(f)["results"]) + + +@pytest.fixture(scope="session") +def consensus_conflicts(): + with open(ASSETS_DIR / "consensus_conflicts.json") as f: + return Container(json.load(f)["results"]) + + +@pytest.fixture(scope="session") +def consensus_assignee_reports(): + with open(ASSETS_DIR / "consensus_assignee_reports.json") as f: + return Container(json.load(f)["results"]) + + +@pytest.fixture(scope="session") +def consensus_settings(): + with open(ASSETS_DIR / "consensus_settings.json") as f: + return Container(json.load(f)["results"]) + + @pytest.fixture(scope="session") def users_by_name(users): return {user["username"]: user for user in users} diff --git a/tests/python/shared/fixtures/init.py b/tests/python/shared/fixtures/init.py index 4a17454617d..0dda076006d 100644 --- a/tests/python/shared/fixtures/init.py +++ b/tests/python/shared/fixtures/init.py @@ -309,9 +309,8 @@ def delete_compose_files(container_name_files): def wait_for_services(num_secs: int = 300) -> None: for i in range(num_secs): logger.debug(f"waiting for the server to load ... ({i})") - response = requests.get(get_server_url("api/server/health/", format="json")) - try: + response = requests.get(get_server_url("api/server/health/", format="json")) statuses = response.json() logger.debug(f"server status: \n{statuses}") diff --git a/tests/python/shared/utils/dump_objects.py b/tests/python/shared/utils/dump_objects.py index ecab740f0ec..d2f89a7e9f9 100644 --- a/tests/python/shared/utils/dump_objects.py +++ b/tests/python/shared/utils/dump_objects.py @@ -26,6 +26,10 @@ "quality/report", "quality/conflict", "quality/setting", + "consensus/report", + "consensus/conflict", + "consensus/setting", + "consensus/assignee_report", ]: response = get_method("admin1", f"{obj}s", page_size="all")