Skip to content

Commit 260f618

Browse files
authored
Merge pull request #8 from DagsHub/bug/test_path
YOLO Export: Use absolute data path + guess train/val/test directories
2 parents 67430f9 + 6ca950f commit 260f618

File tree

3 files changed

+75
-3
lines changed

3 files changed

+75
-3
lines changed

dagshub_annotation_converter/converters/yolo.py

+40-1
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,33 @@ def annotations_to_string(annotations: Sequence[IRImageAnnotationBase], context:
123123
return "\n".join([export_fn(ann, context) for ann in filtered_annotations])
124124

125125

126+
def _get_common_folder_with_part(paths: List[Path], part: str) -> Optional[Path]:
    """
    Find the shallowest folder named ``part`` among the given paths.

    Every prefix of every path that ends in a component equal to ``part``
    (exact match, not a substring) is a candidate.

    :param paths: Paths to inspect
    :param part: Path component to look for, e.g. ``"train"``
    :return: The unique candidate with the fewest components.
        ``None`` if no path contains ``part``, or if several distinct candidates tie for the fewest components.
    """
    prefixes = {
        Path(*path.parts[: idx + 1])
        for path in paths
        for idx, component in enumerate(path.parts)
        if component == part
    }
    if not prefixes:
        return None

    # Only the shallowest prefixes are of interest; a tie between them is ambiguous
    min_depth = min(len(prefix.parts) for prefix in prefixes)
    shallowest = [prefix for prefix in prefixes if len(prefix.parts) == min_depth]
    return shallowest[0] if len(shallowest) == 1 else None
145+
146+
147+
def _guess_train_val_test_split(image_paths: List[str]) -> Tuple[Optional[Path], Optional[Path], Optional[Path]]:
    """
    Guess the train/val/test folders from a list of image paths.

    Each split is looked up independently by searching for a path component
    named ``"train"``, ``"val"`` or ``"test"`` respectively.

    :param image_paths: Image paths as strings
    :return: ``(train, val, test)`` tuple; an entry is ``None`` when no unambiguous folder was found for that split
    """
    candidate_paths = [Path(raw) for raw in image_paths]
    train_dir = _get_common_folder_with_part(candidate_paths, "train")
    val_dir = _get_common_folder_with_part(candidate_paths, "val")
    test_dir = _get_common_folder_with_part(candidate_paths, "test")
    return train_dir, val_dir, test_dir
151+
152+
126153
def export_to_fs(
127154
context: YoloContext,
128155
annotations: List[IRImageAnnotationBase],
@@ -166,7 +193,19 @@ def export_to_fs(
166193
with open(annotation_filename, "w") as f:
167194
f.write(annotation_content)
168195

169-
# TODO: test/val splitting
196+
guessed_train_path, guessed_val_path, guessed_test_path = _guess_train_val_test_split(
197+
list(grouped_annotations.keys())
198+
)
199+
200+
# Don't accidentally override with Nones. If we find Nones, then assume YOLO should train on the whole dataset
201+
if guessed_train_path is not None:
202+
context.train_path = guessed_train_path
203+
204+
if guessed_val_path is not None:
205+
context.val_path = guessed_val_path
206+
207+
context.test_path = guessed_test_path
208+
170209
yaml_file_path = export_path / meta_file
171210
with open(yaml_file_path, "w") as yaml_f:
172211
yaml_f.write(context.get_yaml_content())

dagshub_annotation_converter/formats/yolo/context.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@ class YoloContext(ParentModel):
3939
"""Path to the train data, relative to the base path"""
4040
val_path: Optional[Path] = Path(".")
4141
"""Path to the validation data, relative to the base path"""
42+
test_path: Optional[Path] = None
43+
"""Path to the test data, relative to the base path (defaults to None, might be discovered)"""
4244

4345
@staticmethod
4446
def from_yaml_file(file_path: Union[str, Path], annotation_type: YoloAnnotationTypes) -> "YoloContext":
@@ -59,6 +61,8 @@ def from_yaml_file(file_path: Union[str, Path], annotation_type: YoloAnnotationT
5961
res.train_path = Path(meta_dict["train"])
6062
if "val" in meta_dict:
6163
res.val_path = Path(meta_dict["val"])
64+
if "test" in meta_dict:
65+
res.test_path = Path(meta_dict["test"])
6266

6367
return res
6468

@@ -81,7 +85,7 @@ def get_yaml_content(self, path_override: Optional[Path] = None) -> str:
8185
)
8286

8387
content = {
84-
"path": str(path),
88+
"path": str(path.resolve()),
8589
"names": {cat.id: cat.name for cat in self.categories.categories},
8690
"nc": len(self.categories),
8791
}
@@ -90,6 +94,8 @@ def get_yaml_content(self, path_override: Optional[Path] = None) -> str:
9094
content["train"] = str(self.train_path)
9195
if self.val_path is not None:
9296
content["val"] = str(self.val_path)
97+
if self.test_path is not None:
98+
content["test"] = str(self.test_path)
9399

94100
if self.annotation_type == "pose":
95101
if self.keypoints_in_annotation is None:

tests/fs_export/yolo/test_fs_export.py

+28-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1-
from dagshub_annotation_converter.converters.yolo import export_to_fs
1+
from pathlib import Path
2+
3+
from dagshub_annotation_converter.converters.yolo import export_to_fs, _get_common_folder_with_part
24
from dagshub_annotation_converter.formats.yolo import YoloContext
35
from dagshub_annotation_converter.ir.image import (
46
CoordinateStyle,
@@ -9,6 +11,8 @@
911
IRPosePoint,
1012
)
1113

14+
import pytest
15+
1216

1317
def test_bbox_export(tmp_path):
1418
ctx = YoloContext(annotation_type="bbox", path="data")
@@ -146,3 +150,26 @@ def test_not_exporting_wrong_annotations(tmp_path):
146150
assert (tmp_path / "yolo_dagshub.yaml").exists()
147151
assert (tmp_path / "data" / "labels" / "cats" / "1.txt").exists()
148152
assert not (tmp_path / "data" / "labels" / "dogs" / "2.txt").exists()
153+
154+
155+
@pytest.mark.parametrize(
    "paths, prefix, expected",
    (
        (["/a/b/c", "/a/b/d", "/a/b/e"], "b", "/a/b"),
        (["/a/c", "/a/d", "/a/e"], "b", None),  # No path contains the part at all
        (["/a/b/c", "/a/b/d", "/a/b/b"], "b", "/a/b"),
        (["/a/b/c", "/a/b/d", "/a/b/e/b"], "b", "/a/b"),
        (["/a/b/c", "/a/e/b", "/a/e/b/b"], "b", "/a/b"),
        (["/a/b/c", "/a/b/d", "/some_other/b/e"], "b", None),  # Fails because there are two different common b folders
        (["/a/b/c", "/a/some_other/d", "/a/b/e"], "b", "/a/b"),
        (["/a/b/c", "/a/bbb/d", "/a/b/e"], "b", "/a/b"),
    ),
)
def test__get_common_folder_with_part(paths, prefix, expected):
    """
    Check which folder ``_get_common_folder_with_part`` picks for a set of paths.

    ``expected`` is given as a string (or ``None``) and converted to a ``Path`` before comparing.
    """
    paths = [Path(p) for p in paths]
    actual = _get_common_folder_with_part(paths, prefix)

    if expected is not None:
        expected = Path(expected)

    assert actual == expected

0 commit comments

Comments
 (0)