diff --git a/.github/workflows/full.yml b/.github/workflows/full.yml index 9502e7a0b18..c6369340b5c 100644 --- a/.github/workflows/full.yml +++ b/.github/workflows/full.yml @@ -165,6 +165,8 @@ jobs: id: run_tests run: | pytest tests/python/ + ONE_RUNNING_JOB_IN_QUEUE_PER_USER="true" pytest tests/python/rest_api/test_queues.py + CVAT_ALLOW_STATIC_CACHE="true" pytest -k "TestTaskData" tests/python - name: Creating a log file from cvat containers if: failure() && steps.run_tests.conclusion == 'failure' diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 75b200597f4..0c9211b0c4a 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -177,8 +177,9 @@ jobs: COVERAGE_PROCESS_START: ".coveragerc" run: | pytest tests/python/ --cov --cov-report=json - for COVERAGE_FILE in `find -name "coverage*.json" -type f -printf "%f\n"`; do mv ${COVERAGE_FILE} "${COVERAGE_FILE%%.*}_0.json"; done ONE_RUNNING_JOB_IN_QUEUE_PER_USER="true" pytest tests/python/rest_api/test_queues.py --cov --cov-report=json + CVAT_ALLOW_STATIC_CACHE="true" pytest -k "TestTaskData" tests/python --cov --cov-report=json + for COVERAGE_FILE in `find -name "coverage*.json" -type f -printf "%f\n"`; do mv ${COVERAGE_FILE} "${COVERAGE_FILE%%.*}_0.json"; done - name: Uploading code coverage results as an artifact uses: actions/upload-artifact@v4 diff --git a/.github/workflows/schedule.yml b/.github/workflows/schedule.yml index d8e514cbb44..c2071cd85d1 100644 --- a/.github/workflows/schedule.yml +++ b/.github/workflows/schedule.yml @@ -170,6 +170,12 @@ jobs: pytest tests/python/ pytest tests/python/ --stop-services + ONE_RUNNING_JOB_IN_QUEUE_PER_USER="true" pytest tests/python/rest_api/test_queues.py + pytest tests/python/ --stop-services + + CVAT_ALLOW_STATIC_CACHE="true" pytest tests/python + pytest tests/python/ --stop-services + - name: Unit tests env: HOST_COVERAGE_DATA_DIR: ${{ github.workspace }} diff --git a/changelog.d/20240812_161617_mzhiltso_job_chunks.md 
b/changelog.d/20240812_161617_mzhiltso_job_chunks.md new file mode 100644 index 00000000000..af931641d6d --- /dev/null +++ b/changelog.d/20240812_161617_mzhiltso_job_chunks.md @@ -0,0 +1,24 @@ +### Added + +- A server setting to enable or disable storage of permanent media chunks on the server filesystem + () +- \[Server API\] `GET /api/jobs/{id}/data/?type=chunk&index=x` parameter combination. + The new `index` parameter allows retrieving job chunks using a 0-based index in each job, + instead of the `number` parameter, which used task chunk ids. + () + +### Changed + +- Job assignees will not receive frames from adjacent jobs in chunks + () + +### Deprecated + +- \[Server API\] `GET /api/jobs/{id}/data/?type=chunk&number=x` parameter combination + () + + +### Fixed + +- Various memory leaks in video reading on the server + () diff --git a/cvat-core/src/frames.ts b/cvat-core/src/frames.ts index 96295af7d57..dda847cf7a7 100644 --- a/cvat-core/src/frames.ts +++ b/cvat-core/src/frames.ts @@ -3,7 +3,7 @@ // // SPDX-License-Identifier: MIT -import _ from 'lodash'; +import _, { range, sortedIndexOf } from 'lodash'; import { FrameDecoder, BlockType, DimensionType, ChunkQuality, decodeContextImages, RequestOutdatedError, } from 'cvat-data'; @@ -25,7 +25,7 @@ const frameDataCache: Record | null; activeContextRequest: Promise> | null; @@ -34,7 +34,7 @@ const frameDataCache: Record; - getChunk: (chunkNumber: number, quality: ChunkQuality) => Promise; + getChunk: (chunkIndex: number, quality: ChunkQuality) => Promise; }> = {}; // frame meta data storage by job id @@ -55,6 +55,8 @@ export class FramesMetaData { public size: number; public startFrame: number; public stopFrame: number; + public frameStep: number; + public chunkCount: number; #updateTrigger: FieldUpdateTrigger; @@ -103,6 +105,17 @@ export class FramesMetaData { } } + const frameStep: number = (() => { + if (data.frame_filter) { + const frameStepParts = data.frame_filter.split('=', 2); + if (frameStepParts.length 
!== 2) { + throw new ArgumentError(`Invalid frame filter '${data.frame_filter}'`); + } + return +frameStepParts[1]; + } + return 1; + })(); + Object.defineProperties( this, Object.freeze({ @@ -133,6 +146,20 @@ export class FramesMetaData { stopFrame: { get: () => data.stop_frame, }, + frameStep: { + get: () => frameStep, + }, + }), + ); + + const chunkCount: number = Math.ceil(this.getDataFrameNumbers().length / this.chunkSize); + + Object.defineProperties( + this, + Object.freeze({ + chunkCount: { + get: () => chunkCount, + }, }), ); } @@ -144,6 +171,40 @@ export class FramesMetaData { resetUpdated(): void { this.#updateTrigger.reset(); } + + getFrameIndex(dataFrameNumber: number): number { + // Here we use absolute (task source data) frame numbers. + // TODO: migrate from data frame numbers to local frame numbers to simplify code. + // Requires server changes in api/jobs/{id}/data/meta/ + // for included_frames, start_frame, stop_frame fields + + if (dataFrameNumber < this.startFrame || dataFrameNumber > this.stopFrame) { + throw new ArgumentError(`Frame number ${dataFrameNumber} doesn't belong to the job`); + } + + let frameIndex = null; + if (this.includedFrames) { + frameIndex = sortedIndexOf(this.includedFrames, dataFrameNumber); + if (frameIndex === -1) { + throw new ArgumentError(`Frame number ${dataFrameNumber} doesn't belong to the job`); + } + } else { + frameIndex = Math.floor((dataFrameNumber - this.startFrame) / this.frameStep); + } + return frameIndex; + } + + getFrameChunkIndex(dataFrameNumber: number): number { + return Math.floor(this.getFrameIndex(dataFrameNumber) / this.chunkSize); + } + + getDataFrameNumbers(): number[] { + if (this.includedFrames) { + return this.includedFrames; + } + + return range(this.startFrame, this.stopFrame + 1, this.frameStep); + } } export class FrameData { @@ -206,12 +267,14 @@ export class FrameData { } class PrefetchAnalyzer { - #chunkSize: number; #requestedFrames: number[]; + #meta: FramesMetaData; + 
#getDataFrameNumber: (frameNumber: number) => number; - constructor(chunkSize) { - this.#chunkSize = chunkSize; + constructor(meta: FramesMetaData, dataFrameNumberGetter: (frameNumber: number) => number) { this.#requestedFrames = []; + this.#meta = meta; + this.#getDataFrameNumber = dataFrameNumberGetter; } shouldPrefetchNext(current: number, isPlaying: boolean, isChunkCached: (chunk) => boolean): boolean { @@ -219,13 +282,16 @@ class PrefetchAnalyzer { return true; } - const currentChunk = Math.floor(current / this.#chunkSize); + const currentDataFrameNumber = this.#getDataFrameNumber(current); + const currentChunk = this.#meta.getFrameChunkIndex(currentDataFrameNumber); const { length } = this.#requestedFrames; const isIncreasingOrder = this.#requestedFrames .every((val, index) => index === 0 || val > this.#requestedFrames[index - 1]); if ( length && (isIncreasingOrder && current > this.#requestedFrames[length - 1]) && - (current % this.#chunkSize) >= Math.ceil(this.#chunkSize / 2) && + ( + this.#meta.getFrameIndex(currentDataFrameNumber) % this.#meta.chunkSize + ) >= Math.ceil(this.#meta.chunkSize / 2) && !isChunkCached(currentChunk + 1) ) { // is increasing order including the current frame @@ -247,13 +313,25 @@ class PrefetchAnalyzer { this.#requestedFrames.push(frame); // only half of chunk size is considered in this logic - const limit = Math.ceil(this.#chunkSize / 2); + const limit = Math.ceil(this.#meta.chunkSize / 2); if (this.#requestedFrames.length > limit) { this.#requestedFrames.shift(); } } } +function getDataStartFrame(meta: FramesMetaData, localStartFrame: number): number { + return meta.startFrame - localStartFrame * meta.frameStep; +} + +function getDataFrameNumber(frameNumber: number, dataStartFrame: number, step: number): number { + return frameNumber * step + dataStartFrame; +} + +function getFrameNumber(dataFrameNumber: number, dataStartFrame: number, step: number): number { + return (dataFrameNumber - dataStartFrame) / step; +} + 
Object.defineProperty(FrameData.prototype.data, 'implementation', { value(this: FrameData, onServerRequest) { return new Promise<{ @@ -262,40 +340,57 @@ Object.defineProperty(FrameData.prototype.data, 'implementation', { imageData: ImageBitmap | Blob; } | Blob>((resolve, reject) => { const { - provider, prefetchAnalizer, chunkSize, stopFrame, decodeForward, forwardStep, decodedBlocksCacheSize, + meta, provider, prefetchAnalyzer, chunkSize, startFrame, + decodeForward, forwardStep, decodedBlocksCacheSize, } = frameDataCache[this.jobID]; const requestId = +_.uniqueId(); - const chunkNumber = Math.floor(this.number / chunkSize); + const dataStartFrame = getDataStartFrame(meta, startFrame); + const requestedDataFrameNumber = getDataFrameNumber( + this.number, dataStartFrame, meta.frameStep, + ); + const chunkIndex = meta.getFrameChunkIndex(requestedDataFrameNumber); + const segmentFrameNumbers = meta.getDataFrameNumbers().map( + (dataFrameNumber: number) => getFrameNumber( + dataFrameNumber, dataStartFrame, meta.frameStep, + ), + ); const frame = provider.frame(this.number); - function findTheNextNotDecodedChunk(searchFrom: number): number { - let firstFrameInNextChunk = searchFrom + forwardStep; - let nextChunkNumber = Math.floor(firstFrameInNextChunk / chunkSize); - while (nextChunkNumber === chunkNumber) { - firstFrameInNextChunk += forwardStep; - nextChunkNumber = Math.floor(firstFrameInNextChunk / chunkSize); + function findTheNextNotDecodedChunk(currentFrameIndex: number): number | null { + const { chunkCount } = meta; + let nextFrameIndex = currentFrameIndex + forwardStep; + let nextChunkIndex = Math.floor(nextFrameIndex / chunkSize); + while (nextChunkIndex === chunkIndex) { + nextFrameIndex += forwardStep; + nextChunkIndex = Math.floor(nextFrameIndex / chunkSize); } - if (provider.isChunkCached(nextChunkNumber)) { - return findTheNextNotDecodedChunk(firstFrameInNextChunk); + if (nextChunkIndex < 0 || chunkCount <= nextChunkIndex) { + return null; } - return 
nextChunkNumber; + if (provider.isChunkCached(nextChunkIndex)) { + return findTheNextNotDecodedChunk(nextFrameIndex); + } + + return nextChunkIndex; } if (frame) { if ( - prefetchAnalizer.shouldPrefetchNext( + prefetchAnalyzer.shouldPrefetchNext( this.number, decodeForward, (chunk) => provider.isChunkCached(chunk), ) && decodedBlocksCacheSize > 1 && !frameDataCache[this.jobID].activeChunkRequest ) { - const nextChunkNumber = findTheNextNotDecodedChunk(this.number); + const nextChunkIndex = findTheNextNotDecodedChunk( + meta.getFrameIndex(requestedDataFrameNumber), + ); const predecodeChunksMax = Math.floor(decodedBlocksCacheSize / 2); - if (nextChunkNumber * chunkSize <= stopFrame && - nextChunkNumber <= chunkNumber + predecodeChunksMax + if (nextChunkIndex !== null && + nextChunkIndex <= chunkIndex + predecodeChunksMax ) { frameDataCache[this.jobID].activeChunkRequest = new Promise((resolveForward) => { const releasePromise = (): void => { @@ -304,7 +399,7 @@ Object.defineProperty(FrameData.prototype.data, 'implementation', { }; frameDataCache[this.jobID].getChunk( - nextChunkNumber, ChunkQuality.COMPRESSED, + nextChunkIndex, ChunkQuality.COMPRESSED, ).then((chunk: ArrayBuffer) => { if (!(this.jobID in frameDataCache)) { // check if frameDataCache still exist @@ -316,8 +411,11 @@ Object.defineProperty(FrameData.prototype.data, 'implementation', { provider.cleanup(1); provider.requestDecodeBlock( chunk, - nextChunkNumber * chunkSize, - Math.min(stopFrame, (nextChunkNumber + 1) * chunkSize - 1), + nextChunkIndex, + segmentFrameNumbers.slice( + nextChunkIndex * chunkSize, + (nextChunkIndex + 1) * chunkSize, + ), () => {}, releasePromise, releasePromise, @@ -334,7 +432,7 @@ Object.defineProperty(FrameData.prototype.data, 'implementation', { renderHeight: this.height, imageData: frame, }); - prefetchAnalizer.addRequested(this.number); + prefetchAnalyzer.addRequested(this.number); return; } @@ -355,7 +453,7 @@ Object.defineProperty(FrameData.prototype.data, 
'implementation', { renderHeight: this.height, imageData: currentFrame, }); - prefetchAnalizer.addRequested(this.number); + prefetchAnalyzer.addRequested(this.number); return; } @@ -364,7 +462,7 @@ Object.defineProperty(FrameData.prototype.data, 'implementation', { ) => { let wasResolved = false; frameDataCache[this.jobID].getChunk( - chunkNumber, ChunkQuality.COMPRESSED, + chunkIndex, ChunkQuality.COMPRESSED, ).then((chunk: ArrayBuffer) => { try { if (!(this.jobID in frameDataCache)) { @@ -378,8 +476,11 @@ Object.defineProperty(FrameData.prototype.data, 'implementation', { provider .requestDecodeBlock( chunk, - chunkNumber * chunkSize, - Math.min(stopFrame, (chunkNumber + 1) * chunkSize - 1), + chunkIndex, + segmentFrameNumbers.slice( + chunkIndex * chunkSize, + (chunkIndex + 1) * chunkSize, + ), (_frame: number, bitmap: ImageBitmap | Blob) => { if (decodeForward) { // resolve immediately only if is not playing @@ -395,7 +496,7 @@ Object.defineProperty(FrameData.prototype.data, 'implementation', { renderHeight: this.height, imageData: bitmap, }); - prefetchAnalizer.addRequested(this.number); + prefetchAnalyzer.addRequested(this.number); } }, () => { frameDataCache[this.jobID].activeChunkRequest = null; @@ -592,7 +693,7 @@ export async function getFrame( isPlaying: boolean, step: number, dimension: DimensionType, - getChunk: (chunkNumber: number, quality: ChunkQuality) => Promise, + getChunk: (chunkIndex: number, quality: ChunkQuality) => Promise, ): Promise { if (!(jobID in frameDataCache)) { const blockType = chunkType === 'video' ? 
BlockType.MP4VIDEO : BlockType.ARCHIVE; @@ -608,6 +709,13 @@ export async function getFrame( const decodedBlocksCacheSize = Math.min( Math.floor((2048 * 1024 * 1024) / ((mean + stdDev) * 4 * chunkSize)) || 1, 10, ); + + // TODO: migrate to local frame numbers + const dataStartFrame = getDataStartFrame(meta, startFrame); + const dataFrameNumberGetter = (frameNumber: number): number => ( + getDataFrameNumber(frameNumber, dataStartFrame, meta.frameStep) + ); + frameDataCache[jobID] = { meta, chunkSize, @@ -618,11 +726,13 @@ export async function getFrame( forwardStep: step, provider: new FrameDecoder( blockType, - chunkSize, decodedBlocksCacheSize, + (frameNumber: number): number => ( + meta.getFrameChunkIndex(dataFrameNumberGetter(frameNumber)) + ), dimension, ), - prefetchAnalizer: new PrefetchAnalyzer(chunkSize), + prefetchAnalyzer: new PrefetchAnalyzer(meta, dataFrameNumberGetter), decodedBlocksCacheSize, activeChunkRequest: null, activeContextRequest: null, @@ -697,8 +807,11 @@ export async function findFrame( let lastUndeletedFrame = null; const check = (frame): boolean => { if (meta.includedFrames) { - return (meta.includedFrames.includes(frame)) && - (!filters.notDeleted || !(frame in meta.deletedFrames)); + // meta.includedFrames contains input frame numbers now + const dataStartFrame = meta.startFrame; // this is only true when includedFrames is set + return (meta.includedFrames.includes( + getDataFrameNumber(frame, dataStartFrame, meta.frameStep)) + ) && (!filters.notDeleted || !(frame in meta.deletedFrames)); } if (filters.notDeleted) { return !(frame in meta.deletedFrames); @@ -726,6 +839,18 @@ export function getCachedChunks(jobID): number[] { return frameDataCache[jobID].provider.cachedChunks(true); } +export function getJobFrameNumbers(jobID): number[] { + if (!(jobID in frameDataCache)) { + return []; + } + + const { meta, startFrame } = frameDataCache[jobID]; + const dataStartFrame = getDataStartFrame(meta, startFrame); + return 
meta.getDataFrameNumbers().map((dataFrameNumber: number): number => ( + getFrameNumber(dataFrameNumber, dataStartFrame, meta.frameStep) + )); +} + export function clear(jobID: number): void { if (jobID in frameDataCache) { frameDataCache[jobID].provider.close(); diff --git a/cvat-core/src/server-proxy.ts b/cvat-core/src/server-proxy.ts index 91dc52a7182..51309426198 100644 --- a/cvat-core/src/server-proxy.ts +++ b/cvat-core/src/server-proxy.ts @@ -1438,7 +1438,7 @@ async function getData(jid: number, chunk: number, quality: ChunkQuality, retry ...enableOrganization(), quality, type: 'chunk', - number: chunk, + index: chunk, }, responseType: 'arraybuffer', }); diff --git a/cvat-core/src/session-implementation.ts b/cvat-core/src/session-implementation.ts index fa77c934abd..96177170872 100644 --- a/cvat-core/src/session-implementation.ts +++ b/cvat-core/src/session-implementation.ts @@ -18,6 +18,7 @@ import { deleteFrame, restoreFrame, getCachedChunks, + getJobFrameNumbers, clear as clearFrames, findFrame, getContextImage, @@ -189,7 +190,7 @@ export function implementJob(Job: typeof JobClass): typeof JobClass { isPlaying, step, this.dimension, - (chunkNumber, quality) => this.frames.chunk(chunkNumber, quality), + (chunkIndex, quality) => this.frames.chunk(chunkIndex, quality), ); }, }); @@ -244,6 +245,14 @@ export function implementJob(Job: typeof JobClass): typeof JobClass { }, }); + Object.defineProperty(Job.prototype.frames.frameNumbers, 'implementation', { + value: function includedFramesImplementation( + this: JobClass, + ): ReturnType { + return Promise.resolve(getJobFrameNumbers(this.id)); + }, + }); + Object.defineProperty(Job.prototype.frames.preview, 'implementation', { value: function previewImplementation( this: JobClass, @@ -273,10 +282,10 @@ export function implementJob(Job: typeof JobClass): typeof JobClass { Object.defineProperty(Job.prototype.frames.chunk, 'implementation', { value: function chunkImplementation( this: JobClass, - chunkNumber: 
Parameters[0], + chunkIndex: Parameters[0], quality: Parameters[1], ): ReturnType { - return serverProxy.frames.getData(this.id, chunkNumber, quality); + return serverProxy.frames.getData(this.id, chunkIndex, quality); }, }); @@ -829,7 +838,7 @@ export function implementTask(Task: typeof TaskClass): typeof TaskClass { isPlaying, step, this.dimension, - (chunkNumber, quality) => job.frames.chunk(chunkNumber, quality), + (chunkIndex, quality) => job.frames.chunk(chunkIndex, quality), ); return result; }, diff --git a/cvat-core/src/session.ts b/cvat-core/src/session.ts index 1985a72b268..54133ff6b66 100644 --- a/cvat-core/src/session.ts +++ b/cvat-core/src/session.ts @@ -233,6 +233,10 @@ function buildDuplicatedAPI(prototype) { const result = await PluginRegistry.apiWrapper.call(this, prototype.frames.cachedChunks); return result; }, + async frameNumbers() { + const result = await PluginRegistry.apiWrapper.call(this, prototype.frames.frameNumbers); + return result; + }, async preview() { const result = await PluginRegistry.apiWrapper.call(this, prototype.frames.preview); return result; @@ -255,11 +259,11 @@ function buildDuplicatedAPI(prototype) { ); return result; }, - async chunk(chunkNumber, quality) { + async chunk(chunkIndex, quality) { const result = await PluginRegistry.apiWrapper.call( this, prototype.frames.chunk, - chunkNumber, + chunkIndex, quality, ); return result; @@ -380,6 +384,7 @@ export class Session { restore: (frame: number) => Promise; save: () => Promise; cachedChunks: () => Promise; + frameNumbers: () => Promise; preview: () => Promise; contextImage: (frame: number) => Promise>; search: ( @@ -443,6 +448,7 @@ export class Session { restore: Object.getPrototypeOf(this).frames.restore.bind(this), save: Object.getPrototypeOf(this).frames.save.bind(this), cachedChunks: Object.getPrototypeOf(this).frames.cachedChunks.bind(this), + frameNumbers: Object.getPrototypeOf(this).frames.frameNumbers.bind(this), preview: 
Object.getPrototypeOf(this).frames.preview.bind(this), search: Object.getPrototypeOf(this).frames.search.bind(this), contextImage: Object.getPrototypeOf(this).frames.contextImage.bind(this), diff --git a/cvat-data/src/ts/cvat-data.ts b/cvat-data/src/ts/cvat-data.ts index 2f832ac9d3f..baf00ac443c 100644 --- a/cvat-data/src/ts/cvat-data.ts +++ b/cvat-data/src/ts/cvat-data.ts @@ -72,8 +72,8 @@ export function decodeContextImages( decodeContextImages.mutex = new Mutex(); interface BlockToDecode { - start: number; - end: number; + chunkFrameNumbers: number[]; + chunkIndex: number; block: ArrayBuffer; onDecodeAll(): void; onDecode(frame: number, bitmap: ImageBitmap | Blob): void; @@ -82,7 +82,6 @@ interface BlockToDecode { export class FrameDecoder { private blockType: BlockType; - private chunkSize: number; /* ImageBitmap when decode zip or video chunks Blob when 3D dimension @@ -100,11 +99,12 @@ export class FrameDecoder { private renderHeight: number; private zipWorker: Worker | null; private videoWorker: Worker | null; + private getChunkIndex: (frame: number) => number; constructor( blockType: BlockType, - chunkSize: number, cachedBlockCount: number, + getChunkIndex: (frame: number) => number, dimension: DimensionType = DimensionType.DIMENSION_2D, ) { this.mutex = new Mutex(); @@ -117,7 +117,7 @@ export class FrameDecoder { this.renderWidth = 1920; this.renderHeight = 1080; - this.chunkSize = chunkSize; + this.getChunkIndex = getChunkIndex; this.blockType = blockType; this.decodedChunks = {}; @@ -125,8 +125,8 @@ export class FrameDecoder { this.chunkIsBeingDecoded = null; } - isChunkCached(chunkNumber: number): boolean { - return chunkNumber in this.decodedChunks; + isChunkCached(chunkIndex: number): boolean { + return chunkIndex in this.decodedChunks; } hasFreeSpace(): boolean { @@ -155,17 +155,37 @@ export class FrameDecoder { } } + private validateFrameNumbers(frameNumbers: number[]): void { + if (!Array.isArray(frameNumbers) || !frameNumbers.length) { + throw new 
Error('chunkFrameNumbers must not be empty'); + } + + // ensure is ordered + for (let i = 1; i < frameNumbers.length; ++i) { + const prev = frameNumbers[i - 1]; + const current = frameNumbers[i]; + if (current <= prev) { + throw new Error( + 'chunkFrameNumbers must be sorted in the ascending order, ' + + `got a (${prev}, ${current}) pair instead`, + ); + } + } + } + requestDecodeBlock( block: ArrayBuffer, - start: number, - end: number, + chunkIndex: number, + chunkFrameNumbers: number[], onDecode: (frame: number, bitmap: ImageBitmap | Blob) => void, onDecodeAll: () => void, onReject: (e: Error) => void, ): void { + this.validateFrameNumbers(chunkFrameNumbers); + if (this.requestedChunkToDecode !== null) { // a chunk was already requested to be decoded, but decoding didn't start yet - if (start === this.requestedChunkToDecode.start && end === this.requestedChunkToDecode.end) { + if (chunkIndex === this.requestedChunkToDecode.chunkIndex) { // it was the same chunk this.requestedChunkToDecode.onReject(new RequestOutdatedError()); @@ -175,12 +195,14 @@ export class FrameDecoder { // it was other chunk this.requestedChunkToDecode.onReject(new RequestOutdatedError()); } - } else if (this.chunkIsBeingDecoded === null || this.chunkIsBeingDecoded.start !== start) { + } else if (this.chunkIsBeingDecoded === null || + chunkIndex !== this.chunkIsBeingDecoded.chunkIndex + ) { // everything was decoded or decoding other chunk is in process this.requestedChunkToDecode = { + chunkFrameNumbers, + chunkIndex, block, - start, - end, onDecode, onDecodeAll, onReject, @@ -203,9 +225,9 @@ export class FrameDecoder { } frame(frameNumber: number): ImageBitmap | Blob | null { - const chunkNumber = Math.floor(frameNumber / this.chunkSize); - if (chunkNumber in this.decodedChunks) { - return this.decodedChunks[chunkNumber][frameNumber]; + const chunkIndex = this.getChunkIndex(frameNumber); + if (chunkIndex in this.decodedChunks) { + return this.decodedChunks[chunkIndex][frameNumber]; } 
return null; @@ -253,8 +275,8 @@ export class FrameDecoder { releaseMutex(); }; try { - const { start, end, block } = this.requestedChunkToDecode; - if (start !== blockToDecode.start) { + const { chunkFrameNumbers, chunkIndex, block } = this.requestedChunkToDecode; + if (chunkIndex !== blockToDecode.chunkIndex) { // request is not relevant, another block was already requested // it happens when A is being decoded, B comes and wait for mutex, C comes and wait for mutex // B is not necessary anymore, because C already was requested @@ -262,8 +284,11 @@ export class FrameDecoder { throw new RequestOutdatedError(); } - const chunkNumber = Math.floor(start / this.chunkSize); - this.orderedStack = [chunkNumber, ...this.orderedStack]; + const getFrameNumber = (chunkFrameIndex: number): number => ( + chunkFrameNumbers[chunkFrameIndex] + ); + + this.orderedStack = [chunkIndex, ...this.orderedStack]; this.cleanup(); const decodedFrames: Record = {}; this.chunkIsBeingDecoded = this.requestedChunkToDecode; @@ -273,7 +298,7 @@ export class FrameDecoder { this.videoWorker = new Worker( new URL('./3rdparty/Decoder.worker', import.meta.url), ); - let index = start; + let index = 0; this.videoWorker.onmessage = (e) => { if (e.data.consoleLog) { @@ -281,6 +306,7 @@ export class FrameDecoder { return; } const keptIndex = index; + const frameNumber = getFrameNumber(keptIndex); // do not use e.data.height and e.data.width because they might be not correct // instead, try to understand real height and width of decoded image via scale factor @@ -295,11 +321,11 @@ export class FrameDecoder { width, height, )).then((bitmap) => { - decodedFrames[keptIndex] = bitmap; - this.chunkIsBeingDecoded.onDecode(keptIndex, decodedFrames[keptIndex]); + decodedFrames[frameNumber] = bitmap; + this.chunkIsBeingDecoded.onDecode(frameNumber, decodedFrames[frameNumber]); - if (keptIndex === end) { - this.decodedChunks[chunkNumber] = decodedFrames; + if (keptIndex === chunkFrameNumbers.length - 1) { + 
this.decodedChunks[chunkIndex] = decodedFrames; this.chunkIsBeingDecoded.onDecodeAll(); this.chunkIsBeingDecoded = null; release(); @@ -343,7 +369,7 @@ export class FrameDecoder { this.zipWorker = this.zipWorker || new Worker( new URL('./unzip_imgs.worker', import.meta.url), ); - let index = start; + let decodedCount = 0; this.zipWorker.onmessage = async (event) => { if (event.data.error) { @@ -353,16 +379,18 @@ export class FrameDecoder { return; } - decodedFrames[event.data.index] = event.data.data as ImageBitmap | Blob; - this.chunkIsBeingDecoded.onDecode(event.data.index, decodedFrames[event.data.index]); + const frameNumber = getFrameNumber(event.data.index); + decodedFrames[frameNumber] = event.data.data as ImageBitmap | Blob; + this.chunkIsBeingDecoded.onDecode(frameNumber, decodedFrames[frameNumber]); - if (index === end) { - this.decodedChunks[chunkNumber] = decodedFrames; + if (decodedCount === chunkFrameNumbers.length - 1) { + this.decodedChunks[chunkIndex] = decodedFrames; this.chunkIsBeingDecoded.onDecodeAll(); this.chunkIsBeingDecoded = null; release(); } - index++; + + decodedCount++; }; this.zipWorker.onerror = (event: ErrorEvent) => { @@ -373,8 +401,8 @@ export class FrameDecoder { this.zipWorker.postMessage({ block, - start, - end, + start: 0, + end: chunkFrameNumbers.length - 1, dimension: this.dimension, dimension2D: DimensionType.DIMENSION_2D, }); @@ -400,9 +428,12 @@ export class FrameDecoder { } public cachedChunks(includeInProgress = false): number[] { - const chunkIsBeingDecoded = includeInProgress && this.chunkIsBeingDecoded ? - Math.floor(this.chunkIsBeingDecoded.start / this.chunkSize) : null; - return Object.keys(this.decodedChunks).map((chunkNumber: string) => +chunkNumber).concat( + const chunkIsBeingDecoded = ( + includeInProgress && this.chunkIsBeingDecoded ? + this.chunkIsBeingDecoded.chunkIndex : + null + ); + return Object.keys(this.decodedChunks).map((chunkIndex: string) => +chunkIndex).concat( ...(chunkIsBeingDecoded !== null ? 
[chunkIsBeingDecoded] : []), ).sort((a, b) => a - b); } diff --git a/cvat-ui/src/actions/annotation-actions.ts b/cvat-ui/src/actions/annotation-actions.ts index 31b73314a13..b3fa8b503aa 100644 --- a/cvat-ui/src/actions/annotation-actions.ts +++ b/cvat-ui/src/actions/annotation-actions.ts @@ -587,12 +587,13 @@ export function confirmCanvasReadyAsync(): ThunkAction { const { instance: job } = state.annotation.job; const { changeFrameEvent } = state.annotation.player.frame; const chunks = await job.frames.cachedChunks() as number[]; - const { startFrame, stopFrame, dataChunkSize } = job; + const includedFrames = await job.frames.frameNumbers() as number[]; + const { frameCount, dataChunkSize } = job; const ranges = chunks.map((chunk) => ( [ - Math.max(startFrame, chunk * dataChunkSize), - Math.min(stopFrame, (chunk + 1) * dataChunkSize - 1), + includedFrames[chunk * dataChunkSize], + includedFrames[Math.min(frameCount - 1, (chunk + 1) * dataChunkSize - 1)], ] )).reduce>((acc, val) => { if (acc.length && acc[acc.length - 1][1] + 1 === val[0]) { @@ -905,7 +906,8 @@ export function getJobAsync({ // frame query parameter does not work for GT job const frameNumber = Number.isInteger(initialFrame) && gtJob?.id !== job.id ? 
- initialFrame as number : (await job.frames.search( + initialFrame as number : + (await job.frames.search( { notDeleted: !showDeletedFrames }, job.startFrame, job.stopFrame, )) || job.startFrame; diff --git a/cvat-ui/src/components/annotation-page/top-bar/player-navigation.tsx b/cvat-ui/src/components/annotation-page/top-bar/player-navigation.tsx index 2088d14d7cc..f1a2e9cf289 100644 --- a/cvat-ui/src/components/annotation-page/top-bar/player-navigation.tsx +++ b/cvat-ui/src/components/annotation-page/top-bar/player-navigation.tsx @@ -169,17 +169,14 @@ function PlayerNavigation(props: Props): JSX.Element { {!!ranges && ( {ranges.split(';').map((range) => { - const [start, end] = range.split(':').map((num) => +num); - const adjustedStart = Math.max(0, start - 1); - let totalSegments = stopFrame - startFrame; - if (totalSegments === 0) { - // corner case for jobs with one image - totalSegments = 1; - } + const [rangeStart, rangeStop] = range.split(':').map((num) => +num); + const totalSegments = stopFrame - startFrame + 1; const segmentWidth = 1000 / totalSegments; - const width = Math.max((end - adjustedStart), 1) * segmentWidth; - const offset = (Math.max((adjustedStart - startFrame), 0) / totalSegments) * 1000; - return (); + const width = (rangeStop - rangeStart + 1) * segmentWidth; + const offset = (Math.max((rangeStart - startFrame), 0) / totalSegments) * 1000; + return ( + + ); })} )} diff --git a/cvat/apps/dataset_manager/bindings.py b/cvat/apps/dataset_manager/bindings.py index 534f885449e..eb8fdf26b52 100644 --- a/cvat/apps/dataset_manager/bindings.py +++ b/cvat/apps/dataset_manager/bindings.py @@ -30,8 +30,8 @@ from cvat.apps.dataset_manager.formats.utils import get_label_color from cvat.apps.dataset_manager.util import add_prefetch_fields -from cvat.apps.engine.frame_provider import FrameProvider -from cvat.apps.engine.models import (AttributeSpec, AttributeType, Data, DimensionType, Job, +from cvat.apps.engine.frame_provider import TaskFrameProvider, 
FrameQuality, FrameOutputType +from cvat.apps.engine.models import (AttributeSpec, AttributeType, DimensionType, Job, JobType, Label, LabelType, Project, SegmentType, ShapeType, Task) from cvat.apps.engine.rq_job_handler import RQJobMetaField @@ -301,7 +301,7 @@ def start(self) -> int: @property def stop(self) -> int: - return len(self) + return max(0, len(self) - 1) def _get_queryset(self): raise NotImplementedError() @@ -437,7 +437,7 @@ def _export_tag(self, tag): def _export_track(self, track, idx): track['shapes'] = list(filter(lambda x: not self._is_frame_deleted(x['frame']), track['shapes'])) tracked_shapes = TrackManager.get_interpolated_shapes( - track, 0, self.stop, self._annotation_ir.dimension) + track, 0, self.stop + 1, self._annotation_ir.dimension) for tracked_shape in tracked_shapes: tracked_shape["attributes"] += track["attributes"] tracked_shape["track_id"] = track["track_id"] if self._use_server_track_ids else idx @@ -493,7 +493,7 @@ def get_frame(idx): anno_manager = AnnotationManager(self._annotation_ir) for shape in sorted( - anno_manager.to_shapes(self.stop, self._annotation_ir.dimension, + anno_manager.to_shapes(self.stop + 1, self._annotation_ir.dimension, # Skip outside, deleted and excluded frames included_frames=included_frames, include_outside=False, @@ -840,7 +840,7 @@ def start(self) -> int: @property def stop(self) -> int: segment = self._db_job.segment - return segment.stop_frame + 1 + return segment.stop_frame @property def db_instance(self): @@ -1410,7 +1410,7 @@ def add_task(self, task, files): @attrs(frozen=True, auto_attribs=True) class ImageSource: - db_data: Data + db_task: Task is_video: bool = attrib(kw_only=True) class ImageProvider: @@ -1439,8 +1439,10 @@ def video_frame_loader(_): # optimization for videos: use numpy arrays instead of bytes # some formats or transforms can require image data return self._frame_provider.get_frame(frame_index, - quality=FrameProvider.Quality.ORIGINAL, - 
out_type=FrameProvider.Type.NUMPY_ARRAY)[0] + quality=FrameQuality.ORIGINAL, + out_type=FrameOutputType.NUMPY_ARRAY + ).data + return dm.Image(data=video_frame_loader, **image_kwargs) else: def image_loader(_): @@ -1448,8 +1450,10 @@ def image_loader(_): # for images use encoded data to avoid recoding return self._frame_provider.get_frame(frame_index, - quality=FrameProvider.Quality.ORIGINAL, - out_type=FrameProvider.Type.BUFFER)[0].getvalue() + quality=FrameQuality.ORIGINAL, + out_type=FrameOutputType.BUFFER + ).data.getvalue() + return dm.ByteImage(data=image_loader, **image_kwargs) def _load_source(self, source_id: int, source: ImageSource) -> None: @@ -1457,7 +1461,7 @@ def _load_source(self, source_id: int, source: ImageSource) -> None: return self._unload_source() - self._frame_provider = FrameProvider(source.db_data) + self._frame_provider = TaskFrameProvider(source.db_task) self._current_source_id = source_id def _unload_source(self) -> None: @@ -1473,7 +1477,7 @@ def __init__(self, sources: Dict[int, ImageSource]) -> None: self._images_per_source = { source_id: { image.id: image - for image in source.db_data.images.prefetch_related('related_files') + for image in source.db_task.data.images.prefetch_related('related_files') } for source_id, source in sources.items() } @@ -1482,7 +1486,7 @@ def get_image_for_frame(self, source_id: int, frame_id: int, **image_kwargs): source = self._sources[source_id] point_cloud_path = osp.join( - source.db_data.get_upload_dirname(), image_kwargs['path'], + source.db_task.data.get_upload_dirname(), image_kwargs['path'], ) image = self._images_per_source[source_id][frame_id] @@ -1595,11 +1599,18 @@ def __init__( is_video = instance_meta['mode'] == 'interpolation' ext = '' if is_video: - ext = FrameProvider.VIDEO_FRAME_EXT + ext = TaskFrameProvider.VIDEO_FRAME_EXT if dimension == DimensionType.DIM_3D or include_images: + if isinstance(instance_data, TaskData): + db_task = instance_data.db_instance + elif 
isinstance(instance_data, JobData): + db_task = instance_data.db_instance.segment.task + else: + assert False + self._image_provider = IMAGE_PROVIDERS_BY_DIMENSION[dimension]( - {0: ImageSource(instance_data.db_data, is_video=is_video)} + {0: ImageSource(db_task, is_video=is_video)} ) for frame_data in instance_data.group_by_frame(include_empty=True): @@ -1681,13 +1692,13 @@ def __init__( if self._dimension == DimensionType.DIM_3D or include_images: self._image_provider = IMAGE_PROVIDERS_BY_DIMENSION[self._dimension]( { - task.id: ImageSource(task.data, is_video=task.mode == 'interpolation') + task.id: ImageSource(task, is_video=task.mode == 'interpolation') for task in project_data.tasks } ) ext_per_task: Dict[int, str] = { - task.id: FrameProvider.VIDEO_FRAME_EXT if is_video else '' + task.id: TaskFrameProvider.VIDEO_FRAME_EXT if is_video else '' for task in project_data.tasks for is_video in [task.mode == 'interpolation'] } diff --git a/cvat/apps/dataset_manager/formats/cvat.py b/cvat/apps/dataset_manager/formats/cvat.py index 0191dfe1c8c..4651fd39845 100644 --- a/cvat/apps/dataset_manager/formats/cvat.py +++ b/cvat/apps/dataset_manager/formats/cvat.py @@ -27,7 +27,7 @@ import_dm_annotations, match_dm_item) from cvat.apps.dataset_manager.util import make_zip_archive -from cvat.apps.engine.frame_provider import FrameProvider +from cvat.apps.engine.frame_provider import FrameQuality, FrameOutputType, make_frame_provider from .registry import dm_env, exporter, importer @@ -1371,16 +1371,19 @@ def dump_project_anno(dst_file: BufferedWriter, project_data: ProjectData, callb dumper.close_document() def dump_media_files(instance_data: CommonData, img_dir: str, project_data: ProjectData = None): + frame_provider = make_frame_provider(instance_data.db_instance) + ext = '' if instance_data.meta[instance_data.META_FIELD]['mode'] == 'interpolation': - ext = FrameProvider.VIDEO_FRAME_EXT - - frame_provider = FrameProvider(instance_data.db_data) - frames = 
frame_provider.get_frames( - instance_data.start, instance_data.stop, - frame_provider.Quality.ORIGINAL, - frame_provider.Type.BUFFER) - for frame_id, (frame_data, _) in zip(instance_data.rel_range, frames): + ext = frame_provider.VIDEO_FRAME_EXT + + frames = frame_provider.iterate_frames( + start_frame=instance_data.start, + stop_frame=instance_data.stop, + quality=FrameQuality.ORIGINAL, + out_type=FrameOutputType.BUFFER, + ) + for frame_id, frame in zip(instance_data.rel_range, frames): if (project_data is not None and (instance_data.db_instance.id, frame_id) in project_data.deleted_frames) \ or frame_id in instance_data.deleted_frames: continue @@ -1389,7 +1392,7 @@ def dump_media_files(instance_data: CommonData, img_dir: str, project_data: Proj img_path = osp.join(img_dir, frame_name + ext) os.makedirs(osp.dirname(img_path), exist_ok=True) with open(img_path, 'wb') as f: - f.write(frame_data.getvalue()) + f.write(frame.data.getvalue()) def _export_task_or_job(dst_file, temp_dir, instance_data, anno_callback, save_images=False): with open(osp.join(temp_dir, 'annotations.xml'), 'wb') as f: diff --git a/cvat/apps/dataset_manager/tests/test_formats.py b/cvat/apps/dataset_manager/tests/test_formats.py index 6a03e41c8aa..42b2337304b 100644 --- a/cvat/apps/dataset_manager/tests/test_formats.py +++ b/cvat/apps/dataset_manager/tests/test_formats.py @@ -1,6 +1,6 @@ # Copyright (C) 2020-2022 Intel Corporation -# Copyright (C) 2022 CVAT.ai Corporation +# Copyright (C) 2022-2024 CVAT.ai Corporation # # SPDX-License-Identifier: MIT @@ -14,10 +14,8 @@ from datumaro.components.dataset import Dataset, DatasetItem from datumaro.components.annotation import Mask from django.contrib.auth.models import Group, User -from PIL import Image from rest_framework import status -from rest_framework.test import APIClient, APITestCase import cvat.apps.dataset_manager as dm from cvat.apps.dataset_manager.annotation import AnnotationIR @@ -26,36 +24,13 @@ from cvat.apps.dataset_manager.task 
import TaskAnnotation from cvat.apps.dataset_manager.util import make_zip_archive from cvat.apps.engine.models import Task -from cvat.apps.engine.tests.utils import get_paginated_collection +from cvat.apps.engine.tests.utils import ( + get_paginated_collection, ForceLogin, generate_image_file, ApiTestBase +) - -def generate_image_file(filename, size=(100, 100)): - f = BytesIO() - image = Image.new('RGB', size=size) - image.save(f, 'jpeg') - f.name = filename - f.seek(0) - return f - -class ForceLogin: - def __init__(self, user, client): - self.user = user - self.client = client - - def __enter__(self): - if self.user: - self.client.force_login(self.user, - backend='django.contrib.auth.backends.ModelBackend') - - return self - - def __exit__(self, exception_type, exception_value, traceback): - if self.user: - self.client.logout() - -class _DbTestBase(APITestCase): +class _DbTestBase(ApiTestBase): def setUp(self): - self.client = APIClient() + super().setUp() @classmethod def setUpTestData(cls): @@ -94,6 +69,11 @@ def _create_task(self, data, image_data): response = self.client.post("/api/tasks/%s/data" % tid, data=image_data) assert response.status_code == status.HTTP_202_ACCEPTED, response.status_code + rq_id = response.json()["rq_id"] + + response = self.client.get(f"/api/requests/{rq_id}") + assert response.status_code == status.HTTP_200_OK, response.status_code + assert response.json()["status"] == "finished", response.json().get("status") response = self.client.get("/api/tasks/%s" % tid) diff --git a/cvat/apps/dataset_manager/tests/test_rest_api_formats.py b/cvat/apps/dataset_manager/tests/test_rest_api_formats.py index de961318a5a..a0717c1ef11 100644 --- a/cvat/apps/dataset_manager/tests/test_rest_api_formats.py +++ b/cvat/apps/dataset_manager/tests/test_rest_api_formats.py @@ -171,6 +171,11 @@ def _create_task(self, data, image_data): response = self.client.post("/api/tasks/%s/data" % tid, data=image_data) assert response.status_code == 
status.HTTP_202_ACCEPTED, response.status_code + rq_id = response.json()["rq_id"] + + response = self.client.get(f"/api/requests/{rq_id}") + assert response.status_code == status.HTTP_200_OK, response.status_code + assert response.json()["status"] == "finished", response.json().get("status") response = self.client.get("/api/tasks/%s" % tid) @@ -412,7 +417,7 @@ def test_api_v2_dump_and_upload_annotations_with_objects_type_is_shape(self): url = self._generate_url_dump_tasks_annotations(task_id) for user, edata in list(expected.items()): - self._clear_rq_jobs() # clean up from previous tests and iterations + self._clear_temp_data() # clean up from previous tests and iterations user_name = edata['name'] file_zip_name = osp.join(test_dir, f'{test_name}_{user_name}_{dump_format_name}.zip') @@ -520,7 +525,7 @@ def test_api_v2_dump_annotations_with_objects_type_is_track(self): url = self._generate_url_dump_tasks_annotations(task_id) for user, edata in list(expected.items()): - self._clear_rq_jobs() # clean up from previous tests and iterations + self._clear_temp_data() # clean up from previous tests and iterations user_name = edata['name'] file_zip_name = osp.join(test_dir, f'{test_name}_{user_name}_{dump_format_name}.zip') @@ -605,7 +610,7 @@ def test_api_v2_dump_tag_annotations(self): for user, edata in list(expected.items()): with self.subTest(format=f"{edata['name']}"): with TestDir() as test_dir: - self._clear_rq_jobs() # clean up from previous tests and iterations + self._clear_temp_data() # clean up from previous tests and iterations user_name = edata['name'] url = self._generate_url_dump_tasks_annotations(task_id) @@ -847,7 +852,7 @@ def test_api_v2_export_dataset(self): # dump annotations url = self._generate_url_dump_task_dataset(task_id) for user, edata in list(expected.items()): - self._clear_rq_jobs() # clean up from previous tests and iterations + self._clear_temp_data() # clean up from previous tests and iterations user_name = edata['name'] file_zip_name = 
osp.join(test_dir, f'{test_name}_{user_name}_{dump_format_name}.zip') @@ -2107,7 +2112,7 @@ def test_api_v2_export_import_dataset(self): self._create_annotations(task, dump_format_name, "random") for user, edata in list(expected.items()): - self._clear_rq_jobs() # clean up from previous tests and iterations + self._clear_temp_data() # clean up from previous tests and iterations user_name = edata['name'] file_zip_name = osp.join(test_dir, f'{test_name}_{user_name}_{dump_format_name}.zip') @@ -2170,7 +2175,7 @@ def test_api_v2_export_annotations(self): url = self._generate_url_dump_project_annotations(project['id'], dump_format_name) for user, edata in list(expected.items()): - self._clear_rq_jobs() # clean up from previous tests and iterations + self._clear_temp_data() # clean up from previous tests and iterations user_name = edata['name'] file_zip_name = osp.join(test_dir, f'{test_name}_{user_name}_{dump_format_name}.zip') diff --git a/cvat/apps/engine/apps.py b/cvat/apps/engine/apps.py index 326920e8b49..bcad84510f5 100644 --- a/cvat/apps/engine/apps.py +++ b/cvat/apps/engine/apps.py @@ -10,6 +10,14 @@ class EngineConfig(AppConfig): name = 'cvat.apps.engine' def ready(self): + from django.conf import settings + + from . 
import default_settings + + for key in dir(default_settings): + if key.isupper() and not hasattr(settings, key): + setattr(settings, key, getattr(default_settings, key)) + # Required to define signals in application import cvat.apps.engine.signals # Required in order to silent "unused-import" in pyflake diff --git a/cvat/apps/engine/cache.py b/cvat/apps/engine/cache.py index 2603c2fd5a1..bc4c8616bd7 100644 --- a/cvat/apps/engine/cache.py +++ b/cvat/apps/engine/cache.py @@ -1,349 +1,700 @@ # Copyright (C) 2020-2022 Intel Corporation -# Copyright (C) 2022-2023 CVAT.ai Corporation +# Copyright (C) 2022-2024 CVAT.ai Corporation # # SPDX-License-Identifier: MIT +from __future__ import annotations + import io import os -import zipfile -from datetime import datetime, timezone -from io import BytesIO -import shutil +import os.path +import pickle # nosec import tempfile +import zipfile import zlib - -from typing import Optional, Tuple - +from contextlib import ExitStack, closing +from datetime import datetime, timezone +from itertools import groupby, pairwise +from typing import ( + Any, + Callable, + Collection, + Generator, + Iterator, + Optional, + Sequence, + Tuple, + Type, + Union, + overload, +) + +import av import cv2 import PIL.Image -import pickle # nosec -from django.conf import settings +import PIL.ImageOps from django.core.cache import caches from rest_framework.exceptions import NotFound, ValidationError -from cvat.apps.engine.cloud_provider import (Credentials, - db_storage_to_storage_instance, - get_cloud_storage_instance) +from cvat.apps.engine import models +from cvat.apps.engine.cloud_provider import ( + Credentials, + db_storage_to_storage_instance, + get_cloud_storage_instance, +) from cvat.apps.engine.log import ServerLogManager -from cvat.apps.engine.media_extractors import (ImageDatasetManifestReader, - Mpeg4ChunkWriter, - Mpeg4CompressedChunkWriter, - VideoDatasetManifestReader, - ZipChunkWriter, - ZipCompressedChunkWriter) -from 
cvat.apps.engine.mime_types import mimetypes -from cvat.apps.engine.models import (DataChoice, DimensionType, Job, Image, - StorageChoice, CloudStorage) +from cvat.apps.engine.media_extractors import ( + FrameQuality, + IChunkWriter, + ImageReaderWithManifest, + Mpeg4ChunkWriter, + Mpeg4CompressedChunkWriter, + VideoReader, + VideoReaderWithManifest, + ZipChunkWriter, + ZipCompressedChunkWriter, +) from cvat.apps.engine.utils import md5_hash, preload_images from utils.dataset_manifest import ImageManifestManager slogger = ServerLogManager(__name__) + +DataWithMime = Tuple[io.BytesIO, str] +_CacheItem = Tuple[io.BytesIO, str, int] + + class MediaCache: - def __init__(self, dimension=DimensionType.DIM_2D): - self._dimension = dimension - self._cache = caches['media'] - - def _get_or_set_cache_item(self, key, create_function): - def create_item(): - slogger.glob.info(f'Starting to prepare chunk: key {key}') - item = create_function() - slogger.glob.info(f'Ending to prepare chunk: key {key}') - - if item[0]: - item = (item[0], item[1], zlib.crc32(item[0].getbuffer())) + def __init__(self) -> None: + self._cache = caches["media"] + + def _get_checksum(self, value: bytes) -> int: + return zlib.crc32(value) + + def _get_or_set_cache_item( + self, key: str, create_callback: Callable[[], DataWithMime] + ) -> _CacheItem: + def create_item() -> _CacheItem: + slogger.glob.info(f"Starting to prepare chunk: key {key}") + item_data = create_callback() + slogger.glob.info(f"Ending to prepare chunk: key {key}") + + item_data_bytes = item_data[0].getvalue() + item = (item_data[0], item_data[1], self._get_checksum(item_data_bytes)) + if item_data_bytes: self._cache.set(key, item) return item - slogger.glob.info(f'Starting to get chunk from cache: key {key}') - try: - item = self._cache.get(key) - except pickle.UnpicklingError: - slogger.glob.error(f'Unable to get item from cache: key {key}', exc_info=True) - item = None - slogger.glob.info(f'Ending to get chunk from cache: key {key}, 
is_cached {bool(item)}') - + item = self._get_cache_item(key) if not item: item = create_item() else: # compare checksum item_data = item[0].getbuffer() if isinstance(item[0], io.BytesIO) else item[0] item_checksum = item[2] if len(item) == 3 else None - if item_checksum != zlib.crc32(item_data): - slogger.glob.info(f'Recreating cache item {key} due to checksum mismatch') + if item_checksum != self._get_checksum(item_data): + slogger.glob.info(f"Recreating cache item {key} due to checksum mismatch") item = create_item() - return item[0], item[1] + return item - def get_task_chunk_data_with_mime(self, chunk_number, quality, db_data): - item = self._get_or_set_cache_item( - key=f'{db_data.id}_{chunk_number}_{quality}', - create_function=lambda: self._prepare_task_chunk(db_data, quality, chunk_number), - ) + def _get_cache_item(self, key: str) -> Optional[_CacheItem]: + slogger.glob.info(f"Starting to get chunk from cache: key {key}") + try: + item = self._cache.get(key) + except pickle.UnpicklingError: + slogger.glob.error(f"Unable to get item from cache: key {key}", exc_info=True) + item = None + slogger.glob.info(f"Ending to get chunk from cache: key {key}, is_cached {bool(item)}") return item - def get_selective_job_chunk_data_with_mime(self, chunk_number, quality, job): - item = self._get_or_set_cache_item( - key=f'job_{job.id}_{chunk_number}_{quality}', - create_function=lambda: self.prepare_selective_job_chunk(job, quality, chunk_number), - ) + def _has_key(self, key: str) -> bool: + return self._cache.has_key(key) + + def _make_cache_key_prefix( + self, obj: Union[models.Task, models.Segment, models.Job, models.CloudStorage] + ) -> str: + if isinstance(obj, models.Task): + return f"task_{obj.id}" + elif isinstance(obj, models.Segment): + return f"segment_{obj.id}" + elif isinstance(obj, models.Job): + return f"job_{obj.id}" + elif isinstance(obj, models.CloudStorage): + return f"cloudstorage_{obj.id}" + else: + assert False, f"Unexpected object type 
{type(obj)}" - return item + def _make_chunk_key( + self, + db_obj: Union[models.Task, models.Segment, models.Job], + chunk_number: int, + *, + quality: FrameQuality, + ) -> str: + return f"{self._make_cache_key_prefix(db_obj)}_chunk_{chunk_number}_{quality}" + + def _make_preview_key(self, db_obj: Union[models.Segment, models.CloudStorage]) -> str: + return f"{self._make_cache_key_prefix(db_obj)}_preview" - def get_local_preview_with_mime(self, frame_number, db_data): - item = self._get_or_set_cache_item( - key=f'data_{db_data.id}_{frame_number}_preview', - create_function=lambda: self._prepare_local_preview(frame_number, db_data), + def _make_segment_task_chunk_key( + self, + db_obj: models.Segment, + chunk_number: int, + *, + quality: FrameQuality, + ) -> str: + return f"{self._make_cache_key_prefix(db_obj)}_task_chunk_{chunk_number}_{quality}" + + def _make_context_image_preview_key(self, db_data: models.Data, frame_number: int) -> str: + return f"context_image_{db_data.id}_{frame_number}_preview" + + @overload + def _to_data_with_mime(self, cache_item: _CacheItem) -> DataWithMime: ... + + @overload + def _to_data_with_mime(self, cache_item: Optional[_CacheItem]) -> Optional[DataWithMime]: ... 
+ + def _to_data_with_mime(self, cache_item: Optional[_CacheItem]) -> Optional[DataWithMime]: + if not cache_item: + return None + + return cache_item[:2] + + def get_or_set_segment_chunk( + self, db_segment: models.Segment, chunk_number: int, *, quality: FrameQuality + ) -> DataWithMime: + return self._to_data_with_mime( + self._get_or_set_cache_item( + key=self._make_chunk_key(db_segment, chunk_number, quality=quality), + create_callback=lambda: self.prepare_segment_chunk( + db_segment, chunk_number, quality=quality + ), + ) ) - return item + def get_task_chunk( + self, db_task: models.Task, chunk_number: int, *, quality: FrameQuality + ) -> Optional[DataWithMime]: + return self._to_data_with_mime( + self._get_cache_item(key=self._make_chunk_key(db_task, chunk_number, quality=quality)) + ) - def get_cloud_preview_with_mime( + def get_or_set_task_chunk( self, - db_storage: CloudStorage, - ) -> Optional[Tuple[io.BytesIO, str]]: - key = f'cloudstorage_{db_storage.id}_preview' - return self._cache.get(key) + db_task: models.Task, + chunk_number: int, + *, + quality: FrameQuality, + set_callback: Callable[[], DataWithMime], + ) -> DataWithMime: + return self._to_data_with_mime( + self._get_or_set_cache_item( + key=self._make_chunk_key(db_task, chunk_number, quality=quality), + create_callback=set_callback, + ) + ) - def get_or_set_cloud_preview_with_mime( - self, - db_storage: CloudStorage, - ) -> Tuple[io.BytesIO, str]: - key = f'cloudstorage_{db_storage.id}_preview' + def get_segment_task_chunk( + self, db_segment: models.Segment, chunk_number: int, *, quality: FrameQuality + ) -> Optional[DataWithMime]: + return self._to_data_with_mime( + self._get_cache_item( + key=self._make_segment_task_chunk_key(db_segment, chunk_number, quality=quality) + ) + ) - item = self._get_or_set_cache_item( - key, create_function=lambda: self._prepare_cloud_preview(db_storage) + def get_or_set_segment_task_chunk( + self, + db_segment: models.Segment, + chunk_number: int, + *, + 
quality: FrameQuality, + set_callback: Callable[[], DataWithMime], + ) -> DataWithMime: + return self._to_data_with_mime( + self._get_or_set_cache_item( + key=self._make_segment_task_chunk_key(db_segment, chunk_number, quality=quality), + create_callback=set_callback, + ) ) - return item + def get_or_set_selective_job_chunk( + self, db_job: models.Job, chunk_number: int, *, quality: FrameQuality + ) -> DataWithMime: + return self._to_data_with_mime( + self._get_or_set_cache_item( + key=self._make_chunk_key(db_job, chunk_number, quality=quality), + create_callback=lambda: self.prepare_masked_range_segment_chunk( + db_job.segment, chunk_number, quality=quality + ), + ) + ) - def get_frame_context_images(self, db_data, frame_number): - item = self._get_or_set_cache_item( - key=f'context_image_{db_data.id}_{frame_number}', - create_function=lambda: self._prepare_context_image(db_data, frame_number) + def get_or_set_segment_preview(self, db_segment: models.Segment) -> DataWithMime: + return self._to_data_with_mime( + self._get_or_set_cache_item( + self._make_preview_key(db_segment), + create_callback=lambda: self._prepare_segment_preview(db_segment), + ) ) - return item + def get_cloud_preview(self, db_storage: models.CloudStorage) -> Optional[DataWithMime]: + return self._to_data_with_mime(self._get_cache_item(self._make_preview_key(db_storage))) - @staticmethod - def _get_frame_provider_class(): - from cvat.apps.engine.frame_provider import \ - FrameProvider # TODO: remove circular dependency - return FrameProvider - - from contextlib import contextmanager - - @staticmethod - @contextmanager - def _get_images(db_data, chunk_number, dimension): - images = [] - tmp_dir = None - upload_dir = { - StorageChoice.LOCAL: db_data.get_upload_dirname(), - StorageChoice.SHARE: settings.SHARE_ROOT, - StorageChoice.CLOUD_STORAGE: db_data.get_upload_dirname(), - }[db_data.storage] + def get_or_set_cloud_preview(self, db_storage: models.CloudStorage) -> DataWithMime: + return 
self._to_data_with_mime( + self._get_or_set_cache_item( + self._make_preview_key(db_storage), + create_callback=lambda: self._prepare_cloud_preview(db_storage), + ) + ) - try: - if hasattr(db_data, 'video'): - source_path = os.path.join(upload_dir, db_data.video.path) - - reader = VideoDatasetManifestReader(manifest_path=db_data.get_manifest_path(), - source_path=source_path, chunk_number=chunk_number, - chunk_size=db_data.chunk_size, start=db_data.start_frame, - stop=db_data.stop_frame, step=db_data.get_frame_step()) - for frame in reader: - images.append((frame, source_path, None)) - else: - reader = ImageDatasetManifestReader(manifest_path=db_data.get_manifest_path(), - chunk_number=chunk_number, chunk_size=db_data.chunk_size, - start=db_data.start_frame, stop=db_data.stop_frame, - step=db_data.get_frame_step()) - if db_data.storage == StorageChoice.CLOUD_STORAGE: - db_cloud_storage = db_data.cloud_storage - assert db_cloud_storage, 'Cloud storage instance was deleted' - credentials = Credentials() - credentials.convert_from_db({ - 'type': db_cloud_storage.credentials_type, - 'value': db_cloud_storage.credentials, - }) - details = { - 'resource': db_cloud_storage.resource, - 'credentials': credentials, - 'specific_attributes': db_cloud_storage.get_specific_attributes() + def get_or_set_frame_context_images_chunk( + self, db_data: models.Data, frame_number: int + ) -> DataWithMime: + return self._to_data_with_mime( + self._get_or_set_cache_item( + key=self._make_context_image_preview_key(db_data, frame_number), + create_callback=lambda: self.prepare_context_images_chunk(db_data, frame_number), + ) + ) + + def _read_raw_images( + self, + db_task: models.Task, + frame_ids: Sequence[int], + *, + manifest_path: str, + ): + db_data = db_task.data + + if os.path.isfile(manifest_path) and db_data.storage == models.StorageChoice.CLOUD_STORAGE: + reader = ImageReaderWithManifest(manifest_path) + with ExitStack() as es: + db_cloud_storage = db_data.cloud_storage + assert 
db_cloud_storage, "Cloud storage instance was deleted" + credentials = Credentials() + credentials.convert_from_db( + { + "type": db_cloud_storage.credentials_type, + "value": db_cloud_storage.credentials, } - cloud_storage_instance = get_cloud_storage_instance(cloud_provider=db_cloud_storage.provider_type, **details) + ) + details = { + "resource": db_cloud_storage.resource, + "credentials": credentials, + "specific_attributes": db_cloud_storage.get_specific_attributes(), + } + cloud_storage_instance = get_cloud_storage_instance( + cloud_provider=db_cloud_storage.provider_type, **details + ) + + tmp_dir = es.enter_context(tempfile.TemporaryDirectory(prefix="cvat")) + files_to_download = [] + checksums = [] + media = [] + for item in reader.iterate_frames(frame_ids): + file_name = f"{item['name']}{item['extension']}" + fs_filename = os.path.join(tmp_dir, file_name) + + files_to_download.append(file_name) + checksums.append(item.get("checksum", None)) + media.append((fs_filename, fs_filename, None)) + + cloud_storage_instance.bulk_download_to_dir( + files=files_to_download, upload_dir=tmp_dir + ) + media = preload_images(media) + + for checksum, (_, fs_filename, _) in zip(checksums, media): + if checksum and not md5_hash(fs_filename) == checksum: + slogger.cloud_storage[db_cloud_storage.id].warning( + "Hash sums of files {} do not match".format(file_name) + ) + + yield from media + else: + requested_frame_iter = iter(frame_ids) + next_requested_frame_id = next(requested_frame_iter, None) + if next_requested_frame_id is None: + return + + # TODO: find a way to use prefetched results, if provided + db_images = ( + db_data.images.order_by("frame") + .filter(frame__gte=frame_ids[0], frame__lte=frame_ids[-1]) + .values_list("frame", "path") + .all() + ) - tmp_dir = tempfile.mkdtemp(prefix='cvat') - files_to_download = [] - checksums = [] - for item in reader: - file_name = f"{item['name']}{item['extension']}" - fs_filename = os.path.join(tmp_dir, file_name) + 
raw_data_dir = db_data.get_raw_data_dirname() + media = [] + for frame_id, frame_path in db_images: + if frame_id == next_requested_frame_id: + source_path = os.path.join(raw_data_dir, frame_path) + media.append((source_path, source_path, None)) - files_to_download.append(file_name) - checksums.append(item.get('checksum', None)) - images.append((fs_filename, fs_filename, None)) + next_requested_frame_id = next(requested_frame_iter, None) - cloud_storage_instance.bulk_download_to_dir(files=files_to_download, upload_dir=tmp_dir) - images = preload_images(images) + if next_requested_frame_id is None: + break - for checksum, (_, fs_filename, _) in zip(checksums, images): - if checksum and not md5_hash(fs_filename) == checksum: - slogger.cloud_storage[db_cloud_storage.id].warning('Hash sums of files {} do not match'.format(file_name)) - else: - for item in reader: - source_path = os.path.join(upload_dir, f"{item['name']}{item['extension']}") - images.append((source_path, source_path, None)) - if dimension == DimensionType.DIM_2D: - images = preload_images(images) - - yield images - finally: - if db_data.storage == StorageChoice.CLOUD_STORAGE and tmp_dir is not None: - shutil.rmtree(tmp_dir) - - def _prepare_task_chunk(self, db_data, quality, chunk_number): - FrameProvider = self._get_frame_provider_class() - - writer_classes = { - FrameProvider.Quality.COMPRESSED : Mpeg4CompressedChunkWriter if db_data.compressed_chunk_type == DataChoice.VIDEO else ZipCompressedChunkWriter, - FrameProvider.Quality.ORIGINAL : Mpeg4ChunkWriter if db_data.original_chunk_type == DataChoice.VIDEO else ZipChunkWriter, - } - - image_quality = 100 if writer_classes[quality] in [Mpeg4ChunkWriter, ZipChunkWriter] else db_data.image_quality - mime_type = 'video/mp4' if writer_classes[quality] in [Mpeg4ChunkWriter, Mpeg4CompressedChunkWriter] else 'application/zip' - - kwargs = {} - if self._dimension == DimensionType.DIM_3D: - kwargs["dimension"] = DimensionType.DIM_3D - writer = 
writer_classes[quality](image_quality, **kwargs) - - buff = BytesIO() - with self._get_images(db_data, chunk_number, self._dimension) as images: - writer.save_as_chunk(images, buff) - buff.seek(0) + assert next_requested_frame_id is None - return buff, mime_type + if db_task.dimension == models.DimensionType.DIM_2D: + media = preload_images(media) - def prepare_selective_job_chunk(self, db_job: Job, quality, chunk_number: int): - db_data = db_job.segment.task.data + yield from media - FrameProvider = self._get_frame_provider_class() - frame_provider = FrameProvider(db_data, self._dimension) + def _read_raw_frames( + self, db_task: models.Task, frame_ids: Sequence[int] + ) -> Generator[Tuple[Union[av.VideoFrame, PIL.Image.Image], str, str], None, None]: + for prev_frame, cur_frame in pairwise(frame_ids): + assert ( + prev_frame <= cur_frame + ), f"Requested frame ids must be sorted, got a ({prev_frame}, {cur_frame}) pair" - frame_set = db_job.segment.frame_set - frame_step = db_data.get_frame_step() - chunk_frames = [] + db_data = db_task.data - writer = ZipCompressedChunkWriter(db_data.image_quality, dimension=self._dimension) - dummy_frame = BytesIO() - PIL.Image.new('RGB', (1, 1)).save(dummy_frame, writer.IMAGE_EXT) + manifest_path = db_data.get_manifest_path() - if hasattr(db_data, 'video'): - frame_size = (db_data.video.width, db_data.video.height) - else: - frame_size = None + if hasattr(db_data, "video"): + source_path = os.path.join(db_data.get_raw_data_dirname(), db_data.video.path) - for frame_idx in range(db_data.chunk_size): - frame_idx = ( - db_data.start_frame + chunk_number * db_data.chunk_size + frame_idx * frame_step + reader = VideoReaderWithManifest( + manifest_path=manifest_path, + source_path=source_path, + allow_threading=False, ) - if db_data.stop_frame < frame_idx: - break - - frame_bytes = None - - if frame_idx in frame_set: - frame_bytes = frame_provider.get_frame(frame_idx, quality=quality)[0] + if not os.path.isfile(manifest_path): + try: 
+ reader.manifest.link(source_path, force=True) + reader.manifest.create() + except Exception as e: + slogger.task[db_task.id].warning( + f"Failed to create video manifest: {e}", exc_info=True + ) + reader = None + + if reader: + for frame in reader.iterate_frames(frame_filter=frame_ids): + yield (frame, source_path, None) + else: + reader = VideoReader([source_path], allow_threading=False) - if frame_size is not None: - # Decoded video frames can have different size, restore the original one + for frame_tuple in reader.iterate_frames(frame_filter=frame_ids): + yield frame_tuple + else: + yield from self._read_raw_images(db_task, frame_ids, manifest_path=manifest_path) + + def prepare_segment_chunk( + self, db_segment: models.Segment, chunk_number: int, *, quality: FrameQuality + ) -> DataWithMime: + if db_segment.type == models.SegmentType.RANGE: + return self.prepare_range_segment_chunk(db_segment, chunk_number, quality=quality) + elif db_segment.type == models.SegmentType.SPECIFIC_FRAMES: + return self.prepare_masked_range_segment_chunk( + db_segment, chunk_number, quality=quality + ) + else: + assert False, f"Unknown segment type {db_segment.type}" + + def prepare_range_segment_chunk( + self, db_segment: models.Segment, chunk_number: int, *, quality: FrameQuality + ) -> DataWithMime: + db_task = db_segment.task + db_data = db_task.data + + chunk_size = db_data.chunk_size + chunk_frame_ids = list(db_segment.frame_set)[ + chunk_size * chunk_number : chunk_size * (chunk_number + 1) + ] + + return self.prepare_custom_range_segment_chunk(db_task, chunk_frame_ids, quality=quality) + + def prepare_custom_range_segment_chunk( + self, db_task: models.Task, frame_ids: Sequence[int], *, quality: FrameQuality + ) -> DataWithMime: + with closing(self._read_raw_frames(db_task, frame_ids=frame_ids)) as frame_iter: + return prepare_chunk(frame_iter, quality=quality, db_task=db_task) + + def prepare_masked_range_segment_chunk( + self, db_segment: models.Segment, chunk_number: 
int, *, quality: FrameQuality + ) -> DataWithMime: + db_task = db_segment.task + db_data = db_task.data + + chunk_size = db_data.chunk_size + chunk_frame_ids = sorted(db_segment.frame_set)[ + chunk_size * chunk_number : chunk_size * (chunk_number + 1) + ] + + return self.prepare_custom_masked_range_segment_chunk( + db_task, chunk_frame_ids, chunk_number, quality=quality + ) - frame = PIL.Image.open(frame_bytes) - if frame.size != frame_size: - frame = frame.resize(frame_size) + def prepare_custom_masked_range_segment_chunk( + self, + db_task: models.Task, + frame_ids: Collection[int], + chunk_number: int, + *, + quality: FrameQuality, + insert_placeholders: bool = False, + ) -> DataWithMime: + db_data = db_task.data - frame_bytes = BytesIO() - frame.save(frame_bytes, writer.IMAGE_EXT) - frame_bytes.seek(0) + frame_step = db_data.get_frame_step() - else: - # Populate skipped frames with placeholder data, - # this is required for video chunk decoding implementation in UI - frame_bytes = BytesIO(dummy_frame.getvalue()) + image_quality = 100 if quality == FrameQuality.ORIGINAL else db_data.image_quality + writer = ZipCompressedChunkWriter(image_quality, dimension=db_task.dimension) + + dummy_frame = io.BytesIO() + PIL.Image.new("RGB", (1, 1)).save(dummy_frame, writer.IMAGE_EXT) + + # Optimize frame access if all the required frames are already cached + # Otherwise we might need to download files. 
+ # This is not needed for video tasks, as it will reduce performance + from cvat.apps.engine.frame_provider import FrameOutputType, TaskFrameProvider + + task_frame_provider = TaskFrameProvider(db_task) + + use_cached_data = False + if db_task.mode != "interpolation": + required_frame_set = set(frame_ids) + available_chunks = [ + self._has_key(self._make_chunk_key(db_segment, chunk_number, quality=quality)) + for db_segment in db_task.segment_set.filter(type=models.SegmentType.RANGE).all() + for chunk_number, _ in groupby( + sorted(required_frame_set.intersection(db_segment.frame_set)), + key=lambda frame: frame // db_data.chunk_size, + ) + ] + use_cached_data = bool(available_chunks) and all(available_chunks) + + if hasattr(db_data, "video"): + frame_size = (db_data.video.width, db_data.video.height) + else: + frame_size = None - if frame_bytes is not None: - chunk_frames.append((frame_bytes, None, None)) + def get_frames(): + with ExitStack() as es: + es.callback(task_frame_provider.unload) + + if insert_placeholders: + frame_range = ( + ( + db_data.start_frame + + chunk_number * db_data.chunk_size + + chunk_frame_idx * frame_step + ) + for chunk_frame_idx in range(db_data.chunk_size) + ) + else: + frame_range = frame_ids + + if not use_cached_data: + frames_gen = self._read_raw_frames(db_task, frame_ids) + frames_iter = iter(es.enter_context(closing(frames_gen))) + + for abs_frame_idx in frame_range: + if db_data.stop_frame < abs_frame_idx: + break + + if abs_frame_idx in frame_ids: + if use_cached_data: + frame_data = task_frame_provider.get_frame( + task_frame_provider.get_rel_frame_number(abs_frame_idx), + quality=quality, + out_type=FrameOutputType.BUFFER, + ) + frame = frame_data.data + else: + frame, _, _ = next(frames_iter) + + if hasattr(db_data, "video"): + # Decoded video frames can have different size, restore the original one + + if isinstance(frame, av.VideoFrame): + frame = frame.to_image() + else: + frame = PIL.Image.open(frame) + + if frame.size 
!= frame_size: + frame = frame.resize(frame_size) + else: + # Populate skipped frames with placeholder data, + # this is required for video chunk decoding implementation in UI + frame = io.BytesIO(dummy_frame.getvalue()) + + yield (frame, None, None) + + buff = io.BytesIO() + with closing(get_frames()) as frame_iter: + writer.save_as_chunk( + frame_iter, + buff, + zip_compress_level=1, + # there are likely to be many skips with repeated placeholder frames + # in SPECIFIC_FRAMES segments, it makes sense to compress the archive + ) - buff = BytesIO() - writer.save_as_chunk(chunk_frames, buff, compress_frames=False, - zip_compress_level=1 # these are likely to be many skips in SPECIFIC_FRAMES segments - ) buff.seek(0) + return buff, get_chunk_mime_type_for_writer(writer) - return buff, 'application/zip' + def _prepare_segment_preview(self, db_segment: models.Segment) -> DataWithMime: + if db_segment.task.dimension == models.DimensionType.DIM_3D: + # TODO + preview = PIL.Image.open( + os.path.join(os.path.dirname(__file__), "assets/3d_preview.jpeg") + ) + else: + from cvat.apps.engine.frame_provider import ( # avoid circular import + FrameOutputType, + make_frame_provider, + ) - def _prepare_local_preview(self, frame_number, db_data): - FrameProvider = self._get_frame_provider_class() - frame_provider = FrameProvider(db_data, self._dimension) - buff, mime_type = frame_provider.get_preview(frame_number) + task_frame_provider = make_frame_provider(db_segment.task) + segment_frame_provider = make_frame_provider(db_segment) + preview = segment_frame_provider.get_frame( + task_frame_provider.get_rel_frame_number(min(db_segment.frame_set)), + quality=FrameQuality.COMPRESSED, + out_type=FrameOutputType.PIL, + ).data - return buff, mime_type + return prepare_preview_image(preview) - def _prepare_cloud_preview(self, db_storage): + def _prepare_cloud_preview(self, db_storage: models.CloudStorage) -> DataWithMime: storage = db_storage_to_storage_instance(db_storage) if not 
db_storage.manifests.count(): - raise ValidationError('Cannot get the cloud storage preview. There is no manifest file') + raise ValidationError("Cannot get the cloud storage preview. There is no manifest file") + preview_path = None - for manifest_model in db_storage.manifests.all(): - manifest_prefix = os.path.dirname(manifest_model.filename) - full_manifest_path = os.path.join(db_storage.get_storage_dirname(), manifest_model.filename) - if not os.path.exists(full_manifest_path) or \ - datetime.fromtimestamp(os.path.getmtime(full_manifest_path), tz=timezone.utc) < storage.get_file_last_modified(manifest_model.filename): - storage.download_file(manifest_model.filename, full_manifest_path) + for db_manifest in db_storage.manifests.all(): + manifest_prefix = os.path.dirname(db_manifest.filename) + + full_manifest_path = os.path.join( + db_storage.get_storage_dirname(), db_manifest.filename + ) + if not os.path.exists(full_manifest_path) or datetime.fromtimestamp( + os.path.getmtime(full_manifest_path), tz=timezone.utc + ) < storage.get_file_last_modified(db_manifest.filename): + storage.download_file(db_manifest.filename, full_manifest_path) + manifest = ImageManifestManager( - os.path.join(db_storage.get_storage_dirname(), manifest_model.filename), - db_storage.get_storage_dirname() + os.path.join(db_storage.get_storage_dirname(), db_manifest.filename), + db_storage.get_storage_dirname(), ) # need to update index manifest.set_index() if not len(manifest): continue + preview_info = manifest[0] - preview_filename = ''.join([preview_info['name'], preview_info['extension']]) + preview_filename = "".join([preview_info["name"], preview_info["extension"]]) preview_path = os.path.join(manifest_prefix, preview_filename) break + if not preview_path: - msg = 'Cloud storage {} does not contain any images'.format(db_storage.pk) + msg = "Cloud storage {} does not contain any images".format(db_storage.pk) slogger.cloud_storage[db_storage.pk].info(msg) raise NotFound(msg) buff = 
storage.download_fileobj(preview_path) - mime_type = mimetypes.guess_type(preview_path)[0] + image = PIL.Image.open(buff) + return prepare_preview_image(image) - return buff, mime_type + def prepare_context_images_chunk(self, db_data: models.Data, frame_number: int) -> DataWithMime: + zip_buffer = io.BytesIO() - def _prepare_context_image(self, db_data, frame_number): - zip_buffer = BytesIO() - try: - image = Image.objects.get(data_id=db_data.id, frame=frame_number) - except Image.DoesNotExist: - return None, None - with zipfile.ZipFile(zip_buffer, 'a', zipfile.ZIP_DEFLATED, False) as zip_file: - if not image.related_files.count(): - return None, None - common_path = os.path.commonpath(list(map(lambda x: str(x.path), image.related_files.all()))) - for i in image.related_files.all(): + related_images = db_data.related_files.filter(primary_image__frame=frame_number).all() + if not related_images: + return zip_buffer, "" + + with zipfile.ZipFile(zip_buffer, "a", zipfile.ZIP_DEFLATED, False) as zip_file: + common_path = os.path.commonpath(list(map(lambda x: str(x.path), related_images))) + for i in related_images: path = os.path.realpath(str(i.path)) name = os.path.relpath(str(i.path), common_path) image = cv2.imread(path) - success, result = cv2.imencode('.JPEG', image) + success, result = cv2.imencode(".JPEG", image) if not success: raise Exception('Failed to encode image to ".jpeg" format') - zip_file.writestr(f'{name}.jpg', result.tobytes()) - mime_type = 'application/zip' + zip_file.writestr(f"{name}.jpg", result.tobytes()) + zip_buffer.seek(0) + mime_type = "application/zip" return zip_buffer, mime_type + + +def prepare_preview_image(image: PIL.Image.Image) -> DataWithMime: + PREVIEW_SIZE = (256, 256) + PREVIEW_MIME = "image/jpeg" + + image = PIL.ImageOps.exif_transpose(image) + image.thumbnail(PREVIEW_SIZE) + + output_buf = io.BytesIO() + image.convert("RGB").save(output_buf, format="JPEG") + return output_buf, PREVIEW_MIME + + +def prepare_chunk( + 
task_chunk_frames: Iterator[Tuple[Any, str, int]], + *, + quality: FrameQuality, + db_task: models.Task, + dump_unchanged: bool = False, +) -> DataWithMime: + # TODO: refactor all chunk building into another class + + db_data = db_task.data + + writer_classes: dict[FrameQuality, Type[IChunkWriter]] = { + FrameQuality.COMPRESSED: ( + Mpeg4CompressedChunkWriter + if db_data.compressed_chunk_type == models.DataChoice.VIDEO + else ZipCompressedChunkWriter + ), + FrameQuality.ORIGINAL: ( + Mpeg4ChunkWriter + if db_data.original_chunk_type == models.DataChoice.VIDEO + else ZipChunkWriter + ), + } + + writer_class = writer_classes[quality] + + image_quality = 100 if quality == FrameQuality.ORIGINAL else db_data.image_quality + + writer_kwargs = {} + if db_task.dimension == models.DimensionType.DIM_3D: + writer_kwargs["dimension"] = models.DimensionType.DIM_3D + merged_chunk_writer = writer_class(image_quality, **writer_kwargs) + + writer_kwargs = {} + if dump_unchanged and isinstance(merged_chunk_writer, ZipCompressedChunkWriter): + writer_kwargs = dict(compress_frames=False, zip_compress_level=1) + + buffer = io.BytesIO() + merged_chunk_writer.save_as_chunk(task_chunk_frames, buffer, **writer_kwargs) + + buffer.seek(0) + return buffer, get_chunk_mime_type_for_writer(writer_class) + + +def get_chunk_mime_type_for_writer(writer: Union[IChunkWriter, Type[IChunkWriter]]) -> str: + if isinstance(writer, IChunkWriter): + writer_class = type(writer) + else: + writer_class = writer + + if issubclass(writer_class, ZipChunkWriter): + return "application/zip" + elif issubclass(writer_class, Mpeg4ChunkWriter): + return "video/mp4" + else: + assert False, f"Unknown chunk writer class {writer_class}" diff --git a/cvat/apps/engine/default_settings.py b/cvat/apps/engine/default_settings.py new file mode 100644 index 00000000000..826fe1c9bef --- /dev/null +++ b/cvat/apps/engine/default_settings.py @@ -0,0 +1,16 @@ +# Copyright (C) 2024 CVAT.ai Corporation +# +# SPDX-License-Identifier: 
MIT + +import os + +from attrs.converters import to_bool + +MEDIA_CACHE_ALLOW_STATIC_CACHE = to_bool(os.getenv("CVAT_ALLOW_STATIC_CACHE", False)) +""" +Allow or disallow static media cache. +If disabled, CVAT will only use the dynamic media cache. New tasks requesting static media cache +will be automatically switched to the dynamic cache. +When enabled, this option can increase data access speed and reduce server load, +but significantly increase disk space occupied by tasks. +""" diff --git a/cvat/apps/engine/frame_provider.py b/cvat/apps/engine/frame_provider.py index 4e2f42ef793..ea14b40a75a 100644 --- a/cvat/apps/engine/frame_provider.py +++ b/cvat/apps/engine/frame_provider.py @@ -3,226 +3,693 @@ # # SPDX-License-Identifier: MIT +from __future__ import annotations + +import io +import itertools import math -from enum import Enum +from abc import ABCMeta, abstractmethod +from dataclasses import dataclass +from enum import Enum, auto from io import BytesIO -import os - +from typing import ( + Any, + Callable, + Generic, + Iterator, + Optional, + Sequence, + Tuple, + Type, + TypeVar, + Union, + overload, +) + +import av import cv2 import numpy as np -from PIL import Image, ImageOps +from datumaro.util import take_by +from django.conf import settings +from PIL import Image +from rest_framework.exceptions import ValidationError -from cvat.apps.engine.cache import MediaCache -from cvat.apps.engine.media_extractors import VideoReader, ZipReader +from cvat.apps.engine import models +from cvat.apps.engine.cache import DataWithMime, MediaCache, prepare_chunk +from cvat.apps.engine.media_extractors import ( + FrameQuality, + IMediaReader, + RandomAccessIterator, + VideoReader, + ZipReader, +) from cvat.apps.engine.mime_types import mimetypes -from cvat.apps.engine.models import DataChoice, StorageMethodChoice, DimensionType -from rest_framework.exceptions import ValidationError -class RandomAccessIterator: - def __init__(self, iterable): - self.iterable = iterable - 
self.iterator = None - self.pos = -1 - - def __iter__(self): - return self - - def __next__(self): - return self[self.pos + 1] - - def __getitem__(self, idx): - assert 0 <= idx - if self.iterator is None or idx <= self.pos: - self.reset() - v = None - while self.pos < idx: - # NOTE: don't keep the last item in self, it can be expensive - v = next(self.iterator) - self.pos += 1 - return v - - def reset(self): - self.close() - self.iterator = iter(self.iterable) - - def close(self): - if self.iterator is not None: - if close := getattr(self.iterator, 'close', None): - close() - self.iterator = None - self.pos = -1 - -class FrameProvider: - VIDEO_FRAME_EXT = '.PNG' - VIDEO_FRAME_MIME = 'image/png' - - class Quality(Enum): - COMPRESSED = 0 - ORIGINAL = 100 - - class Type(Enum): - BUFFER = 0 - PIL = 1 - NUMPY_ARRAY = 2 - - class ChunkLoader: - def __init__(self, reader_class, path_getter): - self.chunk_id = None +_T = TypeVar("_T") + + +class _ChunkLoader(metaclass=ABCMeta): + def __init__( + self, + reader_class: Type[IMediaReader], + *, + reader_params: Optional[dict] = None, + ) -> None: + self.chunk_id: Optional[int] = None + self.chunk_reader: Optional[RandomAccessIterator] = None + self.reader_class = reader_class + self.reader_params = reader_params + + def load(self, chunk_id: int) -> RandomAccessIterator[Tuple[Any, str, int]]: + if self.chunk_id != chunk_id: + self.unload() + + self.chunk_id = chunk_id + self.chunk_reader = RandomAccessIterator( + self.reader_class( + [self.read_chunk(chunk_id)[0]], + **(self.reader_params or {}), + ) + ) + return self.chunk_reader + + def unload(self): + self.chunk_id = None + if self.chunk_reader: + self.chunk_reader.close() self.chunk_reader = None - self.reader_class = reader_class - self.get_chunk_path = path_getter - - def load(self, chunk_id): - if self.chunk_id != chunk_id: - self.unload() - - self.chunk_id = chunk_id - self.chunk_reader = RandomAccessIterator( - self.reader_class([self.get_chunk_path(chunk_id)])) - 
return self.chunk_reader - - def unload(self): - self.chunk_id = None - if self.chunk_reader: - self.chunk_reader.close() - self.chunk_reader = None - - class BuffChunkLoader(ChunkLoader): - def __init__(self, reader_class, path_getter, quality, db_data): - super().__init__(reader_class, path_getter) - self.quality = quality - self.db_data = db_data - - def load(self, chunk_id): - if self.chunk_id != chunk_id: - self.chunk_id = chunk_id - self.chunk_reader = RandomAccessIterator( - self.reader_class([self.get_chunk_path(chunk_id, self.quality, self.db_data)[0]])) - return self.chunk_reader - - def __init__(self, db_data, dimension=DimensionType.DIM_2D): - self._db_data = db_data - self._dimension = dimension - self._loaders = {} - - reader_class = { - DataChoice.IMAGESET: ZipReader, - DataChoice.VIDEO: VideoReader, - } - if db_data.storage_method == StorageMethodChoice.CACHE: - cache = MediaCache(dimension=dimension) - - self._loaders[self.Quality.COMPRESSED] = self.BuffChunkLoader( - reader_class[db_data.compressed_chunk_type], - cache.get_task_chunk_data_with_mime, - self.Quality.COMPRESSED, - self._db_data) - self._loaders[self.Quality.ORIGINAL] = self.BuffChunkLoader( - reader_class[db_data.original_chunk_type], - cache.get_task_chunk_data_with_mime, - self.Quality.ORIGINAL, - self._db_data) - else: - self._loaders[self.Quality.COMPRESSED] = self.ChunkLoader( - reader_class[db_data.compressed_chunk_type], - db_data.get_compressed_chunk_path) - self._loaders[self.Quality.ORIGINAL] = self.ChunkLoader( - reader_class[db_data.original_chunk_type], - db_data.get_original_chunk_path) + @abstractmethod + def read_chunk(self, chunk_id: int) -> DataWithMime: ... 
- def __len__(self): - return self._db_data.size - def unload(self): - for loader in self._loaders.values(): - loader.unload() +class _FileChunkLoader(_ChunkLoader): + def __init__( + self, + reader_class: Type[IMediaReader], + get_chunk_path_callback: Callable[[int], str], + *, + reader_params: Optional[dict] = None, + ) -> None: + super().__init__(reader_class, reader_params=reader_params) + self.get_chunk_path = get_chunk_path_callback + + def read_chunk(self, chunk_id: int) -> DataWithMime: + chunk_path = self.get_chunk_path(chunk_id) + with open(chunk_path, "rb") as f: + return ( + io.BytesIO(f.read()), + mimetypes.guess_type(chunk_path)[0], + ) + + +class _BufferChunkLoader(_ChunkLoader): + def __init__( + self, + reader_class: Type[IMediaReader], + get_chunk_callback: Callable[[int], DataWithMime], + *, + reader_params: Optional[dict] = None, + ) -> None: + super().__init__(reader_class, reader_params=reader_params) + self.get_chunk = get_chunk_callback + + def read_chunk(self, chunk_id: int) -> DataWithMime: + return self.get_chunk(chunk_id) + - def _validate_frame_number(self, frame_number): - frame_number_ = int(frame_number) - if frame_number_ < 0 or frame_number_ >= self._db_data.size: - raise ValidationError('Incorrect requested frame number: {}'.format(frame_number_)) +class FrameOutputType(Enum): + BUFFER = auto() + PIL = auto() + NUMPY_ARRAY = auto() - chunk_number = frame_number_ // self._db_data.chunk_size - frame_offset = frame_number_ % self._db_data.chunk_size - return frame_number_, chunk_number, frame_offset +Frame2d = Union[BytesIO, np.ndarray, Image.Image] +Frame3d = BytesIO +AnyFrame = Union[Frame2d, Frame3d] - def get_chunk_number(self, frame_number): - return int(frame_number) // self._db_data.chunk_size - def _validate_chunk_number(self, chunk_number): - chunk_number_ = int(chunk_number) - if chunk_number_ < 0 or chunk_number_ >= math.ceil(self._db_data.size / self._db_data.chunk_size): - raise ValidationError('requested chunk does not 
exist') +@dataclass +class DataWithMeta(Generic[_T]): + data: _T + mime: str - return chunk_number_ + +class IFrameProvider(metaclass=ABCMeta): + VIDEO_FRAME_EXT = ".PNG" + VIDEO_FRAME_MIME = "image/png" + + def unload(self): + pass @classmethod - def _av_frame_to_png_bytes(cls, av_frame): + def _av_frame_to_png_bytes(cls, av_frame: av.VideoFrame) -> BytesIO: ext = cls.VIDEO_FRAME_EXT - image = av_frame.to_ndarray(format='bgr24') + image = av_frame.to_ndarray(format="bgr24") success, result = cv2.imencode(ext, image) if not success: - raise RuntimeError("Failed to encode image to '%s' format" % (ext)) + raise RuntimeError(f"Failed to encode image to '{ext}' format") return BytesIO(result.tobytes()) - def _convert_frame(self, frame, reader_class, out_type): - if out_type == self.Type.BUFFER: - return self._av_frame_to_png_bytes(frame) if reader_class is VideoReader else frame - elif out_type == self.Type.PIL: - return frame.to_image() if reader_class is VideoReader else Image.open(frame) - elif out_type == self.Type.NUMPY_ARRAY: - if reader_class is VideoReader: - image = frame.to_ndarray(format='bgr24') + def _convert_frame( + self, frame: Any, reader_class: Type[IMediaReader], out_type: FrameOutputType + ) -> AnyFrame: + if out_type == FrameOutputType.BUFFER: + return ( + self._av_frame_to_png_bytes(frame) + if issubclass(reader_class, VideoReader) + else frame + ) + elif out_type == FrameOutputType.PIL: + return frame.to_image() if issubclass(reader_class, VideoReader) else Image.open(frame) + elif out_type == FrameOutputType.NUMPY_ARRAY: + if issubclass(reader_class, VideoReader): + image = frame.to_ndarray(format="bgr24") else: image = np.array(Image.open(frame)) if len(image.shape) == 3 and image.shape[2] in {3, 4}: - image[:, :, :3] = image[:, :, 2::-1] # RGB to BGR + image[:, :, :3] = image[:, :, 2::-1] # RGB to BGR return image else: - raise RuntimeError('unsupported output type') + raise RuntimeError("unsupported output type") + + @abstractmethod + def 
validate_frame_number(self, frame_number: int) -> int: ... + + @abstractmethod + def validate_chunk_number(self, chunk_number: int) -> int: ... + + @abstractmethod + def get_chunk_number(self, frame_number: int) -> int: ... + + @abstractmethod + def get_preview(self) -> DataWithMeta[BytesIO]: ... + + @abstractmethod + def get_chunk( + self, chunk_number: int, *, quality: FrameQuality = FrameQuality.ORIGINAL + ) -> DataWithMeta[BytesIO]: ... + + @abstractmethod + def get_frame( + self, + frame_number: int, + *, + quality: FrameQuality = FrameQuality.ORIGINAL, + out_type: FrameOutputType = FrameOutputType.BUFFER, + ) -> DataWithMeta[AnyFrame]: ... + + @abstractmethod + def get_frame_context_images_chunk( + self, + frame_number: int, + ) -> Optional[DataWithMeta[BytesIO]]: ... + + @abstractmethod + def iterate_frames( + self, + *, + start_frame: Optional[int] = None, + stop_frame: Optional[int] = None, + quality: FrameQuality = FrameQuality.ORIGINAL, + out_type: FrameOutputType = FrameOutputType.BUFFER, + ) -> Iterator[DataWithMeta[AnyFrame]]: ... + + def _get_abs_frame_number(self, db_data: models.Data, rel_frame_number: int) -> int: + return db_data.start_frame + rel_frame_number * db_data.get_frame_step() + + def _get_rel_frame_number(self, db_data: models.Data, abs_frame_number: int) -> int: + return (abs_frame_number - db_data.start_frame) // db_data.get_frame_step() + + +class TaskFrameProvider(IFrameProvider): + def __init__(self, db_task: models.Task) -> None: + self._db_task = db_task + + def validate_frame_number(self, frame_number: int) -> int: + if frame_number not in range(0, self._db_task.data.size): + raise ValidationError( + f"Invalid frame '{frame_number}'. 
" + f"The frame number should be in the [0, {self._db_task.data.size}] range" + ) + + return frame_number + + def validate_chunk_number(self, chunk_number: int) -> int: + last_chunk = math.ceil(self._db_task.data.size / self._db_task.data.chunk_size) - 1 + if not 0 <= chunk_number <= last_chunk: + raise ValidationError( + f"Invalid chunk number '{chunk_number}'. " + f"The chunk number should be in the [0, {last_chunk}] range" + ) + + return chunk_number + + def get_chunk_number(self, frame_number: int) -> int: + return int(frame_number) // self._db_task.data.chunk_size + + def get_abs_frame_number(self, rel_frame_number: int) -> int: + "Returns absolute frame number in the task (in the range [start, stop, step])" + return super()._get_abs_frame_number(self._db_task.data, rel_frame_number) + + def get_rel_frame_number(self, abs_frame_number: int) -> int: + """ + Returns relative frame number in the task (in the range [0, task_size - 1]). + This is the "normal" frame number, expected in other methods. 
+ """ + return super()._get_rel_frame_number(self._db_task.data, abs_frame_number) + + def get_preview(self) -> DataWithMeta[BytesIO]: + return self._get_segment_frame_provider(0).get_preview() + + def get_chunk( + self, chunk_number: int, *, quality: FrameQuality = FrameQuality.ORIGINAL + ) -> DataWithMeta[BytesIO]: + return_type = DataWithMeta[BytesIO] + chunk_number = self.validate_chunk_number(chunk_number) + + cache = MediaCache() + cached_chunk = cache.get_task_chunk(self._db_task, chunk_number, quality=quality) + if cached_chunk: + return return_type(cached_chunk[0], cached_chunk[1]) + + db_data = self._db_task.data + step = db_data.get_frame_step() + task_chunk_start_frame = chunk_number * db_data.chunk_size + task_chunk_stop_frame = (chunk_number + 1) * db_data.chunk_size - 1 + task_chunk_frame_set = set( + range( + db_data.start_frame + task_chunk_start_frame * step, + min(db_data.start_frame + task_chunk_stop_frame * step, db_data.stop_frame) + step, + step, + ) + ) + + matching_segments: list[models.Segment] = sorted( + [ + s + for s in self._db_task.segment_set.all() + if s.type == models.SegmentType.RANGE + if not task_chunk_frame_set.isdisjoint(s.frame_set) + ], + key=lambda s: s.start_frame, + ) + assert matching_segments + + # Don't put this into set_callback to avoid data duplication in the cache + + if len(matching_segments) == 1: + segment_frame_provider = SegmentFrameProvider(matching_segments[0]) + matching_chunk_index = segment_frame_provider.find_matching_chunk( + sorted(task_chunk_frame_set) + ) + if matching_chunk_index is not None: + # The requested frames match one of the job chunks, we can use it directly + return segment_frame_provider.get_chunk(matching_chunk_index, quality=quality) + + def _set_callback() -> DataWithMime: + # Create and return a joined / cleaned chunk + task_chunk_frames = {} + for db_segment in matching_segments: + segment_frame_provider = SegmentFrameProvider(db_segment) + segment_frame_set = db_segment.frame_set + 
+ for task_chunk_frame_id in sorted(task_chunk_frame_set): + if ( + task_chunk_frame_id not in segment_frame_set + or task_chunk_frame_id in task_chunk_frames + ): + continue + + frame, frame_name, _ = segment_frame_provider._get_raw_frame( + self.get_rel_frame_number(task_chunk_frame_id), quality=quality + ) + task_chunk_frames[task_chunk_frame_id] = (frame, frame_name, None) + + return prepare_chunk( + task_chunk_frames.values(), + quality=quality, + db_task=self._db_task, + dump_unchanged=True, + ) + + buffer, mime_type = cache.get_or_set_task_chunk( + self._db_task, chunk_number, quality=quality, set_callback=_set_callback + ) + + return return_type(data=buffer, mime=mime_type) + + def get_frame( + self, + frame_number: int, + *, + quality: FrameQuality = FrameQuality.ORIGINAL, + out_type: FrameOutputType = FrameOutputType.BUFFER, + ) -> DataWithMeta[AnyFrame]: + return self._get_segment_frame_provider(frame_number).get_frame( + frame_number, quality=quality, out_type=out_type + ) + + def get_frame_context_images_chunk( + self, + frame_number: int, + ) -> Optional[DataWithMeta[BytesIO]]: + return self._get_segment_frame_provider(frame_number).get_frame_context_images_chunk( + frame_number + ) + + def iterate_frames( + self, + *, + start_frame: Optional[int] = None, + stop_frame: Optional[int] = None, + quality: FrameQuality = FrameQuality.ORIGINAL, + out_type: FrameOutputType = FrameOutputType.BUFFER, + ) -> Iterator[DataWithMeta[AnyFrame]]: + frame_range = itertools.count(start_frame, self._db_task.data.get_frame_step()) + if stop_frame: + frame_range = itertools.takewhile(lambda x: x <= stop_frame, frame_range) + + db_segment = None + db_segment_frame_set = None + db_segment_frame_provider = None + for idx in frame_range: + if db_segment and idx not in db_segment_frame_set: + db_segment = None + db_segment_frame_set = None + db_segment_frame_provider = None + + if not db_segment: + db_segment = self._get_segment(idx) + db_segment_frame_set = 
set(db_segment.frame_set) + db_segment_frame_provider = SegmentFrameProvider(db_segment) + + yield db_segment_frame_provider.get_frame(idx, quality=quality, out_type=out_type) + + def _get_segment(self, validated_frame_number: int) -> models.Segment: + if not self._db_task.data or not self._db_task.data.size: + raise ValidationError("Task has no data") + + abs_frame_number = self.get_abs_frame_number(validated_frame_number) + + return next( + s + for s in self._db_task.segment_set.all() + if s.type == models.SegmentType.RANGE + if abs_frame_number in s.frame_set + ) + + def _get_segment_frame_provider(self, frame_number: int) -> SegmentFrameProvider: + return SegmentFrameProvider(self._get_segment(self.validate_frame_number(frame_number))) + + +class SegmentFrameProvider(IFrameProvider): + def __init__(self, db_segment: models.Segment) -> None: + super().__init__() + self._db_segment = db_segment + + db_data = db_segment.task.data + + reader_class: dict[models.DataChoice, Tuple[Type[IMediaReader], Optional[dict]]] = { + models.DataChoice.IMAGESET: (ZipReader, None), + models.DataChoice.VIDEO: ( + VideoReader, + { + "allow_threading": False + # disable threading to avoid unpredictable server + # resource consumption during reading in endpoints + # can be enabled for other clients + }, + ), + } + + self._loaders: dict[FrameQuality, _ChunkLoader] = {} + if ( + db_data.storage_method == models.StorageMethodChoice.CACHE + or not settings.MEDIA_CACHE_ALLOW_STATIC_CACHE + # TODO: separate handling, extract cache creation logic from media cache + ): + cache = MediaCache() + + self._loaders[FrameQuality.COMPRESSED] = _BufferChunkLoader( + reader_class=reader_class[db_data.compressed_chunk_type][0], + reader_params=reader_class[db_data.compressed_chunk_type][1], + get_chunk_callback=lambda chunk_idx: cache.get_or_set_segment_chunk( + db_segment, chunk_idx, quality=FrameQuality.COMPRESSED + ), + ) + + self._loaders[FrameQuality.ORIGINAL] = _BufferChunkLoader( + 
reader_class=reader_class[db_data.original_chunk_type][0], + reader_params=reader_class[db_data.original_chunk_type][1], + get_chunk_callback=lambda chunk_idx: cache.get_or_set_segment_chunk( + db_segment, chunk_idx, quality=FrameQuality.ORIGINAL + ), + ) + else: + self._loaders[FrameQuality.COMPRESSED] = _FileChunkLoader( + reader_class=reader_class[db_data.compressed_chunk_type][0], + reader_params=reader_class[db_data.compressed_chunk_type][1], + get_chunk_path_callback=lambda chunk_idx: db_data.get_compressed_segment_chunk_path( + chunk_idx, segment_id=db_segment.id + ), + ) + + self._loaders[FrameQuality.ORIGINAL] = _FileChunkLoader( + reader_class=reader_class[db_data.original_chunk_type][0], + reader_params=reader_class[db_data.original_chunk_type][1], + get_chunk_path_callback=lambda chunk_idx: db_data.get_original_segment_chunk_path( + chunk_idx, segment_id=db_segment.id + ), + ) + + def unload(self): + for loader in self._loaders.values(): + loader.unload() + + def __len__(self): + return self._db_segment.frame_count + + def validate_frame_number(self, frame_number: int) -> Tuple[int, int, int]: + frame_sequence = list(self._db_segment.frame_set) + abs_frame_number = self._get_abs_frame_number(self._db_segment.task.data, frame_number) + if abs_frame_number not in frame_sequence: + raise ValidationError(f"Incorrect requested frame number: {frame_number}") + + # TODO: maybe optimize search + chunk_number, frame_position = divmod( + frame_sequence.index(abs_frame_number), self._db_segment.task.data.chunk_size + ) + return frame_number, chunk_number, frame_position + + def get_chunk_number(self, frame_number: int) -> int: + return int(frame_number) // self._db_segment.task.data.chunk_size + + def find_matching_chunk(self, frames: Sequence[int]) -> Optional[int]: + return next( + ( + i + for i, chunk_frames in enumerate( + take_by( + sorted(self._db_segment.frame_set), self._db_segment.task.data.chunk_size + ) + ) + if frames == set(chunk_frames) + ), + None, 
+ ) + + def validate_chunk_number(self, chunk_number: int) -> int: + segment_size = self._db_segment.frame_count + last_chunk = math.ceil(segment_size / self._db_segment.task.data.chunk_size) - 1 + if not 0 <= chunk_number <= last_chunk: + raise ValidationError( + f"Invalid chunk number '{chunk_number}'. " + f"The chunk number should be in the [0, {last_chunk}] range" + ) + + return chunk_number + + def get_preview(self) -> DataWithMeta[BytesIO]: + cache = MediaCache() + preview, mime = cache.get_or_set_segment_preview(self._db_segment) + return DataWithMeta[BytesIO](preview, mime=mime) + + def get_chunk( + self, chunk_number: int, *, quality: FrameQuality = FrameQuality.ORIGINAL + ) -> DataWithMeta[BytesIO]: + chunk_number = self.validate_chunk_number(chunk_number) + chunk_data, mime = self._loaders[quality].read_chunk(chunk_number) + return DataWithMeta[BytesIO](chunk_data, mime=mime) + + def _get_raw_frame( + self, + frame_number: int, + *, + quality: FrameQuality = FrameQuality.ORIGINAL, + ) -> Tuple[Any, str, Type[IMediaReader]]: + _, chunk_number, frame_offset = self.validate_frame_number(frame_number) + loader = self._loaders[quality] + chunk_reader = loader.load(chunk_number) + frame, frame_name, _ = chunk_reader[frame_offset] + return frame, frame_name, loader.reader_class + + def get_frame( + self, + frame_number: int, + *, + quality: FrameQuality = FrameQuality.ORIGINAL, + out_type: FrameOutputType = FrameOutputType.BUFFER, + ) -> DataWithMeta[AnyFrame]: + return_type = DataWithMeta[AnyFrame] - def get_preview(self, frame_number): - PREVIEW_SIZE = (256, 256) - PREVIEW_MIME = 'image/jpeg' + frame, frame_name, reader_class = self._get_raw_frame(frame_number, quality=quality) - if self._dimension == DimensionType.DIM_3D: - # TODO - preview = Image.open(os.path.join(os.path.dirname(__file__), 'assets/3d_preview.jpeg')) + frame = self._convert_frame(frame, reader_class, out_type) + if issubclass(reader_class, VideoReader): + return return_type(frame, 
mime=self.VIDEO_FRAME_MIME) + + return return_type(frame, mime=mimetypes.guess_type(frame_name)[0]) + + def get_frame_context_images_chunk( + self, + frame_number: int, + ) -> Optional[DataWithMeta[BytesIO]]: + self.validate_frame_number(frame_number) + + db_data = self._db_segment.task.data + + cache = MediaCache() + if db_data.storage_method == models.StorageMethodChoice.CACHE: + data, mime = cache.get_or_set_frame_context_images_chunk(db_data, frame_number) else: - preview, _ = self.get_frame(frame_number, self.Quality.COMPRESSED, self.Type.PIL) + data, mime = cache.prepare_context_images_chunk(db_data, frame_number) + + if not data.getvalue(): + return None + + return DataWithMeta[BytesIO](data, mime=mime) + + def iterate_frames( + self, + *, + start_frame: Optional[int] = None, + stop_frame: Optional[int] = None, + quality: FrameQuality = FrameQuality.ORIGINAL, + out_type: FrameOutputType = FrameOutputType.BUFFER, + ) -> Iterator[DataWithMeta[AnyFrame]]: + frame_range = itertools.count(start_frame) + if stop_frame: + frame_range = itertools.takewhile(lambda x: x <= stop_frame, frame_range) + + segment_frame_set = set(self._db_segment.frame_set) + for idx in frame_range: + if self._get_abs_frame_number(self._db_segment.task.data, idx) in segment_frame_set: + yield self.get_frame(idx, quality=quality, out_type=out_type) + + +class JobFrameProvider(SegmentFrameProvider): + def __init__(self, db_job: models.Job) -> None: + super().__init__(db_job.segment) + + def get_chunk( + self, + chunk_number: int, + *, + quality: FrameQuality = FrameQuality.ORIGINAL, + is_task_chunk: bool = False, + ) -> DataWithMeta[BytesIO]: + if not is_task_chunk: + return super().get_chunk(chunk_number, quality=quality) + + # Backward compatibility for the "number" parameter + # Reproduce the task chunks, limited by this job + return_type = DataWithMeta[BytesIO] + + task_frame_provider = TaskFrameProvider(self._db_segment.task) + segment_start_chunk = 
task_frame_provider.get_chunk_number(self._db_segment.start_frame) + segment_stop_chunk = task_frame_provider.get_chunk_number(self._db_segment.stop_frame) + if not segment_start_chunk <= chunk_number <= segment_stop_chunk: + raise ValidationError( + f"Invalid chunk number '{chunk_number}'. " + "The chunk number should be in the " + f"[{segment_start_chunk}, {segment_stop_chunk}] range" + ) + + cache = MediaCache() + cached_chunk = cache.get_segment_task_chunk(self._db_segment, chunk_number, quality=quality) + if cached_chunk: + return return_type(cached_chunk[0], cached_chunk[1]) + + db_data = self._db_segment.task.data + step = db_data.get_frame_step() + task_chunk_start_frame = chunk_number * db_data.chunk_size + task_chunk_stop_frame = (chunk_number + 1) * db_data.chunk_size - 1 + task_chunk_frame_set = set( + range( + db_data.start_frame + task_chunk_start_frame * step, + min(db_data.start_frame + task_chunk_stop_frame * step, db_data.stop_frame) + step, + step, + ) + ) + + # Don't put this into set_callback to avoid data duplication in the cache + matching_chunk = self.find_matching_chunk(sorted(task_chunk_frame_set)) + if matching_chunk is not None: + return self.get_chunk(matching_chunk, quality=quality) + + def _set_callback() -> DataWithMime: + # Create and return a joined / cleaned chunk + segment_chunk_frame_ids = sorted( + task_chunk_frame_set.intersection(self._db_segment.frame_set) + ) + + if self._db_segment.type == models.SegmentType.RANGE: + return cache.prepare_custom_range_segment_chunk( + db_task=self._db_segment.task, + frame_ids=segment_chunk_frame_ids, + quality=quality, + ) + elif self._db_segment.type == models.SegmentType.SPECIFIC_FRAMES: + return cache.prepare_custom_masked_range_segment_chunk( + db_task=self._db_segment.task, + frame_ids=segment_chunk_frame_ids, + chunk_number=chunk_number, + quality=quality, + insert_placeholders=True, + ) + else: + assert False - preview = ImageOps.exif_transpose(preview) - 
preview.thumbnail(PREVIEW_SIZE) + buffer, mime_type = cache.get_or_set_segment_task_chunk( + self._db_segment, chunk_number, quality=quality, set_callback=_set_callback + ) - output_buf = BytesIO() - preview.convert('RGB').save(output_buf, format="JPEG") + return return_type(data=buffer, mime=mime_type) - return output_buf, PREVIEW_MIME - def get_chunk(self, chunk_number, quality=Quality.ORIGINAL): - chunk_number = self._validate_chunk_number(chunk_number) - if self._db_data.storage_method == StorageMethodChoice.CACHE: - return self._loaders[quality].get_chunk_path(chunk_number, quality, self._db_data) - return self._loaders[quality].get_chunk_path(chunk_number) +@overload +def make_frame_provider(data_source: models.Job) -> JobFrameProvider: ... - def get_frame(self, frame_number, quality=Quality.ORIGINAL, - out_type=Type.BUFFER): - _, chunk_number, frame_offset = self._validate_frame_number(frame_number) - loader = self._loaders[quality] - chunk_reader = loader.load(chunk_number) - frame, frame_name, _ = chunk_reader[frame_offset] - frame = self._convert_frame(frame, loader.reader_class, out_type) - if loader.reader_class is VideoReader: - return (frame, self.VIDEO_FRAME_MIME) - return (frame, mimetypes.guess_type(frame_name)[0]) +@overload +def make_frame_provider(data_source: models.Segment) -> SegmentFrameProvider: ... + + +@overload +def make_frame_provider(data_source: models.Task) -> TaskFrameProvider: ... 
+ - def get_frames(self, start_frame, stop_frame, quality=Quality.ORIGINAL, out_type=Type.BUFFER): - for idx in range(start_frame, stop_frame): - yield self.get_frame(idx, quality=quality, out_type=out_type) +def make_frame_provider( + data_source: Union[models.Job, models.Segment, models.Task, Any] +) -> IFrameProvider: + if isinstance(data_source, models.Task): + frame_provider = TaskFrameProvider(data_source) + elif isinstance(data_source, models.Segment): + frame_provider = SegmentFrameProvider(data_source) + elif isinstance(data_source, models.Job): + frame_provider = JobFrameProvider(data_source) + else: + raise TypeError(f"Unexpected data source type {type(data_source)}") - @property - def data_id(self): - return self._db_data.id + return frame_provider diff --git a/cvat/apps/engine/log.py b/cvat/apps/engine/log.py index 5f123d33eef..6f1740e74fd 100644 --- a/cvat/apps/engine/log.py +++ b/cvat/apps/engine/log.py @@ -59,24 +59,31 @@ def get_logger(logger_name, log_file): vlogger = logging.getLogger('vector') + +def get_migration_log_dir() -> str: + return settings.MIGRATIONS_LOGS_ROOT + +def get_migration_log_file_path(migration_name: str) -> str: + return osp.join(get_migration_log_dir(), f'{migration_name}.log') + @contextmanager def get_migration_logger(migration_name): - migration_log_file = '{}.log'.format(migration_name) + migration_log_file_path = get_migration_log_file_path(migration_name) stdout = sys.stdout stderr = sys.stderr + # redirect all stdout to the file - log_file_object = open(osp.join(settings.MIGRATIONS_LOGS_ROOT, migration_log_file), 'w') - sys.stdout = log_file_object - sys.stderr = log_file_object - - log = logging.getLogger(migration_name) - log.addHandler(logging.StreamHandler(stdout)) - log.addHandler(logging.StreamHandler(log_file_object)) - log.setLevel(logging.INFO) - - try: - yield log - finally: - log_file_object.close() - sys.stdout = stdout - sys.stderr = stderr + with open(migration_log_file_path, 'w') as log_file_object: + 
sys.stdout = log_file_object + sys.stderr = log_file_object + + log = logging.getLogger(migration_name) + log.addHandler(logging.StreamHandler(stdout)) + log.addHandler(logging.StreamHandler(log_file_object)) + log.setLevel(logging.INFO) + + try: + yield log + finally: + sys.stdout = stdout + sys.stderr = stderr diff --git a/cvat/apps/engine/media_extractors.py b/cvat/apps/engine/media_extractors.py index 9a352c3b930..9ddbad10e3a 100644 --- a/cvat/apps/engine/media_extractors.py +++ b/cvat/apps/engine/media_extractors.py @@ -3,6 +3,8 @@ # # SPDX-License-Identifier: MIT +from __future__ import annotations + import os import sysconfig import tempfile @@ -11,12 +13,20 @@ import io import itertools import struct -from enum import IntEnum from abc import ABC, abstractmethod -from contextlib import closing -from typing import Iterable +from bisect import bisect +from contextlib import ExitStack, closing, contextmanager +from dataclasses import dataclass +from enum import IntEnum +from typing import ( + Any, Callable, ContextManager, Generator, Iterable, Iterator, Optional, Protocol, + Sequence, Tuple, TypeVar, Union +) import av +import av.codec +import av.container +import av.video.stream import numpy as np from natsort import os_sorted from pyunpack import Archive @@ -45,6 +55,10 @@ class ORIENTATION(IntEnum): MIRROR_HORIZONTAL_90_ROTATED=7 NORMAL_270_ROTATED=8 +class FrameQuality(IntEnum): + COMPRESSED = 0 + ORIGINAL = 100 + def get_mime(name): for type_name, type_def in MEDIA_TYPES.items(): if type_def['has_mime_type'](name): @@ -78,21 +92,126 @@ def sort(images, sorting_method=SortingMethod.LEXICOGRAPHICAL, func=None): else: raise NotImplementedError() -def image_size_within_orientation(img: Image): +def image_size_within_orientation(img: Image.Image): orientation = img.getexif().get(ORIENTATION_EXIF_TAG, ORIENTATION.NORMAL_HORIZONTAL) if orientation > 4: return img.height, img.width return img.width, img.height -def has_exif_rotation(img: Image): +def 
has_exif_rotation(img: Image.Image): return img.getexif().get(ORIENTATION_EXIF_TAG, ORIENTATION.NORMAL_HORIZONTAL) != ORIENTATION.NORMAL_HORIZONTAL +_T = TypeVar("_T") + + +class RandomAccessIterator(Iterator[_T]): + def __init__(self, iterable: Iterable[_T]): + self.iterable: Iterable[_T] = iterable + self.iterator: Optional[Iterator[_T]] = None + self.pos: int = -1 + + def __iter__(self): + return self + + def __next__(self): + return self[self.pos + 1] + + def __getitem__(self, idx: int) -> Optional[_T]: + assert 0 <= idx + if self.iterator is None or idx <= self.pos: + self.reset() + v = None + while self.pos < idx: + # NOTE: don't keep the last item in self, it can be expensive + v = next(self.iterator) + self.pos += 1 + return v + + def reset(self): + self.close() + self.iterator = iter(self.iterable) + + def close(self): + if self.iterator is not None: + if close := getattr(self.iterator, "close", None): + close() + self.iterator = None + self.pos = -1 + + +class Sized(Protocol): + def get_size(self) -> int: ... 
+ +_MediaT = TypeVar("_MediaT", bound=Sized) + +class CachingMediaIterator(RandomAccessIterator[_MediaT]): + @dataclass + class _CacheItem: + value: _MediaT + size: int + + def __init__( + self, + iterable: Iterable, + *, + max_cache_memory: int, + max_cache_entries: int, + object_size_callback: Optional[Callable[[_MediaT], int]] = None, + ): + super().__init__(iterable) + self.max_cache_entries = max_cache_entries + self.max_cache_memory = max_cache_memory + self._get_object_size_callback = object_size_callback + self.used_cache_memory = 0 + self._cache: dict[int, self._CacheItem] = {} + + def _get_object_size(self, obj: _MediaT) -> int: + if self._get_object_size_callback: + return self._get_object_size_callback(obj) + + return obj.get_size() + + def __getitem__(self, idx: int): + cache_item = self._cache.get(idx) + if cache_item: + return cache_item.value + + value = super().__getitem__(idx) + value_size = self._get_object_size(value) + + while ( + len(self._cache) + 1 > self.max_cache_entries or + self.used_cache_memory + value_size > self.max_cache_memory + ): + min_key = min(self._cache.keys()) + self._cache.pop(min_key) + + if self.used_cache_memory + value_size <= self.max_cache_memory: + self._cache[idx] = self._CacheItem(value, value_size) + + return value + + class IMediaReader(ABC): - def __init__(self, source_path, step, start, stop, dimension): + def __init__( + self, + source_path, + *, + start: int = 0, + stop: Optional[int] = None, + step: int = 1, + dimension: DimensionType = DimensionType.DIM_2D + ): self._source_path = source_path + self._step = step + self._start = start + "The first included index" + self._stop = stop + "The last included index" + self._dimension = dimension @abstractmethod @@ -140,30 +259,25 @@ def _get_preview(obj): def get_image_size(self, i): pass - def __len__(self): - return len(self.frame_range) - - @property - def frame_range(self): - return range(self._start, self._stop, self._step) - class 
ImageListReader(IMediaReader): def __init__(self, - source_path, - step=1, - start=0, - stop=None, - dimension=DimensionType.DIM_2D, - sorting_method=SortingMethod.LEXICOGRAPHICAL): + source_path, + step: int = 1, + start: int = 0, + stop: Optional[int] = None, + dimension: DimensionType = DimensionType.DIM_2D, + sorting_method: SortingMethod = SortingMethod.LEXICOGRAPHICAL, + ): if not source_path: raise Exception('No image found') if not stop: - stop = len(source_path) + stop = len(source_path) - 1 else: - stop = min(len(source_path), stop + 1) + stop = min(len(source_path) - 1, stop) + step = max(step, 1) - assert stop > start + assert stop >= start super().__init__( source_path=sort(source_path, sorting_method), @@ -176,7 +290,7 @@ def __init__(self, self._sorting_method = sorting_method def __iter__(self): - for i in range(self._start, self._stop, self._step): + for i in self.frame_range: yield (self.get_image(i), self.get_path(i), i) def __contains__(self, media_file): @@ -189,7 +303,7 @@ def filter(self, callback): source_path, step=self._step, start=self._start, - stop=self._stop - 1, + stop=self._stop, dimension=self._dimension, sorting_method=self._sorting_method ) @@ -201,7 +315,7 @@ def get_image(self, i): return self._source_path[i] def get_progress(self, pos): - return (pos - self._start + 1) / (self._stop - self._start) + return (pos + 1) / (len(self.frame_range) or 1) def get_preview(self, frame): if self._dimension == DimensionType.DIM_3D: @@ -233,6 +347,13 @@ def reconcile(self, source_files, step=1, start=0, stop=None, dimension=Dimensio def absolute_source_paths(self): return [self.get_path(idx) for idx, _ in enumerate(self._source_path)] + def __len__(self): + return len(self.frame_range) + + @property + def frame_range(self): + return range(self._start, self._stop + 1, self._step) + class DirectoryReader(ImageListReader): def __init__(self, source_path, @@ -403,57 +524,149 @@ def extract(self): if not self.extract_dir: 
os.remove(self._zip_source.filename) +class _AvVideoReading: + @contextmanager + def read_av_container(self, source: Union[str, io.BytesIO]) -> av.container.InputContainer: + if isinstance(source, io.BytesIO): + source.seek(0) # required for re-reading + + container = av.open(source) + try: + yield container + finally: + # fixes a memory leak in input container closing + # https://github.com/PyAV-Org/PyAV/issues/1117 + for stream in container.streams: + context = stream.codec_context + if context and context.is_open: + context.close() + + if container.open_files: + container.close() + + def decode_stream( + self, container: av.container.Container, video_stream: av.video.stream.VideoStream + ) -> Generator[av.VideoFrame, None, None]: + demux_iter = container.demux(video_stream) + try: + for packet in demux_iter: + yield from packet.decode() + finally: + # av v9.2.0 seems to have a memory corruption or a deadlock + # in exception handling for demux() in the multithreaded mode. + # Instead of breaking the iteration, we iterate over packets till the end. + # Fixed in av v12.2.0. 
+ if av.__version__ == "9.2.0" and video_stream.thread_type == 'AUTO': + exhausted = object() + while next(demux_iter, exhausted) is not exhausted: + pass + class VideoReader(IMediaReader): - def __init__(self, source_path, step=1, start=0, stop=None, dimension=DimensionType.DIM_2D): + def __init__( + self, + source_path: Union[str, io.BytesIO], + step: int = 1, + start: int = 0, + stop: Optional[int] = None, + dimension: DimensionType = DimensionType.DIM_2D, + *, + allow_threading: bool = True, + ): super().__init__( source_path=source_path, step=step, start=start, - stop=stop + 1 if stop is not None else stop, + stop=stop, dimension=dimension, ) - def _has_frame(self, i): - if i >= self._start: - if (i - self._start) % self._step == 0: - if self._stop is None or i < self._stop: - return True + self.allow_threading = allow_threading + self._frame_count: Optional[int] = None + self._frame_size: Optional[tuple[int, int]] = None # (w, h) - return False + def iterate_frames( + self, + *, + frame_filter: Union[bool, Iterable[int]] = True, + video_stream: Optional[av.video.stream.VideoStream] = None, + ) -> Iterator[Tuple[av.VideoFrame, str, int]]: + """ + If provided, frame_filter must be an ordered sequence in the ascending order. + 'True' means using the frames configured in the reader object. + 'False' or 'None' means returning all the video frames. 
+ """ - def __iter__(self): - with self._get_av_container() as container: - stream = container.streams.video[0] - stream.thread_type = 'AUTO' - frame_num = 0 - for packet in container.demux(stream): - for image in packet.decode(): - frame_num += 1 - if self._has_frame(frame_num - 1): - if packet.stream.metadata.get('rotate'): - pts = image.pts - image = av.VideoFrame().from_ndarray( + if frame_filter is True: + frame_filter = itertools.count(self._start, self._step) + if self._stop: + frame_filter = itertools.takewhile(lambda x: x <= self._stop, frame_filter) + elif not frame_filter: + frame_filter = itertools.count() + + frame_filter_iter = iter(frame_filter) + next_frame_filter_frame = next(frame_filter_iter, None) + if next_frame_filter_frame is None: + return + + es = ExitStack() + + needs_init = video_stream is None + if needs_init: + container = es.enter_context(self._read_av_container()) + else: + container = video_stream.container + + with es: + if needs_init: + video_stream = container.streams.video[0] + + if self.allow_threading: + video_stream.thread_type = 'AUTO' + + frame_counter = itertools.count() + with closing(self._decode_stream(container, video_stream)) as stream_decoder: + for frame, frame_number in zip(stream_decoder, frame_counter): + if frame_number == next_frame_filter_frame: + if video_stream.metadata.get('rotate'): + pts = frame.pts + frame = av.VideoFrame().from_ndarray( rotate_image( - image.to_ndarray(format='bgr24'), - 360 - int(stream.metadata.get('rotate')) + frame.to_ndarray(format='bgr24'), + 360 - int(video_stream.metadata.get('rotate')) ), format ='bgr24' ) - image.pts = pts - yield (image, self._source_path[0], image.pts) + frame.pts = pts + + if self._frame_size is None: + self._frame_size = (frame.width, frame.height) + + yield (frame, self._source_path[0], frame.pts) + + next_frame_filter_frame = next(frame_filter_iter, None) + + if next_frame_filter_frame is None: + return + + def __iter__(self) -> 
Iterator[Tuple[av.VideoFrame, str, int]]: + return self.iterate_frames() def get_progress(self, pos): duration = self._get_duration() return pos / duration if duration else None - def _get_av_container(self): - if isinstance(self._source_path[0], io.BytesIO): - self._source_path[0].seek(0) # required for re-reading - return av.open(self._source_path[0]) + def _read_av_container(self) -> ContextManager[av.container.InputContainer]: + return _AvVideoReading().read_av_container(self._source_path[0]) + + def _decode_stream( + self, container: av.container.Container, video_stream: av.video.stream.VideoStream + ) -> Generator[av.VideoFrame, None, None]: + return _AvVideoReading().decode_stream(container, video_stream) def _get_duration(self): - with self._get_av_container() as container: + with self._read_av_container() as container: stream = container.streams.video[0] + duration = None if stream.duration: duration = stream.duration @@ -468,122 +681,128 @@ def _get_duration(self): return duration def get_preview(self, frame): - with self._get_av_container() as container: + with self._read_av_container() as container: stream = container.streams.video[0] + tb_denominator = stream.time_base.denominator needed_time = int((frame / stream.guessed_rate) * tb_denominator) container.seek(offset=needed_time, stream=stream) - for packet in container.demux(stream): - for frame in packet.decode(): - return self._get_preview(frame.to_image() if not stream.metadata.get('rotate') \ - else av.VideoFrame().from_ndarray( - rotate_image( - frame.to_ndarray(format='bgr24'), - 360 - int(container.streams.video[0].metadata.get('rotate')) - ), - format ='bgr24' - ).to_image() - ) + + with closing(self.iterate_frames(video_stream=stream)) as frame_iter: + return self._get_preview(next(frame_iter)) def get_image_size(self, i): - image = (next(iter(self)))[0] - return image.width, image.height + if self._frame_size is not None: + return self._frame_size -class FragmentMediaReader: - def 
__init__(self, chunk_number, chunk_size, start, stop, step=1): - self._start = start - self._stop = stop + 1 # up to the last inclusive - self._step = step - self._chunk_number = chunk_number - self._chunk_size = chunk_size - self._start_chunk_frame_number = \ - self._start + self._chunk_number * self._chunk_size * self._step - self._end_chunk_frame_number = min(self._start_chunk_frame_number \ - + (self._chunk_size - 1) * self._step + 1, self._stop) - self._frame_range = self._get_frame_range() + with closing(iter(self)) as frame_iter: + frame = next(frame_iter)[0] + self._frame_size = (frame.width, frame.height) - @property - def frame_range(self): - return self._frame_range + return self._frame_size - def _get_frame_range(self): - frame_range = [] - for idx in range(self._start, self._stop, self._step): - if idx < self._start_chunk_frame_number: - continue - elif idx < self._end_chunk_frame_number and \ - not (idx - self._start_chunk_frame_number) % self._step: - frame_range.append(idx) - elif (idx - self._start_chunk_frame_number) % self._step: - continue - else: - break - return frame_range + def get_frame_count(self) -> int: + """ + Returns total frame count in the video + + Note that not all videos provide length / duration metainfo, so the + result may require full video decoding. + + The total count is NOT affected by the frame filtering options of the object, + i.e. start frame, end frame and frame step. + """ + # It's possible to retrieve frame count from the stream.frames, + # but the number may be incorrect. 
+ # https://superuser.com/questions/1512575/why-total-frame-count-is-different-in-ffmpeg-than-ffprobe + if self._frame_count is not None: + return self._frame_count + + frame_count = 0 + for _ in self.iterate_frames(frame_filter=False): + frame_count += 1 + + self._frame_count = frame_count -class ImageDatasetManifestReader(FragmentMediaReader): - def __init__(self, manifest_path, **kwargs): - super().__init__(**kwargs) + return frame_count + + +class ImageReaderWithManifest: + def __init__(self, manifest_path: str): self._manifest = ImageManifestManager(manifest_path) self._manifest.init_index() - def __iter__(self): - for idx in self._frame_range: + def iterate_frames(self, frame_ids: Iterable[int]): + for idx in frame_ids: yield self._manifest[idx] -class VideoDatasetManifestReader(FragmentMediaReader): - def __init__(self, manifest_path, **kwargs): - self.source_path = kwargs.pop('source_path') - super().__init__(**kwargs) - self._manifest = VideoManifestManager(manifest_path) - self._manifest.init_index() +class VideoReaderWithManifest: + # TODO: merge this class with VideoReader + + def __init__(self, manifest_path: str, source_path: str, *, allow_threading: bool = False): + self.source_path = source_path + self.manifest = VideoManifestManager(manifest_path) + if self.manifest.exists: + self.manifest.init_index() + + self.allow_threading = allow_threading + + def _read_av_container(self) -> ContextManager[av.container.InputContainer]: + return _AvVideoReading().read_av_container(self.source_path) + + def _decode_stream( + self, container: av.container.Container, video_stream: av.video.stream.VideoStream + ) -> Generator[av.VideoFrame, None, None]: + return _AvVideoReading().decode_stream(container, video_stream) - def _get_nearest_left_key_frame(self): - if self._start_chunk_frame_number >= \ - self._manifest[len(self._manifest) - 1].get('number'): - left_border = len(self._manifest) - 1 + def _get_nearest_left_key_frame(self, frame_id: int) -> tuple[int, 
int]: + nearest_left_keyframe_pos = bisect( + self.manifest, frame_id, key=lambda entry: entry.get('number') + ) + if nearest_left_keyframe_pos: + frame_number = self.manifest[nearest_left_keyframe_pos - 1].get('number') + timestamp = self.manifest[nearest_left_keyframe_pos - 1].get('pts') else: - left_border = 0 - delta = len(self._manifest) - while delta: - step = delta // 2 - cur_position = left_border + step - if self._manifest[cur_position].get('number') < self._start_chunk_frame_number: - cur_position += 1 - left_border = cur_position - delta -= step + 1 - else: - delta = step - if self._manifest[cur_position].get('number') > self._start_chunk_frame_number: - left_border -= 1 - frame_number = self._manifest[left_border].get('number') - timestamp = self._manifest[left_border].get('pts') + frame_number = 0 + timestamp = 0 return frame_number, timestamp - def __iter__(self): - start_decode_frame_number, start_decode_timestamp = self._get_nearest_left_key_frame() - with closing(av.open(self.source_path, mode='r')) as container: - video_stream = next(stream for stream in container.streams if stream.type == 'video') - video_stream.thread_type = 'AUTO' + def iterate_frames(self, *, frame_filter: Iterable[int]) -> Iterable[av.VideoFrame]: + "frame_filter must be an ordered sequence in the ascending order" + + frame_filter_iter = iter(frame_filter) + next_frame_filter_frame = next(frame_filter_iter, None) + if next_frame_filter_frame is None: + return + + start_decode_frame_number, start_decode_timestamp = self._get_nearest_left_key_frame( + next_frame_filter_frame + ) + + with self._read_av_container() as container: + video_stream = container.streams.video[0] + if self.allow_threading: + video_stream.thread_type = 'AUTO' container.seek(offset=start_decode_timestamp, stream=video_stream) - frame_number = start_decode_frame_number - 1 - for packet in container.demux(video_stream): - for frame in packet.decode(): - frame_number += 1 - if frame_number in self._frame_range:
+ frame_counter = itertools.count(start_decode_frame_number) + with closing(self._decode_stream(container, video_stream)) as stream_decoder: + for frame, frame_number in zip(stream_decoder, frame_counter): + if frame_number == next_frame_filter_frame: if video_stream.metadata.get('rotate'): frame = av.VideoFrame().from_ndarray( rotate_image( frame.to_ndarray(format='bgr24'), - 360 - int(container.streams.video[0].metadata.get('rotate')) + 360 - int(video_stream.metadata.get('rotate')) ), format ='bgr24' ) + yield frame - elif frame_number < self._frame_range[-1]: - continue - else: + + next_frame_filter_frame = next(frame_filter_iter, None) + + if next_frame_filter_frame is None: return class IChunkWriter(ABC): @@ -648,33 +867,37 @@ class ZipChunkWriter(IChunkWriter): POINT_CLOUD_EXT = 'pcd' def _write_pcd_file(self, image: str|io.BytesIO) -> tuple[io.BytesIO, str, int, int]: - image_buf = open(image, "rb") if isinstance(image, str) else image - try: + with ExitStack() as es: + if isinstance(image, str): + image_buf = es.enter_context(open(image, "rb")) + else: + image_buf = image + properties = ValidateDimension.get_pcd_properties(image_buf) w, h = int(properties["WIDTH"]), int(properties["HEIGHT"]) image_buf.seek(0, 0) return io.BytesIO(image_buf.read()), self.POINT_CLOUD_EXT, w, h - finally: - if isinstance(image, str): - image_buf.close() - def save_as_chunk(self, images: Iterable[tuple[Image.Image|io.IOBase|str, str, str]], chunk_path: str): + def save_as_chunk(self, images: Iterator[tuple[Image.Image|io.IOBase|str, str, str]], chunk_path: str): with zipfile.ZipFile(chunk_path, 'x') as zip_chunk: for idx, (image, path, _) in enumerate(images): ext = os.path.splitext(path)[1].replace('.', '') - output = io.BytesIO() + if self._dimension == DimensionType.DIM_2D: # current version of Pillow applies exif rotation immediately when TIFF image opened # and it removes rotation tag after that # so, has_exif_rotation(image) will return False for TIFF images even if they 
were actually rotated # and original files will be added to the archive (without applied rotation) # that is why we need the second part of the condition - if has_exif_rotation(image) or image.format == 'TIFF': + if isinstance(image, Image.Image) and ( + has_exif_rotation(image) or image.format == 'TIFF' + ): + output = io.BytesIO() rot_image = ImageOps.exif_transpose(image) try: if image.format == 'TIFF': # https://pillow.readthedocs.io/en/stable/handbook/image-file-formats.html - # use loseless lzw compression for tiff images + # use lossless lzw compression for tiff images rot_image.save(output, format='TIFF', compression='tiff_lzw') else: rot_image.save( @@ -686,16 +909,22 @@ def save_as_chunk(self, images: Iterable[tuple[Image.Image|io.IOBase|str, str, s ) finally: rot_image.close() + elif isinstance(image, io.IOBase): + output = image else: output = path else: - output, ext = self._write_pcd_file(path)[0:2] - arcname = '{:06d}.{}'.format(idx, ext) + if isinstance(image, io.BytesIO): + output, ext = self._write_pcd_file(image)[0:2] + else: + output, ext = self._write_pcd_file(path)[0:2] + arcname = '{:06d}.{}'.format(idx, ext) if isinstance(output, io.BytesIO): zip_chunk.writestr(arcname, output.getvalue()) else: zip_chunk.write(filename=output, arcname=arcname) + # return empty list because ZipChunkWriter write files as is # and does not decode it to know img size. 
return [] @@ -703,7 +932,7 @@ def save_as_chunk(self, images: Iterable[tuple[Image.Image|io.IOBase|str, str, s class ZipCompressedChunkWriter(ZipChunkWriter): def save_as_chunk( self, - images: Iterable[tuple[Image.Image|io.IOBase|str, str, str]], + images: Iterator[tuple[Image.Image|io.IOBase|str, str, str]], chunk_path: str, *, compress_frames: bool = True, zip_compress_level: int = 0 ): image_sizes = [] @@ -719,7 +948,11 @@ def save_as_chunk( w, h = img.size extension = self.IMAGE_EXT else: - image_buf, extension, w, h = self._write_pcd_file(path) + if isinstance(image, io.BytesIO): + image_buf, extension, w, h = self._write_pcd_file(image) + else: + image_buf, extension, w, h = self._write_pcd_file(path) + image_sizes.append((w, h)) arcname = '{:06d}.{}'.format(idx, extension) zip_chunk.writestr(arcname, image_buf.getvalue()) @@ -751,7 +984,7 @@ def __init__(self, quality=67): "preset": "ultrafast", } - def _add_video_stream(self, container, w, h, rate, options): + def _add_video_stream(self, container: av.container.OutputContainer, w, h, rate, options): # x264 requires width and height must be divisible by 2 for yuv420p if h % 2: h += 1 @@ -772,12 +1005,28 @@ def _add_video_stream(self, container, w, h, rate, options): return video_stream - def save_as_chunk(self, images, chunk_path): - if not images: + FrameDescriptor = Tuple[av.VideoFrame, Any, Any] + + def _peek_first_frame( + self, frame_iter: Iterator[FrameDescriptor] + ) -> Tuple[Optional[FrameDescriptor], Iterator[FrameDescriptor]]: + "Gets the first frame and returns the same full iterator" + + if not hasattr(frame_iter, '__next__'): + frame_iter = iter(frame_iter) + + first_frame = next(frame_iter, None) + return first_frame, itertools.chain((first_frame, ), frame_iter) + + def save_as_chunk( + self, images: Iterator[FrameDescriptor], chunk_path: str + ) -> Sequence[Tuple[int, int]]: + first_frame, images = self._peek_first_frame(images) + if not first_frame: raise Exception('no images to save') - 
input_w = images[0][0].width - input_h = images[0][0].height + input_w = first_frame[0].width + input_h = first_frame[0].height with av.open(chunk_path, 'w', format=self.FORMAT) as output_container: output_v_stream = self._add_video_stream( @@ -788,11 +1037,15 @@ def save_as_chunk(self, images, chunk_path): options=self._codec_opts, ) - self._encode_images(images, output_container, output_v_stream) + with closing(output_v_stream): + self._encode_images(images, output_container, output_v_stream) + return [(input_w, input_h)] @staticmethod - def _encode_images(images, container, stream): + def _encode_images( + images, container: av.container.OutputContainer, stream: av.video.stream.VideoStream + ): for frame, _, _ in images: # let libav set the correct pts and time_base frame.pts = None @@ -818,11 +1071,12 @@ def __init__(self, quality): } def save_as_chunk(self, images, chunk_path): - if not images: + first_frame, images = self._peek_first_frame(images) + if not first_frame: raise Exception('no images to save') - input_w = images[0][0].width - input_h = images[0][0].height + input_w = first_frame[0].width + input_h = first_frame[0].height downscale_factor = 1 while input_h / downscale_factor >= 1080: @@ -840,7 +1094,9 @@ def save_as_chunk(self, images, chunk_path): options=self._codec_opts, ) - self._encode_images(images, output_container, output_v_stream) + with closing(output_v_stream): + self._encode_images(images, output_container, output_v_stream) + return [(input_w, input_h)] def _is_archive(path): diff --git a/cvat/apps/engine/migrations/0083_move_to_segment_chunks.py b/cvat/apps/engine/migrations/0083_move_to_segment_chunks.py new file mode 100644 index 00000000000..8ef887d4c54 --- /dev/null +++ b/cvat/apps/engine/migrations/0083_move_to_segment_chunks.py @@ -0,0 +1,118 @@ +# Generated by Django 4.2.13 on 2024-08-12 09:49 + +import os +from itertools import islice +from typing import Iterable, TypeVar + +from django.db import migrations + +from 
cvat.apps.engine.log import get_migration_log_dir, get_migration_logger + +T = TypeVar("T") + + +def take_by(iterable: Iterable[T], count: int) -> Iterable[T]: + """ + Returns elements from the input iterable by batches of N items. + ('abcdefg', 3) -> ['a', 'b', 'c'], ['d', 'e', 'f'], ['g'] + """ + + it = iter(iterable) + while True: + batch = list(islice(it, count)) + if len(batch) == 0: + break + + yield batch + + +def get_migration_name() -> str: + return os.path.splitext(os.path.basename(__file__))[0] + + +def get_updated_ids_filename(log_dir: str, migration_name: str) -> str: + return os.path.join(log_dir, migration_name + "-data_ids.log") + + +MIGRATION_LOG_HEADER = ( + 'The following Data ids have been switched from using "filesystem" chunk storage ' 'to "cache":' +) + + +def switch_tasks_with_static_chunks_to_dynamic_chunks(apps, schema_editor): + migration_name = get_migration_name() + migration_log_dir = get_migration_log_dir() + with get_migration_logger(migration_name) as common_logger: + Data = apps.get_model("engine", "Data") + + data_with_static_cache_query = Data.objects.filter(storage_method="file_system") + + data_with_static_cache_ids = list( + v[0] + for v in ( + data_with_static_cache_query.order_by("id") + .values_list("id") + .iterator(chunk_size=100000) + ) + ) + + data_with_static_cache_query.update(storage_method="cache") + + updated_ids_filename = get_updated_ids_filename(migration_log_dir, migration_name) + with open(updated_ids_filename, "w") as data_ids_file: + print(MIGRATION_LOG_HEADER, file=data_ids_file) + + for data_id in data_with_static_cache_ids: + print(data_id, file=data_ids_file) + + common_logger.info( + "Information about migrated tasks is available in the migration log file: " + "{}. 
You will need to remove data manually for these tasks.".format( + updated_ids_filename + ) + ) + + +def revert_switch_tasks_with_static_chunks_to_dynamic_chunks(apps, schema_editor): + migration_name = get_migration_name() + migration_log_dir = get_migration_log_dir() + + updated_ids_filename = get_updated_ids_filename(migration_log_dir, migration_name) + if not os.path.isfile(updated_ids_filename): + raise FileNotFoundError( + "Can't revert the migration: can't find the forward migration logfile at " + f"'{updated_ids_filename}'." + ) + + with open(updated_ids_filename, "r") as data_ids_file: + header = data_ids_file.readline().strip() + if header != MIGRATION_LOG_HEADER: + raise ValueError( + "Can't revert the migration: the migration log file has an unexpected header" + ) + + forward_updated_ids = tuple(map(int, data_ids_file)) + + if not forward_updated_ids: + return + + Data = apps.get_model("engine", "Data") + + for id_batch in take_by(forward_updated_ids, 1000): + Data.objects.filter(storage_method="cache", id__in=id_batch).update( + storage_method="file_system" + ) + + +class Migration(migrations.Migration): + + dependencies = [ + ("engine", "0082_alter_labeledimage_job_and_more"), + ] + + operations = [ + migrations.RunPython( + switch_tasks_with_static_chunks_to_dynamic_chunks, + reverse_code=revert_switch_tasks_with_static_chunks_to_dynamic_chunks, + ) + ] diff --git a/cvat/apps/engine/models.py b/cvat/apps/engine/models.py index eda765e6beb..c57eb0371d5 100644 --- a/cvat/apps/engine/models.py +++ b/cvat/apps/engine/models.py @@ -252,6 +252,13 @@ def get_data_dirname(self): def get_upload_dirname(self): return os.path.join(self.get_data_dirname(), "raw") + def get_raw_data_dirname(self) -> str: + return { + StorageChoice.LOCAL: self.get_upload_dirname(), + StorageChoice.SHARE: settings.SHARE_ROOT, + StorageChoice.CLOUD_STORAGE: self.get_upload_dirname(), + }[self.storage] + + def get_compressed_cache_dirname(self): return os.path.join(self.get_data_dirname(), 
"compressed") @@ -259,7 +266,7 @@ def get_original_cache_dirname(self): return os.path.join(self.get_data_dirname(), "original") @staticmethod - def _get_chunk_name(chunk_number, chunk_type): + def _get_chunk_name(segment_id: int, chunk_number: int, chunk_type: DataChoice | str) -> str: if chunk_type == DataChoice.VIDEO: ext = 'mp4' elif chunk_type == DataChoice.IMAGESET: @@ -267,21 +274,21 @@ def _get_chunk_name(chunk_number, chunk_type): else: ext = 'list' - return '{}.{}'.format(chunk_number, ext) + return 'segment_{}-{}.{}'.format(segment_id, chunk_number, ext) - def _get_compressed_chunk_name(self, chunk_number): - return self._get_chunk_name(chunk_number, self.compressed_chunk_type) + def _get_compressed_chunk_name(self, segment_id: int, chunk_number: int) -> str: + return self._get_chunk_name(segment_id, chunk_number, self.compressed_chunk_type) - def _get_original_chunk_name(self, chunk_number): - return self._get_chunk_name(chunk_number, self.original_chunk_type) + def _get_original_chunk_name(self, segment_id: int, chunk_number: int) -> str: + return self._get_chunk_name(segment_id, chunk_number, self.original_chunk_type) - def get_original_chunk_path(self, chunk_number): + def get_original_segment_chunk_path(self, chunk_number: int, segment_id: int) -> str: return os.path.join(self.get_original_cache_dirname(), - self._get_original_chunk_name(chunk_number)) + self._get_original_chunk_name(segment_id, chunk_number)) - def get_compressed_chunk_path(self, chunk_number): + def get_compressed_segment_chunk_path(self, chunk_number: int, segment_id: int) -> str: return os.path.join(self.get_compressed_cache_dirname(), - self._get_compressed_chunk_name(chunk_number)) + self._get_compressed_chunk_name(segment_id, chunk_number)) def get_manifest_path(self): return os.path.join(self.get_upload_dirname(), 'manifest.jsonl') @@ -600,7 +607,7 @@ def __str__(self): class Segment(models.Model): # Common fields - task = models.ForeignKey(Task, on_delete=models.CASCADE) + 
task = models.ForeignKey(Task, on_delete=models.CASCADE) # TODO: add related name start_frame = models.IntegerField() stop_frame = models.IntegerField() type = models.CharField(choices=SegmentType.choices(), default=SegmentType.RANGE, max_length=32) diff --git a/cvat/apps/engine/pyproject.toml b/cvat/apps/engine/pyproject.toml new file mode 100644 index 00000000000..567b7836258 --- /dev/null +++ b/cvat/apps/engine/pyproject.toml @@ -0,0 +1,12 @@ +[tool.isort] +profile = "black" +forced_separate = ["tests"] +line_length = 100 +skip_gitignore = true # align tool behavior with Black +known_first_party = ["cvat"] + +# Can't just use a pyproject in the root dir, so duplicate +# https://github.com/psf/black/issues/2863 +[tool.black] +line-length = 100 +target-version = ['py38'] diff --git a/cvat/apps/engine/serializers.py b/cvat/apps/engine/serializers.py index 9d66b1716c1..ed937a993ff 100644 --- a/cvat/apps/engine/serializers.py +++ b/cvat/apps/engine/serializers.py @@ -594,6 +594,7 @@ class JobReadSerializer(serializers.ModelSerializer): dimension = serializers.CharField(max_length=2, source='segment.task.dimension', read_only=True) data_chunk_size = serializers.ReadOnlyField(source='segment.task.data.chunk_size') organization = serializers.ReadOnlyField(source='segment.task.organization.id', allow_null=True) + data_original_chunk_type = serializers.ReadOnlyField(source='segment.task.data.original_chunk_type') data_compressed_chunk_type = serializers.ReadOnlyField(source='segment.task.data.compressed_chunk_type') mode = serializers.ReadOnlyField(source='segment.task.mode') bug_tracker = serializers.CharField(max_length=2000, source='get_bug_tracker', @@ -607,7 +608,8 @@ class Meta: model = models.Job fields = ('url', 'id', 'task_id', 'project_id', 'assignee', 'guide_id', 'dimension', 'bug_tracker', 'status', 'stage', 'state', 'mode', 'frame_count', - 'start_frame', 'stop_frame', 'data_chunk_size', 'data_compressed_chunk_type', + 'start_frame', 'stop_frame', + 
'data_chunk_size', 'data_compressed_chunk_type', 'data_original_chunk_type', 'created_date', 'updated_date', 'issues', 'labels', 'type', 'organization', 'target_storage', 'source_storage', 'assignee_updated_date') read_only_fields = fields diff --git a/cvat/apps/engine/task.py b/cvat/apps/engine/task.py index 0db84cebc32..f24cd686a58 100644 --- a/cvat/apps/engine/task.py +++ b/cvat/apps/engine/task.py @@ -1,32 +1,37 @@ # Copyright (C) 2018-2022 Intel Corporation -# Copyright (C) 2022-2023 CVAT.ai Corporation +# Copyright (C) 2022-2024 CVAT.ai Corporation # # SPDX-License-Identifier: MIT +import concurrent.futures import itertools import fnmatch import os -from typing import Any, Dict, Iterator, List, NamedTuple, Optional, Union, Iterable -from rest_framework.serializers import ValidationError -import rq import re +import rq import shutil +from contextlib import closing +from datetime import datetime, timezone +from pathlib import Path +from typing import Any, Dict, Iterator, List, NamedTuple, Optional, Sequence, Tuple, Union from urllib import parse as urlparse from urllib import request as urlrequest -import django_rq -import concurrent.futures -import queue +import av +import attrs +import django_rq from django.conf import settings from django.db import transaction from django.http import HttpRequest -from datetime import datetime, timezone -from pathlib import Path +from rest_framework.serializers import ValidationError from cvat.apps.engine import models from cvat.apps.engine.log import ServerLogManager -from cvat.apps.engine.media_extractors import (MEDIA_TYPES, ImageListReader, Mpeg4ChunkWriter, Mpeg4CompressedChunkWriter, - ValidateDimension, ZipChunkWriter, ZipCompressedChunkWriter, get_mime, sort) +from cvat.apps.engine.media_extractors import ( + MEDIA_TYPES, CachingMediaIterator, IMediaReader, ImageListReader, + Mpeg4ChunkWriter, Mpeg4CompressedChunkWriter, RandomAccessIterator, + ValidateDimension, ZipChunkWriter, ZipCompressedChunkWriter, get_mime, 
sort +) from cvat.apps.engine.models import RequestAction, RequestTarget from cvat.apps.engine.utils import ( av_scan_paths,get_rq_job_meta, define_dependent_job, get_rq_lock_by_user, preload_images @@ -71,6 +76,8 @@ def create( class SegmentParams(NamedTuple): start_frame: int stop_frame: int + type: models.SegmentType = models.SegmentType.RANGE + frames: Optional[Sequence[int]] = [] class SegmentsParams(NamedTuple): segments: Iterator[SegmentParams] @@ -116,7 +123,7 @@ def _copy_data_from_share_point( os.makedirs(target_dir) shutil.copyfile(source_path, target_path) -def _get_task_segment_data( +def _generate_segment_params( db_task: models.Task, *, data_size: Optional[int] = None, @@ -127,10 +134,14 @@ def _segments(): # It is assumed here that files are already saved ordered in the task # Here we just need to create segments by the job sizes start_frame = 0 - for jf in job_file_mapping: - segment_size = len(jf) + for job_files in job_file_mapping: + segment_size = len(job_files) stop_frame = start_frame + segment_size - 1 - yield SegmentParams(start_frame, stop_frame) + yield SegmentParams( + start_frame=start_frame, + stop_frame=stop_frame, + type=models.SegmentType.RANGE, + ) start_frame = stop_frame + 1 @@ -153,31 +164,39 @@ def _segments(): ) segments = ( - SegmentParams(start_frame, min(start_frame + segment_size - 1, data_size - 1)) + SegmentParams( + start_frame=start_frame, + stop_frame=min(start_frame + segment_size - 1, data_size - 1), + type=models.SegmentType.RANGE + ) for start_frame in range(0, data_size - overlap, segment_size - overlap) ) return SegmentsParams(segments, segment_size, overlap) -def _save_task_to_db(db_task: models.Task, *, job_file_mapping: Optional[JobFileMapping] = None): - job = rq.get_current_job() - job.meta['status'] = 'Task is being saved in database' - job.save_meta() +def _create_segments_and_jobs( + db_task: models.Task, + *, + job_file_mapping: Optional[JobFileMapping] = None, +): + rq_job = rq.get_current_job() + 
rq_job.meta['status'] = 'Task is being saved in database' + rq_job.save_meta() - segments, segment_size, overlap = _get_task_segment_data( - db_task=db_task, job_file_mapping=job_file_mapping + segments, segment_size, overlap = _generate_segment_params( + db_task=db_task, job_file_mapping=job_file_mapping, ) db_task.segment_size = segment_size db_task.overlap = overlap - for segment_idx, (start_frame, stop_frame) in enumerate(segments): - slogger.glob.info("New segment for task #{}: idx = {}, start_frame = {}, \ - stop_frame = {}".format(db_task.id, segment_idx, start_frame, stop_frame)) + for segment_idx, segment_params in enumerate(segments): + slogger.glob.info( + "New segment for task #{task_id}: idx = {segment_idx}, start_frame = {start_frame}, " + "stop_frame = {stop_frame}".format( + task_id=db_task.id, segment_idx=segment_idx, **segment_params._asdict() + )) - db_segment = models.Segment() - db_segment.task = db_task - db_segment.start_frame = start_frame - db_segment.stop_frame = stop_frame + db_segment = models.Segment(task=db_task, **segment_params._asdict()) db_segment.save() db_job = models.Job(segment=db_segment) @@ -322,48 +341,28 @@ def _validate_manifest( *, is_in_cloud: bool, db_cloud_storage: Optional[Any], - data_storage_method: str, - data_sorting_method: str, - isBackupRestore: bool, ) -> Optional[str]: - if manifests: - if len(manifests) != 1: - raise ValidationError('Only one manifest file can be attached to data') - manifest_file = manifests[0] - full_manifest_path = os.path.join(root_dir, manifests[0]) - - if is_in_cloud: - cloud_storage_instance = db_storage_to_storage_instance(db_cloud_storage) - # check that cloud storage manifest file exists and is up to date - if not os.path.exists(full_manifest_path) or \ - datetime.fromtimestamp(os.path.getmtime(full_manifest_path), tz=timezone.utc) \ - < cloud_storage_instance.get_file_last_modified(manifest_file): - cloud_storage_instance.download_file(manifest_file, full_manifest_path) - - if 
is_manifest(full_manifest_path): - if not ( - data_sorting_method == models.SortingMethod.PREDEFINED or - (settings.USE_CACHE and data_storage_method == models.StorageMethodChoice.CACHE) or - isBackupRestore or is_in_cloud - ): - cache_disabled_message = "" - if data_storage_method == models.StorageMethodChoice.CACHE and not settings.USE_CACHE: - cache_disabled_message = ( - "This server doesn't allow to use cache for data. " - "Please turn 'use cache' off and try to recreate the task" - ) - slogger.glob.warning(cache_disabled_message) - - raise ValidationError( - "A manifest file can only be used with the 'use cache' option " - "or when 'sorting_method' is 'predefined'" + \ - (". " + cache_disabled_message if cache_disabled_message else "") - ) - return manifest_file + if not manifests: + return None + if len(manifests) != 1: + raise ValidationError('Only one manifest file can be attached to data') + manifest_file = manifests[0] + full_manifest_path = os.path.join(root_dir, manifests[0]) + + if is_in_cloud: + cloud_storage_instance = db_storage_to_storage_instance(db_cloud_storage) + # check that cloud storage manifest file exists and is up to date + if not os.path.exists(full_manifest_path) or ( + datetime.fromtimestamp(os.path.getmtime(full_manifest_path), tz=timezone.utc) \ + < cloud_storage_instance.get_file_last_modified(manifest_file) + ): + cloud_storage_instance.download_file(manifest_file, full_manifest_path) + + if not is_manifest(full_manifest_path): raise ValidationError('Invalid manifest was uploaded') - return None + return manifest_file def _validate_scheme(url): ALLOWED_SCHEMES = ['http', 'https'] @@ -522,18 +521,18 @@ def _create_thread( slogger.glob.info("create task #{}".format(db_task.id)) - job_file_mapping = _validate_job_file_mapping(db_task, data) - - db_data = db_task.data - upload_dir = db_data.get_upload_dirname() if db_data.storage != models.StorageChoice.SHARE else settings.SHARE_ROOT - is_data_in_cloud = db_data.storage == 
models.StorageChoice.CLOUD_STORAGE - job = rq.get_current_job() def _update_status(msg: str) -> None: job.meta['status'] = msg job.save_meta() + job_file_mapping = _validate_job_file_mapping(db_task, data) + + db_data = db_task.data + upload_dir = db_data.get_upload_dirname() if db_data.storage != models.StorageChoice.SHARE else settings.SHARE_ROOT + is_data_in_cloud = db_data.storage == models.StorageChoice.CLOUD_STORAGE + if data['remote_files'] and not isDatasetImport: data['remote_files'] = _download_data(data['remote_files'], upload_dir) @@ -551,14 +550,17 @@ def _update_status(msg: str) -> None: else: assert False, f"Unknown file storage {db_data.storage}" + if ( + db_data.storage_method == models.StorageMethodChoice.FILE_SYSTEM and + not settings.MEDIA_CACHE_ALLOW_STATIC_CACHE + ): + db_data.storage_method = models.StorageMethodChoice.CACHE + manifest_file = _validate_manifest( manifest_files, manifest_root, is_in_cloud=is_data_in_cloud, db_cloud_storage=db_data.cloud_storage if is_data_in_cloud else None, - data_storage_method=db_data.storage_method, - data_sorting_method=data['sorting_method'], - isBackupRestore=isBackupRestore, ) manifest = None @@ -668,14 +670,16 @@ def _update_status(msg: str) -> None: is_media_sorted = False if is_data_in_cloud: - # first we need to filter files and keep only supported ones - if any([v for k, v in media.items() if k != 'image']) and db_data.storage_method == models.StorageMethodChoice.CACHE: - # FUTURE-FIXME: This is a temporary workaround for creating tasks - # with unsupported cloud storage data (video, archive, pdf) when use_cache is enabled - db_data.storage_method = models.StorageMethodChoice.FILE_SYSTEM - _update_status("The 'use cache' option is ignored") - - if db_data.storage_method == models.StorageMethodChoice.FILE_SYSTEM or not settings.USE_CACHE: + if ( + # Download remote data if local storage is requested + # TODO: maybe move into cache building to fail faster on invalid task configurations + 
db_data.storage_method == models.StorageMethodChoice.FILE_SYSTEM or + + # Packed media must be downloaded for task creation + any(v for k, v in media.items() if k != 'image') + ): + _update_status("Downloading input media") + filtered_data = [] for files in (i for i in media.values() if i): filtered_data.extend(files) @@ -690,9 +694,11 @@ def _update_status(msg: str) -> None: step = db_data.get_frame_step() if start_frame or step != 1 or stop_frame != len(filtered_data) - 1: media_to_download = filtered_data[start_frame : stop_frame + 1: step] + _download_data_from_cloud_storage(db_data.cloud_storage, media_to_download, upload_dir) del media_to_download del filtered_data + is_data_in_cloud = False db_data.storage = models.StorageChoice.LOCAL else: @@ -757,7 +763,7 @@ def _update_status(msg: str) -> None: ) # Extract input data - extractor = None + extractor: Optional[IMediaReader] = None manifest_index = _get_manifest_frame_indexer() for media_type, media_files in media.items(): if not media_files: @@ -917,38 +923,9 @@ def _update_status(msg: str) -> None: db_data.compressed_chunk_type = models.DataChoice.VIDEO if task_mode == 'interpolation' and not data['use_zip_chunks'] else models.DataChoice.IMAGESET db_data.original_chunk_type = models.DataChoice.VIDEO if task_mode == 'interpolation' else models.DataChoice.IMAGESET - def update_progress(progress): - progress_animation = '|/-\\' - if not hasattr(update_progress, 'call_counter'): - update_progress.call_counter = 0 - - status_message = 'CVAT is preparing data chunks' - if not progress: - status_message = '{} {}'.format(status_message, progress_animation[update_progress.call_counter]) - job.meta['status'] = status_message - job.meta['task_progress'] = progress or 0. 
- job.save_meta() - update_progress.call_counter = (update_progress.call_counter + 1) % len(progress_animation) - - compressed_chunk_writer_class = Mpeg4CompressedChunkWriter if db_data.compressed_chunk_type == models.DataChoice.VIDEO else ZipCompressedChunkWriter - if db_data.original_chunk_type == models.DataChoice.VIDEO: - original_chunk_writer_class = Mpeg4ChunkWriter - # Let's use QP=17 (that is 67 for 0-100 range) for the original chunks, which should be visually lossless or nearly so. - # A lower value will significantly increase the chunk size with a slight increase of quality. - original_quality = 67 - else: - original_chunk_writer_class = ZipChunkWriter - original_quality = 100 - - kwargs = {} - if validate_dimension.dimension == models.DimensionType.DIM_3D: - kwargs["dimension"] = validate_dimension.dimension - compressed_chunk_writer = compressed_chunk_writer_class(db_data.image_quality, **kwargs) - original_chunk_writer = original_chunk_writer_class(original_quality, **kwargs) - # calculate chunk size if it isn't specified if db_data.chunk_size is None: - if isinstance(compressed_chunk_writer, ZipCompressedChunkWriter): + if db_data.compressed_chunk_type == models.DataChoice.IMAGESET: first_image_idx = db_data.start_frame if not is_data_in_cloud: w, h = extractor.get_image_size(first_image_idx) @@ -960,206 +937,317 @@ def update_progress(progress): else: db_data.chunk_size = 36 - video_path = "" - video_size = (0, 0) + # TODO: try to pull up + # replace manifest file (e.g was uploaded 'subdir/manifest.jsonl' or 'some_manifest.jsonl') + if (manifest_file and not os.path.exists(db_data.get_manifest_path())): + shutil.copyfile(os.path.join(manifest_root, manifest_file), + db_data.get_manifest_path()) + if manifest_root and manifest_root.startswith(db_data.get_upload_dirname()): + os.remove(os.path.join(manifest_root, manifest_file)) + manifest_file = os.path.relpath(db_data.get_manifest_path(), upload_dir) - db_images = [] + # Create task frames from the 
metadata collected + video_path: str = "" + video_frame_size: tuple[int, int] = (0, 0) - if settings.USE_CACHE and db_data.storage_method == models.StorageMethodChoice.CACHE: - for media_type, media_files in media.items(): - if not media_files: - continue + images: list[models.Image] = [] - # replace manifest file (e.g was uploaded 'subdir/manifest.jsonl' or 'some_manifest.jsonl') - if manifest_file and not os.path.exists(db_data.get_manifest_path()): - shutil.copyfile(os.path.join(manifest_root, manifest_file), - db_data.get_manifest_path()) - if manifest_root and manifest_root.startswith(db_data.get_upload_dirname()): - os.remove(os.path.join(manifest_root, manifest_file)) - manifest_file = os.path.relpath(db_data.get_manifest_path(), upload_dir) + for media_type, media_files in media.items(): + if not media_files: + continue - if task_mode == MEDIA_TYPES['video']['mode']: + if task_mode == MEDIA_TYPES['video']['mode']: + if manifest_file: try: - manifest_is_prepared = False - if manifest_file: - try: - manifest = VideoManifestValidator(source_path=os.path.join(upload_dir, media_files[0]), - manifest_path=db_data.get_manifest_path()) - manifest.init_index() - manifest.validate_seek_key_frames() - assert len(manifest) > 0, 'No key frames.' - - all_frames = manifest.video_length - video_size = manifest.video_resolution - manifest_is_prepared = True - except Exception as ex: - manifest.remove() - if isinstance(ex, AssertionError): - base_msg = str(ex) - else: - base_msg = 'Invalid manifest file was upload.' 
- slogger.glob.warning(str(ex)) - _update_status('{} Start prepare a valid manifest file.'.format(base_msg)) - - if not manifest_is_prepared: - _update_status('Start prepare a manifest file') - manifest = VideoManifestManager(db_data.get_manifest_path()) - manifest.link( - media_file=media_files[0], - upload_dir=upload_dir, - chunk_size=db_data.chunk_size - ) - manifest.create() - _update_status('A manifest had been created') + _update_status('Validating the input manifest file') - all_frames = len(manifest.reader) - video_size = manifest.reader.resolution - manifest_is_prepared = True + manifest = VideoManifestValidator( + source_path=os.path.join(upload_dir, media_files[0]), + manifest_path=db_data.get_manifest_path() + ) + manifest.init_index() + manifest.validate_seek_key_frames() + + if not len(manifest): + raise ValidationError("No key frames found in the manifest") - db_data.size = len(range(db_data.start_frame, min(data['stop_frame'] + 1 \ - if data['stop_frame'] else all_frames, all_frames), db_data.get_frame_step())) - video_path = os.path.join(upload_dir, media_files[0]) except Exception as ex: - db_data.storage_method = models.StorageMethodChoice.FILE_SYSTEM manifest.remove() - del manifest - base_msg = str(ex) if isinstance(ex, AssertionError) \ - else "Uploaded video does not support a quick way of task creating." 
- _update_status("{} The task will be created using the old method".format(base_msg)) - else: # images, archive, pdf - db_data.size = len(extractor) - manifest = ImageManifestManager(db_data.get_manifest_path()) - - if not manifest.exists: + manifest = None + + if isinstance(ex, (ValidationError, AssertionError)): + base_msg = f"Invalid manifest file was uploaded: {ex}" + else: + base_msg = "Failed to parse the uploaded manifest file" + slogger.glob.warning(ex, exc_info=True) + + _update_status(base_msg) + else: + manifest = None + + if not manifest: + try: + _update_status('Preparing a manifest file') + + # TODO: maybe generate manifest in a temp directory + manifest = VideoManifestManager(db_data.get_manifest_path()) manifest.link( - sources=extractor.absolute_source_paths, - meta={ k: {'related_images': related_images[k] } for k in related_images }, - data_dir=upload_dir, - DIM_3D=(db_task.dimension == models.DimensionType.DIM_3D), + media_file=media_files[0], + upload_dir=upload_dir, + chunk_size=db_data.chunk_size, # TODO: why it's needed here? 
+ force=True ) manifest.create() - else: - manifest.init_index() - counter = itertools.count() - for _, chunk_frames in itertools.groupby(extractor.frame_range, lambda x: next(counter) // db_data.chunk_size): - chunk_paths = [(extractor.get_path(i), i) for i in chunk_frames] - img_sizes = [] - - for chunk_path, frame_id in chunk_paths: - properties = manifest[manifest_index(frame_id)] - - # check mapping - if not chunk_path.endswith(f"{properties['name']}{properties['extension']}"): - raise Exception('Incorrect file mapping to manifest content') - - if db_task.dimension == models.DimensionType.DIM_2D and ( - properties.get('width') is not None and - properties.get('height') is not None - ): - resolution = (properties['width'], properties['height']) - elif is_data_in_cloud: - raise Exception( - "Can't find image '{}' width or height info in the manifest" - .format(f"{properties['name']}{properties['extension']}") - ) - else: - resolution = extractor.get_image_size(frame_id) - img_sizes.append(resolution) - - db_images.extend([ - models.Image(data=db_data, - path=os.path.relpath(path, upload_dir), - frame=frame, width=w, height=h) - for (path, frame), (w, h) in zip(chunk_paths, img_sizes) - ]) - if db_data.storage_method == models.StorageMethodChoice.FILE_SYSTEM or not settings.USE_CACHE: - counter = itertools.count() - generator = itertools.groupby(extractor, lambda _: next(counter) // db_data.chunk_size) - generator = ((idx, list(chunk_data)) for idx, chunk_data in generator) - - def save_chunks( - executor: concurrent.futures.ThreadPoolExecutor, - chunk_idx: int, - chunk_data: Iterable[tuple[str, str, str]]) -> list[tuple[str, int, tuple[int, int]]]: - nonlocal db_data, db_task, extractor, original_chunk_writer, compressed_chunk_writer - if (db_task.dimension == models.DimensionType.DIM_2D and - isinstance(extractor, ( - MEDIA_TYPES['image']['extractor'], - MEDIA_TYPES['zip']['extractor'], - MEDIA_TYPES['pdf']['extractor'], - MEDIA_TYPES['archive']['extractor'], - 
))): - chunk_data = preload_images(chunk_data) - - fs_original = executor.submit( - original_chunk_writer.save_as_chunk, - images=chunk_data, - chunk_path=db_data.get_original_chunk_path(chunk_idx) - ) - fs_compressed = executor.submit( - compressed_chunk_writer.save_as_chunk, - images=chunk_data, - chunk_path=db_data.get_compressed_chunk_path(chunk_idx), - ) - fs_original.result() - image_sizes = fs_compressed.result() - # (path, frame, size) - return list((i[0][1], i[0][2], i[1]) for i in zip(chunk_data, image_sizes)) + _update_status('A manifest has been created') - def process_results(img_meta: list[tuple[str, int, tuple[int, int]]]): - nonlocal db_images, db_data, video_path, video_size + except Exception as ex: + manifest.remove() + manifest = None - if db_task.mode == 'annotation': - db_images.extend( - models.Image( - data=db_data, - path=os.path.relpath(frame_path, upload_dir), - frame=frame_number, - width=frame_size[0], - height=frame_size[1]) - for frame_path, frame_number, frame_size in img_meta) + if isinstance(ex, AssertionError): + base_msg = f": {ex}" + else: + base_msg = "" + slogger.glob.warning(ex, exc_info=True) + + _update_status( + f"Failed to create manifest for the uploaded video{base_msg}. 
" + "A manifest will not be used in this task" + ) + + if manifest: + video_frame_count = manifest.video_length + video_frame_size = manifest.video_resolution else: - video_size = img_meta[0][2] - video_path = img_meta[0][0] + video_frame_count = extractor.get_frame_count() + video_frame_size = extractor.get_image_size(0) + + db_data.size = len(range( + db_data.start_frame, + min( + data['stop_frame'] + 1 if data['stop_frame'] else video_frame_count, + video_frame_count, + ), + db_data.get_frame_step() + )) + video_path = os.path.join(upload_dir, media_files[0]) + else: # images, archive, pdf + db_data.size = len(extractor) - progress = extractor.get_progress(img_meta[-1][1]) - update_progress(progress) + manifest = ImageManifestManager(db_data.get_manifest_path()) + if not manifest.exists: + manifest.link( + sources=extractor.absolute_source_paths, + meta={ + k: {'related_images': related_images[k] } + for k in related_images + }, + data_dir=upload_dir, + DIM_3D=(db_task.dimension == models.DimensionType.DIM_3D), + ) + manifest.create() + else: + manifest.init_index() + + for frame_id in extractor.frame_range: + image_path = extractor.get_path(frame_id) + image_size = None + + if manifest: + image_info = manifest[manifest_index(frame_id)] + + # check mapping + if not image_path.endswith(f"{image_info['name']}{image_info['extension']}"): + raise ValidationError('Incorrect file mapping to manifest content') + + if db_task.dimension == models.DimensionType.DIM_2D and ( + image_info.get('width') is not None and + image_info.get('height') is not None + ): + image_size = (image_info['width'], image_info['height']) + elif is_data_in_cloud: + raise ValidationError( + "Can't find image '{}' width or height info in the manifest" + .format(f"{image_info['name']}{image_info['extension']}") + ) - futures = queue.Queue(maxsize=settings.CVAT_CONCURRENT_CHUNK_PROCESSING) - with concurrent.futures.ThreadPoolExecutor(max_workers=2*settings.CVAT_CONCURRENT_CHUNK_PROCESSING) as 
executor: - for chunk_idx, chunk_data in generator: - db_data.size += len(chunk_data) - if futures.full(): - process_results(futures.get().result()) - futures.put(executor.submit(save_chunks, executor, chunk_idx, chunk_data)) + if not image_size: + image_size = extractor.get_image_size(frame_id) - while not futures.empty(): - process_results(futures.get().result()) + images.append( + models.Image( + data=db_data, + path=os.path.relpath(image_path, upload_dir), + frame=frame_id, + width=image_size[0], + height=image_size[1], + ) + ) if db_task.mode == 'annotation': - models.Image.objects.bulk_create(db_images) - created_images = models.Image.objects.filter(data_id=db_data.id) + models.Image.objects.bulk_create(images) + images = models.Image.objects.filter(data_id=db_data.id) db_related_files = [ models.RelatedFile(data=image.data, primary_image=image, path=os.path.join(upload_dir, related_file_path)) - for image in created_images + for image in images for related_file_path in related_images.get(image.path, []) ] models.RelatedFile.objects.bulk_create(db_related_files) - db_images = [] else: models.Video.objects.create( data=db_data, path=os.path.relpath(video_path, upload_dir), - width=video_size[0], height=video_size[1]) + width=video_frame_size[0], height=video_frame_size[1] + ) + # validate stop_frame if db_data.stop_frame == 0: db_data.stop_frame = db_data.start_frame + (db_data.size - 1) * db_data.get_frame_step() else: - # validate stop_frame db_data.stop_frame = min(db_data.stop_frame, \ db_data.start_frame + (db_data.size - 1) * db_data.get_frame_step()) slogger.glob.info("Found frames {} for Data #{}".format(db_data.size, db_data.id)) - _save_task_to_db(db_task, job_file_mapping=job_file_mapping) + _create_segments_and_jobs(db_task, job_file_mapping=job_file_mapping) + + if ( + settings.MEDIA_CACHE_ALLOW_STATIC_CACHE and + db_data.storage_method == models.StorageMethodChoice.FILE_SYSTEM + ): + _create_static_chunks(db_task, media_extractor=extractor) + 
+def _create_static_chunks(db_task: models.Task, *, media_extractor: IMediaReader): + @attrs.define + class _ChunkProgressUpdater: + _call_counter: int = attrs.field(default=0, init=False) + _rq_job: rq.job.Job = attrs.field(factory=rq.get_current_job) + + def update_progress(self, progress: float): + progress_animation = '|/-\\' + + status_message = 'CVAT is preparing data chunks' + if not progress: + status_message = '{} {}'.format( + status_message, progress_animation[self._call_counter] + ) + + self._rq_job.meta['status'] = status_message + self._rq_job.meta['task_progress'] = progress or 0. + self._rq_job.save_meta() + + self._call_counter = (self._call_counter + 1) % len(progress_animation) + + def save_chunks( + executor: concurrent.futures.ThreadPoolExecutor, + db_segment: models.Segment, + chunk_idx: int, + chunk_frame_ids: Sequence[int] + ): + chunk_data = [media_iterator[frame_idx] for frame_idx in chunk_frame_ids] + + if ( + db_task.dimension == models.DimensionType.DIM_2D and + isinstance(media_extractor, ( + MEDIA_TYPES['image']['extractor'], + MEDIA_TYPES['zip']['extractor'], + MEDIA_TYPES['pdf']['extractor'], + MEDIA_TYPES['archive']['extractor'], + )) + ): + chunk_data = preload_images(chunk_data) + + # TODO: extract into a class + + fs_original = executor.submit( + original_chunk_writer.save_as_chunk, + images=chunk_data, + chunk_path=db_data.get_original_segment_chunk_path( + chunk_idx, segment_id=db_segment.id + ), + ) + compressed_chunk_writer.save_as_chunk( + images=chunk_data, + chunk_path=db_data.get_compressed_segment_chunk_path( + chunk_idx, segment_id=db_segment.id + ), + ) + + fs_original.result() + + db_data = db_task.data + + if db_data.compressed_chunk_type == models.DataChoice.VIDEO: + compressed_chunk_writer_class = Mpeg4CompressedChunkWriter + else: + compressed_chunk_writer_class = ZipCompressedChunkWriter + + if db_data.original_chunk_type == models.DataChoice.VIDEO: + original_chunk_writer_class = Mpeg4ChunkWriter + + # Let's 
use QP=17 (that is 67 for 0-100 range) for the original chunks, + # which should be visually lossless or nearly so. + # A lower value will significantly increase the chunk size with a slight increase of quality. + original_quality = 67 # TODO: fix discrepancy in values in different parts of code + else: + original_chunk_writer_class = ZipChunkWriter + original_quality = 100 + + chunk_writer_kwargs = {} + if db_task.dimension == models.DimensionType.DIM_3D: + chunk_writer_kwargs["dimension"] = db_task.dimension + compressed_chunk_writer = compressed_chunk_writer_class( + db_data.image_quality, **chunk_writer_kwargs + ) + original_chunk_writer = original_chunk_writer_class(original_quality, **chunk_writer_kwargs) + + db_segments = db_task.segment_set.all() + + if isinstance(media_extractor, MEDIA_TYPES['video']['extractor']): + def _get_frame_size(frame_tuple: Tuple[av.VideoFrame, Any, Any]) -> int: + # There is no need to be absolutely precise here, + # just need to provide the reasonable upper boundary. 
+ # Return bytes needed for 1 frame + frame = frame_tuple[0] + return frame.width * frame.height * (frame.format.padded_bits_per_pixel // 8) + + # Currently, we only optimize video creation for sequential + # chunks with potential overlap, so parallel processing is likely to + # help only for image datasets + media_iterator = CachingMediaIterator( + media_extractor, + max_cache_memory=2 ** 30, max_cache_entries=db_task.overlap, + object_size_callback=_get_frame_size + ) + else: + media_iterator = RandomAccessIterator(media_extractor) + + with closing(media_iterator): + progress_updater = _ChunkProgressUpdater() + + # TODO: remove 2 * or the configuration option + # TODO: maybe make real multithreading support, currently the code is limited by 1 + # video segment chunk, even if more threads are available + max_concurrency = 2 * settings.CVAT_CONCURRENT_CHUNK_PROCESSING if not isinstance( + media_extractor, MEDIA_TYPES['video']['extractor'] + ) else 2 + with concurrent.futures.ThreadPoolExecutor(max_workers=max_concurrency) as executor: + frame_step = db_data.get_frame_step() + for segment_idx, db_segment in enumerate(db_segments): + frame_counter = itertools.count() + for chunk_idx, chunk_frame_ids in ( + (chunk_idx, list(chunk_frame_ids)) + for chunk_idx, chunk_frame_ids in itertools.groupby( + ( + # Convert absolute to relative ids (extractor output positions) + # Extractor will skip frames outside requested + (abs_frame_id - db_data.start_frame) // frame_step + for abs_frame_id in db_segment.frame_set + ), + lambda _: next(frame_counter) // db_data.chunk_size + ) + ): + save_chunks(executor, db_segment, chunk_idx, chunk_frame_ids) + + progress_updater.update_progress(segment_idx / len(db_segments)) diff --git a/cvat/apps/engine/tests/test_rest_api.py b/cvat/apps/engine/tests/test_rest_api.py index ae0200b6a2a..e7ae8ae9ba7 100644 --- a/cvat/apps/engine/tests/test_rest_api.py +++ b/cvat/apps/engine/tests/test_rest_api.py @@ -1422,7 +1422,13 @@ def 
_create_task(task_data, media_data): if isinstance(media, io.BytesIO): media.seek(0) response = cls.client.post("/api/tasks/{}/data".format(tid), data=media_data) - assert response.status_code == status.HTTP_202_ACCEPTED + assert response.status_code == status.HTTP_202_ACCEPTED, response.status_code + rq_id = response.json()["rq_id"] + + response = cls.client.get(f"/api/requests/{rq_id}") + assert response.status_code == status.HTTP_200_OK, response.status_code + assert response.json()["status"] == "finished", response.json().get("status") + response = cls.client.get("/api/tasks/{}".format(tid)) data_id = response.data["data"] cls.tasks.append({ @@ -1766,6 +1772,12 @@ def _create_task(task_data, media_data): media.seek(0) response = self.client.post("/api/tasks/{}/data".format(tid), data=media_data) assert response.status_code == status.HTTP_202_ACCEPTED + rq_id = response.json()["rq_id"] + + response = self.client.get(f"/api/requests/{rq_id}") + assert response.status_code == status.HTTP_200_OK, response.status_code + assert response.json()["status"] == "finished", response.json().get("status") + response = self.client.get("/api/tasks/{}".format(tid)) data_id = response.data["data"] self.tasks.append({ @@ -2882,6 +2894,12 @@ def _create_task(task_data, media_data): media.seek(0) response = self.client.post("/api/tasks/{}/data".format(tid), data=media_data) assert response.status_code == status.HTTP_202_ACCEPTED + rq_id = response.json()["rq_id"] + + response = self.client.get(f"/api/requests/{rq_id}") + assert response.status_code == status.HTTP_200_OK, response.status_code + assert response.json()["status"] == "finished", response.json().get("status") + response = self.client.get("/api/tasks/{}".format(tid)) data_id = response.data["data"] self.tasks.append({ @@ -3433,7 +3451,7 @@ def _test_api_v2_tasks_id_data_spec(self, user, spec, data, expected_compressed_type, expected_original_type, expected_image_sizes, - 
expected_storage_method=StorageMethodChoice.FILE_SYSTEM, + expected_storage_method=None, expected_uploaded_data_location=StorageChoice.LOCAL, dimension=DimensionType.DIM_2D, expected_task_creation_status_state='Finished', @@ -3448,6 +3466,12 @@ def _test_api_v2_tasks_id_data_spec(self, user, spec, data, if get_status_callback is None: get_status_callback = self._get_task_creation_status + if expected_storage_method is None: + if settings.MEDIA_CACHE_ALLOW_STATIC_CACHE: + expected_storage_method = StorageMethodChoice.FILE_SYSTEM + else: + expected_storage_method = StorageMethodChoice.CACHE + # create task response = self._create_task(user, spec) self.assertEqual(response.status_code, status.HTTP_201_CREATED) @@ -4007,7 +4031,7 @@ def _test_api_v2_tasks_id_data_create_can_use_chunked_local_video(self, user): image_sizes = self._share_image_sizes['test_rotated_90_video.mp4'] self._test_api_v2_tasks_id_data_spec(user, task_spec, task_data, self.ChunkType.IMAGESET, - self.ChunkType.VIDEO, image_sizes, StorageMethodChoice.FILE_SYSTEM) + self.ChunkType.VIDEO, image_sizes, StorageMethodChoice.CACHE) def _test_api_v2_tasks_id_data_create_can_use_chunked_cached_local_video(self, user): task_spec = { @@ -4104,7 +4128,6 @@ def _test_api_v2_tasks_id_data_create_can_use_server_images_and_manifest(self, u task_data = { "image_quality": 70, - "use_cache": True } manifest_name = "images_manifest_sorted.jsonl" @@ -4115,79 +4138,34 @@ def _test_api_v2_tasks_id_data_create_can_use_server_images_and_manifest(self, u for i, fn in enumerate(images + [manifest_name]) }) - for copy_data in [True, False]: - with self.subTest(current_function_name(), copy=copy_data): - task_spec = task_spec_common.copy() - task_spec['name'] = task_spec['name'] + f' copy={copy_data}' - task_data['copy_data'] = copy_data - self._test_api_v2_tasks_id_data_spec(user, task_spec, task_data, - self.ChunkType.IMAGESET, self.ChunkType.IMAGESET, - image_sizes, StorageMethodChoice.CACHE, - StorageChoice.LOCAL if 
copy_data else StorageChoice.SHARE) - - with self.subTest(current_function_name() + ' file order mismatch'): - task_spec = task_spec_common.copy() - task_spec['name'] = task_spec['name'] + f' mismatching file order' - task_data_copy = task_data.copy() - task_data_copy[f'server_files[{len(images)}]'] = "images_manifest.jsonl" - self._test_api_v2_tasks_id_data_spec(user, task_spec, task_data_copy, - self.ChunkType.IMAGESET, self.ChunkType.IMAGESET, - image_sizes, StorageMethodChoice.CACHE, StorageChoice.SHARE, - expected_task_creation_status_state='Failed', - expected_task_creation_status_reason='Incorrect file mapping to manifest content') - - for copy_data in [True, False]: - with self.subTest(current_function_name(), copy=copy_data): - task_spec = task_spec_common.copy() - task_spec['name'] = task_spec['name'] + f' copy={copy_data}' - task_data['copy_data'] = copy_data - self._test_api_v2_tasks_id_data_spec(user, task_spec, task_data, - self.ChunkType.IMAGESET, self.ChunkType.IMAGESET, - image_sizes, StorageMethodChoice.CACHE, - StorageChoice.LOCAL if copy_data else StorageChoice.SHARE) - - with self.subTest(current_function_name() + ' file order mismatch'): - task_spec = task_spec_common.copy() - task_spec['name'] = task_spec['name'] + f' mismatching file order' - task_data_copy = task_data.copy() - task_data_copy[f'server_files[{len(images)}]'] = "images_manifest.jsonl" - self._test_api_v2_tasks_id_data_spec(user, task_spec, task_data_copy, - self.ChunkType.IMAGESET, self.ChunkType.IMAGESET, - image_sizes, StorageMethodChoice.CACHE, StorageChoice.SHARE, - expected_task_creation_status_state='Failed', - expected_task_creation_status_reason='Incorrect file mapping to manifest content') - - for copy_data in [True, False]: - with self.subTest(current_function_name(), copy=copy_data): + for use_cache in [True, False]: + task_data['use_cache'] = use_cache + + for copy_data in [True, False]: + with self.subTest(current_function_name(), copy=copy_data, 
use_cache=use_cache): + task_spec = task_spec_common.copy() + task_spec['name'] = task_spec['name'] + f' copy={copy_data}' + task_data_copy = task_data.copy() + task_data_copy['copy_data'] = copy_data + self._test_api_v2_tasks_id_data_spec(user, task_spec, task_data_copy, + self.ChunkType.IMAGESET, self.ChunkType.IMAGESET, + image_sizes, + expected_uploaded_data_location=( + StorageChoice.LOCAL if copy_data else StorageChoice.SHARE + ) + ) + + with self.subTest(current_function_name() + ' file order mismatch', use_cache=use_cache): task_spec = task_spec_common.copy() - task_spec['name'] = task_spec['name'] + f' copy={copy_data}' - task_data['copy_data'] = copy_data - self._test_api_v2_tasks_id_data_spec(user, task_spec, task_data, + task_spec['name'] = task_spec['name'] + f' mismatching file order' + task_data_copy = task_data.copy() + task_data_copy[f'server_files[{len(images)}]'] = "images_manifest.jsonl" + self._test_api_v2_tasks_id_data_spec(user, task_spec, task_data_copy, self.ChunkType.IMAGESET, self.ChunkType.IMAGESET, - image_sizes, StorageMethodChoice.CACHE, - StorageChoice.LOCAL if copy_data else StorageChoice.SHARE) - - with self.subTest(current_function_name() + ' file order mismatch'): - task_spec = task_spec_common.copy() - task_spec['name'] = task_spec['name'] + f' mismatching file order' - task_data_copy = task_data.copy() - task_data_copy[f'server_files[{len(images)}]'] = "images_manifest.jsonl" - self._test_api_v2_tasks_id_data_spec(user, task_spec, task_data_copy, - self.ChunkType.IMAGESET, self.ChunkType.IMAGESET, - image_sizes, StorageMethodChoice.CACHE, StorageChoice.SHARE, - expected_task_creation_status_state='Failed', - expected_task_creation_status_reason='Incorrect file mapping to manifest content') - - with self.subTest(current_function_name() + ' without use cache'): - task_spec = task_spec_common.copy() - task_spec['name'] = task_spec['name'] + f' manifest without cache' - task_data_copy = task_data.copy() - 
task_data_copy['use_cache'] = False - self._test_api_v2_tasks_id_data_spec(user, task_spec, task_data_copy, - self.ChunkType.IMAGESET, self.ChunkType.IMAGESET, - image_sizes, StorageMethodChoice.CACHE, StorageChoice.SHARE, - expected_task_creation_status_state='Failed', - expected_task_creation_status_reason="A manifest file can only be used with the 'use cache' option") + image_sizes, + expected_uploaded_data_location=StorageChoice.SHARE, + expected_task_creation_status_state='Failed', + expected_task_creation_status_reason='Incorrect file mapping to manifest content') def _test_api_v2_tasks_id_data_create_can_use_server_images_with_predefined_sorting(self, user): task_spec = { @@ -4219,7 +4197,7 @@ def _test_api_v2_tasks_id_data_create_can_use_server_images_with_predefined_sort task_data = task_data_common.copy() task_data["use_cache"] = caching_enabled - if caching_enabled: + if caching_enabled or not settings.MEDIA_CACHE_ALLOW_STATIC_CACHE: storage_method = StorageMethodChoice.CACHE else: storage_method = StorageMethodChoice.FILE_SYSTEM @@ -4278,7 +4256,7 @@ def _test_api_v2_tasks_id_data_create_can_use_local_images_with_predefined_sorti sorting_method=SortingMethod.PREDEFINED) task_data_common["use_cache"] = caching_enabled - if caching_enabled: + if caching_enabled or not settings.MEDIA_CACHE_ALLOW_STATIC_CACHE: storage_method = StorageMethodChoice.CACHE else: storage_method = StorageMethodChoice.FILE_SYSTEM @@ -4339,7 +4317,7 @@ def _test_api_v2_tasks_id_data_create_can_use_server_archive_with_predefined_sor task_data = task_data_common.copy() task_data["use_cache"] = caching_enabled - if caching_enabled: + if caching_enabled or not settings.MEDIA_CACHE_ALLOW_STATIC_CACHE: storage_method = StorageMethodChoice.CACHE else: storage_method = StorageMethodChoice.FILE_SYSTEM @@ -4412,7 +4390,7 @@ def _test_api_v2_tasks_id_data_create_can_use_local_archive_with_predefined_sort sorting_method=SortingMethod.PREDEFINED) task_data["use_cache"] = caching_enabled - if 
caching_enabled: + if caching_enabled or not settings.MEDIA_CACHE_ALLOW_STATIC_CACHE: storage_method = StorageMethodChoice.CACHE else: storage_method = StorageMethodChoice.FILE_SYSTEM @@ -4590,7 +4568,7 @@ def _send_data_and_fail(*args, **kwargs): self._test_api_v2_tasks_id_data_spec(user, task_spec, task_data, self.ChunkType.IMAGESET, self.ChunkType.IMAGESET, - image_sizes, StorageMethodChoice.FILE_SYSTEM, StorageChoice.LOCAL, + image_sizes, expected_uploaded_data_location=StorageChoice.LOCAL, send_data_callback=_send_data) with self.subTest(current_function_name() + ' mismatching file sets - extra files'): @@ -4604,7 +4582,7 @@ def _send_data_and_fail(*args, **kwargs): with self.assertRaisesMessage(Exception, "(extra)"): self._test_api_v2_tasks_id_data_spec(user, task_spec, task_data, self.ChunkType.IMAGESET, self.ChunkType.IMAGESET, - image_sizes, StorageMethodChoice.FILE_SYSTEM, StorageChoice.LOCAL, + image_sizes, expected_uploaded_data_location=StorageChoice.LOCAL, send_data_callback=_send_data_and_fail) with self.subTest(current_function_name() + ' mismatching file sets - missing files'): @@ -4618,7 +4596,7 @@ def _send_data_and_fail(*args, **kwargs): with self.assertRaisesMessage(Exception, "(missing)"): self._test_api_v2_tasks_id_data_spec(user, task_spec, task_data, self.ChunkType.IMAGESET, self.ChunkType.IMAGESET, - image_sizes, StorageMethodChoice.FILE_SYSTEM, StorageChoice.LOCAL, + image_sizes, expected_uploaded_data_location=StorageChoice.LOCAL, send_data_callback=_send_data_and_fail) def _test_api_v2_tasks_id_data_create_can_use_server_rar(self, user): diff --git a/cvat/apps/engine/tests/test_rest_api_3D.py b/cvat/apps/engine/tests/test_rest_api_3D.py index a67a79109f3..9f000be5d21 100644 --- a/cvat/apps/engine/tests/test_rest_api_3D.py +++ b/cvat/apps/engine/tests/test_rest_api_3D.py @@ -86,9 +86,13 @@ def _create_task(self, data, image_data): assert response.status_code == status.HTTP_201_CREATED, response.status_code tid = response.data["id"] - 
response = self.client.post("/api/tasks/%s/data" % tid, - data=image_data) + response = self.client.post("/api/tasks/%s/data" % tid, data=image_data) assert response.status_code == status.HTTP_202_ACCEPTED, response.status_code + rq_id = response.json()["rq_id"] + + response = self.client.get(f"/api/requests/{rq_id}") + assert response.status_code == status.HTTP_200_OK, response.status_code + assert response.json()["status"] == "finished", response.json().get("status") response = self.client.get("/api/tasks/%s" % tid) @@ -527,7 +531,7 @@ def test_api_v2_dump_and_upload_annotation(self): for user, edata in list(self.expected_dump_upload.items()): with self.subTest(format=f"{format_name}_{edata['name']}_dump"): - self._clear_rq_jobs() # clean up from previous tests and iterations + self._clear_temp_data() # clean up from previous tests and iterations url = self._generate_url_dump_tasks_annotations(task_id) file_name = osp.join(test_dir, f"{format_name}_{edata['name']}.zip") @@ -718,7 +722,7 @@ def test_api_v2_export_dataset(self): for user, edata in list(self.expected_dump_upload.items()): with self.subTest(format=f"{format_name}_{edata['name']}_export"): - self._clear_rq_jobs() # clean up from previous tests and iterations + self._clear_temp_data() # clean up from previous tests and iterations url = self._generate_url_dump_dataset(task_id) file_name = osp.join(test_dir, f"{format_name}_{edata['name']}.zip") @@ -740,6 +744,8 @@ def test_api_v2_export_dataset(self): content = io.BytesIO(b"".join(response.streaming_content)) with open(file_name, "wb") as f: f.write(content.getvalue()) - self.assertEqual(osp.exists(file_name), edata['file_exists']) - self._check_dump_content(content, task_ann_prev.data, format_name,related_files=False) + self.assertEqual(osp.exists(file_name), edata['file_exists']) + self._check_dump_content( + content, task_ann_prev.data, format_name, related_files=False + ) diff --git a/cvat/apps/engine/tests/utils.py b/cvat/apps/engine/tests/utils.py 
index b884b3e9b4c..3d2a533d1e9 100644 --- a/cvat/apps/engine/tests/utils.py +++ b/cvat/apps/engine/tests/utils.py @@ -13,7 +13,7 @@ from django.core.cache import caches from django.http.response import HttpResponse from PIL import Image -from rest_framework.test import APIClient, APITestCase +from rest_framework.test import APITestCase import av import django_rq import numpy as np @@ -92,14 +92,7 @@ def clear_rq_jobs(): class ApiTestBase(APITestCase): - def _clear_rq_jobs(self): - clear_rq_jobs() - - def setUp(self): - super().setUp() - self.client = APIClient() - - def tearDown(self): + def _clear_temp_data(self): # Clear server frame/chunk cache. # The parent class clears DB changes, and it can lead to under-cleaned task data, # which can affect other tests. @@ -112,7 +105,14 @@ def tearDown(self): # Clear any remaining RQ jobs produced by the tests executed self._clear_rq_jobs() - return super().tearDown() + def _clear_rq_jobs(self): + clear_rq_jobs() + + def setUp(self): + self._clear_temp_data() + + super().setUp() + self.client = self.client_class() def generate_image_file(filename, size=(100, 100)): diff --git a/cvat/apps/engine/views.py b/cvat/apps/engine/views.py index 528c8314b67..3cb7e34c5c4 100644 --- a/cvat/apps/engine/views.py +++ b/cvat/apps/engine/views.py @@ -3,6 +3,7 @@ # # SPDX-License-Identifier: MIT +from abc import ABCMeta, abstractmethod import os import os.path as osp import re @@ -12,7 +13,7 @@ from contextlib import suppress from PIL import Image from types import SimpleNamespace -from typing import Optional, Any, Dict, List, cast, Callable, Mapping, Iterable +from typing import Optional, Any, Dict, List, Union, cast, Callable, Mapping, Iterable import traceback import textwrap from collections import namedtuple @@ -58,12 +59,14 @@ from cvat.apps.events.handlers import handle_dataset_import from cvat.apps.dataset_manager.bindings import CvatImportError from cvat.apps.dataset_manager.serializers import DatasetFormatsSerializer -from 
cvat.apps.engine.frame_provider import FrameProvider +from cvat.apps.engine.frame_provider import ( + IFrameProvider, TaskFrameProvider, JobFrameProvider, FrameQuality +) from cvat.apps.engine.filters import NonModelSimpleFilter, NonModelOrderingFilter, NonModelJsonLogicFilter from cvat.apps.engine.media_extractors import get_mime from cvat.apps.engine.permissions import AnnotationGuidePermission, get_iam_context from cvat.apps.engine.models import ( - ClientFile, Job, JobType, Label, SegmentType, Task, Project, Issue, Data, + ClientFile, Job, JobType, Label, Task, Project, Issue, Data, Comment, StorageMethodChoice, StorageChoice, CloudProviderChoice, Location, CloudStorage as CloudStorageModel, Asset, AnnotationGuide, RequestStatus, RequestAction, RequestTarget, RequestSubresource @@ -631,19 +634,17 @@ def append_backup_chunk(self, request, file_id): def preview(self, request, pk): self._object = self.get_object() # call check_object_permissions as well - first_task = self._object.tasks.select_related('data__video').order_by('-id').first() + first_task: Optional[models.Task] = self._object.tasks.order_by('-id').first() if not first_task: return HttpResponseNotFound('Project image preview not found') - data_getter = DataChunkGetter( + data_getter = _TaskDataGetter( + db_task=first_task, data_type='preview', data_quality='compressed', - data_num=first_task.data.start_frame, - task_dim=first_task.dimension ) - return data_getter(request, first_task.data.start_frame, - first_task.data.stop_frame, first_task.data) + return data_getter() @staticmethod def _get_rq_response(queue, job_id): @@ -663,80 +664,50 @@ def _get_rq_response(queue, job_id): return response -class DataChunkGetter: - def __init__(self, data_type, data_num, data_quality, task_dim): +class _DataGetter(metaclass=ABCMeta): + def __init__( + self, data_type: str, data_num: Optional[Union[str, int]], data_quality: str + ) -> None: possible_data_type_values = ('chunk', 'frame', 'preview', 'context_image') 
possible_quality_values = ('compressed', 'original') if not data_type or data_type not in possible_data_type_values: raise ValidationError('Data type not specified or has wrong value') elif data_type == 'chunk' or data_type == 'frame' or data_type == 'preview': - if data_num is None: + if data_num is None and data_type != 'preview': raise ValidationError('Number is not specified') elif data_quality not in possible_quality_values: raise ValidationError('Wrong quality value') self.type = data_type self.number = int(data_num) if data_num is not None else None - self.quality = FrameProvider.Quality.COMPRESSED \ - if data_quality == 'compressed' else FrameProvider.Quality.ORIGINAL - - self.dimension = task_dim - - def _check_frame_range(self, frame: int): - frame_range = range(self._start, self._stop + 1, self._db_data.get_frame_step()) - if frame not in frame_range: - raise ValidationError( - f'The frame number should be in the [{self._start}, {self._stop}] range' - ) - - def __call__(self, request, start: int, stop: int, db_data: Optional[Data]): - if not db_data: - raise NotFound(detail='Cannot find requested data') + self.quality = FrameQuality.COMPRESSED \ + if data_quality == 'compressed' else FrameQuality.ORIGINAL - self._start = start - self._stop = stop - self._db_data = db_data + @abstractmethod + def _get_frame_provider(self) -> IFrameProvider: ... 
- frame_provider = FrameProvider(db_data, self.dimension) + def __call__(self): + frame_provider = self._get_frame_provider() try: if self.type == 'chunk': - start_chunk = frame_provider.get_chunk_number(start) - stop_chunk = frame_provider.get_chunk_number(stop) - # pylint: disable=superfluous-parens - if not (start_chunk <= self.number <= stop_chunk): - raise ValidationError('The chunk number should be in the ' + - f'[{start_chunk}, {stop_chunk}] range') - - # TODO: av.FFmpegError processing - if settings.USE_CACHE and db_data.storage_method == StorageMethodChoice.CACHE: - buff, mime_type = frame_provider.get_chunk(self.number, self.quality) - return HttpResponse(buff.getvalue(), content_type=mime_type) - - # Follow symbol links if the chunk is a link on a real image otherwise - # mimetype detection inside sendfile will work incorrectly. - path = os.path.realpath(frame_provider.get_chunk(self.number, self.quality)) - return sendfile(request, path) + data = frame_provider.get_chunk(self.number, quality=self.quality) + return HttpResponse(data.data.getvalue(), content_type=data.mime) elif self.type == 'frame' or self.type == 'preview': - self._check_frame_range(self.number) - if self.type == 'preview': - cache = MediaCache(self.dimension) - buf, mime = cache.get_local_preview_with_mime(self.number, db_data) + data = frame_provider.get_preview() else: - buf, mime = frame_provider.get_frame(self.number, self.quality) + data = frame_provider.get_frame(self.number, quality=self.quality) - return HttpResponse(buf.getvalue(), content_type=mime) + return HttpResponse(data.data.getvalue(), content_type=data.mime) elif self.type == 'context_image': - self._check_frame_range(self.number) - - cache = MediaCache(self.dimension) - buff, mime = cache.get_frame_context_images(db_data, self.number) - if not buff: + data = frame_provider.get_frame_context_images_chunk(self.number) + if not data: return HttpResponseNotFound() - return HttpResponse(buff, content_type=mime) + + return 
HttpResponse(data.data, content_type=data.mime) else: return Response(data='unknown data type {}.'.format(self.type), status=status.HTTP_400_BAD_REQUEST) @@ -745,44 +716,78 @@ def __call__(self, request, start: int, stop: int, db_data: Optional[Data]): '\n'.join([str(d) for d in ex.detail]) return Response(data=msg, status=ex.status_code) +class _TaskDataGetter(_DataGetter): + def __init__( + self, + db_task: models.Task, + *, + data_type: str, + data_quality: str, + data_num: Optional[Union[str, int]] = None, + ) -> None: + super().__init__(data_type=data_type, data_num=data_num, data_quality=data_quality) + self._db_task = db_task + + def _get_frame_provider(self) -> TaskFrameProvider: + return TaskFrameProvider(self._db_task) + + +class _JobDataGetter(_DataGetter): + def __init__( + self, + db_job: models.Job, + *, + data_type: str, + data_quality: str, + data_num: Optional[Union[str, int]] = None, + data_index: Optional[Union[str, int]] = None, + ) -> None: + possible_data_type_values = ('chunk', 'frame', 'preview', 'context_image') + possible_quality_values = ('compressed', 'original') + + if not data_type or data_type not in possible_data_type_values: + raise ValidationError('Data type not specified or has wrong value') + elif data_type == 'chunk' or data_type == 'frame' or data_type == 'preview': + if data_type == 'chunk': + if data_num is None and data_index is None: + raise ValidationError('Number or Index is not specified') + if data_num is not None and data_index is not None: + raise ValidationError('Number and Index cannot be used together') + elif data_num is None and data_type != 'preview': + raise ValidationError('Number is not specified') + elif data_quality not in possible_quality_values: + raise ValidationError('Wrong quality value') + + self.type = data_type -class JobDataGetter(DataChunkGetter): - def __init__(self, job: Job, data_type, data_num, data_quality): - super().__init__(data_type, data_num, data_quality, 
task_dim=job.segment.task.dimension) - self.job = job + self.index = int(data_index) if data_index is not None else None + self.number = int(data_num) if data_num is not None else None - def _check_frame_range(self, frame: int): - frame_range = self.job.segment.frame_set - if frame not in frame_range: - raise ValidationError("The frame number doesn't belong to the job") + self.quality = FrameQuality.COMPRESSED \ + if data_quality == 'compressed' else FrameQuality.ORIGINAL - def __call__(self, request, start, stop, db_data): - if self.type == 'chunk' and self.job.segment.type == SegmentType.SPECIFIC_FRAMES: - frame_provider = FrameProvider(db_data, self.dimension) + self._db_job = db_job - start_chunk = frame_provider.get_chunk_number(start) - stop_chunk = frame_provider.get_chunk_number(stop) - # pylint: disable=superfluous-parens - if not (start_chunk <= self.number <= stop_chunk): - raise ValidationError('The chunk number should be in the ' + - f'[{start_chunk}, {stop_chunk}] range') + def _get_frame_provider(self) -> JobFrameProvider: + return JobFrameProvider(self._db_job) - cache = MediaCache() + def __call__(self): + if self.type == 'chunk': + # Reproduce the task chunk indexing + frame_provider = self._get_frame_provider() - if settings.USE_CACHE and db_data.storage_method == StorageMethodChoice.CACHE: - buf, mime = cache.get_selective_job_chunk_data_with_mime( - chunk_number=self.number, quality=self.quality, job=self.job + if self.index is not None: + data = frame_provider.get_chunk( + self.index, quality=self.quality, is_task_chunk=False ) else: - buf, mime = cache.prepare_selective_job_chunk( - chunk_number=self.number, quality=self.quality, db_job=self.job + data = frame_provider.get_chunk( + self.number, quality=self.quality, is_task_chunk=True ) - return HttpResponse(buf.getvalue(), content_type=mime) - + return HttpResponse(data.data.getvalue(), content_type=data.mime) else: - return super().__call__(request, start, stop, db_data) - + return 
super().__call__() @extend_schema(tags=['tasks']) @extend_schema_view( @@ -1306,11 +1311,10 @@ def data(self, request, pk): data_num = request.query_params.get('number', None) data_quality = request.query_params.get('quality', 'compressed') - data_getter = DataChunkGetter(data_type, data_num, data_quality, - self._object.dimension) - - return data_getter(request, self._object.data.start_frame, - self._object.data.stop_frame, self._object.data) + data_getter = _TaskDataGetter( + self._object, data_type=data_type, data_num=data_num, data_quality=data_quality + ) + return data_getter() @tus_chunk_action(detail=True, suffix_base="data") def append_data_chunk(self, request, pk, file_id): @@ -1651,15 +1655,12 @@ def preview(self, request, pk): if not self._object.data: return HttpResponseNotFound('Task image preview not found') - data_getter = DataChunkGetter( + data_getter = _TaskDataGetter( + db_task=self._object, data_type='preview', data_quality='compressed', - data_num=self._object.data.start_frame, - task_dim=self._object.dimension ) - - return data_getter(request, self._object.data.start_frame, - self._object.data.stop_frame, self._object.data) + return data_getter() @extend_schema(tags=['jobs']) @@ -2026,8 +2027,14 @@ def get_export_callback(self, save_images: bool) -> Callable: OpenApiParameter('quality', location=OpenApiParameter.QUERY, required=False, type=OpenApiTypes.STR, enum=['compressed', 'original'], description="Specifies the quality level of the requested data"), - OpenApiParameter('number', location=OpenApiParameter.QUERY, required=False, type=OpenApiTypes.INT, - description="A unique number value identifying chunk or frame"), + OpenApiParameter('number', + location=OpenApiParameter.QUERY, required=False, type=OpenApiTypes.INT, + description="A unique number value identifying chunk or frame. " + "The numbers are the same as for the task. 
" + "Deprecated for chunks in favor of 'index'"), + OpenApiParameter('index', + location=OpenApiParameter.QUERY, required=False, type=OpenApiTypes.INT, + description="A unique number value identifying chunk, starts from 0 for each job"), ], responses={ '200': OpenApiResponse(OpenApiTypes.BINARY, description='Data of a specific type'), @@ -2039,12 +2046,15 @@ def data(self, request, pk): db_job = self.get_object() # call check_object_permissions as well data_type = request.query_params.get('type', None) data_num = request.query_params.get('number', None) + data_index = request.query_params.get('index', None) data_quality = request.query_params.get('quality', 'compressed') - data_getter = JobDataGetter(db_job, data_type, data_num, data_quality) - - return data_getter(request, db_job.segment.start_frame, - db_job.segment.stop_frame, db_job.segment.task.data) + data_getter = _JobDataGetter( + db_job, + data_type=data_type, data_quality=data_quality, + data_index=data_index, data_num=data_num + ) + return data_getter() @extend_schema(methods=['GET'], summary='Get metainformation for media files in a job', @@ -2137,15 +2147,12 @@ def metadata(self, request, pk): def preview(self, request, pk): self._object = self.get_object() # call check_object_permissions as well - data_getter = DataChunkGetter( + data_getter = _JobDataGetter( + db_job=self._object, data_type='preview', data_quality='compressed', - data_num=self._object.segment.start_frame, - task_dim=self._object.segment.task.dimension ) - - return data_getter(request, self._object.segment.start_frame, - self._object.segment.stop_frame, self._object.segment.task.data) + return data_getter() @extend_schema(tags=['issues']) @@ -2716,13 +2723,13 @@ def preview(self, request, pk): # The idea is try to define real manifest preview only for the storages that have related manifests # because otherwise it can lead to extra calls to a bucket, that are usually not free. 
if not db_storage.has_at_least_one_manifest: - result = cache.get_cloud_preview_with_mime(db_storage) + result = cache.get_cloud_preview(db_storage) if not result: return HttpResponseNotFound('Cloud storage preview not found') - return HttpResponse(result[0], result[1]) + return HttpResponse(result[0].getvalue(), result[1]) - preview, mime = cache.get_or_set_cloud_preview_with_mime(db_storage) - return HttpResponse(preview, mime) + preview, mime = cache.get_or_set_cloud_preview(db_storage) + return HttpResponse(preview.getvalue(), mime) except CloudStorageModel.DoesNotExist: message = f"Storage {pk} does not exist" slogger.glob.error(message) diff --git a/cvat/apps/lambda_manager/tests/test_lambda.py b/cvat/apps/lambda_manager/tests/test_lambda.py index c86b4eaa61a..e49b93e24f1 100644 --- a/cvat/apps/lambda_manager/tests/test_lambda.py +++ b/cvat/apps/lambda_manager/tests/test_lambda.py @@ -1,11 +1,10 @@ # Copyright (C) 2021-2022 Intel Corporation -# Copyright (C) 2023 CVAT.ai Corporation +# Copyright (C) 2023-2024 CVAT.ai Corporation # # SPDX-License-Identifier: MIT from collections import OrderedDict from itertools import groupby -from io import BytesIO from typing import Dict, Optional from unittest import mock, skip import json @@ -14,11 +13,11 @@ import requests from django.contrib.auth.models import Group, User from django.http import HttpResponseNotFound, HttpResponseServerError -from PIL import Image from rest_framework import status -from rest_framework.test import APIClient, APITestCase -from cvat.apps.engine.tests.utils import filter_dict, get_paginated_collection +from cvat.apps.engine.tests.utils import ( + ApiTestBase, filter_dict, ForceLogin, generate_image_file, get_paginated_collection +) LAMBDA_ROOT_PATH = '/api/lambda' LAMBDA_FUNCTIONS_PATH = f'{LAMBDA_ROOT_PATH}/functions' @@ -49,34 +48,11 @@ with open(path) as f: functions = json.load(f) - -def generate_image_file(filename, size=(100, 100)): - f = BytesIO() - image = Image.new('RGB', size=size) 
- image.save(f, 'jpeg') - f.name = filename - f.seek(0) - return f - - -class ForceLogin: - def __init__(self, user, client): - self.user = user - self.client = client - - def __enter__(self): - if self.user: - self.client.force_login(self.user, backend='django.contrib.auth.backends.ModelBackend') - - return self - - def __exit__(self, exception_type, exception_value, traceback): - if self.user: - self.client.logout() - -class _LambdaTestCaseBase(APITestCase): +class _LambdaTestCaseBase(ApiTestBase): def setUp(self): - self.client = APIClient(raise_request_exception=False) + super().setUp() + + self.client = self.client_class(raise_request_exception=False) http_patcher = mock.patch('cvat.apps.lambda_manager.views.LambdaGateway._http', side_effect = self._get_data_from_lambda_manager_http) self.addCleanup(http_patcher.stop) @@ -181,6 +157,11 @@ def _create_task(self, task_spec, data, *, owner=None, org_id=None): data=data, QUERY_STRING=f'org_id={org_id}' if org_id is not None else None) assert response.status_code == status.HTTP_202_ACCEPTED, response.status_code + rq_id = response.json()["rq_id"] + + response = self.client.get(f"/api/requests/{rq_id}") + assert response.status_code == status.HTTP_200_OK, response.status_code + assert response.json()["status"] == "finished", response.json().get("status") response = self.client.get("/api/tasks/%s" % tid, QUERY_STRING=f'org_id={org_id}' if org_id is not None else None) diff --git a/cvat/apps/lambda_manager/views.py b/cvat/apps/lambda_manager/views.py index 286b8b4cc98..143537985fd 100644 --- a/cvat/apps/lambda_manager/views.py +++ b/cvat/apps/lambda_manager/views.py @@ -1,5 +1,5 @@ # Copyright (C) 2022 Intel Corporation -# Copyright (C) 2022-2023 CVAT.ai Corporation +# Copyright (C) 2022-2024 CVAT.ai Corporation # # SPDX-License-Identifier: MIT @@ -32,9 +32,9 @@ from rest_framework.request import Request import cvat.apps.dataset_manager as dm -from cvat.apps.engine.frame_provider import FrameProvider +from 
cvat.apps.engine.frame_provider import FrameQuality, TaskFrameProvider from cvat.apps.engine.models import ( - Job, ShapeType, SourceType, Task, Label, RequestAction, RequestTarget, + Job, ShapeType, SourceType, Task, Label, RequestAction, RequestTarget ) from cvat.apps.engine.rq_job_handler import RQId, RQJobMetaField from cvat.apps.engine.serializers import LabeledDataSerializer @@ -489,19 +489,19 @@ def transform_attributes(input_attributes, attr_mapping, db_attributes): def _get_image(self, db_task, frame, quality): if quality is None or quality == "original": - quality = FrameProvider.Quality.ORIGINAL + quality = FrameQuality.ORIGINAL elif quality == "compressed": - quality = FrameProvider.Quality.COMPRESSED + quality = FrameQuality.COMPRESSED else: raise ValidationError( '`{}` lambda function was run '.format(self.id) + 'with wrong arguments (quality={})'.format(quality), code=status.HTTP_400_BAD_REQUEST) - frame_provider = FrameProvider(db_task.data) + frame_provider = TaskFrameProvider(db_task) image = frame_provider.get_frame(frame, quality=quality) - return base64.b64encode(image[0].getvalue()).decode('utf-8') + return base64.b64encode(image.data.getvalue()).decode('utf-8') class LambdaQueue: RESULT_TTL = timedelta(minutes=30) diff --git a/cvat/requirements/base.in b/cvat/requirements/base.in index 50723357d27..2bc36c18d8e 100644 --- a/cvat/requirements/base.in +++ b/cvat/requirements/base.in @@ -1,7 +1,13 @@ -r ../../utils/dataset_manifest/requirements.in attrs==21.4.0 + +# This is the last version of av that supports ffmpeg we depend on. +# Changing ffmpeg is undesirable, as there might be video decoding differences +# between versions. 
+# TODO: try to move to the newer version av==9.2.0 + azure-storage-blob==12.13.0 boto3==1.17.61 clickhouse-connect==0.6.8 diff --git a/cvat/schema.yml b/cvat/schema.yml index ff97755b26c..779b08fe376 100644 --- a/cvat/schema.yml +++ b/cvat/schema.yml @@ -2322,11 +2322,18 @@ paths: type: integer description: A unique integer value identifying this job. required: true + - in: query + name: index + schema: + type: integer + description: A unique number value identifying chunk, starts from 0 for each + job - in: query name: number schema: type: integer - description: A unique number value identifying chunk or frame + description: A unique number value identifying chunk or frame. The numbers + are the same as for the task. Deprecated for chunks in favor of 'index' - in: query name: quality schema: @@ -8074,6 +8081,10 @@ components: allOf: - $ref: '#/components/schemas/ChunkType' readOnly: true + data_original_chunk_type: + allOf: + - $ref: '#/components/schemas/ChunkType' + readOnly: true created_date: type: string format: date-time diff --git a/dev/format_python_code.sh b/dev/format_python_code.sh index 5b455a296f4..7eff923abb8 100755 --- a/dev/format_python_code.sh +++ b/dev/format_python_code.sh @@ -25,6 +25,9 @@ for paths in \ "cvat/apps/analytics_report" \ "cvat/apps/engine/lazy_list.py" \ "cvat/apps/engine/background.py" \ + "cvat/apps/engine/frame_provider.py" \ + "cvat/apps/engine/cache.py" \ + "cvat/apps/engine/default_settings.py" \ ; do ${BLACK} -- ${paths} ${ISORT} -- ${paths} diff --git a/docker-compose.yml b/docker-compose.yml index 051bd0bfd8c..569e163e9fe 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -10,6 +10,7 @@ x-backend-env: &backend-env CVAT_REDIS_ONDISK_HOST: cvat_redis_ondisk CVAT_REDIS_ONDISK_PORT: 6666 CVAT_LOG_IMPORT_ERRORS: 'true' + CVAT_ALLOW_STATIC_CACHE: '${CVAT_ALLOW_STATIC_CACHE:-no}' DJANGO_LOG_SERVER_HOST: vector DJANGO_LOG_SERVER_PORT: 80 no_proxy: clickhouse,grafana,vector,nuclio,opa,${no_proxy:-} diff --git 
a/helm-chart/test.values.yaml b/helm-chart/test.values.yaml index 5a5fa8fe6ba..73edaa815d7 100644 --- a/helm-chart/test.values.yaml +++ b/helm-chart/test.values.yaml @@ -27,6 +27,12 @@ cvat: frontend: imagePullPolicy: Never +redis: + master: + # The "flushall" command, which we use in tests, is disabled in helm by default + # https://artifacthub.io/packages/helm/bitnami/redis#redis-master-configuration-parameters + disableCommands: [] + keydb: resources: requests: diff --git a/tests/python/rest_api/test_jobs.py b/tests/python/rest_api/test_jobs.py index 4fbea276e0a..a6cd225a5d5 100644 --- a/tests/python/rest_api/test_jobs.py +++ b/tests/python/rest_api/test_jobs.py @@ -11,7 +11,7 @@ from copy import deepcopy from http import HTTPStatus from io import BytesIO -from itertools import product +from itertools import groupby, product from typing import Any, Dict, List, Optional, Tuple, Union import numpy as np @@ -361,7 +361,7 @@ def _test_destroy_job_fails(self, user, job_id, *, expected_status: int, **kwarg assert response.status == expected_status return response - @pytest.mark.usefixtures("restore_cvat_data") + @pytest.mark.usefixtures("restore_cvat_data_per_function") @pytest.mark.parametrize("job_type, allow", (("ground_truth", True), ("annotation", False))) def test_destroy_job(self, admin_user, jobs, job_type, allow): job = next(j for j in jobs if j["type"] == job_type) @@ -603,12 +603,8 @@ def test_get_gt_job_in_org_task( self._test_get_job_403(user["username"], job["id"]) -@pytest.mark.usefixtures( - # if the db is restored per test, there are conflicts with the server data cache - # if we don't clean the db, the gt jobs created will be reused, and their - # ids won't conflict - "restore_db_per_class" -) +@pytest.mark.usefixtures("restore_db_per_class") +@pytest.mark.usefixtures("restore_redis_ondisk_per_class") class TestGetGtJobData: def _delete_gt_job(self, user, gt_job_id): with make_api_client(user) as api_client: @@ -636,12 +632,11 @@ def 
test_can_get_gt_job_meta(self, admin_user, tasks, jobs, task_mode, request): :job_frame_count ] gt_job = self._create_gt_job(admin_user, task_id, job_frame_ids) + request.addfinalizer(lambda: self._delete_gt_job(user, gt_job.id)) with make_api_client(user) as api_client: (gt_job_meta, _) = api_client.jobs_api.retrieve_data_meta(gt_job.id) - request.addfinalizer(lambda: self._delete_gt_job(user, gt_job.id)) - # These values are relative to the resulting task frames, unlike meta values assert 0 == gt_job.start_frame assert task_meta.size - 1 == gt_job.stop_frame @@ -691,12 +686,11 @@ def test_can_get_gt_job_meta_with_complex_frame_setup(self, admin_user, request) task_frame_ids = range(start_frame, stop_frame, frame_step) job_frame_ids = list(task_frame_ids[::3]) gt_job = self._create_gt_job(admin_user, task_id, job_frame_ids) + request.addfinalizer(lambda: self._delete_gt_job(admin_user, gt_job.id)) with make_api_client(admin_user) as api_client: (gt_job_meta, _) = api_client.jobs_api.retrieve_data_meta(gt_job.id) - request.addfinalizer(lambda: self._delete_gt_job(admin_user, gt_job.id)) - # These values are relative to the resulting task frames, unlike meta values assert 0 == gt_job.start_frame assert len(task_frame_ids) - 1 == gt_job.stop_frame @@ -717,7 +711,10 @@ def test_can_get_gt_job_meta_with_complex_frame_setup(self, admin_user, request) @pytest.mark.parametrize("task_mode", ["annotation", "interpolation"]) @pytest.mark.parametrize("quality", ["compressed", "original"]) - def test_can_get_gt_job_chunk(self, admin_user, tasks, jobs, task_mode, quality, request): + @pytest.mark.parametrize("indexing", ["absolute", "relative"]) + def test_can_get_gt_job_chunk( + self, admin_user, tasks, jobs, task_mode, quality, request, indexing + ): user = admin_user job_frame_count = 4 task = next( @@ -734,41 +731,49 @@ def test_can_get_gt_job_chunk(self, admin_user, tasks, jobs, task_mode, quality, (task_meta, _) = api_client.tasks_api.retrieve_data_meta(task_id) 
frame_step = int(task_meta.frame_filter.split("=")[-1]) if task_meta.frame_filter else 1 - job_frame_ids = list(range(task_meta.start_frame, task_meta.stop_frame, frame_step))[ - :job_frame_count - ] + task_frame_ids = range(task_meta.start_frame, task_meta.stop_frame + 1, frame_step) + rng = np.random.Generator(np.random.MT19937(42)) + job_frame_ids = sorted(rng.choice(task_frame_ids, job_frame_count, replace=False).tolist()) + gt_job = self._create_gt_job(admin_user, task_id, job_frame_ids) + request.addfinalizer(lambda: self._delete_gt_job(admin_user, gt_job.id)) - with make_api_client(admin_user) as api_client: - (chunk_file, response) = api_client.jobs_api.retrieve_data( - gt_job.id, number=0, quality=quality, type="chunk" - ) - assert response.status == HTTPStatus.OK + if indexing == "absolute": + chunk_iter = groupby(task_frame_ids, key=lambda f: f // task_meta.chunk_size) + else: + chunk_iter = groupby(job_frame_ids, key=lambda f: f // task_meta.chunk_size) - request.addfinalizer(lambda: self._delete_gt_job(admin_user, gt_job.id)) + for chunk_id, chunk_frames in chunk_iter: + chunk_frames = list(chunk_frames) - frame_range = range( - task_meta.start_frame, min(task_meta.stop_frame + 1, task_meta.chunk_size), frame_step - ) - included_frames = job_frame_ids + if indexing == "absolute": + kwargs = {"number": chunk_id} + else: + kwargs = {"index": chunk_id} - # The frame count is the same as in the whole range - # with placeholders in the frames outside the job. - # This is required by the UI implementation - with zipfile.ZipFile(chunk_file) as chunk: - assert set(chunk.namelist()) == set("{:06d}.jpeg".format(i) for i in frame_range) + with make_api_client(admin_user) as api_client: + (chunk_file, response) = api_client.jobs_api.retrieve_data( + gt_job.id, **kwargs, quality=quality, type="chunk" + ) + assert response.status == HTTPStatus.OK + + # The frame count is the same as in the whole range + # with placeholders in the frames outside the job. 
+ # This is required by the UI implementation + with zipfile.ZipFile(chunk_file) as chunk: + assert set(chunk.namelist()) == set( + f"{i:06d}.jpeg" for i in range(len(chunk_frames)) + ) - for file_info in chunk.filelist: - with chunk.open(file_info) as image_file: - image = Image.open(image_file) - image_data = np.array(image) + for file_info in chunk.filelist: + with chunk.open(file_info) as image_file: + image = Image.open(image_file) - if int(os.path.splitext(file_info.filename)[0]) not in included_frames: - assert image.size == (1, 1) - assert np.all(image_data == 0), image_data - else: - assert image.size > (1, 1) - assert np.any(image_data != 0) + chunk_frame_id = int(os.path.splitext(file_info.filename)[0]) + if chunk_frames[chunk_frame_id] not in job_frame_ids: + assert image.size == (1, 1) + else: + assert image.size > (1, 1) def _create_gt_job(self, user, task_id, frames): with make_api_client(user) as api_client: @@ -813,6 +818,7 @@ def test_can_get_gt_job_frame(self, admin_user, tasks, jobs, task_mode, quality, :job_frame_count ] gt_job = self._create_gt_job(admin_user, task_id, job_frame_ids) + request.addfinalizer(lambda: self._delete_gt_job(admin_user, gt_job.id)) frame_range = range( task_meta.start_frame, min(task_meta.stop_frame + 1, task_meta.chunk_size), frame_step @@ -830,15 +836,13 @@ def test_can_get_gt_job_frame(self, admin_user, tasks, jobs, task_mode, quality, _check_status=False, ) assert response.status == HTTPStatus.BAD_REQUEST - assert b"The frame number doesn't belong to the job" in response.data + assert b"Incorrect requested frame number" in response.data (_, response) = api_client.jobs_api.retrieve_data( gt_job.id, number=included_frames[0], quality=quality, type="frame" ) assert response.status == HTTPStatus.OK - request.addfinalizer(lambda: self._delete_gt_job(admin_user, gt_job.id)) - @pytest.mark.usefixtures("restore_db_per_class") class TestListJobs: diff --git a/tests/python/rest_api/test_queues.py 
b/tests/python/rest_api/test_queues.py index f801e661e42..4ce314b865b 100644 --- a/tests/python/rest_api/test_queues.py +++ b/tests/python/rest_api/test_queues.py @@ -18,7 +18,7 @@ @pytest.mark.usefixtures("restore_db_per_function") -@pytest.mark.usefixtures("restore_cvat_data") +@pytest.mark.usefixtures("restore_cvat_data_per_function") @pytest.mark.usefixtures("restore_redis_inmem_per_function") class TestRQQueueWorking: _USER_1 = "admin1" diff --git a/tests/python/rest_api/test_resource_import_export.py b/tests/python/rest_api/test_resource_import_export.py index 833661fcfab..39f4be22a01 100644 --- a/tests/python/rest_api/test_resource_import_export.py +++ b/tests/python/rest_api/test_resource_import_export.py @@ -177,7 +177,7 @@ def test_user_cannot_export_to_cloud_storage_with_specific_location_without_acce @pytest.mark.usefixtures("restore_db_per_function") -@pytest.mark.usefixtures("restore_cvat_data") +@pytest.mark.usefixtures("restore_cvat_data_per_function") class TestImportResourceFromS3(_S3ResourceTest): @pytest.mark.usefixtures("restore_redis_inmem_per_function") @pytest.mark.parametrize("cloud_storage_id", [3]) diff --git a/tests/python/rest_api/test_tasks.py b/tests/python/rest_api/test_tasks.py index e849244361f..eda54b8ddd0 100644 --- a/tests/python/rest_api/test_tasks.py +++ b/tests/python/rest_api/test_tasks.py @@ -6,10 +6,14 @@ import io import itertools import json +import math import os import os.path as osp import zipfile +from abc import ABCMeta, abstractmethod +from contextlib import closing from copy import deepcopy +from enum import Enum from functools import partial from http import HTTPStatus from itertools import chain, product @@ -18,8 +22,10 @@ from pathlib import Path from tempfile import NamedTemporaryFile, TemporaryDirectory from time import sleep, time -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, ClassVar, Dict, Generator, List, Optional, Sequence, Tuple, Union +import attrs 
+import numpy as np import pytest from cvat_sdk import Client, Config, exceptions from cvat_sdk.api_client import models @@ -30,6 +36,7 @@ from cvat_sdk.core.uploading import Uploader from deepdiff import DeepDiff from PIL import Image +from pytest_cases import fixture_ref, parametrize import shared.utils.s3 as s3 from shared.fixtures.init import docker_exec_cvat, kube_exec_cvat @@ -48,6 +55,7 @@ generate_image_files, generate_manifest, generate_video_file, + read_video_file, ) from .utils import ( @@ -903,7 +911,7 @@ def test_uses_subset_name( @pytest.mark.usefixtures("restore_db_per_function") -@pytest.mark.usefixtures("restore_cvat_data") +@pytest.mark.usefixtures("restore_cvat_data_per_function") @pytest.mark.usefixtures("restore_redis_ondisk_per_function") class TestPostTaskData: _USERNAME = "admin1" @@ -2028,6 +2036,525 @@ def test_create_task_with_cloud_storage_directories_and_default_bucket_prefix( assert task.size == expected_task_size +class _SourceDataType(str, Enum): + images = "images" + video = "video" + + +class _TaskSpec(models.ITaskWriteRequest, models.IDataRequest, metaclass=ABCMeta): + size: int + frame_step: int + source_data_type: _SourceDataType + + @abstractmethod + def read_frame(self, i: int) -> Image.Image: ... 
+ + +@attrs.define +class _TaskSpecBase(_TaskSpec): + _params: Union[Dict, models.TaskWriteRequest] + _data_params: Union[Dict, models.DataRequest] + size: int = attrs.field(kw_only=True) + + @property + def frame_step(self) -> int: + v = getattr(self, "frame_filter", "step=1") + return int(v.split("=")[-1]) + + def __getattr__(self, k: str) -> Any: + notfound = object() + + for params in [self._params, self._data_params]: + if isinstance(params, dict): + v = params.get(k, notfound) + else: + v = getattr(params, k, notfound) + + if v is not notfound: + return v + + raise AttributeError(k) + + +@attrs.define +class _ImagesTaskSpec(_TaskSpecBase): + source_data_type: ClassVar[_SourceDataType] = _SourceDataType.images + + _get_frame: Callable[[int], bytes] = attrs.field(kw_only=True) + + def read_frame(self, i: int) -> Image.Image: + return Image.open(io.BytesIO(self._get_frame(i))) + + +@attrs.define +class _VideoTaskSpec(_TaskSpecBase): + source_data_type: ClassVar[_SourceDataType] = _SourceDataType.video + + _get_video_file: Callable[[], io.IOBase] = attrs.field(kw_only=True) + + def read_frame(self, i: int) -> Image.Image: + with closing(read_video_file(self._get_video_file())) as reader: + for _ in range(i + 1): + frame = next(reader) + + return frame + + +@pytest.mark.usefixtures("restore_db_per_class") +@pytest.mark.usefixtures("restore_redis_ondisk_per_class") +@pytest.mark.usefixtures("restore_cvat_data_per_function") +class TestTaskData: + _USERNAME = "admin1" + + def _uploaded_images_task_fxt_base( + self, + request: pytest.FixtureRequest, + *, + frame_count: int = 10, + segment_size: Optional[int] = None, + ) -> Generator[Tuple[_TaskSpec, int], None, None]: + task_params = { + "name": request.node.name, + "labels": [{"name": "a"}], + } + if segment_size: + task_params["segment_size"] = segment_size + + image_files = generate_image_files(frame_count) + images_data = [f.getvalue() for f in image_files] + data_params = { + "image_quality": 70, + 
"client_files": image_files, + } + + def get_frame(i: int) -> bytes: + return images_data[i] + + task_id, _ = create_task(self._USERNAME, spec=task_params, data=data_params) + yield _ImagesTaskSpec( + models.TaskWriteRequest._from_openapi_data(**task_params), + models.DataRequest._from_openapi_data(**data_params), + get_frame=get_frame, + size=len(images_data), + ), task_id + + @pytest.fixture(scope="class") + def fxt_uploaded_images_task( + self, request: pytest.FixtureRequest + ) -> Generator[Tuple[_TaskSpec, int], None, None]: + yield from self._uploaded_images_task_fxt_base(request=request) + + @pytest.fixture(scope="class") + def fxt_uploaded_images_task_with_segments( + self, request: pytest.FixtureRequest + ) -> Generator[Tuple[_TaskSpec, int], None, None]: + yield from self._uploaded_images_task_fxt_base(request=request, segment_size=4) + + def _uploaded_video_task_fxt_base( + self, + request: pytest.FixtureRequest, + *, + frame_count: int = 10, + segment_size: Optional[int] = None, + ) -> Generator[Tuple[_TaskSpec, int], None, None]: + task_params = { + "name": request.node.name, + "labels": [{"name": "a"}], + } + if segment_size: + task_params["segment_size"] = segment_size + + video_file = generate_video_file(frame_count) + video_data = video_file.getvalue() + data_params = { + "image_quality": 70, + "client_files": [video_file], + } + + def get_video_file() -> io.BytesIO: + return io.BytesIO(video_data) + + task_id, _ = create_task(self._USERNAME, spec=task_params, data=data_params) + yield _VideoTaskSpec( + models.TaskWriteRequest._from_openapi_data(**task_params), + models.DataRequest._from_openapi_data(**data_params), + get_video_file=get_video_file, + size=frame_count, + ), task_id + + @pytest.fixture(scope="class") + def fxt_uploaded_video_task( + self, + request: pytest.FixtureRequest, + ) -> Generator[Tuple[_TaskSpec, int], None, None]: + yield from self._uploaded_video_task_fxt_base(request=request) + + @pytest.fixture(scope="class") + def 
fxt_uploaded_video_task_with_segments( + self, request: pytest.FixtureRequest + ) -> Generator[Tuple[_TaskSpec, int], None, None]: + yield from self._uploaded_video_task_fxt_base(request=request, segment_size=4) + + def _compute_segment_params(self, task_spec: _TaskSpec) -> List[Tuple[int, int]]: + segment_params = [] + segment_size = getattr(task_spec, "segment_size", 0) or task_spec.size + start_frame = getattr(task_spec, "start_frame", 0) + end_frame = (getattr(task_spec, "stop_frame", None) or (task_spec.size - 1)) + 1 + overlap = min( + ( + getattr(task_spec, "overlap", None) or 0 + if task_spec.source_data_type == _SourceDataType.images + else 5 + ), + segment_size // 2, + ) + segment_start = start_frame + while segment_start < end_frame: + if start_frame < segment_start: + segment_start -= overlap * task_spec.frame_step + + segment_end = segment_start + task_spec.frame_step * segment_size + + segment_params.append((segment_start, min(segment_end, end_frame) - 1)) + segment_start = segment_end + + return segment_params + + @staticmethod + def _compare_images( + expected: Image.Image, actual: Image.Image, *, must_be_identical: bool = True + ): + expected_pixels = np.array(expected) + chunk_frame_pixels = np.array(actual) + assert expected_pixels.shape == chunk_frame_pixels.shape + + if not must_be_identical: + # video chunks can have slightly changed colors, due to codec specifics + # compressed images can also be distorted + assert np.allclose(chunk_frame_pixels, expected_pixels, atol=2) + else: + assert np.array_equal(chunk_frame_pixels, expected_pixels) + + _default_task_cases = [ + fixture_ref("fxt_uploaded_images_task"), + fixture_ref("fxt_uploaded_images_task_with_segments"), + fixture_ref("fxt_uploaded_video_task"), + fixture_ref("fxt_uploaded_video_task_with_segments"), + ] + + @parametrize("task_spec, task_id", _default_task_cases) + def test_can_get_task_meta(self, task_spec: _TaskSpec, task_id: int): + with make_api_client(self._USERNAME) as 
api_client: + (task_meta, _) = api_client.tasks_api.retrieve_data_meta(task_id) + + assert task_meta.size == task_spec.size + assert task_meta.start_frame == getattr(task_spec, "start_frame", 0) + assert task_meta.stop_frame == getattr(task_spec, "stop_frame", None) or task_spec.size + assert task_meta.frame_filter == getattr(task_spec, "frame_filter", "") + + task_frame_set = set( + range(task_meta.start_frame, task_meta.stop_frame + 1, task_spec.frame_step) + ) + assert len(task_frame_set) == task_meta.size + + if getattr(task_spec, "chunk_size", None): + assert task_meta.chunk_size == task_spec.chunk_size + + if task_spec.source_data_type == _SourceDataType.video: + assert len(task_meta.frames) == 1 + else: + assert len(task_meta.frames) == task_meta.size + + @parametrize("task_spec, task_id", _default_task_cases) + def test_can_get_task_frames(self, task_spec: _TaskSpec, task_id: int): + with make_api_client(self._USERNAME) as api_client: + (task_meta, _) = api_client.tasks_api.retrieve_data_meta(task_id) + + for quality, abs_frame_id in product( + ["original", "compressed"], + range(task_meta.start_frame, task_meta.stop_frame + 1, task_spec.frame_step), + ): + rel_frame_id = ( + abs_frame_id - getattr(task_spec, "start_frame", 0) // task_spec.frame_step + ) + (_, response) = api_client.tasks_api.retrieve_data( + task_id, + type="frame", + quality=quality, + number=rel_frame_id, + _parse_response=False, + ) + + if task_spec.source_data_type == _SourceDataType.video: + frame_size = (task_meta.frames[0].width, task_meta.frames[0].height) + else: + frame_size = ( + task_meta.frames[rel_frame_id].width, + task_meta.frames[rel_frame_id].height, + ) + + frame = Image.open(io.BytesIO(response.data)) + assert frame_size == frame.size + + self._compare_images( + task_spec.read_frame(abs_frame_id), + frame, + must_be_identical=( + task_spec.source_data_type == _SourceDataType.images + and quality == "original" + ), + ) + + @parametrize("task_spec, task_id", 
_default_task_cases) + def test_can_get_task_chunks(self, task_spec: _TaskSpec, task_id: int): + with make_api_client(self._USERNAME) as api_client: + (task, _) = api_client.tasks_api.retrieve(task_id) + (task_meta, _) = api_client.tasks_api.retrieve_data_meta(task_id) + + if task_spec.source_data_type == _SourceDataType.images: + assert task.data_original_chunk_type == "imageset" + assert task.data_compressed_chunk_type == "imageset" + elif task_spec.source_data_type == _SourceDataType.video: + assert task.data_original_chunk_type == "video" + + if getattr(task_spec, "use_zip_chunks", False): + assert task.data_compressed_chunk_type == "imageset" + else: + assert task.data_compressed_chunk_type == "video" + else: + assert False + + chunk_count = math.ceil(task_meta.size / task_meta.chunk_size) + for quality, chunk_id in product(["original", "compressed"], range(chunk_count)): + expected_chunk_frame_ids = range( + chunk_id * task_meta.chunk_size, + min((chunk_id + 1) * task_meta.chunk_size, task_meta.size), + ) + + (_, response) = api_client.tasks_api.retrieve_data( + task_id, type="chunk", quality=quality, number=chunk_id, _parse_response=False + ) + + chunk_file = io.BytesIO(response.data) + if zipfile.is_zipfile(chunk_file): + with zipfile.ZipFile(chunk_file, "r") as chunk_archive: + chunk_images = { + int(os.path.splitext(name)[0]): np.array( + Image.open(io.BytesIO(chunk_archive.read(name))) + ) + for name in chunk_archive.namelist() + } + chunk_images = dict(sorted(chunk_images.items(), key=lambda e: e[0])) + else: + chunk_images = dict(enumerate(read_video_file(chunk_file))) + + assert sorted(chunk_images.keys()) == list(range(len(expected_chunk_frame_ids))) + + for chunk_frame, abs_frame_id in zip(chunk_images, expected_chunk_frame_ids): + self._compare_images( + task_spec.read_frame(abs_frame_id), + chunk_images[chunk_frame], + must_be_identical=( + task_spec.source_data_type == _SourceDataType.images + and quality == "original" + ), + ) + + 
@parametrize("task_spec, task_id", _default_task_cases) + def test_can_get_job_meta(self, task_spec: _TaskSpec, task_id: int): + segment_params = self._compute_segment_params(task_spec) + with make_api_client(self._USERNAME) as api_client: + jobs = sorted( + get_paginated_collection(api_client.jobs_api.list_endpoint, task_id=task_id), + key=lambda j: j.start_frame, + ) + assert len(jobs) == len(segment_params) + + for (segment_start, segment_end), job in zip(segment_params, jobs): + (job_meta, _) = api_client.jobs_api.retrieve_data_meta(job.id) + + assert (job_meta.start_frame, job_meta.stop_frame) == (segment_start, segment_end) + assert job_meta.frame_filter == getattr(task_spec, "frame_filter", "") + + segment_size = segment_end - segment_start + 1 + assert job_meta.size == segment_size + + task_frame_set = set( + range(job_meta.start_frame, job_meta.stop_frame + 1, task_spec.frame_step) + ) + assert len(task_frame_set) == job_meta.size + + if getattr(task_spec, "chunk_size", None): + assert job_meta.chunk_size == task_spec.chunk_size + + if task_spec.source_data_type == _SourceDataType.video: + assert len(job_meta.frames) == 1 + else: + assert len(job_meta.frames) == job_meta.size + + @parametrize("task_spec, task_id", _default_task_cases) + def test_can_get_job_frames(self, task_spec: _TaskSpec, task_id: int): + with make_api_client(self._USERNAME) as api_client: + jobs = sorted( + get_paginated_collection(api_client.jobs_api.list_endpoint, task_id=task_id), + key=lambda j: j.start_frame, + ) + for job in jobs: + (job_meta, _) = api_client.jobs_api.retrieve_data_meta(job.id) + + for quality, (frame_pos, abs_frame_id) in product( + ["original", "compressed"], + enumerate(range(job_meta.start_frame, job_meta.stop_frame)), + ): + rel_frame_id = ( + abs_frame_id - getattr(task_spec, "start_frame", 0) // task_spec.frame_step + ) + (_, response) = api_client.jobs_api.retrieve_data( + job.id, + type="frame", + quality=quality, + number=rel_frame_id, + 
_parse_response=False, + ) + + if task_spec.source_data_type == _SourceDataType.video: + frame_size = (job_meta.frames[0].width, job_meta.frames[0].height) + else: + frame_size = ( + job_meta.frames[frame_pos].width, + job_meta.frames[frame_pos].height, + ) + + frame = Image.open(io.BytesIO(response.data)) + assert frame_size == frame.size + + self._compare_images( + task_spec.read_frame(abs_frame_id), + frame, + must_be_identical=( + task_spec.source_data_type == _SourceDataType.images + and quality == "original" + ), + ) + + @parametrize("task_spec, task_id", _default_task_cases) + @parametrize("indexing", ["absolute", "relative"]) + def test_can_get_job_chunks(self, task_spec: _TaskSpec, task_id: int, indexing: str): + with make_api_client(self._USERNAME) as api_client: + jobs = sorted( + get_paginated_collection(api_client.jobs_api.list_endpoint, task_id=task_id), + key=lambda j: j.start_frame, + ) + + (task_meta, _) = api_client.tasks_api.retrieve_data_meta(task_id) + + for job in jobs: + (job_meta, _) = api_client.jobs_api.retrieve_data_meta(job.id) + + if task_spec.source_data_type == _SourceDataType.images: + assert job.data_original_chunk_type == "imageset" + assert job.data_compressed_chunk_type == "imageset" + elif task_spec.source_data_type == _SourceDataType.video: + assert job.data_original_chunk_type == "video" + + if getattr(task_spec, "use_zip_chunks", False): + assert job.data_compressed_chunk_type == "imageset" + else: + assert job.data_compressed_chunk_type == "video" + else: + assert False + + if indexing == "absolute": + chunk_count = math.ceil(task_meta.size / job_meta.chunk_size) + + def get_task_chunk_abs_frame_ids(chunk_id: int) -> Sequence[int]: + return range( + task_meta.start_frame + + chunk_id * task_meta.chunk_size * task_spec.frame_step, + task_meta.start_frame + + min((chunk_id + 1) * task_meta.chunk_size, task_meta.size) + * task_spec.frame_step, + ) + + def get_job_frame_ids() -> Sequence[int]: + return range( + 
job_meta.start_frame, job_meta.stop_frame + 1, task_spec.frame_step + ) + + def get_expected_chunk_abs_frame_ids(chunk_id: int): + return sorted( + set(get_task_chunk_abs_frame_ids(chunk_id)) & set(get_job_frame_ids()) + ) + + job_chunk_ids = ( + task_chunk_id + for task_chunk_id in range(chunk_count) + if get_expected_chunk_abs_frame_ids(task_chunk_id) + ) + else: + chunk_count = math.ceil(job_meta.size / job_meta.chunk_size) + job_chunk_ids = range(chunk_count) + + def get_expected_chunk_abs_frame_ids(chunk_id: int): + return sorted( + frame + for frame in range( + job_meta.start_frame + + chunk_id * job_meta.chunk_size * task_spec.frame_step, + job_meta.start_frame + + min((chunk_id + 1) * job_meta.chunk_size, job_meta.size) + * task_spec.frame_step, + ) + if not job_meta.included_frames or frame in job_meta.included_frames + ) + + for quality, chunk_id in product(["original", "compressed"], job_chunk_ids): + expected_chunk_abs_frame_ids = get_expected_chunk_abs_frame_ids(chunk_id) + + kwargs = {} + if indexing == "absolute": + kwargs["number"] = chunk_id + elif indexing == "relative": + kwargs["index"] = chunk_id + else: + assert False + + (_, response) = api_client.jobs_api.retrieve_data( + job.id, + type="chunk", + quality=quality, + **kwargs, + _parse_response=False, + ) + + chunk_file = io.BytesIO(response.data) + if zipfile.is_zipfile(chunk_file): + with zipfile.ZipFile(chunk_file, "r") as chunk_archive: + chunk_images = { + int(os.path.splitext(name)[0]): np.array( + Image.open(io.BytesIO(chunk_archive.read(name))) + ) + for name in chunk_archive.namelist() + } + chunk_images = dict(sorted(chunk_images.items(), key=lambda e: e[0])) + else: + chunk_images = dict(enumerate(read_video_file(chunk_file))) + + assert sorted(chunk_images.keys()) == list(range(job_meta.size)) + + for chunk_frame, abs_frame_id in zip( + chunk_images, expected_chunk_abs_frame_ids + ): + self._compare_images( + task_spec.read_frame(abs_frame_id), + chunk_images[chunk_frame], + 
must_be_identical=( + task_spec.source_data_type == _SourceDataType.images + and quality == "original" + ), + ) + + @pytest.mark.usefixtures("restore_db_per_function") class TestPatchTaskLabel: def _get_task_labels(self, pid, user, **kwargs) -> List[models.Label]: @@ -2229,7 +2756,7 @@ def test_admin_can_add_skeleton(self, tasks, admin_user): @pytest.mark.usefixtures("restore_db_per_function") -@pytest.mark.usefixtures("restore_cvat_data") +@pytest.mark.usefixtures("restore_cvat_data_per_function") @pytest.mark.usefixtures("restore_redis_ondisk_per_function") class TestWorkWithTask: _USERNAME = "admin1" @@ -2286,7 +2813,13 @@ def _make_client(self) -> Client: return Client(BASE_URL, config=Config(status_check_period=0.01)) @pytest.fixture(autouse=True) - def setup(self, restore_db_per_function, restore_cvat_data, tmp_path: Path, admin_user: str): + def setup( + self, + restore_db_per_function, + restore_cvat_data_per_function, + tmp_path: Path, + admin_user: str, + ): self.tmp_dir = tmp_path self.client = self._make_client() diff --git a/tests/python/sdk/test_auto_annotation.py b/tests/python/sdk/test_auto_annotation.py index 142c4354c4d..e7ac8418b69 100644 --- a/tests/python/sdk/test_auto_annotation.py +++ b/tests/python/sdk/test_auto_annotation.py @@ -29,6 +29,7 @@ def _common_setup( tmp_path: Path, fxt_login: Tuple[Client, str], fxt_logger: Tuple[Logger, io.StringIO], + restore_redis_ondisk_per_function, ): logger = fxt_logger[0] client = fxt_login[0] diff --git a/tests/python/sdk/test_datasets.py b/tests/python/sdk/test_datasets.py index d5fbc0957eb..542ad9a1e80 100644 --- a/tests/python/sdk/test_datasets.py +++ b/tests/python/sdk/test_datasets.py @@ -23,6 +23,7 @@ def _common_setup( tmp_path: Path, fxt_login: Tuple[Client, str], fxt_logger: Tuple[Logger, io.StringIO], + restore_redis_ondisk_per_function, ): logger = fxt_logger[0] client = fxt_login[0] diff --git a/tests/python/sdk/test_jobs.py b/tests/python/sdk/test_jobs.py index ef46fcb8cf0..3202e2957ff 
100644 --- a/tests/python/sdk/test_jobs.py +++ b/tests/python/sdk/test_jobs.py @@ -29,6 +29,7 @@ def setup( fxt_login: Tuple[Client, str], fxt_logger: Tuple[Logger, io.StringIO], fxt_stdout: io.StringIO, + restore_redis_ondisk_per_function, ): self.tmp_path = tmp_path logger, self.logger_stream = fxt_logger diff --git a/tests/python/sdk/test_projects.py b/tests/python/sdk/test_projects.py index 43d6257c03c..b03df660d87 100644 --- a/tests/python/sdk/test_projects.py +++ b/tests/python/sdk/test_projects.py @@ -32,6 +32,7 @@ def setup( fxt_login: Tuple[Client, str], fxt_logger: Tuple[Logger, io.StringIO], fxt_stdout: io.StringIO, + restore_redis_ondisk_per_function, ): self.tmp_path = tmp_path logger, self.logger_stream = fxt_logger diff --git a/tests/python/sdk/test_pytorch.py b/tests/python/sdk/test_pytorch.py index 722cb37ab00..2bcbd122abf 100644 --- a/tests/python/sdk/test_pytorch.py +++ b/tests/python/sdk/test_pytorch.py @@ -36,6 +36,7 @@ def _common_setup( tmp_path: Path, fxt_login: Tuple[Client, str], fxt_logger: Tuple[Logger, io.StringIO], + restore_redis_ondisk_per_function, ): logger = fxt_logger[0] client = fxt_login[0] diff --git a/tests/python/sdk/test_tasks.py b/tests/python/sdk/test_tasks.py index 0dc5c0694e9..54e0823d331 100644 --- a/tests/python/sdk/test_tasks.py +++ b/tests/python/sdk/test_tasks.py @@ -33,6 +33,7 @@ def setup( fxt_login: Tuple[Client, str], fxt_logger: Tuple[Logger, io.StringIO], fxt_stdout: io.StringIO, + restore_redis_ondisk_per_function, ): self.tmp_path = tmp_path logger, self.logger_stream = fxt_logger diff --git a/tests/python/shared/assets/jobs.json b/tests/python/shared/assets/jobs.json index d4add795c78..415fb67d44c 100644 --- a/tests/python/shared/assets/jobs.json +++ b/tests/python/shared/assets/jobs.json @@ -10,6 +10,7 @@ "created_date": "2024-07-15T15:34:53.594000Z", "data_chunk_size": 72, "data_compressed_chunk_type": "imageset", + "data_original_chunk_type": "imageset", "dimension": "2d", "frame_count": 1, "guide_id": 
null, @@ -51,6 +52,7 @@ "created_date": "2024-07-15T15:33:10.549000Z", "data_chunk_size": 72, "data_compressed_chunk_type": "imageset", + "data_original_chunk_type": "imageset", "dimension": "2d", "frame_count": 1, "guide_id": null, @@ -92,6 +94,7 @@ "created_date": "2024-03-21T20:50:05.838000Z", "data_chunk_size": 72, "data_compressed_chunk_type": "imageset", + "data_original_chunk_type": "imageset", "dimension": "2d", "frame_count": 3, "guide_id": null, @@ -125,6 +128,7 @@ "created_date": "2024-03-21T20:50:05.815000Z", "data_chunk_size": 72, "data_compressed_chunk_type": "imageset", + "data_original_chunk_type": "imageset", "dimension": "2d", "frame_count": 1, "guide_id": null, @@ -158,6 +162,7 @@ "created_date": "2024-03-21T20:50:05.811000Z", "data_chunk_size": 72, "data_compressed_chunk_type": "imageset", + "data_original_chunk_type": "imageset", "dimension": "2d", "frame_count": 5, "guide_id": null, @@ -191,6 +196,7 @@ "created_date": "2024-03-21T20:50:05.805000Z", "data_chunk_size": 72, "data_compressed_chunk_type": "imageset", + "data_original_chunk_type": "imageset", "dimension": "2d", "frame_count": 5, "guide_id": null, @@ -224,6 +230,7 @@ "created_date": "2023-05-26T16:11:23.946000Z", "data_chunk_size": 72, "data_compressed_chunk_type": "imageset", + "data_original_chunk_type": "imageset", "dimension": "2d", "frame_count": 3, "guide_id": null, @@ -257,6 +264,7 @@ "created_date": "2023-05-26T16:11:23.880000Z", "data_chunk_size": 72, "data_compressed_chunk_type": "imageset", + "data_original_chunk_type": "imageset", "dimension": "2d", "frame_count": 11, "guide_id": null, @@ -290,6 +298,7 @@ "created_date": "2023-03-27T19:08:07.649000Z", "data_chunk_size": 72, "data_compressed_chunk_type": "imageset", + "data_original_chunk_type": "imageset", "dimension": "2d", "frame_count": 4, "guide_id": null, @@ -331,6 +340,7 @@ "created_date": "2023-03-27T19:08:07.649000Z", "data_chunk_size": 72, "data_compressed_chunk_type": "imageset", + "data_original_chunk_type": 
"imageset", "dimension": "2d", "frame_count": 6, "guide_id": null, @@ -372,6 +382,7 @@ "created_date": "2023-03-10T11:57:31.614000Z", "data_chunk_size": 72, "data_compressed_chunk_type": "imageset", + "data_original_chunk_type": "imageset", "dimension": "2d", "frame_count": 2, "guide_id": null, @@ -413,6 +424,7 @@ "created_date": "2023-03-10T11:56:33.757000Z", "data_chunk_size": 72, "data_compressed_chunk_type": "imageset", + "data_original_chunk_type": "imageset", "dimension": "2d", "frame_count": 2, "guide_id": null, @@ -454,6 +466,7 @@ "created_date": "2023-03-01T15:36:26.668000Z", "data_chunk_size": 72, "data_compressed_chunk_type": "imageset", + "data_original_chunk_type": "imageset", "dimension": "2d", "frame_count": 2, "guide_id": null, @@ -495,6 +508,7 @@ "created_date": "2023-02-10T14:05:25.947000Z", "data_chunk_size": 72, "data_compressed_chunk_type": "imageset", + "data_original_chunk_type": "imageset", "dimension": "2d", "frame_count": 5, "guide_id": null, @@ -528,6 +542,7 @@ "created_date": "2022-12-01T12:53:10.425000Z", "data_chunk_size": 72, "data_compressed_chunk_type": "imageset", + "data_original_chunk_type": "video", "dimension": "2d", "frame_count": 25, "guide_id": null, @@ -569,6 +584,7 @@ "created_date": "2022-09-22T14:22:25.820000Z", "data_chunk_size": 72, "data_compressed_chunk_type": "imageset", + "data_original_chunk_type": "imageset", "dimension": "2d", "frame_count": 8, "guide_id": null, @@ -610,6 +626,7 @@ "created_date": "2022-06-08T08:33:06.505000Z", "data_chunk_size": 72, "data_compressed_chunk_type": "imageset", + "data_original_chunk_type": "imageset", "dimension": "2d", "frame_count": 5, "guide_id": null, @@ -649,6 +666,7 @@ "created_date": "2022-03-05T10:32:19.149000Z", "data_chunk_size": 72, "data_compressed_chunk_type": "imageset", + "data_original_chunk_type": "imageset", "dimension": "2d", "frame_count": 11, "guide_id": null, @@ -690,6 +708,7 @@ "created_date": "2022-03-05T09:33:10.420000Z", "data_chunk_size": 72, 
"data_compressed_chunk_type": "imageset", + "data_original_chunk_type": "imageset", "dimension": "2d", "frame_count": 5, "guide_id": null, @@ -723,6 +742,7 @@ "created_date": "2022-03-05T09:33:10.420000Z", "data_chunk_size": 72, "data_compressed_chunk_type": "imageset", + "data_original_chunk_type": "imageset", "dimension": "2d", "frame_count": 5, "guide_id": null, @@ -756,6 +776,7 @@ "created_date": "2022-03-05T09:33:10.420000Z", "data_chunk_size": 72, "data_compressed_chunk_type": "imageset", + "data_original_chunk_type": "imageset", "dimension": "2d", "frame_count": 5, "guide_id": null, @@ -795,6 +816,7 @@ "created_date": "2022-03-05T09:33:10.420000Z", "data_chunk_size": 72, "data_compressed_chunk_type": "imageset", + "data_original_chunk_type": "imageset", "dimension": "2d", "frame_count": 5, "guide_id": null, @@ -834,6 +856,7 @@ "created_date": "2022-03-05T08:30:48.612000Z", "data_chunk_size": 72, "data_compressed_chunk_type": "imageset", + "data_original_chunk_type": "imageset", "dimension": "2d", "frame_count": 14, "guide_id": null, @@ -867,6 +890,7 @@ "created_date": "2022-02-21T10:31:52.429000Z", "data_chunk_size": 72, "data_compressed_chunk_type": "imageset", + "data_original_chunk_type": "imageset", "dimension": "2d", "frame_count": 11, "guide_id": null, @@ -900,6 +924,7 @@ "created_date": "2022-02-16T06:26:54.631000Z", "data_chunk_size": 72, "data_compressed_chunk_type": "imageset", + "data_original_chunk_type": "imageset", "dimension": "3d", "frame_count": 1, "guide_id": null, @@ -939,6 +964,7 @@ "created_date": "2022-02-16T06:25:48.168000Z", "data_chunk_size": 72, "data_compressed_chunk_type": "imageset", + "data_original_chunk_type": "video", "dimension": "2d", "frame_count": 25, "guide_id": null, @@ -978,6 +1004,7 @@ "created_date": "2021-12-14T18:50:29.458000Z", "data_chunk_size": 72, "data_compressed_chunk_type": "imageset", + "data_original_chunk_type": "imageset", "dimension": "2d", "frame_count": 23, "guide_id": null, diff --git 
a/tests/python/shared/fixtures/init.py b/tests/python/shared/fixtures/init.py index 8e9d334f7a4..4a17454617d 100644 --- a/tests/python/shared/fixtures/init.py +++ b/tests/python/shared/fixtures/init.py @@ -96,12 +96,20 @@ def pytest_addoption(parser): def _run(command, capture_output=True): _command = command.split() if isinstance(command, str) else command try: + logger.debug(f"Executing a command: {_command}") + stdout, stderr = "", "" if capture_output: proc = run(_command, check=True, stdout=PIPE, stderr=PIPE) # nosec stdout, stderr = proc.stdout.decode(), proc.stderr.decode() else: proc = run(_command) # nosec + + if stdout: + logger.debug(f"Output (stdout): {stdout}") + if stderr: + logger.debug(f"Output (stderr): {stderr}") + return stdout, stderr except CalledProcessError as exc: message = f"Command failed: {' '.join(map(shlex.quote, _command))}." @@ -232,20 +240,20 @@ def kube_restore_clickhouse_db(): def docker_restore_redis_inmem(): - docker_exec_redis_inmem(["redis-cli", "flushall"]) + docker_exec_redis_inmem(["redis-cli", "-e", "flushall"]) def kube_restore_redis_inmem(): - kube_exec_redis_inmem(["redis-cli", "flushall"]) + kube_exec_redis_inmem(["sh", "-c", 'redis-cli -e -a "${REDIS_PASSWORD}" flushall']) def docker_restore_redis_ondisk(): - docker_exec_redis_ondisk(["redis-cli", "-p", "6666", "flushall"]) + docker_exec_redis_ondisk(["redis-cli", "-e", "-p", "6666", "flushall"]) def kube_restore_redis_ondisk(): kube_exec_redis_ondisk( - ["redis-cli", "-p", "6666", "-a", "${CVAT_REDIS_ONDISK_PASSWORD}", "flushall"] + ["sh", "-c", 'redis-cli -e -p 6666 -a "${CVAT_REDIS_ONDISK_PASSWORD}" flushall'] ) @@ -551,7 +559,7 @@ def restore_db_per_class(request): @pytest.fixture(scope="function") -def restore_cvat_data(request): +def restore_cvat_data_per_function(request): platform = request.config.getoption("--platform") if platform == "local": docker_restore_data_volumes() @@ -592,6 +600,15 @@ def restore_redis_inmem_per_function(request): 
kube_restore_redis_inmem() +@pytest.fixture(scope="class") +def restore_redis_inmem_per_class(request): + platform = request.config.getoption("--platform") + if platform == "local": + docker_restore_redis_inmem() + else: + kube_restore_redis_inmem() + + @pytest.fixture(scope="function") def restore_redis_ondisk_per_function(request): platform = request.config.getoption("--platform") @@ -599,3 +616,12 @@ def restore_redis_ondisk_per_function(request): docker_restore_redis_ondisk() else: kube_restore_redis_ondisk() + + +@pytest.fixture(scope="class") +def restore_redis_ondisk_per_class(request): + platform = request.config.getoption("--platform") + if platform == "local": + docker_restore_redis_ondisk() + else: + kube_restore_redis_ondisk() diff --git a/tests/python/shared/utils/helpers.py b/tests/python/shared/utils/helpers.py index f336cb3f911..ac5948182d7 100644 --- a/tests/python/shared/utils/helpers.py +++ b/tests/python/shared/utils/helpers.py @@ -1,10 +1,11 @@ -# Copyright (C) 2022 CVAT.ai Corporation +# Copyright (C) 2022-2024 CVAT.ai Corporation # # SPDX-License-Identifier: MIT import subprocess +from contextlib import closing from io import BytesIO -from typing import List, Optional +from typing import Generator, List, Optional import av import av.video.reformatter @@ -13,7 +14,7 @@ from shared.fixtures.init import get_server_image_tag -def generate_image_file(filename="image.png", size=(50, 50), color=(0, 0, 0)): +def generate_image_file(filename="image.png", size=(100, 50), color=(0, 0, 0)): f = BytesIO() f.name = filename image = Image.new("RGB", size=size, color=color) @@ -40,7 +41,7 @@ def generate_image_files( return images -def generate_video_file(num_frames: int, size=(50, 50)) -> BytesIO: +def generate_video_file(num_frames: int, size=(100, 50)) -> BytesIO: f = BytesIO() f.name = "video.avi" @@ -60,6 +61,19 @@ def generate_video_file(num_frames: int, size=(50, 50)) -> BytesIO: return f +def read_video_file(file: BytesIO) -> Generator[Image.Image, 
None, None]: + file.seek(0) + + with av.open(file) as container: + video_stream = container.streams.video[0] + + with closing(video_stream.codec_context): # pyav has a memory leak in stream.close() + with closing(container.demux(video_stream)) as demux_iter: + for packet in demux_iter: + for frame in packet.decode(): + yield frame.to_image() + + def generate_manifest(path: str) -> None: command = [ "docker",