Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Option to skip an error when downloading images from ls #361

Merged
merged 4 commits into from
Nov 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions ml_utils/ml_utils_cli/cli/apps/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,12 @@ def export(
"provided (typically, if the source is Label Studio)"
),
] = 0.8,
error_raise: Annotated[
bool,
typer.Option(
help="Raise an error if an image download fails, only for Ultralytics"
),
] = True,
):
"""Export Label Studio annotation, either to Hugging Face Datasets or
local files (ultralytics format)."""
Expand Down Expand Up @@ -204,6 +210,7 @@ def export(
label_names_list,
typing.cast(int, project_id),
train_ratio=train_ratio,
error_raise=error_raise,
)

elif from_ == ExportSource.hf:
Expand All @@ -212,6 +219,7 @@ def export(
typing.cast(str, repo_id),
typing.cast(Path, output_dir),
download_images=download_images,
error_raise=error_raise,
)
else:
raise typer.BadParameter("Unsupported export format")
13 changes: 11 additions & 2 deletions ml_utils/ml_utils_cli/cli/apps/projects.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,12 @@ def add_prediction(
help="Launch in dry run mode, without uploading annotations to Label Studio"
),
] = False,
error_raise: Annotated[
bool,
typer.Option(
help="Raise an error if image download fails"
),
] = True,
):
"""Add predictions as pre-annotations to Label Studio tasks,
for an object detection model running on Triton Inference Server."""
Expand Down Expand Up @@ -245,7 +251,10 @@ def add_prediction(
threshold = 0.1

model = YOLO(model_name)
model.set_classes(labels)
if hasattr(model, "set_classes"):
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this was from my previous PR, but it is running and error when I have a locally pre-trained model, hence this handling method

model.set_classes(labels)
else:
logger.warning("The model does not support setting classes directly.")
elif backend == PredictorBackend.triton:
if triton_uri is None:
raise typer.BadParameter("Triton URI is required for Triton backend")
Expand All @@ -262,7 +271,7 @@ def add_prediction(
image_url = task.data["image_url"]
image = typing.cast(
Image.Image,
get_image_from_url(image_url, error_raise=True),
get_image_from_url(image_url, error_raise=error_raise),
)
if backend == PredictorBackend.ultralytics:
results = model.predict(
Expand Down
9 changes: 6 additions & 3 deletions ml_utils/ml_utils_cli/cli/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ def export_from_ls_to_ultralytics(
category_names: list[str],
project_id: int,
train_ratio: float = 0.8,
error_raise: bool = True,
):
"""Export annotations from a Label Studio project to the Ultralytics
format.
Expand Down Expand Up @@ -161,7 +162,7 @@ def export_from_ls_to_ultralytics(
has_valid_annotation = True

if has_valid_annotation:
download_output = download_image(image_url, return_bytes=True)
download_output = download_image(image_url, return_bytes=True, error_raise=error_raise)
if download_output is None:
logger.error("Failed to download image: %s", image_url)
continue
Expand All @@ -182,7 +183,9 @@ def export_from_ls_to_ultralytics(


def export_from_hf_to_ultralytics(
repo_id: str, output_dir: Path, download_images: bool = True
repo_id: str, output_dir: Path,
download_images: bool = True,
error_raise: bool = True,
):
"""Export annotations from a Hugging Face dataset project to the
Ultralytics format.
Expand All @@ -207,7 +210,7 @@ def export_from_hf_to_ultralytics(
image_url = sample["meta"]["image_url"]

if download_images:
download_output = download_image(image_url, return_bytes=True)
download_output = download_image(image_url, return_bytes=True, error_raise=error_raise)
if download_output is None:
logger.error("Failed to download image: %s", image_url)
continue
Expand Down
Loading