From b1c547a8fb22f7b62bb9c46cbb39e16841db08a7 Mon Sep 17 00:00:00 2001 From: AlexandreKempf Date: Tue, 20 Feb 2024 14:49:44 +0100 Subject: [PATCH] plots: add bboxes to `dvc plots diff` so VSCode can used it to display bbox on images Fixes #10198 --- dvc/render/__init__.py | 1 + dvc/render/convert.py | 3 ++- dvc/render/converter/image.py | 21 ++++++++++++++++----- dvc/repo/plots/__init__.py | 27 +++++++++++++++++++++++++++ 4 files changed, 46 insertions(+), 6 deletions(-) diff --git a/dvc/render/__init__.py b/dvc/render/__init__.py index a3ee972fb9..82c11337dc 100644 --- a/dvc/render/__init__.py +++ b/dvc/render/__init__.py @@ -6,3 +6,4 @@ ANCHOR_DEFINITIONS = "anchor_definitions" TYPE_KEY = "type" SRC = "src" +ANNOTATIONS = "annotations" diff --git a/dvc/render/convert.py b/dvc/render/convert.py index 11e98114e9..c4d739c674 100644 --- a/dvc/render/convert.py +++ b/dvc/render/convert.py @@ -1,6 +1,6 @@ from typing import Union -from dvc.render import REVISION, REVISIONS, SRC, TYPE_KEY +from dvc.render import REVISION, REVISIONS, SRC, TYPE_KEY, ANNOTATIONS from dvc.render.converter.image import ImageConverter from dvc.render.converter.vega import VegaConverter @@ -42,6 +42,7 @@ def to_json(renderer, split: bool = False) -> list[dict]: { TYPE_KEY: renderer.TYPE, REVISIONS: [datapoint.get(REVISION)], + **datapoint.get(ANNOTATIONS), "url": datapoint.get(SRC), } for datapoint in renderer.datapoints diff --git a/dvc/render/converter/image.py b/dvc/render/converter/image.py index c3bbdf628e..3ae8ede2bc 100644 --- a/dvc/render/converter/image.py +++ b/dvc/render/converter/image.py @@ -1,8 +1,9 @@ import base64 import os +import json from typing import TYPE_CHECKING, Any -from dvc.render import FILENAME, REVISION, SRC +from dvc.render import FILENAME, REVISION, SRC, ANNOTATIONS from . import Converter @@ -46,15 +47,25 @@ def flat_datapoints(self, revision: str) -> tuple[list[dict], dict]: path = self.properties.get("out") datapoints = [] datas, properties = self.convert() - for filename, _, image_data in datas: + + if "annotations" in properties: + with open(properties["annotations"], encoding="utf-8") as annotations_path: + annotations = json.load(annotations_path) + + for filename, _, image_content in datas: if path: if not os.path.isdir(path): os.makedirs(path, exist_ok=True) src = self._write_image( - os.path.abspath(path), revision, filename, image_data + os.path.abspath(path), revision, filename, image_content ) else: - src = self._encode_image(image_data) - datapoint = {REVISION: revision, FILENAME: filename, SRC: src} + src = self._encode_image(image_content) + datapoint = { + REVISION: revision, + FILENAME: filename, + SRC: src, + ANNOTATIONS: annotations, + } datapoints.append(datapoint) return datapoints, properties diff --git a/dvc/repo/plots/__init__.py b/dvc/repo/plots/__init__.py index fe526d2891..d5a5737315 100644 --- a/dvc/repo/plots/__init__.py +++ b/dvc/repo/plots/__init__.py @@ -6,6 +6,7 @@ from copy import deepcopy from functools import partial from multiprocessing import cpu_count +from pathlib import Path from typing import TYPE_CHECKING, Any, Callable, Optional, Union import dpath @@ -357,6 +358,10 @@ def infer_data_sources(plot_id, config=None): if isinstance(x, dict): sources.append(first(x.keys())) + # annotation = config.get("annotations", None) + # if isinstance(annotation, str): + # sources.append(annotation) + return ldistinct(source for source in sources) @@ -467,6 +472,7 @@ def _resolve_definitions( for k, v in unpacked["data"].items() if _closest_parent(fs, k, plot_ids_parents) == data_path } + unpacked = _add_annotations_to_image_definition(unpacked) dpath.merge(result, unpacked) elif _matches(targets, config_path, plot_id): adjusted_props = _adjust_sources(fs, plot_props, config_dir) @@ -483,6 +489,27 @@ def _closest_parent(fs, path, parents): best_result = common_path return best_result +def _add_annotations_to_image_definition(target): + if "data" not in target: + return target + + path_to_remove = [] + for path in target["data"]: + annotation_file = Path(path).with_suffix(".json").as_posix() + if ImageRenderer.matches(path) and annotation_file in target["data"]: + annotations = {"annotations": annotation_file} + # empty dict all share the same reference, so override them + if target["data"][path]: + target["data"][path].update({"annotations": annotation_file}) + else: + target["data"][path] = annotations + path_to_remove.append(annotation_file) + # remove the annotation files from the data once they are added to the config + target["data"] = { + k: v for k, v in target["data"].items() if k not in path_to_remove + } + return target + def _collect_pipeline_files(repo, targets: list[str], props, onerror=None): result: dict[str, dict] = {}