Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

YOLO Export: Use absolute data path + guess train/val/test directories #8

Merged
merged 2 commits into from
Oct 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 40 additions & 1 deletion dagshub_annotation_converter/converters/yolo.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,33 @@ def annotations_to_string(annotations: Sequence[IRImageAnnotationBase], context:
return "\n".join([export_fn(ann, context) for ann in filtered_annotations])


def _get_common_folder_with_part(paths: List[Path], part: str) -> Optional[Path]:
paths_with_part = [p for p in paths if part in p.parts]
if len(paths_with_part) == 0:
return None

candidates = set()
for p in paths_with_part:
for i, path_part in enumerate(p.parts):
if path_part == part:
candidates.add(Path(*p.parts[: i + 1]))
if len(candidates) == 1:
return candidates.pop()

# Choose the shortest one. If there are multiple shortest ones - return None
shortest = min(candidates, key=lambda p: len(p.parts))
shortest_candidates = [p for p in candidates if len(p.parts) == len(shortest.parts)]
if len(shortest_candidates) > 1:
return None
return shortest


def _guess_train_val_test_split(image_paths: List[str]) -> Tuple[Optional[Path], Optional[Path], Optional[Path]]:
paths = [Path(p) for p in image_paths]
splits = ["train", "val", "test"]
return (*[_get_common_folder_with_part(paths, split) for split in splits],)


def export_to_fs(
context: YoloContext,
annotations: List[IRImageAnnotationBase],
Expand Down Expand Up @@ -166,7 +193,19 @@ def export_to_fs(
with open(annotation_filename, "w") as f:
f.write(annotation_content)

# TODO: test/val splitting
guessed_train_path, guessed_val_path, guessed_test_path = _guess_train_val_test_split(
list(grouped_annotations.keys())
)

# Don't accidentally override with Nones. If we find Nones, then assume YOLO should train on the whole dataset
if guessed_train_path is not None:
context.train_path = guessed_train_path

if guessed_val_path is not None:
context.val_path = guessed_val_path

context.test_path = guessed_test_path

yaml_file_path = export_path / meta_file
with open(yaml_file_path, "w") as yaml_f:
yaml_f.write(context.get_yaml_content())
Expand Down
8 changes: 7 additions & 1 deletion dagshub_annotation_converter/formats/yolo/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ class YoloContext(ParentModel):
"""Path to the train data, relative to the base path"""
val_path: Optional[Path] = Path(".")
"""Path to the validation data, relative to the base path"""
test_path: Optional[Path] = None
"""Path to the test data, relative to the base path (defaults to None, might be discovered)"""

@staticmethod
def from_yaml_file(file_path: Union[str, Path], annotation_type: YoloAnnotationTypes) -> "YoloContext":
Expand All @@ -59,6 +61,8 @@ def from_yaml_file(file_path: Union[str, Path], annotation_type: YoloAnnotationT
res.train_path = Path(meta_dict["train"])
if "val" in meta_dict:
res.val_path = Path(meta_dict["val"])
if "test" in meta_dict:
res.test_path = Path(meta_dict["test"])

return res

Expand All @@ -81,7 +85,7 @@ def get_yaml_content(self, path_override: Optional[Path] = None) -> str:
)

content = {
"path": str(path),
"path": str(path.resolve()),
"names": {cat.id: cat.name for cat in self.categories.categories},
"nc": len(self.categories),
}
Expand All @@ -90,6 +94,8 @@ def get_yaml_content(self, path_override: Optional[Path] = None) -> str:
content["train"] = str(self.train_path)
if self.val_path is not None:
content["val"] = str(self.val_path)
if self.test_path is not None:
content["test"] = str(self.test_path)

if self.annotation_type == "pose":
if self.keypoints_in_annotation is None:
Expand Down
29 changes: 28 additions & 1 deletion tests/fs_export/yolo/test_fs_export.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from dagshub_annotation_converter.converters.yolo import export_to_fs
from pathlib import Path

from dagshub_annotation_converter.converters.yolo import export_to_fs, _get_common_folder_with_part
from dagshub_annotation_converter.formats.yolo import YoloContext
from dagshub_annotation_converter.ir.image import (
CoordinateStyle,
Expand All @@ -9,6 +11,8 @@
IRPosePoint,
)

import pytest


def test_bbox_export(tmp_path):
ctx = YoloContext(annotation_type="bbox", path="data")
Expand Down Expand Up @@ -146,3 +150,26 @@ def test_not_exporting_wrong_annotations(tmp_path):
assert (tmp_path / "yolo_dagshub.yaml").exists()
assert (tmp_path / "data" / "labels" / "cats" / "1.txt").exists()
assert not (tmp_path / "data" / "labels" / "dogs" / "2.txt").exists()


@pytest.mark.parametrize(
"paths, prefix, expected",
(
(["/a/b/c", "/a/b/d", "/a/b/e"], "b", "/a/b"),
(["/a/b/c", "/a/b/d", "/a/b/e"], "b", "/a/b"),
(["/a/b/c", "/a/b/d", "/a/b/b"], "b", "/a/b"),
(["/a/b/c", "/a/b/d", "/a/b/e/b"], "b", "/a/b"),
(["/a/b/c", "/a/e/b", "/a/e/b/b"], "b", "/a/b"),
(["/a/b/c", "/a/b/d", "/some_other/b/e"], "b", None), # Fails because there are two different common b folders
(["/a/b/c", "/a/some_other/d", "/a/b/e"], "b", "/a/b"),
(["/a/b/c", "/a/bbb/d", "/a/b/e"], "b", "/a/b"),
),
)
def test__get_common_folder_with_part(paths, prefix, expected):
paths = [Path(p) for p in paths]
actual = _get_common_folder_with_part(paths, prefix)

if expected is not None:
expected = Path(expected)

assert actual == expected
Loading