From aabe9009a2ec87e6fbe4f84fb520b344bde85346 Mon Sep 17 00:00:00 2001 From: fredzhu Date: Tue, 6 Feb 2024 15:24:52 +0800 Subject: [PATCH] feat(annotator): change model api & improve editor interface detail --- .../components/EditorStatus/index.less | 2 + .../components/ModelSelectModal/index.tsx | 97 +- .../SmartAnnotationControl/index.tsx | 74 +- packages/components/src/Annotator/editor.tsx | 22 +- .../src/Annotator/hooks/useActions.tsx | 898 +----------------- .../src/Annotator/hooks/useAiModels.ts | 792 +++++++++++++++ .../src/Annotator/hooks/useMouseEvents.tsx | 20 +- .../src/Annotator/hooks/useSubtools.tsx | 46 +- .../src/Annotator/hooks/useToolActions.ts | 167 +++- .../src/Annotator/hooks/useTopTools.tsx | 3 + packages/components/src/Annotator/index.less | 1 + .../components/src/Annotator/sevices/index.ts | 94 +- .../components/src/Annotator/tools/base.ts | 2 +- .../src/Annotator/tools/usePolygon.ts | 5 + .../src/Annotator/tools/useRectangle.ts | 1 + packages/components/src/Annotator/type.ts | 9 +- .../components/src/Annotator/utils/compute.ts | 12 +- 17 files changed, 1098 insertions(+), 1147 deletions(-) create mode 100644 packages/components/src/Annotator/hooks/useAiModels.ts diff --git a/packages/components/src/Annotator/components/EditorStatus/index.less b/packages/components/src/Annotator/components/EditorStatus/index.less index 2abf72f..e1e15b1 100644 --- a/packages/components/src/Annotator/components/EditorStatus/index.less +++ b/packages/components/src/Annotator/components/EditorStatus/index.less @@ -10,6 +10,8 @@ font-size: 14px; font-weight: 500; border-radius: 5px; + white-space: nowrap; + text-overflow: ellipsis; svg { width: 20px; diff --git a/packages/components/src/Annotator/components/ModelSelectModal/index.tsx b/packages/components/src/Annotator/components/ModelSelectModal/index.tsx index 5141f45..0130548 100644 --- a/packages/components/src/Annotator/components/ModelSelectModal/index.tsx +++ b/packages/components/src/Annotator/components/ModelSelectModal/index.tsx @@ -5,6 +5,7 @@ import { useLocale } from 'dds-utils'; import { memo, useMemo } from 'react'; import { EnumModelType, MODEL_INTRO_MAP } from '../../constants'; +import { FloatWrapper } from '../FloatWrapper'; import './index.less'; @@ -42,54 +43,56 @@ const ModelSelectModal: React.FC = memo( }, [AIAnnotation, modelOptions, selectedModel]); return ( - -
- {modelOptions.map((model, index) => { - const intro = MODEL_INTRO_MAP[model]; - if (!intro) return <>; - return ( -
onSelectModel(model)} - key={index} - > - -
- {localeText(intro.name)} + + +
+ {modelOptions.map((model, index) => { + const intro = MODEL_INTRO_MAP[model]; + if (!intro) return <>; + return ( +
onSelectModel(model)} + key={index} + > + +
+ {localeText(intro.name)} +
+
+ {localeText(intro.description)} +
+ {intro.hightlight && ( + + {'New'} + + )}
-
- {localeText(intro.description)} -
- {intro.hightlight && ( - - {'New'} - - )} -
- ); - })} -
- + ); + })} +
+ + ); }, ); diff --git a/packages/components/src/Annotator/components/SmartAnnotationControl/index.tsx b/packages/components/src/Annotator/components/SmartAnnotationControl/index.tsx index 9175867..525b00c 100644 --- a/packages/components/src/Annotator/components/SmartAnnotationControl/index.tsx +++ b/packages/components/src/Annotator/components/SmartAnnotationControl/index.tsx @@ -1,9 +1,9 @@ import { CloseOutlined } from '@ant-design/icons'; import Icon from '@ant-design/icons/lib/components/Icon'; -import { Button, Card, Select, Slider, Space } from 'antd'; +import { Button, Card, Input, Slider, Space } from 'antd'; import classNames from 'classnames'; import { useLocale } from 'dds-utils/locale'; -import { useMemo, memo, useState } from 'react'; +import { useMemo, memo } from 'react'; import { useImmer } from 'use-immer'; import { ReactComponent as DragToolIcon } from '../../assets/drag.svg'; @@ -19,8 +19,7 @@ import { EToolType, EnumModelType, } from '../../constants'; -import { OnAiAnnotationFunc } from '../../hooks/useActions'; -import { Category } from '../../type'; +import { OnAiAnnotationFunc } from '../../hooks/useAiModels'; import { FloatWrapper } from '../FloatWrapper'; import './index.less'; @@ -36,7 +35,6 @@ interface IProps { naturalSize: ISize; aiLabels?: string; limitConf: number; - categories: Category[]; setAiLabels: (labels?: string) => void; forceChangeTool: (tool: EBasicToolItem, subtool: ESubToolItem) => void; onExitAIAnnotation: () => void; @@ -56,7 +54,6 @@ const SmartAnnotationControl: React.FC = memo( isBatchEditing, isCtrlPressed, aiLabels, - categories, naturalSize, limitConf, setAiLabels, @@ -69,7 +66,6 @@ const SmartAnnotationControl: React.FC = memo( forceChangeTool, }) => { const { localeText } = useLocale(); - const [inputText, setInputText] = useState(''); /** Parameters for requesting segmemt everything API */ const [samParams, setSamParams] = useImmer({ @@ -109,29 +105,6 @@ const SmartAnnotationControl: React.FC = memo( }, }; - const labelOptions = useMemo(() => { - if (selectedTool === EBasicToolItem.Rectangle) { - let options = categories?.map((c) => c.name); - options = - inputText && !options.includes(inputText) - ? [inputText, ...options] - : options; - return options.map((text) => ( - - {text} - - )); - } else if (selectedTool === EBasicToolItem.Polygon) { - return []; - } else if (selectedTool === EBasicToolItem.Skeleton) { - return ['person'].map((label) => ( - - {label} - - )); - } - }, [selectedTool, categories, inputText]); - const mouseEventHandler = (event: React.MouseEvent) => { if ( event.type === 'mouseup' && @@ -183,11 +156,6 @@ const SmartAnnotationControl: React.FC = memo( isCtrlPressed, ]); - const onApplyCurrMaskObjs = () => { - onAcceptValidObjects(); - forceChangeTool(EBasicToolItem.Drag, ESubToolItem.PenAdd); - }; - const aiDetectionTip = useMemo(() => { if ( selectedTool === EBasicToolItem.Rectangle && @@ -296,27 +264,16 @@ const SmartAnnotationControl: React.FC = memo(
) : (
- + onChange={(e) => setAiLabels(e.target.value)} + onKeyUp={(event) => event.stopPropagation()} + onKeyDown={(event) => event.stopPropagation()} + /> - diff --git a/packages/components/src/Annotator/editor.tsx b/packages/components/src/Annotator/editor.tsx index f9a4daf..b3805fd 100755 --- a/packages/components/src/Annotator/editor.tsx +++ b/packages/components/src/Annotator/editor.tsx @@ -16,6 +16,7 @@ import SmartAnnotationControl from './components/SmartAnnotationControl'; import { TopPagination } from './components/TopPagination'; import { DisplayOption, EBasicToolItem, TOOL_MODELS_MAP } from './constants'; import useActions from './hooks/useActions'; +import useAiModels from './hooks/useAiModels'; import useAttributes from './hooks/useAttributes'; import useCanvasContainer from './hooks/useCanvasContainer'; import useCanvasRender from './hooks/useCanvasRender'; @@ -268,8 +269,19 @@ const Edit: React.FC = (props) => { updateAllObjectWithoutHistory, }); + const { onAiAnnotation } = useAiModels({ + currImageItem, + drawData, + setDrawData, + setDrawDataWithHistory, + editState, + setEditState, + naturalSize, + clientSize, + getAnnotColor, + }); + const { - onAiAnnotation, onSaveAnnotations, onCommitAnnotations, onCancelAnnotations, @@ -281,16 +293,9 @@ const Edit: React.FC = (props) => { currImageItem, modal, drawData, - setDrawData, - setDrawDataWithHistory, editState, setEditState, - naturalSize, - clientSize, - imagePos, - containerMouse, hadChangeRecord, - getAnnotColor, categories, translateObject, flagSaved, @@ -623,7 +628,6 @@ const Edit: React.FC = (props) => { limitConf={drawData.limitConf} aiLabels={aiLabels} naturalSize={naturalSize} - categories={categories} setAiLabels={setAiLabels} forceChangeTool={forceChangeTool} onAiAnnotation={onAiAnnotation} diff --git a/packages/components/src/Annotator/hooks/useActions.tsx b/packages/components/src/Annotator/hooks/useActions.tsx index f190c1c..ce32e14 100644 --- a/packages/components/src/Annotator/hooks/useActions.tsx +++ b/packages/components/src/Annotator/hooks/useActions.tsx @@ -1,49 +1,17 @@ -import { useModel } from '@umijs/max'; -import { CursorState } from 'ahooks/lib/useMouse'; -import { Modal, message } from 'antd'; +import { Modal } from 'antd'; import { ModalStaticFunctions } from 'antd/es/modal/confirm'; import { useLocale } from 'dds-utils/locale'; import { useCallback } from 'react'; import { Updater } from 'use-immer'; -import { - BODY_TEMPLATE, - EBasicToolItem, - EBasicToolTypeMap, - EnumModelType, - EObjectType, - ESubToolItem, -} from '../constants'; -import { NsApiAnnotator, fetchModelResults } from '../sevices'; -import { rleToCanvas } from '../tools/useMask'; import { DrawData, AnnoItem, EditState, EditorMode, - IAnnotationObject, - PromptItem, - EObjectStatus, Category, VideoFramesData, - EPromptType, - ReqPromptItem, - IMask, } from '../type'; -import { getImageBase64, getServerAddressableUrl } from '../utils/base64'; -import { - getVisibleAreaForImage, - translateBoundingBoxToRect, - translatePointsToPointObjs, - translatePointZoom, - translateRectToAbsBbox, - getCanvasPoint, - getNaturalPoint, - translateRectToBoundingBox, - translatePointObjsToPointAttrs, - translateRectZoom, - translateAbsBBoxToRect, -} from '../utils/compute'; interface IProps { mode: EditorMode; @@ -51,16 +19,9 @@ interface IProps { modal: Omit; framesData?: VideoFramesData; drawData: DrawData; - setDrawData: Updater; - setDrawDataWithHistory: Updater; editState: EditState; setEditState: Updater; - naturalSize: ISize; - clientSize: ISize; - containerMouse: CursorState; - imagePos: React.MutableRefObject; hadChangeRecord: boolean; - getAnnotColor: (category: string, forceColorByCategory?: boolean) => string; categories: Category[]; translateObject?: (object: any) => any; flagSaved?: () => void; @@ -85,44 +46,16 @@ interface IProps { classificationOptions?: Category[]; } -export type OnAiAnnotationFunc = ({ - type, - drawData, - aiLabels, - bbox, - promptsQueue, - segmentationClicks, - segmentEverythingParams, -}: { - type?: EObjectType; - drawData?: DrawData; - aiLabels?: string; - bbox?: IBoundingBox; - promptsQueue?: PromptItem[]; - segmentationClicks?: { - point: IPoint; - isPositive: boolean; - }[]; - segmentEverythingParams?: NsApiAnnotator.SegmentEverythingParams; -}) => Promise; - const useActions = ({ mode, currImageItem, modal, framesData, drawData: editorDrawData, - setDrawData, - setDrawDataWithHistory, editState, setEditState, - naturalSize, - clientSize, - imagePos, - containerMouse, hadChangeRecord, categories, - getAnnotColor, translateObject, flagSaved, onCancel, @@ -134,840 +67,12 @@ const useActions = ({ classificationOptions, }: IProps) => { const { localeText } = useLocale(); - const { setLoading } = useModel('global'); const { isRequiring } = editState; const setIsRequiring = (requiring: boolean) => setEditState((s) => { s.isRequiring = requiring; }); - const requestAiDetection = async (aiLabels: string) => { - if (!currImageItem) return; - - try { - setLoading(true); - const { result } = await fetchModelResults( - EnumModelType.Detection, - { - image: await getImageBase64(currImageItem.url), - text: aiLabels, - }, - ); - - if (result) { - const { objects, suggestThreshold } = result; - const limitConf = suggestThreshold || 0; - const newObjects: IAnnotationObject[] = objects - .map((item) => { - // mouse.elementW is not necessarily identical to the size during initialization transformation - const rect = { - ...translateBoundingBoxToRect(item.boundingBox, clientSize), - }; - return { - rect: { ...rect, visible: true }, - labelId: editState.latestLabelId, - type: EObjectType.Rectangle, - hidden: false, - status: - item.normalizedScore >= limitConf - ? EObjectStatus.Checked - : EObjectStatus.Unchecked, - conf: item.normalizedScore, - color: getAnnotColor(editState.latestLabelId, true), - }; - }) - .reverse(); - setDrawDataWithHistory((s) => { - s.isBatchEditing = true; - s.limitConf = limitConf; - const commitedObjects = s.objectList.filter( - (obj) => obj?.status === EObjectStatus.Commited, - ); - s.objectList = [...commitedObjects, ...newObjects]; - if (s.creatingObject && s.objectList[s.activeObjectIndex]) { - s.creatingObject = { ...s.objectList[s.activeObjectIndex] }; - } - }); - message.success(localeText('DDSAnnotator.smart.msg.success')); - } - } catch (error: any) { - message.error(localeText('DDSAnnotator.smart.msg.error')); - } finally { - setLoading(false); - } - }; - - const convertPromptFormat = (prompt: PromptItem[]): ReqPromptItem[] => { - const newPromptArr = prompt.map((item) => { - const { type, isPositive, point, rect, stroke, radius, polygons } = item; - - const newItem = { type, isPositive }; - - if (rect) { - const { xmax, xmin, ymax, ymin } = translateRectToAbsBbox(rect); - const topleftPoint = getNaturalPoint( - [xmin, ymin], - naturalSize, - clientSize, - ); - const bottomRightPoint = getNaturalPoint( - [xmax, ymax], - naturalSize, - clientSize, - ); - Object.assign(newItem, { - rect: [ - topleftPoint.x, - topleftPoint.y, - bottomRightPoint.x, - bottomRightPoint.y, - ], - }); - } - - if (point) { - const naturalPoint = getNaturalPoint( - [point.x, point.y], - naturalSize, - clientSize, - ); - Object.assign(newItem, { - point: [naturalPoint.x, naturalPoint.y], - }); - } - - if (stroke) { - const points = stroke.reduce((acc: number[], point: IPoint) => { - const { x, y } = point; - const naturalPoint = getNaturalPoint([x, y], naturalSize, clientSize); - return acc.concat([naturalPoint.x, naturalPoint.y]); - }, []); - Object.assign(newItem, { - stroke: points, - radius, - }); - } - - if (polygons) { - const transformedPolygons = polygons.map((polygon) => { - const res = []; - for (let i = 0; i < polygon.length; i += 2) { - const transformedPoint = getNaturalPoint( - [polygon[i], polygon[i + 1]], - naturalSize, - clientSize, - ); - res.push(transformedPoint.x, transformedPoint.y); - } - return res; - }); - Object.assign(newItem, { - polygons: transformedPolygons, - }); - } - - return newItem; - }); - - return newPromptArr; - }; - - const requestIvpDetection = async ( - drawData: DrawData, - promptsQueue?: PromptItem[], - ) => { - if (!currImageItem || !promptsQueue) return; - - if (promptsQueue.every((prompt) => !prompt.isPositive)) { - message.error(localeText('DDSAnnotator.smart.msg.positivePrompt')); - setDrawDataWithHistory((s) => { - s.prompt.creatingPrompt = undefined; - }); - return; - } - - try { - setLoading(true); - const reqParams = { - prompts: convertPromptFormat(promptsQueue || []), - labelTypes: ['bbox'], - }; - if (drawData.prompt.sessionId) { - Object.assign(reqParams, { sessionId: drawData.prompt.sessionId }); - } else { - const url = await getServerAddressableUrl(currImageItem.url); - Object.assign(reqParams, { - promptImage: url, - inferImage: url, - }); - } - - const { result, sessionId } = await fetchModelResults( - EnumModelType.IVP, - reqParams, - ); - - if (result) { - const { objects } = result; - const limitConf = 0.3; - const newObjects: IAnnotationObject[] = objects - .filter((item) => { - return item.bbox; - }) - .map((item) => { - const [xmin, ymin, xmax, ymax] = item.bbox!; - const rect = translateRectZoom( - translateAbsBBoxToRect({ xmin, ymin, xmax, ymax }), - naturalSize, - clientSize, - ); - return { - rect: { ...rect, visible: true }, - labelId: editState.latestLabelId, - type: EObjectType.Rectangle, - hidden: false, - status: - item.score >= limitConf - ? EObjectStatus.Checked - : EObjectStatus.Unchecked, - conf: item.score, - color: getAnnotColor(editState.latestLabelId, true), - }; - }) - .reverse(); - - setDrawDataWithHistory((s) => { - s.isBatchEditing = true; - s.limitConf = limitConf; - const commitedObjects = s.objectList.filter( - (obj) => obj.status === EObjectStatus.Commited, - ); - s.objectList = [...commitedObjects, ...newObjects]; - if (s.creatingObject && s.objectList[s.activeObjectIndex]) { - s.creatingObject = { ...s.objectList[s.activeObjectIndex] }; - } - s.prompt.promptsQueue = promptsQueue; - s.prompt.sessionId = sessionId; - s.prompt.creatingPrompt = undefined; - }); - message.success(localeText('DDSAnnotator.smart.msg.success')); - } - } catch (error: any) { - message.error(localeText('DDSAnnotator.smart.msg.error')); - setDrawDataWithHistory((s) => { - s.prompt.creatingPrompt = undefined; - }); - } finally { - setLoading(false); - } - }; - - const requestIvpMask = async ( - drawData: DrawData, - promptsQueue?: PromptItem[], - ) => { - if (!currImageItem || !promptsQueue) return; - - if (promptsQueue.every((prompt) => !prompt.isPositive)) { - message.error(localeText('DDSAnnotator.smart.msg.positivePrompt')); - setDrawDataWithHistory((s) => { - s.prompt.creatingPrompt = undefined; - }); - return; - } - - try { - setLoading(true); - const reqParams = { - prompts: convertPromptFormat(promptsQueue || []), - labelTypes: ['mask'], - }; - if (drawData.prompt.sessionId) { - Object.assign(reqParams, { sessionId: drawData.prompt.sessionId }); - } else { - const url = await getServerAddressableUrl(currImageItem.url); - Object.assign(reqParams, { - promptImage: url, - inferImage: url, - }); - } - - const { result, sessionId } = await fetchModelResults( - EnumModelType.IVP, - reqParams, - ); - - if (result) { - // Display mask in different color - setEditState((s) => { - s.annotsDisplayOptions.colorByCategory = false; - }); - - const { objects } = result; - const newObjects: IAnnotationObject[] = objects - .filter((item) => !!item.mask) - .map((item) => { - const color = getAnnotColor(editState.latestLabelId); - const maskRleStr = item.mask?.counts || ''; - return { - type: EObjectType.Mask, - hidden: false, - labelId: editState.latestLabelId, - maskRle: maskRleStr, - maskCanvasElement: rleToCanvas(maskRleStr, naturalSize, color), - status: EObjectStatus.Checked, - conf: item.score, - color: getAnnotColor(editState.latestLabelId, true), - }; - }); - - setDrawDataWithHistory((s) => { - s.isBatchEditing = true; - const commitedObjects = s.objectList.filter( - (obj) => obj.status === EObjectStatus.Commited, - ); - s.objectList = [...commitedObjects, ...newObjects]; - if (s.creatingObject && s.objectList[s.activeObjectIndex]) { - s.creatingObject = { ...s.objectList[s.activeObjectIndex] }; - } - s.prompt.promptsQueue = promptsQueue; - s.prompt.sessionId = sessionId; - s.prompt.creatingPrompt = undefined; - }); - message.success(localeText('DDSAnnotator.smart.msg.success')); - } - } catch (error: any) { - message.error(localeText('DDSAnnotator.smart.msg.error')); - setDrawDataWithHistory((s) => { - s.prompt.creatingPrompt = undefined; - }); - } finally { - setLoading(false); - } - }; - - const getCurrVisibleBbox = () => { - // record visible area currently for model prediction - const { xmin, ymin, xmax, ymax } = getVisibleAreaForImage( - imagePos.current, - clientSize, - containerMouse, - ); - let area = [0, 0, naturalSize.width, naturalSize.height]; - if (xmax > 0 && ymax > 0) { - const { x: x1, y: y1 } = translatePointZoom( - { - x: xmin, - y: ymin, - }, - clientSize, - naturalSize, - ); - const { x: x2, y: y2 } = translatePointZoom( - { - x: xmax, - y: ymax, - }, - clientSize, - naturalSize, - ); - area = [Math.round(x1), Math.round(y1), Math.round(x2), Math.round(y2)]; - } - return area; - }; - - const requestAiSegmentByPolygon = async ( - drawData: DrawData, - promptsQueue?: PromptItem[], - ) => { - if (!currImageItem || !promptsQueue) return; - - const reqParams = { - image: editState.imageCacheIdForPolygon - ? `image_id://${editState.imageCacheIdForPolygon}` - : await getImageBase64(currImageItem.url), - density: drawData.pointResolution, - area: getCurrVisibleBbox(), - prompts: convertPromptFormat(promptsQueue || []), - }; - - if (drawData.prompt.sessionId) { - Object.assign(reqParams, { sessionId: drawData.prompt.sessionId }); - } - - try { - setLoading(true); - const { result } = - await fetchModelResults( - EnumModelType.SegmentByPolygon, - reqParams, - ); - if (result) { - const { image, polygons, sessionId } = result; - - if (polygons && polygons.length > 0) { - const predictPolygons = polygons - .filter((item) => { - return item.length >= 6; - }) - .map((item) => { - const result: IPolygon = []; - for (let i = 0; i < item.length; i += 2) { - const x = item[i]; - const y = item[i + 1]; - const canvasPoint = getCanvasPoint( - [x, y], - naturalSize, - clientSize, - ); - result.push(canvasPoint); - } - return result; - }); - - const creatingObj = { - type: EObjectType.Polygon, - hidden: false, - labelId: editState.latestLabelId, - color: - drawData.creatingObject?.color || - getAnnotColor(editState.latestLabelId), - currIndex: -1, - polygon: { - visible: true, - group: predictPolygons, - }, - status: EObjectStatus.Checked, - }; - - setDrawDataWithHistory((s) => { - s.creatingObject = creatingObj; - s.prompt.promptsQueue = promptsQueue; - s.prompt.sessionId = sessionId; - s.prompt.creatingPrompt = undefined; - }); - setEditState((s) => { - s.imageCacheIdForPolygon = image.replace(/^image_id:\/\//, ''); - }); - message.success(localeText('DDSAnnotator.smart.msg.success')); - } - } - } catch (error: any) { - message.error(localeText('DDSAnnotator.smart.msg.error')); - setDrawDataWithHistory((s) => { - s.prompt.creatingPrompt = undefined; - }); - } finally { - setLoading(false); - } - }; - - const requestAiSegmentByMask = async ( - drawData: DrawData, - promptsQueue?: PromptItem[], - ) => { - if (!promptsQueue || !currImageItem) return; - - const reqParams: NsApiAnnotator.FetchAIMaskSegmentReq = { - prompts: convertPromptFormat(promptsQueue || []), - }; - if (drawData.prompt.sessionId) { - Object.assign(reqParams, { sessionId: drawData.prompt.sessionId }); - } else { - Object.assign(reqParams, { - image: await getServerAddressableUrl(currImageItem.url), - }); - } - - try { - setLoading(true); - const { result, sessionId } = - await fetchModelResults( - EnumModelType.SegmentByMask, - reqParams, - ); - if (result) { - const { mask } = result; - const color = - drawData.creatingObject?.color || - getAnnotColor(editState.latestLabelId); - const maskRleStr = mask.counts || ''; - const creatingObj = { - type: EObjectType.Mask, - hidden: false, - labelId: editState.latestLabelId, - currIndex: -1, - maskCanvasElement: rleToCanvas(maskRleStr, naturalSize, color), - maskRle: maskRleStr, - status: EObjectStatus.Checked, - color, - }; - setDrawDataWithHistory((s) => { - s.creatingObject = creatingObj; - s.prompt.promptsQueue = promptsQueue; - s.prompt.sessionId = sessionId; - s.prompt.creatingPrompt = undefined; - }); - message.success(localeText('DDSAnnotator.smart.msg.success')); - } - } catch (error: any) { - message.error(localeText('DDSAnnotator.smart.msg.error')); - setDrawDataWithHistory((s) => { - s.prompt.creatingPrompt = undefined; - }); - } finally { - setLoading(false); - } - }; - - const requestAiPoseEstimation = async ( - drawData: DrawData, - aiLabels: string, - ) => { - if (!currImageItem) return; - - // TODO: Integrate custom templates - const { lines, pointNames, pointColors } = BODY_TEMPLATE; - const reqParams = { - image: await getImageBase64(currImageItem.url), - targets: aiLabels, - template: { - lines, - pointNames, - pointColors, - }, - }; - - if (drawData.isBatchEditing) { - const objectList = [...drawData.objectList]; - if ( - drawData.activeObjectIndex > -1 && - objectList[drawData.activeObjectIndex] && - drawData.creatingObject - ) { - // update creating object - objectList[drawData.activeObjectIndex] = { - ...objectList[drawData.activeObjectIndex], - ...drawData.creatingObject, - }; - } - const skeletonObjs = objectList.filter( - (obj) => - obj.type === EObjectType.Skeleton && - obj.status === EObjectStatus.Checked, - ); - if (skeletonObjs.length > 0) { - const objects = skeletonObjs.map((item) => { - return { - categoryName: aiLabels, - points: item.keypoints - ? translatePointObjsToPointAttrs( - item.keypoints.points, - naturalSize, - clientSize, - ).points - : undefined, - boundingBox: item.rect - ? translateRectToBoundingBox(item.rect, clientSize) - : undefined, - }; - }); - Object.assign(reqParams, { objects }); - } - } - - try { - setLoading(true); - const { result } = await fetchModelResults( - EnumModelType.Pose, - reqParams, - ); - - if (result) { - const { objects } = result; - - if (objects && objects.length > 0) { - const skeletonObjs = objects.map((obj) => { - let { boundingBox, points, conf } = obj; - const newObj: IAnnotationObject = { - labelId: editState.latestLabelId, - color: getAnnotColor(editState.latestLabelId), - type: EObjectType.Skeleton, - hidden: false, - conf, - status: EObjectStatus.Checked, - }; - if (boundingBox) { - const rect = translateBoundingBoxToRect(boundingBox!, clientSize); - Object.assign(newObj, { rect: { visible: true, ...rect } }); - } - if (points && lines && pointColors && pointNames) { - const pointObjs = translatePointsToPointObjs( - points, - pointNames, - pointColors, - naturalSize, - clientSize, - ); - Object.assign(newObj, { - keypoints: { - points: pointObjs, - lines, - }, - }); - } - return newObj; - }); - - setDrawDataWithHistory((s) => { - if (!s.isBatchEditing) { - s.isBatchEditing = true; - } - const commitedObjects = s.objectList.filter( - (obj) => obj.status === EObjectStatus.Commited, - ); - s.objectList = [...commitedObjects, ...skeletonObjs]; - if (s.creatingObject && s.objectList[s.activeObjectIndex]) { - s.creatingObject = { ...s.objectList[s.activeObjectIndex] }; - } - }); - - message.success(localeText('DDSAnnotator.smart.msg.success')); - } - } - } catch (error: any) { - message.error(localeText('DDSAnnotator.smart.msg.error')); - } finally { - setLoading(false); - } - }; - - const requestEdgeStitchingForMask = async (drawData: DrawData) => { - if ( - !currImageItem || - !drawData.prompt.creatingPrompt?.stroke || - !drawData.prompt.creatingPrompt?.radius - ) - return; - - const { stroke, radius } = drawData.prompt.creatingPrompt; - - const maskObjects = drawData.objectList.filter( - (item) => item.type === EObjectType.Mask, - ); - - if (maskObjects.length < 2) { - message.error( - 'To ensure valid results when using intelligent edge stitching, make sure to use at least 2 mask objects.', - ); - setDrawData((s) => { - s.prompt.creatingPrompt = undefined; - }); - return; - } - - const masks: IMask[] = maskObjects.map((item) => ({ - counts: item.maskRle || '', - size: [naturalSize.height, naturalSize.width], - })); - - const points = stroke.reduce((acc: number[], point: IPoint) => { - const { x, y } = point; - const naturalPoint = getNaturalPoint([x, y], naturalSize, clientSize); - return acc.concat([naturalPoint.x, naturalPoint.y]); - }, []); - - const reqParams: NsApiAnnotator.FetchEdgeStitchingReq = { - masks, - prompts: [ - { - type: EPromptType.Stroke, - stroke: points, - radius, - }, - ], - }; - if (drawData.prompt.sessionId) { - Object.assign(reqParams, { sessionId: drawData.prompt.sessionId }); - } else { - Object.assign(reqParams, { - image: await getServerAddressableUrl(currImageItem.url), - }); - } - - try { - setLoading(true); - const { result, sessionId } = - await fetchModelResults( - EnumModelType.MaskEdgeStitching, - reqParams, - ); - if (result && result.masks?.length > 0) { - const newMaskObjects = maskObjects.map((item, index) => { - const maskRleStr = result.masks?.[index]?.counts || ''; - return { - ...item, - maskRle: maskRleStr, - maskCanvasElement: rleToCanvas(maskRleStr, naturalSize, item.color), - }; - }); - - // Replace all instances of the mask type - const leftObjs = drawData.objectList.filter( - (obj) => obj.type !== EObjectType.Mask, - ); - - setDrawDataWithHistory((s) => { - s.objectList = [...leftObjs, ...newMaskObjects]; - s.prompt.creatingPrompt = undefined; - s.prompt.sessionId = sessionId; - }); - - message.success(localeText('DDSAnnotator.smart.msg.success')); - } - } catch (error: any) { - message.error(localeText('DDSAnnotator.smart.msg.error')); - setDrawDataWithHistory((s) => { - s.prompt.creatingPrompt = undefined; - }); - } finally { - setLoading(false); - } - }; - - const requestSegmentEverything = async ( - params?: NsApiAnnotator.SegmentEverythingParams, - ) => { - if (!currImageItem) return; - - const reqParams: NsApiAnnotator.FetchSegmentEverythingReq = { - image: await getServerAddressableUrl(currImageItem.url), - ...params, - }; - - try { - setLoading(true); - const { result } = - await fetchModelResults( - EnumModelType.SegmentEverything, - reqParams, - ); - if (result && result.masks?.length > 0) { - // change to display different color - setEditState((s) => { - s.annotsDisplayOptions.colorByCategory = false; - }); - const maskObjects: IAnnotationObject[] = result.masks.map((item) => { - const color = getAnnotColor(editState.latestLabelId); - const maskRleStr = item?.counts || ''; - return { - type: EObjectType.Mask, - hidden: false, - labelId: editState.latestLabelId, - maskRle: maskRleStr, - maskCanvasElement: rleToCanvas(maskRleStr, naturalSize, color), - conf: 1, - status: EObjectStatus.Checked, - color, - }; - }); - setDrawDataWithHistory((s) => { - s.objectList = maskObjects; - s.isBatchEditing = true; - }); - message.success(localeText('DDSAnnotator.smart.msg.success')); - } - } catch (error: any) { - message.error(localeText('DDSAnnotator.smart.msg.error')); - } finally { - setLoading(false); - } - }; - - const onAiAnnotation: OnAiAnnotationFunc = useCallback( - async ({ - type, - drawData: propsDrawData, - aiLabels, - promptsQueue, - segmentEverythingParams, - }) => { - if (isRequiring) return; - - const drawData = propsDrawData || editorDrawData; - - if ( - !aiLabels && - (drawData.selectedTool === EBasicToolItem.Skeleton || - (drawData.selectedTool === EBasicToolItem.Rectangle && - drawData.selectedModel[drawData.selectedTool] === - EnumModelType.Detection)) - ) { - message.warning(localeText('DDSAnnotator.smart.msg.labelRequired')); - return; - } - - const hide = message.loading( - localeText('DDSAnnotator.smart.msg.loading'), - 100000, - ); - try { - setIsRequiring(true); - const aiType = type || EBasicToolTypeMap[drawData.selectedTool]; - switch (aiType) { - case EObjectType.Rectangle: { - if ( - drawData.selectedModel[drawData.selectedTool] === - EnumModelType.Detection - ) { - await requestAiDetection(aiLabels || ''); - } else { - await requestIvpDetection(drawData, promptsQueue); - } - break; - } - case EObjectType.Skeleton: { - await requestAiPoseEstimation(drawData, aiLabels || ''); - break; - } - case EObjectType.Polygon: { - await requestAiSegmentByPolygon(drawData, promptsQueue); - break; - } - case EObjectType.Mask: { - const model = drawData.selectedModel[drawData.selectedTool]; - if (model === EnumModelType.SegmentEverything) { - if (drawData.selectedSubTool === ESubToolItem.AutoEdgeStitching) { - await requestEdgeStitchingForMask(drawData); - } else if ( - drawData.selectedSubTool === ESubToolItem.AutoSegmentEverything - ) { - await requestSegmentEverything(segmentEverythingParams); - } - } else if (model === EnumModelType.IVP) { - await requestIvpMask(drawData, promptsQueue); - } else { - await requestAiSegmentByMask(drawData, promptsQueue); - } - break; - } - default: - message.warning('Plan to Support!'); - break; - } - } catch (error) { - message.error(localeText('DDSAnnotator.smart.msg.error')); - } finally { - setIsRequiring(false); - setDrawData((s) => { - s.prompt.activeRectWhileLoading = undefined; - }); - hide(); - } - }, - [editorDrawData], - ); - const translateDrawData = useCallback( (drawData: DrawData): [string, any[], any] => { let objectList = []; @@ -1131,7 +236,6 @@ const useActions = ({ }; return { - onAiAnnotation, onSaveAnnotations, onCommitAnnotations, onCancelAnnotations, diff --git a/packages/components/src/Annotator/hooks/useAiModels.ts b/packages/components/src/Annotator/hooks/useAiModels.ts new file mode 100644 index 0000000..de7bce4 --- /dev/null +++ b/packages/components/src/Annotator/hooks/useAiModels.ts @@ -0,0 +1,792 @@ +import { useModel } from '@umijs/max'; +import { message } from 'antd'; +import { useLocale } from 'dds-utils/locale'; +import { useCallback } from 'react'; +import { Updater } from 'use-immer'; + +import { + BODY_TEMPLATE, + EBasicToolTypeMap, + EnumModelType, + EObjectType, + ESubToolItem, +} from '../constants'; +import { NsApiAnnotator, fetchModelResults } from '../sevices'; +import { rleToCanvas } from '../tools/useMask'; +import { + DrawData, + AnnoItem, + EditState, + IAnnotationObject, + PromptItem, + EObjectStatus, + EPromptType, + ReqPromptItem, + IMask, +} from '../type'; +import { getServerAddressableUrl } from '../utils/base64'; +import { + translateRectToAbsBbox, + getCanvasPoint, + getNaturalPoint, + translateRectZoom, + translateAbsBBoxToRect, + translatePointsToRect, + translateRectToPointsArray, + newTranslatePointsToPointObjs, + newTranslatePointObjsToPointAttrs, +} from '../utils/compute'; + +interface IProps { + currImageItem?: AnnoItem; + drawData: DrawData; + setDrawData: Updater; + setDrawDataWithHistory: Updater; + editState: EditState; + setEditState: Updater; + naturalSize: ISize; + clientSize: ISize; + getAnnotColor: (category: string, forceColorByCategory?: boolean) => string; +} + +export type OnAiAnnotationFunc = ({ + type, + drawData, + aiLabels, + bbox, + promptsQueue, + segmentationClicks, + segmentEverythingParams, +}: { + type?: EObjectType; + drawData?: DrawData; + aiLabels?: string; + bbox?: IBoundingBox; + promptsQueue?: PromptItem[]; + segmentationClicks?: { + point: IPoint; + isPositive: boolean; + }[]; + segmentEverythingParams?: NsApiAnnotator.FetchSegmentEverythingReq; +}) => Promise; + +const useAiModels = ({ + currImageItem, + drawData: editorDrawData, + setDrawData, + setDrawDataWithHistory, + editState, + setEditState, + naturalSize, + clientSize, + getAnnotColor, +}: IProps) => { + const { localeText } = useLocale(); + const { setLoading } = useModel('global'); + + const fetchCommonReqParams = async ( + drawData: DrawData, + reqParams: T, + ): Promise => { + if (drawData.prompt.sessionId) { + Object.assign(reqParams, { sessionId: drawData.prompt.sessionId }); + } else if (currImageItem) { + Object.assign(reqParams, { + image: await getServerAddressableUrl(currImageItem.url), + }); + } + return reqParams; + }; + + const convertPromptFormat = (prompt: PromptItem[]): ReqPromptItem[] => { + const newPromptArr = prompt.map((item) => { + const { type, isPositive, point, rect, stroke, radius, polygons } = item; + + const newItem = { type, isPositive }; + + if (rect) { + const { xmax, xmin, ymax, ymin } = translateRectToAbsBbox(rect); + const topleftPoint = getNaturalPoint( + [xmin, ymin], + naturalSize, + clientSize, + ); + const bottomRightPoint = getNaturalPoint( + [xmax, ymax], + naturalSize, + clientSize, + ); + Object.assign(newItem, { + rect: [ + topleftPoint.x, + topleftPoint.y, + bottomRightPoint.x, + bottomRightPoint.y, + ], + }); + } + + if (point) { + const naturalPoint = getNaturalPoint( + [point.x, point.y], + naturalSize, + clientSize, + ); + Object.assign(newItem, { + point: [naturalPoint.x, naturalPoint.y], + }); + } + + if (stroke) { + const points = stroke.reduce((acc: number[], point: IPoint) => { + const { x, y } = point; + const naturalPoint = getNaturalPoint([x, y], naturalSize, clientSize); + return acc.concat([naturalPoint.x, naturalPoint.y]); + }, []); + Object.assign(newItem, { + stroke: points, + radius, + }); + } + + if (polygons) { + const transformedPolygons = polygons.map((polygon) => { + const res = []; + for (let i = 0; i < polygon.length; i += 2) { + const transformedPoint = getNaturalPoint( + [polygon[i], polygon[i + 1]], + naturalSize, + clientSize, + ); + res.push(transformedPoint.x, transformedPoint.y); + } + return res; + }); + Object.assign(newItem, { + polygons: transformedPolygons, + }); + } + + return newItem; + }); + + return newPromptArr; + }; + + const requestAiDetection = async (drawData: DrawData, aiLabels: string) => { + if (!aiLabels) { + message.warning(localeText('DDSAnnotator.smart.msg.labelRequired')); + return; + } + + const reqParams = await fetchCommonReqParams(drawData, { + prompts: [ + { + type: EPromptType.Text, + text: aiLabels, + }, + ], + }); + + const { result, sessionId } = + await fetchModelResults( + EnumModelType.Detection, + reqParams, + ); + + if (result) { + const { objects, suggestThreshold } = result; + const limitConf = suggestThreshold || 0; + const maxScore = objects.reduce( + (max, item) => (item.score > max ? item.score : max), + objects[0]?.score || 0, + ); + const newObjects: IAnnotationObject[] = objects + .map((item) => { + // mouse.elementW is not necessarily identical to the size during initialization transformation + const rect = { + ...translatePointsToRect(item.bbox, naturalSize, clientSize), + }; + const conf = item.score / maxScore; + return { + rect: { ...rect, visible: true }, + labelId: editState.latestLabelId, + type: EObjectType.Rectangle, + hidden: false, + status: + conf >= limitConf + ? EObjectStatus.Checked + : EObjectStatus.Unchecked, + conf, + color: getAnnotColor(editState.latestLabelId, true), + }; + }) + .reverse(); + setDrawDataWithHistory((s) => { + s.isBatchEditing = true; + s.limitConf = limitConf; + const commitedObjects = s.objectList.filter( + (obj) => obj?.status === EObjectStatus.Commited, + ); + s.objectList = [...commitedObjects, ...newObjects]; + if (s.creatingObject && s.objectList[s.activeObjectIndex]) { + s.creatingObject = { ...s.objectList[s.activeObjectIndex] }; + } + s.prompt.sessionId = sessionId; + }); + return true; + } + }; + + const requestIvpDetection = async ( + drawData: DrawData, + promptsQueue?: PromptItem[], + ) => { + if (!currImageItem || !promptsQueue) return; + + const reqParams = { + prompts: convertPromptFormat(promptsQueue || []), + labelTypes: ['bbox'], + }; + if (drawData.prompt.sessionId) { + Object.assign(reqParams, { sessionId: drawData.prompt.sessionId }); + } else { + const url = await getServerAddressableUrl(currImageItem.url); + Object.assign(reqParams, { + promptImage: url, + inferImage: url, + }); + } + + const { result, sessionId } = await fetchModelResults( + EnumModelType.IVP, + reqParams, + ); + + if (result) { + const { objects } = result; + const limitConf = 0.3; + const newObjects: IAnnotationObject[] = objects + .filter((item) => { + return item.bbox; + }) + .map((item) => { + const [xmin, ymin, xmax, ymax] = item.bbox!; + const rect = translateRectZoom( + translateAbsBBoxToRect({ xmin, ymin, xmax, ymax }), + naturalSize, + clientSize, + ); + return { + rect: { ...rect, visible: true }, + labelId: editState.latestLabelId, + type: EObjectType.Rectangle, + hidden: false, + status: + item.score >= limitConf + ? EObjectStatus.Checked + : EObjectStatus.Unchecked, + conf: item.score, + color: getAnnotColor(editState.latestLabelId, true), + }; + }) + .reverse(); + + setDrawDataWithHistory((s) => { + s.isBatchEditing = true; + s.limitConf = limitConf; + const commitedObjects = s.objectList.filter( + (obj) => obj.status === EObjectStatus.Commited, + ); + s.objectList = [...commitedObjects, ...newObjects]; + if (s.creatingObject && s.objectList[s.activeObjectIndex]) { + s.creatingObject = { ...s.objectList[s.activeObjectIndex] }; + } + s.prompt.promptsQueue = promptsQueue; + s.prompt.sessionId = sessionId; + s.prompt.creatingPrompt = undefined; + }); + return true; + } + }; + + const requestIvpMask = async ( + drawData: DrawData, + promptsQueue?: PromptItem[], + ) => { + if (!currImageItem || !promptsQueue) return; + + const reqParams = { + prompts: convertPromptFormat(promptsQueue || []), + labelTypes: ['mask'], + }; + if (drawData.prompt.sessionId) { + Object.assign(reqParams, { sessionId: drawData.prompt.sessionId }); + } else { + const url = await getServerAddressableUrl(currImageItem.url); + Object.assign(reqParams, { + promptImage: url, + inferImage: url, + }); + } + + const { result, sessionId } = await fetchModelResults( + EnumModelType.IVP, + reqParams, + ); + + if (result) { + // Display mask in different color + setEditState((s) => { + s.annotsDisplayOptions.colorByCategory = false; + }); + + const { objects } = result; + const newObjects: IAnnotationObject[] = objects + .filter((item) => !!item.mask) + .map((item) => { + const color = getAnnotColor(editState.latestLabelId); + const maskRleStr = item.mask?.counts || ''; + return { + type: EObjectType.Mask, + hidden: false, + labelId: editState.latestLabelId, + maskRle: maskRleStr, + maskCanvasElement: rleToCanvas(maskRleStr, naturalSize, color), + status: EObjectStatus.Checked, + conf: item.score, + color: getAnnotColor(editState.latestLabelId, true), + }; + }); + + setDrawDataWithHistory((s) => { + s.isBatchEditing = true; + const commitedObjects = s.objectList.filter( + (obj) => obj.status === EObjectStatus.Commited, + ); + s.objectList = [...commitedObjects, ...newObjects]; + if (s.creatingObject && s.objectList[s.activeObjectIndex]) { + s.creatingObject = { ...s.objectList[s.activeObjectIndex] }; + } + s.prompt.promptsQueue = promptsQueue; + s.prompt.sessionId = sessionId; + s.prompt.creatingPrompt = undefined; + }); + return true; + } + }; + + const requestAiSegmentByPolygon = async ( + drawData: DrawData, + promptsQueue?: PromptItem[], + ) => { + if (!promptsQueue) return; + + const reqParams = await fetchCommonReqParams(drawData, { + density: drawData.pointResolution, + prompts: convertPromptFormat(promptsQueue || []), + }); + + const { result, sessionId } = + await fetchModelResults( + EnumModelType.SegmentByPolygon, + reqParams, + ); + if (result) { + const { polygons } = result; + + if (polygons && polygons.length > 0) { + const predictPolygons = polygons + .filter((item) => { + return item.length >= 6; + }) + .map((item) => { + const result: IPolygon = []; + for (let i = 0; i < item.length; i += 2) { + const x = item[i]; + const y = item[i + 1]; + const canvasPoint = getCanvasPoint( + [x, y], + naturalSize, + clientSize, + ); + result.push(canvasPoint); + } + return result; + }); + + const creatingObj = { + type: EObjectType.Polygon, + hidden: false, + labelId: editState.latestLabelId, + color: + drawData.creatingObject?.color || + getAnnotColor(editState.latestLabelId), + currIndex: -1, + polygon: { + visible: true, + group: predictPolygons, + }, + status: EObjectStatus.Checked, + }; + + setDrawDataWithHistory((s) => { + s.creatingObject = creatingObj; + s.prompt.promptsQueue = promptsQueue; + s.prompt.sessionId = sessionId; + s.prompt.creatingPrompt = undefined; + }); + return true; + } + } + }; + + const requestAiSegmentByMask = async ( + drawData: DrawData, + promptsQueue?: PromptItem[], + ) => { + if (!promptsQueue) return; + + const reqParams = await fetchCommonReqParams(drawData, { + prompts: convertPromptFormat(promptsQueue || []), + }); + + const { result, sessionId } = + await fetchModelResults( + EnumModelType.SegmentByMask, + reqParams, + ); + if (result) { + const { mask } = result; + const color = + drawData.creatingObject?.color || + getAnnotColor(editState.latestLabelId); + const maskRleStr = mask.counts || ''; + const creatingObj = { + type: EObjectType.Mask, + hidden: false, + labelId: editState.latestLabelId, + currIndex: -1, + maskCanvasElement: rleToCanvas(maskRleStr, naturalSize, color), + maskRle: maskRleStr, + status: EObjectStatus.Checked, + color, + }; + setDrawDataWithHistory((s) => { + s.creatingObject = creatingObj; + s.prompt.promptsQueue = promptsQueue; + s.prompt.sessionId = sessionId; + s.prompt.creatingPrompt = undefined; + }); + return true; + } + }; + + const requestAiPoseEstimation = async (drawData: DrawData) => { + // TODO: Integrate custom templates + const { lines, pointNames, pointColors } = BODY_TEMPLATE; + const reqParams = await fetchCommonReqParams(drawData, {}); + if (drawData.isBatchEditing) { + const objectList = [...drawData.objectList]; + if ( + drawData.activeObjectIndex > -1 && + objectList[drawData.activeObjectIndex] && + drawData.creatingObject + ) { + // update creating object + objectList[drawData.activeObjectIndex] = { + ...objectList[drawData.activeObjectIndex], + ...drawData.creatingObject, + }; + } + const skeletonObjs = objectList.filter( + (obj) => + obj.type === EObjectType.Skeleton && + obj.status === EObjectStatus.Checked, + ); + if (skeletonObjs.length > 0) { + const objects = skeletonObjs.map((item) => { + return { + keypoints: item.keypoints + ? newTranslatePointObjsToPointAttrs( + item.keypoints.points, + naturalSize, + clientSize, + ).points + : undefined, + bbox: item.rect + ? translateRectToPointsArray(item.rect, clientSize, naturalSize) + : undefined, + }; + }); + Object.assign(reqParams, { objects }); + } + } + + const { result, sessionId } = await fetchModelResults( + EnumModelType.Pose, + reqParams, + ); + + if (result) { + const { objects } = result; + if (objects && objects.length > 0) { + const skeletonObjs = objects.map((obj) => { + let { bbox, keypoints, score } = obj; + const newObj: IAnnotationObject = { + labelId: editState.latestLabelId, + color: getAnnotColor(editState.latestLabelId), + type: EObjectType.Skeleton, + hidden: false, + conf: score, + status: EObjectStatus.Checked, + }; + if (bbox) { + const rect = translatePointsToRect(bbox, naturalSize, clientSize); + Object.assign(newObj, { rect: { visible: true, ...rect } }); + } + if (keypoints && lines && pointColors && pointNames) { + const pointObjs = newTranslatePointsToPointObjs( + keypoints, + pointNames, + pointColors, + naturalSize, + clientSize, + ); + Object.assign(newObj, { + keypoints: { + points: pointObjs, + lines, + }, + }); + } + return newObj; + }); + + setDrawDataWithHistory((s) => { + if (!s.isBatchEditing) { + s.isBatchEditing = true; + } + const commitedObjects = s.objectList.filter( + (obj) => obj.status === EObjectStatus.Commited, + ); + s.objectList = [...commitedObjects, ...skeletonObjs]; + if (s.creatingObject && s.objectList[s.activeObjectIndex]) { + s.creatingObject = { ...s.objectList[s.activeObjectIndex] }; + } + s.prompt.sessionId = sessionId; + }); + return true; + } + } + }; + + const requestEdgeStitchingForMask = async (drawData: DrawData) => { + if ( + !drawData.prompt.creatingPrompt?.stroke || + !drawData.prompt.creatingPrompt?.radius + ) + return; + + const { stroke, radius } = drawData.prompt.creatingPrompt; + + const maskObjects = drawData.objectList.filter( + (item) => item.type === EObjectType.Mask, + ); + + if (maskObjects.length < 2) { + message.error(localeText('DDSAnnotator.smart.tip.edgeStitchError')); + setDrawData((s) => { + s.prompt.creatingPrompt = undefined; + }); + return; + } + + const masks: IMask[] = maskObjects.map((item) => ({ + counts: item.maskRle || '', + size: [naturalSize.height, naturalSize.width], + })); + + const points = stroke.reduce((acc: number[], point: IPoint) => { + const { x, y } = point; + const naturalPoint = getNaturalPoint([x, y], naturalSize, clientSize); + return acc.concat([naturalPoint.x, naturalPoint.y]); + }, []); + + const reqParams = await fetchCommonReqParams(drawData, { + masks, + prompts: [ + { + type: EPromptType.Stroke, + stroke: points, + radius, + }, + ], + }); + + const { result, sessionId } = + await fetchModelResults( + EnumModelType.MaskEdgeStitching, + reqParams, + ); + if (result && result.masks?.length > 0) { + const newMaskObjects = maskObjects.map((item, index) => { + const maskRleStr = result.masks?.[index]?.counts || ''; + return { + ...item, + maskRle: maskRleStr, + maskCanvasElement: rleToCanvas(maskRleStr, naturalSize, item.color), + }; + }); + + // Replace all instances of the mask type + const leftObjs = drawData.objectList.filter( + (obj) => obj.type !== EObjectType.Mask, + ); + + setDrawDataWithHistory((s) => { + s.objectList = [...leftObjs, ...newMaskObjects]; + s.prompt.creatingPrompt = undefined; + s.prompt.sessionId = sessionId; + }); + return true; + } + }; + + const requestSegmentEverything = async ( + params?: NsApiAnnotator.FetchSegmentEverythingReq, + ) => { + if (!currImageItem) return; + + const reqParams = { + image: await getServerAddressableUrl(currImageItem.url), + ...params, + }; + + const { result } = await fetchModelResults( + EnumModelType.SegmentEverything, + reqParams, + ); + if (result && result.masks?.length > 0) { + // change to display different color + setEditState((s) => { + s.annotsDisplayOptions.colorByCategory = false; + }); + const maskObjects: IAnnotationObject[] = result.masks.map((item) => { + const color = getAnnotColor(editState.latestLabelId); + const maskRleStr = item?.counts || ''; + return { + type: EObjectType.Mask, + hidden: false, + labelId: editState.latestLabelId, + maskRle: maskRleStr, + maskCanvasElement: rleToCanvas(maskRleStr, naturalSize, color), + conf: 1, + status: EObjectStatus.Checked, + color, + }; + }); + setDrawDataWithHistory((s) => { + s.objectList = maskObjects; + s.isBatchEditing = true; + }); + return true; + } + }; + + const onAiAnnotation: OnAiAnnotationFunc = useCallback( + async ({ + type, + drawData: propsDrawData, + aiLabels, + promptsQueue, + segmentEverythingParams, + }) => { + if (editState.isRequiring || !currImageItem) return; + + const drawData = propsDrawData || editorDrawData; + + const hide = message.loading( + localeText('DDSAnnotator.smart.msg.loading'), + 100000, + ); + try { + setLoading(true); + setEditState((s) => { + s.isRequiring = true; + }); + const aiType = type || EBasicToolTypeMap[drawData.selectedTool]; + let isSuccess; + switch (aiType) { + case EObjectType.Rectangle: { + if ( + drawData.selectedModel[drawData.selectedTool] === + EnumModelType.Detection + ) { + isSuccess = await requestAiDetection(drawData, aiLabels || ''); + } else { + isSuccess = await requestIvpDetection(drawData, promptsQueue); + } + break; + } + case EObjectType.Skeleton: { + isSuccess = await requestAiPoseEstimation(drawData); + break; + } + case EObjectType.Polygon: { + isSuccess = await requestAiSegmentByPolygon(drawData, promptsQueue); + break; + } + case EObjectType.Mask: { + const model = drawData.selectedModel[drawData.selectedTool]; + if (model === EnumModelType.SegmentEverything) { + if (drawData.selectedSubTool === ESubToolItem.AutoEdgeStitching) { + isSuccess = await requestEdgeStitchingForMask(drawData); + } else if ( + drawData.selectedSubTool === ESubToolItem.AutoSegmentEverything + ) { + isSuccess = await requestSegmentEverything( + segmentEverythingParams, + ); + } + } else if (model === EnumModelType.IVP) { + isSuccess = await requestIvpMask(drawData, promptsQueue); + } else { + isSuccess = await requestAiSegmentByMask(drawData, promptsQueue); + } + break; + } + default: + message.warning('Plan to Support!'); + break; + } + if (isSuccess) { + message.success(localeText('DDSAnnotator.smart.msg.success')); + } + } catch (error) { + setDrawDataWithHistory((s) => { + if (s.prompt.creatingPrompt) { + s.prompt.creatingPrompt = undefined; + } + }); + message.error(localeText('DDSAnnotator.smart.msg.error')); + } finally { + setLoading(false); + setEditState((s) => { + s.isRequiring = false; + }); + setDrawData((s) => { + s.prompt.activeRectWhileLoading = undefined; + }); + hide(); + } + }, + [editorDrawData], + ); + + return { + onAiAnnotation, + }; +}; + +export default useAiModels; diff --git a/packages/components/src/Annotator/hooks/useMouseEvents.tsx b/packages/components/src/Annotator/hooks/useMouseEvents.tsx index 19a230d..ec25be2 100644 --- a/packages/components/src/Annotator/hooks/useMouseEvents.tsx +++ b/packages/components/src/Annotator/hooks/useMouseEvents.tsx @@ -162,7 +162,6 @@ const useMouseEvents = ({ } else { setMoveVisibleAreaInterval(undefined); } - updateRender(); }; const getFocusFilter = () => { @@ -273,9 +272,9 @@ const useMouseEvents = ({ }); } else { s.activeObjectIndex = index; - if (!drawData.objectList[index].frameEmpty) { + if (!s.objectList[index]?.frameEmpty) { s.creatingObject = { - ...drawData.objectList[index], + ...s.objectList[index], currIndex: undefined, startPoint: undefined, tempMaskSteps: [], @@ -445,7 +444,13 @@ const useMouseEvents = ({ }, ) ) { - checkContainerVisibleArea(); + const noUnfinishedMaskStep = + drawData.creatingObject.type === EObjectType.Mask && + !drawData.creatingObject?.maskStep; + if (!noUnfinishedMaskStep) { + checkContainerVisibleArea(); + } + updateRender(); return; } } else if ( @@ -461,7 +466,12 @@ const useMouseEvents = ({ object: drawData.creatingObject, }) ) { - checkContainerVisibleArea(); + const noUnfinishedMaskStep = + objectType === EObjectType.Mask && !drawData.creatingObject?.maskStep; + if (!noUnfinishedMaskStep) { + checkContainerVisibleArea(); + } + updateRender(); return; } } diff --git a/packages/components/src/Annotator/hooks/useSubtools.tsx b/packages/components/src/Annotator/hooks/useSubtools.tsx index c28cc87..9a33827 100644 --- a/packages/components/src/Annotator/hooks/useSubtools.tsx +++ b/packages/components/src/Annotator/hooks/useSubtools.tsx @@ -55,46 +55,36 @@ const useSubTools = ({ drawData, onChangePointResolution }: IProps) => { ); }, [drawData.objectList, drawData.creatingObject, drawData.isBatchEditing]); - const isManualAvailable = useMemo(() => { - return ( - !drawData.prompt.sessionId && - !( - drawData.prompt.promptsQueue && drawData.prompt.promptsQueue.length > 0 - ) && - !drawData.isBatchEditing - ); - }, [drawData.prompt, drawData.isBatchEditing]); - const basicMaskTools: TToolItem[] = useMemo( () => [ { key: ESubToolItem.PenAdd, name: localeText('DDSAnnotator.subtoolbar.mask.penAdd'), icon: , - available: isManualAvailable, + available: true, }, { key: ESubToolItem.PenErase, name: localeText('DDSAnnotator.subtoolbar.mask.penErase'), icon: , - available: isManualAvailable && !!drawData.creatingObject, + available: !!drawData.creatingObject, }, { key: ESubToolItem.BrushAdd, name: localeText('DDSAnnotator.subtoolbar.mask.brushAdd'), icon: , - available: isManualAvailable, + available: true, withSize: true, }, { key: ESubToolItem.BrushErase, name: localeText('DDSAnnotator.subtoolbar.mask.brushErase'), icon: , - available: isManualAvailable && !!drawData.creatingObject, + available: !!drawData.creatingObject, withSize: true, }, ], - [isManualAvailable, drawData.creatingObject], + [drawData.creatingObject], ); const isgTools: TToolItem[] = useMemo(() => { @@ -160,10 +150,12 @@ const useSubTools = ({ drawData, onChangePointResolution }: IProps) => { key: ESubToolItem.NegativeVisualPrompt, name: localeText('DDSAnnotator.subtoolbar.visualprompt.negative'), icon: , - available: true, + available: !!drawData.prompt.promptsQueue?.some( + (item) => item.isPositive, + ), }, ]; - }, []); + }, [drawData.prompt]); const samTools: TToolItem[] = useMemo(() => { return [ @@ -187,26 +179,26 @@ const useSubTools = ({ drawData, onChangePointResolution }: IProps) => { }, [isSegEverythingAvailable]); const showSubTools = useMemo(() => { - if (drawData.selectedTool === EBasicToolItem.Mask) return true; + const { selectedTool, creatingObject, AIAnnotation, selectedModel } = + drawData; if ( - drawData.selectedTool === EBasicToolItem.Polygon && - drawData.AIAnnotation + selectedTool === EBasicToolItem.Mask || + creatingObject?.type === EObjectType.Mask ) return true; if ( - drawData.selectedTool === EBasicToolItem.Rectangle && - drawData.AIAnnotation && - drawData.selectedModel[drawData.selectedTool] === EnumModelType.IVP + selectedTool === EBasicToolItem.Rectangle && + AIAnnotation && + selectedModel[selectedTool] === EnumModelType.IVP ) return true; - if (drawData.creatingObject?.type === EObjectType.Mask) return true; - if ( - drawData.creatingObject?.type === EObjectType.Polygon && - drawData.AIAnnotation + (selectedTool === EBasicToolItem.Polygon || + creatingObject?.type === EObjectType.Polygon) && + AIAnnotation ) return true; diff --git a/packages/components/src/Annotator/hooks/useToolActions.ts b/packages/components/src/Annotator/hooks/useToolActions.ts index 087f456..84fe78c 100644 --- a/packages/components/src/Annotator/hooks/useToolActions.ts +++ b/packages/components/src/Annotator/hooks/useToolActions.ts @@ -9,6 +9,7 @@ import { EnumModelType, EObjectType, ESubToolItem, + TOOL_MODELS_MAP, } from '../constants'; import { objectToRle, rleToCanvas } from '../tools/useMask'; import { @@ -21,7 +22,7 @@ import { IAnnotsDisplayOptions, } from '../type'; -import { OnAiAnnotationFunc } from './useActions'; +import { OnAiAnnotationFunc } from './useAiModels'; interface IProps { mode: EditorMode; @@ -187,28 +188,24 @@ const useToolActions = ({ s.creatingObject = undefined; s.prompt = {}; s.activeObjectIndex = -1; - if ( - [ESubToolItem.PenErase, ESubToolItem.BrushErase].includes( - s.selectedSubTool, - ) - ) { + if (s.selectedSubTool === ESubToolItem.PenErase) { s.selectedSubTool = ESubToolItem.PenAdd; + } else if (s.selectedSubTool === ESubToolItem.BrushErase) { + s.selectedSubTool = ESubToolItem.BrushAdd; } }); setEditState((s) => { s.latestLabelId = labelId; }); }, - [drawData.creatingObject, drawData.activeObjectIndex, drawData.objectList], + [ + drawData.creatingObject, + drawData.activeObjectIndex, + drawData.objectList, + drawData.selectedSubTool, + ], ); - const onCloseAnnotationEditor = useCallback(() => { - setDrawData((s) => { - s.creatingObject = undefined; - s.activeObjectIndex = -1; - }); - }, []); - const onAcceptValidObjects = useCallback(() => { setDrawDataWithHistory((s) => { const validObjs = cloneDeep(drawData.objectList) @@ -244,9 +241,61 @@ const useToolActions = ({ }); }, [drawData.objectList]); + const isInAiSession = useCallback(() => { + const { + selectedTool, + AIAnnotation, + selectedModel, + selectedSubTool, + isBatchEditing, + creatingObject, + } = drawData; + + if (!AIAnnotation) return false; + + if (selectedTool === EBasicToolItem.Rectangle) { + return isBatchEditing; + } + + if (selectedTool === EBasicToolItem.Polygon) { + return creatingObject; + } + + if (selectedTool === EBasicToolItem.Skeleton) { + return isBatchEditing; + } + + if (selectedTool === EBasicToolItem.Mask) { + const currModel = selectedModel[selectedTool]; + if (currModel === EnumModelType.IVP) { + return isBatchEditing; + } + + if ( + currModel === EnumModelType.SegmentEverything && + selectedSubTool === ESubToolItem.AutoSegmentEverything + ) { + return isBatchEditing; + } + + if (currModel === EnumModelType.SegmentByMask) { + return creatingObject; + } + + return false; + } + return false; + }, [ + drawData.selectedTool, + drawData.selectedModel, + drawData.AIAnnotation, + drawData.selectedSubTool, + drawData.isBatchEditing, + drawData.creatingObject, + ]); + const selectTool = useCallback( (tool: EBasicToolItem) => { - console.log(drawData.selectedTool, drawData.AIAnnotation, tool); if ( mode !== EditorMode.Edit || (tool === drawData.selectedTool && !drawData.AIAnnotation) || @@ -254,6 +303,8 @@ const useToolActions = ({ ) return; + if (isInAiSession()) return; + setDrawData((s) => { s.selectedTool = tool; s.AIAnnotation = false; @@ -268,34 +319,46 @@ const useToolActions = ({ drawData.selectedTool, drawData.isBatchEditing, drawData.AIAnnotation, + isInAiSession, ], ); const selectSubTool = useCallback( - (tool: ESubToolItem) => { - if (mode !== EditorMode.Edit || tool === drawData.selectedSubTool) return; + (subtool: ESubToolItem) => { + const { + selectedTool, + selectedModel, + selectedSubTool, + AIAnnotation, + isBatchEditing, + } = drawData; + + if (mode !== EditorMode.Edit || subtool === selectedSubTool) return; + + // TODO: check subtool belong to current tool & model if ( - drawData.selectedTool === EBasicToolItem.Mask && - drawData.selectedModel[drawData.selectedTool] === - EnumModelType.SegmentEverything && - drawData.isBatchEditing + selectedTool === EBasicToolItem.Mask && + AIAnnotation && + selectedModel[selectedTool] === EnumModelType.SegmentEverything && + isBatchEditing ) { return; } setDrawData((s) => { - s.selectedSubTool = tool; + s.selectedSubTool = subtool; }); - - // save unfinished mask object - if (tool === ESubToolItem.AutoEdgeStitching && drawData.creatingObject) { - onFinishCurrCreate( - drawData.creatingObject.labelId || editState.latestLabelId || '', - ); - } }, - [mode, drawData.selectedSubTool, drawData.isBatchEditing], + [ + mode, + drawData.selectedTool, + drawData.AIAnnotation, + drawData.selectedModel, + drawData.isBatchEditing, + drawData.selectedSubTool, + isInAiSession, + ], ); const forceChangeTool = useCallback( @@ -379,17 +442,15 @@ const useToolActions = ({ const activeAIAnnotation = useCallback( (active: boolean) => { - if (!process.env.MODEL_API_PATH && active) { - displayAIModeUnavailableModal(); - return; - } - if (mode !== EditorMode.Edit || drawData.isBatchEditing || manualMode) - return; + if (mode !== EditorMode.Edit || manualMode) return; + + if (isInAiSession()) return; + setDrawData((s) => { s.AIAnnotation = active; }); }, - [mode, drawData.isBatchEditing], + [mode, manualMode, isInAiSession], ); const onChangeSkeletonConf = useCallback( @@ -480,11 +541,32 @@ const useToolActions = ({ updateAllObject(newObjectList); }, [drawData.objectList, getAnnotColor]); - const onSelectModel = useCallback((modelKey: EnumModelType) => { - setDrawData((s) => { - s.selectedModel[s.selectedTool] = modelKey; - }); - }, []); + const checkChangeModel = useCallback( + (modelKey: EnumModelType) => { + const { selectedTool } = drawData; + + const currModels = TOOL_MODELS_MAP[selectedTool]; + if (!currModels.includes(modelKey)) return false; + + if (isInAiSession()) return false; + + return true; + }, + [TOOL_MODELS_MAP, drawData.selectedTool, isInAiSession], + ); + + const onSelectModel = useCallback( + (modelKey: EnumModelType) => { + if (!checkChangeModel(modelKey)) { + return; + } + + setDrawData((s) => { + s.selectedModel[s.selectedTool] = modelKey; + }); + }, + [checkChangeModel], + ); useEffect(() => { setDrawData((s) => { @@ -529,7 +611,6 @@ const useToolActions = ({ return { onChangeObjectLabel, onFinishCurrCreate, - onCloseAnnotationEditor, onAcceptValidObjects, onAbortBatchObjects, selectTool, diff --git a/packages/components/src/Annotator/hooks/useTopTools.tsx b/packages/components/src/Annotator/hooks/useTopTools.tsx index c5d4803..8f0c7e2 100644 --- a/packages/components/src/Annotator/hooks/useTopTools.tsx +++ b/packages/components/src/Annotator/hooks/useTopTools.tsx @@ -172,6 +172,9 @@ const useTopTools = ({ ), }); + if (mode === EditorMode.Edit && fileName) { + actions.unshift({ customElement: <>{fileName} }); + } return actions; }, [ mode, diff --git a/packages/components/src/Annotator/index.less b/packages/components/src/Annotator/index.less index 5f1e56d..1757583 100644 --- a/packages/components/src/Annotator/index.less +++ b/packages/components/src/Annotator/index.less @@ -100,6 +100,7 @@ width: 100%; height: 100vh; background-color: #000; + border-radius: 0; overflow: hidden; .editor-container { diff --git a/packages/components/src/Annotator/sevices/index.ts b/packages/components/src/Annotator/sevices/index.ts index 540164e..0275f5d 100644 --- a/packages/components/src/Annotator/sevices/index.ts +++ b/packages/components/src/Annotator/sevices/index.ts @@ -41,70 +41,53 @@ export namespace NsApiAnnotator { ? FetchAIPoseEstimationRsp : never; - export interface FetchAIDetectionReq { - image: string; - text: string; + export interface CommonReqParams { + image?: string; + sessionId?: string; + } + + export interface FetchAIDetectionReq extends CommonReqParams { + prompts: ReqPromptItem[]; } - export interface FetchIVPReq { + export interface FetchIVPReq extends CommonReqParams { promptImage?: string; inferImage?: string; labelTypes: string[]; // ["bbox", "mask"] prompts: ReqPromptItem[]; - sessionId?: string; } - export interface FetchAIPolygonSegmentReq { - image: string; // image_id:// | base64:// | http:// | https:// + export interface FetchAIPolygonSegmentReq extends CommonReqParams { density: number; // (0, 1) default 0.2 - area: number[]; // [xmin, ymin, xmax, ymax]; prompts: ReqPromptItem[]; - sessionId?: string; } - export interface FetchAIMaskSegmentReq { - image?: string; // required when first request - sessionId?: string; + export interface FetchAIMaskSegmentReq extends CommonReqParams { prompts: ReqPromptItem[]; } - export interface FetchEdgeStitchingReq { - image?: string; + export interface FetchEdgeStitchingReq extends CommonReqParams { masks: IMask[]; prompts: ReqPromptItem[]; } - export interface SegmentEverythingParams { + export interface FetchSegmentEverythingReq extends CommonReqParams { pointsPerSide?: number; // default 32 predIouThresh?: number; // default 0.89 minMaskRegionArea?: number; // default 300 } - export interface FetchSegmentEverythingReq extends SegmentEverythingParams { - image?: string; - } - - export interface FetchAIPoseEstimationReq { - image: string; - targets: string; - template: { - lines: number[]; - pointNames: string[]; - pointColors: string[]; - }; + export interface FetchAIPoseEstimationReq extends CommonReqParams { objects?: Array<{ - categoryName: string; - boundingBox: IBoundingBox; - points: number[]; + bbox: [number, number, number, number]; + keypoints: number[]; // [x, y, visible, conf, ...] }>; } export interface FetchAIDetectionRsp { objects: Array<{ - categoryName: string; - boundingBox: IBoundingBox; + bbox: [number, number, number, number]; score: number; - normalizedScore: number; }>; suggestThreshold: number; } @@ -119,8 +102,6 @@ export namespace NsApiAnnotator { } export interface FetchAIPolygonSegmentRsp { - image: string; // image_id:// - sessionId: string; polygons: number[][]; // [[x1, y1, x2, y2, ...], [xn, yn, xn+1, yn+1, ...], ....] } @@ -132,10 +113,9 @@ export namespace NsApiAnnotator { } export interface FetchAIPoseEstimationRsp { objects: Array<{ - categoryName: string; - boundingBox: IBoundingBox; - points: number[]; - conf: number; + bbox: [number, number, number, number]; + keypoints: number[]; // [x, y, visible, conf, ...] + score: number; }>; } @@ -161,31 +141,31 @@ async function fetchTaskUuid( params: any, options?: { [key: string]: any }, ) { - return request( - `${process.env.MODEL_API_PATH}/tasks/${type}`, - { - method: 'POST', - data: { - ...params, - }, - ...(options || { - hideCodeErrorMsg: true, - }), + const postUrl = process.env.MODEL_API_PATH + ? `${process.env.MODEL_API_PATH}/tasks/${type}` + : `/v1/algos/${type}`; + return request(postUrl, { + method: 'POST', + data: { + ...params, }, - ); + ...(options || { + hideCodeErrorMsg: true, + }), + }); } function fetchTaskResults( taskUuid: string, options?: { [key: string]: any }, ) { - return request>( - `${process.env.MODEL_API_PATH}/task_statuses/${taskUuid}`, - { - method: 'GET', - ...(options || {}), - }, - ); + const getUrl = process.env.MODEL_API_PATH + ? `${process.env.MODEL_API_PATH}/task_statuses/${taskUuid}` + : `/v1/algos/tasks/${taskUuid}`; + return request>(getUrl, { + method: 'GET', + ...(options || {}), + }); } export async function pollTaskResults( diff --git a/packages/components/src/Annotator/tools/base.ts b/packages/components/src/Annotator/tools/base.ts index 21ace95..d132964 100644 --- a/packages/components/src/Annotator/tools/base.ts +++ b/packages/components/src/Annotator/tools/base.ts @@ -3,7 +3,7 @@ import { CursorState } from 'ahooks/lib/useMouse'; import { Updater } from 'use-immer'; import { DisplayOption, EElementType, EObjectType } from '../constants'; -import { OnAiAnnotationFunc } from '../hooks/useActions'; +import { OnAiAnnotationFunc } from '../hooks/useAiModels'; import { Category, DrawData, diff --git a/packages/components/src/Annotator/tools/usePolygon.ts b/packages/components/src/Annotator/tools/usePolygon.ts index 8bb0ae3..90543de 100644 --- a/packages/components/src/Annotator/tools/usePolygon.ts +++ b/packages/components/src/Annotator/tools/usePolygon.ts @@ -520,6 +520,11 @@ const usePolygon: ToolInstanceHook = ({ }); return true; } + + if (drawData.creatingObject) { + return true; + } + return false; }; diff --git a/packages/components/src/Annotator/tools/useRectangle.ts b/packages/components/src/Annotator/tools/useRectangle.ts index d0579f5..1d6cf8b 100644 --- a/packages/components/src/Annotator/tools/useRectangle.ts +++ b/packages/components/src/Annotator/tools/useRectangle.ts @@ -292,6 +292,7 @@ const useRectangle: ToolInstanceHook = ({ setDrawData((s) => { const model = s.selectedModel[s.selectedTool]; if (s.AIAnnotation && model === EnumModelType.IVP) { + s.activeObjectIndex = -1; s.prompt.creatingPrompt = { type: EPromptType.Rect, startPoint: point, diff --git a/packages/components/src/Annotator/type.ts b/packages/components/src/Annotator/type.ts index 7bae8ed..76bc1b6 100644 --- a/packages/components/src/Annotator/type.ts +++ b/packages/components/src/Annotator/type.ts @@ -150,6 +150,7 @@ export enum EPromptType { Stroke = 'stroke', EdgeStitch = 'edgeStitch', Modify = 'modify', + Text = 'text', } export interface PromptItem { @@ -165,6 +166,8 @@ export interface PromptItem { radius?: number; /** Modify */ polygons?: number[][]; + /** Text */ + text?: string; } export interface ReqPromptItem { @@ -175,12 +178,13 @@ export interface ReqPromptItem { stroke?: number[]; radius?: number; polygons?: number[][]; + text?: string; } export interface IPrompt { + sessionId?: string; creatingPrompt?: PromptItem; promptsQueue?: PromptItem[]; - sessionId?: string; activeRectWhileLoading?: IRect; } @@ -256,9 +260,6 @@ export interface EditState { pointIndex: number; lineIndex: number; }; - imageCacheId?: string; - // TODO - imageCacheIdForPolygon?: string; isCtrlPressed: boolean; hideCreatingObject: boolean; imageDisplayOptions: IImageDisplayOptions; diff --git a/packages/components/src/Annotator/utils/compute.ts b/packages/components/src/Annotator/utils/compute.ts index 48d5eca..48f6722 100644 --- a/packages/components/src/Annotator/utils/compute.ts +++ b/packages/components/src/Annotator/utils/compute.ts @@ -1,5 +1,5 @@ import { CursorState } from 'ahooks/lib/useMouse'; -import { cloneDeep, isEqual, isNumber } from 'lodash'; +import { cloneDeep, isEqual, isNumber, isUndefined, omitBy } from 'lodash'; import { EElementType, @@ -1693,7 +1693,13 @@ export const convertFrameObjectsIntoFramesObjects = ( } const frameEmpty = obj?.frameEmpty || Boolean(!obj); let resultObject = obj; - if (frameIdx > activeIndex && isEqual(obj, objectframes[activeIndex])) { + if ( + frameIdx > activeIndex && + isEqual( + omitBy(obj, isUndefined), + omitBy(objectframes[activeIndex], isUndefined), + ) + ) { // [active frame, later changed frame] -> same change resultObject = item; } else if ( @@ -1719,7 +1725,7 @@ export const convertFrameObjectsIntoFramesObjects = ( customStyles: item.customStyles, attributes: item.attributes, status: item.status, - frameEmpty: obj?.frameEmpty || Boolean(!obj), + frameEmpty, }; }); });