Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Plots: Add the bounding boxes informations in the dvc plots diff command so that VScode can display them. #10312

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions dvc/render/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@
ANCHOR_DEFINITIONS = "anchor_definitions"
TYPE_KEY = "type"
SRC = "src"
ANNOTATIONS = "annotations"
3 changes: 2 additions & 1 deletion dvc/render/convert.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Union

from dvc.render import REVISION, REVISIONS, SRC, TYPE_KEY
from dvc.render import ANNOTATIONS, REVISION, REVISIONS, SRC, TYPE_KEY
from dvc.render.converter.image import ImageConverter
from dvc.render.converter.vega import VegaConverter

Expand Down Expand Up @@ -42,6 +42,7 @@ def to_json(renderer, split: bool = False) -> list[dict]:
{
TYPE_KEY: renderer.TYPE,
REVISIONS: [datapoint.get(REVISION)],
**datapoint.get(ANNOTATIONS, {}),
"url": datapoint.get(SRC),
}
for datapoint in renderer.datapoints
Expand Down
68 changes: 59 additions & 9 deletions dvc/render/converter/image.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,20 @@
import base64
import json
import logging
import os
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Annotated, Any

from dvc.render import FILENAME, REVISION, SRC
from pydantic import BaseModel, Field, ValidationError

from dvc.render import ANNOTATIONS, FILENAME, REVISION, SRC

from . import Converter

if TYPE_CHECKING:
from dvc.types import StrPath

logger = logging.getLogger(__name__)


class ImageConverter(Converter):
@staticmethod
Expand All @@ -33,9 +39,11 @@ def _encode_image(image_data: bytes) -> str:
return f"data:image;base64,{base64_str}"

def convert(self) -> tuple[list[tuple[str, str, Any]], dict]:
datas = []
for filename, image_data in self.data.items():
datas.append((filename, "", image_data))
datas = [
(filename, "", image_data)
for filename, image_data in self.data.items()
if not filename.endswith(".json")
]
return datas, self.properties

def flat_datapoints(self, revision: str) -> tuple[list[dict], dict]:
Expand All @@ -46,15 +54,57 @@ def flat_datapoints(self, revision: str) -> tuple[list[dict], dict]:
path = self.properties.get("out")
datapoints = []
datas, properties = self.convert()
for filename, _, image_data in datas:

annotations = {}
if "annotations" in properties:
annotations = self._load_annotations(properties["annotations"])

for filename, _, image_content in datas:
if path:
if not os.path.isdir(path):
os.makedirs(path, exist_ok=True)
src = self._write_image(
os.path.abspath(path), revision, filename, image_data
os.path.abspath(path), revision, filename, image_content
)
else:
src = self._encode_image(image_data)
datapoint = {REVISION: revision, FILENAME: filename, SRC: src}
src = self._encode_image(image_content)
datapoint: dict[str, Any] = {
REVISION: revision,
FILENAME: filename,
SRC: src,
}
if annotations:
datapoint[ANNOTATIONS] = annotations
datapoints.append(datapoint)
return datapoints, properties

def _load_annotations(self, path: "StrPath") -> dict:
with open(path, encoding="utf-8") as annotations_path:
try:
return _Annotations(**json.load(annotations_path)).model_dump()
except json.JSONDecodeError as json_error:
logger.warning(json_error)
logger.warning("Annotations file %s is not a valid JSON file.", path)
return {"annotations": {}}
except ValidationError as pydantic_error:
logger.warning(pydantic_error)
logger.warning(
"Annotations file %s is not a valid annotations file.", path
)
return {"annotations": {}}


class _Coordinates(BaseModel):
left: int
top: int
bottom: int
right: int


class _BBoxe(BaseModel):
box: _Coordinates
score: Annotated[float, Field(ge=0, le=1)]


class _Annotations(BaseModel):
annotations: dict[str, list[_BBoxe]]
31 changes: 31 additions & 0 deletions dvc/repo/plots/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from copy import deepcopy
from functools import partial
from multiprocessing import cpu_count
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Optional, Union

import dpath
Expand Down Expand Up @@ -357,6 +358,10 @@ def infer_data_sources(plot_id, config=None):
if isinstance(x, dict):
sources.append(first(x.keys()))

annotation = config.get("annotations", None)
if isinstance(annotation, str):
sources.append(annotation)

return ldistinct(source for source in sources)


Expand Down Expand Up @@ -467,6 +472,7 @@ def _resolve_definitions(
for k, v in unpacked["data"].items()
if _closest_parent(fs, k, plot_ids_parents) == data_path
}
unpacked = _add_annotations_to_image_definition(unpacked)
dpath.merge(result, unpacked)
elif _matches(targets, config_path, plot_id):
adjusted_props = _adjust_sources(fs, plot_props, config_dir)
Expand All @@ -484,6 +490,31 @@ def _closest_parent(fs, path, parents):
return best_result


def _add_annotations_to_image_definition(target):
if "data" not in target:
return target

path_to_remove = []
for path in target["data"]:
annotation_file = Path(path).with_suffix(".json").as_posix()
# if an image and a json match names, we consider the JSON to be annotations
# in that case, we add the annotations to the image "definition" as a property
# so they are link together.
if ImageRenderer.matches(path) and annotation_file in target["data"]:
# target["data"] empty dict all share the same reference, so override them
annotations = {"annotations": annotation_file}
if target["data"][path]:
target["data"][path].update(annotations)
else:
target["data"][path] = annotations
path_to_remove.append(annotation_file)
# remove the annotation files from the data once they are added to the config
target["data"] = {
k: v for k, v in target["data"].items() if k not in path_to_remove
}
return target


def _collect_pipeline_files(repo, targets: list[str], props, onerror=None):
result: dict[str, dict] = {}
top_plots = repo.index._plots
Expand Down
Loading