Commit
Update trainer to Nodelib 0.13.0
chores: improve type hints and cleanup imports
denniswittich committed Dec 6, 2024
1 parent 5cda389 commit b8969bd
Showing 10 changed files with 51 additions and 61 deletions.
1 change: 0 additions & 1 deletion trainer/Dockerfile
@@ -9,7 +9,6 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
python3-setuptools \
python3-matplotlib \
curl \
-# git \
&& rm -rf /var/lib/apt/lists/*

RUN apt update && \
13 changes: 6 additions & 7 deletions trainer/app_code/model_files.py
@@ -1,10 +1,9 @@
import logging
import os
from pathlib import Path
-from typing import List, Optional, Union


-def get_best(training_path: Path) -> Optional[Path]:
+def get_best(training_path: Path) -> Path | None:
path = training_path / 'result/weights'
if not path.exists():
return None
@@ -14,7 +13,7 @@ def get_best(training_path: Path) -> Optional[Path]:
return weightfiles[0]


-def get_all_weightfiles(training_path: Path) -> List[Path]:
+def get_all_weightfiles(training_path: Path) -> list[Path]:
path = (training_path / 'result/weights').absolute()
if not path.exists():
return []
@@ -32,7 +31,7 @@ def epoch_from_weightfile(weightfile: Path) -> int:
return 0


-def delete_older_epochs(training_path: Path, weightfile: Path):
+def delete_older_epochs(training_path: Path, weightfile: Path) -> None:
all_weightfiles = get_all_weightfiles(training_path)

target_epoch = epoch_from_weightfile(weightfile)
@@ -42,18 +41,18 @@ def delete_older_epochs(training_path: Path, weightfile: Path):
delete_json_for_weightfile(f)


-def delete_json_for_weightfile(weightfile: Path):
+def delete_json_for_weightfile(weightfile: Path) -> None:
_try_remove(weightfile.with_suffix('.json'))


-def _try_remove(file: Path):
+def _try_remove(file: Path) -> None:
try:
os.remove(file)
except Exception:
logging.exception(f'could not remove {file}')


-def get_new(training_path: Path) -> Union[Path, None]:
+def get_new(training_path: Path) -> Path | None:
all_weightfiles = get_all_weightfiles(training_path)
if all_weightfiles:
all_weightfiles.sort(key=epoch_from_weightfile)
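For context: this commit swaps typing.Optional, typing.Union and typing.List/Dict/Tuple for the builtin generics and union syntax of PEP 585 and PEP 604. A minimal standalone sketch (not taken from the repository) showing the modernized annotations; the "X | None" form needs Python 3.10+, list[...] needs 3.9+, and both work as annotations on older interpreters with "from __future__ import annotations":

from pathlib import Path


def newest_file(folder: Path) -> Path | None:  # previously: Optional[Path]
    """Return the most recently modified file in a folder, or None if it has none."""
    files: list[Path] = [p for p in folder.iterdir() if p.is_file()]  # previously: List[Path]
    if not files:
        return None
    return max(files, key=lambda p: p.stat().st_mtime)
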
9 changes: 4 additions & 5 deletions trainer/app_code/tests/test_yolov5.py
@@ -5,13 +5,12 @@
import os
import shutil
from pathlib import Path
-from typing import Dict, List, Tuple
from uuid import uuid4

import pytest
-from learning_loop_node.data_classes import (Category, Context, TrainerState,
-Training)
+from learning_loop_node.data_classes import Category, Context, Training
from learning_loop_node.data_exchanger import DataExchanger
+from learning_loop_node.enums import TrainerState
from learning_loop_node.helpers.misc import create_image_folder
from learning_loop_node.loop_communication import LoopCommunicator
from learning_loop_node.trainer.downloader import TrainingsDownloader
@@ -305,7 +304,7 @@ def assert_yaml_content(yaml_path, **kwargs):


async def download_training_data(images_folder: str, data_exchanger: DataExchanger, glc: LoopCommunicator
-) -> Tuple[List[Category], List[Dict]]:
+) -> tuple[list[Category], list[dict]]:

image_data, _ = await TrainingsDownloader(data_exchanger).download_training_data(images_folder)

@@ -317,7 +316,7 @@ async def download_training_data(images_folder: str, data_exchanger: DataExchang
return categories, image_data


-def mock_epoch(number: int, confusion_matrix: Dict):
+def mock_epoch(number: int, confusion_matrix: dict) -> None:
os.makedirs('result/weights/', exist_ok=True)
with open(f'result/weights/epoch{number}.json', 'w') as f:
json.dump(confusion_matrix, f)
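The mock_epoch helper above fakes a finished training epoch purely on disk: it creates result/weights/ and writes a metrics JSON named after the epoch, mirroring the result/weights layout that model_files.py works with. A self-contained sketch with made-up metric keys (the real keys come from the training loop, not from this page):

import json
import os


def mock_epoch(number: int, confusion_matrix: dict) -> None:
    """Simulate a finished epoch by writing its metrics where the weights would live."""
    os.makedirs('result/weights/', exist_ok=True)
    with open(f'result/weights/epoch{number}.json', 'w') as f:
        json.dump(confusion_matrix, f)


mock_epoch(1, {'some_category_id': {'tp': 10, 'fp': 2, 'fn': 1}})  # illustrative values only
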
11 changes: 5 additions & 6 deletions trainer/app_code/tests/test_yolov5_cla.py
@@ -4,14 +4,13 @@
import logging
import os
from pathlib import Path
-from typing import Dict, List, Tuple
from uuid import uuid4

import pytest
-from learning_loop_node.data_classes import (Category, CategoryType, Context,
-ModelInformation, TrainerState,
-Training)
+from learning_loop_node.data_classes import (Category, Context,
+ModelInformation, Training)
from learning_loop_node.data_exchanger import DataExchanger
+from learning_loop_node.enums import CategoryType, TrainerState
from learning_loop_node.helpers.misc import create_image_folder
from learning_loop_node.loop_communication import LoopCommunicator
from learning_loop_node.trainer.downloader import TrainingsDownloader
@@ -262,7 +261,7 @@ async def test_cla_create_file_structure(self, use_training_dir):


async def download_training_data(images_folder: str, data_exchanger: DataExchanger, glc: LoopCommunicator
-) -> Tuple[List[Category], List[Dict]]:
+) -> tuple[list[Category], list[dict]]:
image_data, _ = await TrainingsDownloader(data_exchanger).download_training_data(images_folder)

response = await glc.get(f"/{os.environ['LOOP_ORGANIZATION']}/projects/{os.environ['LOOP_PROJECT']}/data")
@@ -275,7 +274,7 @@ async def download_training_data(images_folder: str, data_exchanger: DataExchang
return categories, image_data


-def mock_epoch(confusion_matrix: Dict):
+def mock_epoch(confusion_matrix: dict) -> None:
os.makedirs('result/weights/', exist_ok=True)
with open('result/weights/best.json', 'w') as f:
json.dump(confusion_matrix, f)
19 changes: 10 additions & 9 deletions trainer/app_code/yolov5_format.py
@@ -2,15 +2,16 @@
import os
import shutil
from pathlib import Path
-from typing import Any, Dict, List, Tuple
+from typing import Any

-from learning_loop_node.data_classes import CategoryType, Training
+from learning_loop_node.data_classes import Training
+from learning_loop_node.enums import CategoryType
from ruamel.yaml import YAML

yaml = YAML()


-def get_ids_and_sizes_of_point_classes(training: Training) -> Tuple[List[str], List[str]]:
+def get_ids_and_sizes_of_point_classes(training: Training) -> tuple[list[str], list[str]]:
"""Returns a list of trainingids and sizes (in px) of point classes in the training data."""
assert training is not None, 'Training should have data'
point_ids, point_sizes = [], []
@@ -21,7 +22,7 @@ def get_ids_and_sizes_of_point_classes(training: Training) -> Tuple[List[str], L
return point_ids, point_sizes


-def category_lookup_from_training(training: Training) -> Dict[str, str]:
+def category_lookup_from_training(training: Training) -> dict[str, str]:
return {c.name: c.id for c in training.categories}


@@ -81,7 +82,7 @@ def _create_set(training: Training, set_name: str) -> int:
return img_count


-def _create_set_cla(training: Training, set_name: str):
+def _create_set_cla(training: Training, set_name: str) -> None:
training_path = training.training_folder
images_path = f'{training_path}/{set_name}'

@@ -116,7 +117,7 @@ def _create_set_cla(training: Training, set_name: str):
logging.info(f'Created {count} image links')


-def create_dataset_yaml(training: Training):
+def create_dataset_yaml(training: Training) -> None:
categories = category_lookup_from_training(training)
path = training.training_folder
data = {
@@ -131,7 +132,7 @@ def create_dataset_yaml(training: Training):
yaml.dump(data, f)


-def create_file_structure_cla(training: Training):
+def create_file_structure_cla(training: Training) -> None:
path = training.training_folder
assert path is not None, 'Training should have a path'
Path(path).mkdir(parents=True, exist_ok=True)
@@ -140,7 +141,7 @@ def create_file_structure_cla(training: Training):
_create_set_cla(training, 'train')


-def create_file_structure(training: Training):
+def create_file_structure(training: Training) -> None:
"""Uses:
- training.training_folder to create the file structure.
- training.image_data to create the image links and annotations.
@@ -155,7 +156,7 @@ def create_file_structure(training: Training):
logging.info(f'Prepared file structure with {num_train_imgs} training images and {num_test_imgs} test images')


-def set_hyperparameters_in_file(yaml_path: str, hyperparameter: Dict[str, Any]):
+def set_hyperparameters_in_file(yaml_path: str, hyperparameter: dict[str, Any]) -> None:

with open(yaml_path) as f:
content = yaml.load(f)
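set_hyperparameters_in_file, whose body is only partly visible above, uses ruamel.yaml's round-trip mode, which preserves comments and key order when a YAML file is rewritten. A sketch of the read-modify-write pattern, under the assumption that the update is a plain key-by-key overwrite; the hyperparameter names in the usage line are placeholders:

from typing import Any

from ruamel.yaml import YAML

yaml = YAML()  # round-trip mode: keeps comments and key order on dump


def set_hyperparameters_in_file(yaml_path: str, hyperparameter: dict[str, Any]) -> None:
    with open(yaml_path) as f:
        content = yaml.load(f)
    content.update(hyperparameter)  # assumption: overwrite matching top-level keys
    with open(yaml_path, 'w') as f:
        yaml.dump(content, f)


# usage (placeholder keys): set_hyperparameters_in_file('hyp.yaml', {'lr0': 0.01, 'momentum': 0.937})
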
36 changes: 17 additions & 19 deletions trainer/app_code/yolov5_trainer.py
@@ -6,16 +6,16 @@
import shutil
from dataclasses import asdict
from pathlib import Path
-from typing import Dict, List, Optional, Tuple

import cv2
import yaml # type: ignore
from fastapi.encoders import jsonable_encoder
-from learning_loop_node.data_classes import (BoxDetection, CategoryType,
+from learning_loop_node.data_classes import (BoxDetection,
ClassificationDetection,
Detections, ModelInformation,
PointDetection, PretrainedModel,
TrainingStateData)
+from learning_loop_node.enums import CategoryType
from learning_loop_node.trainer import trainer_logic
from learning_loop_node.trainer.exceptions import (CriticalError,
NodeNeedsRestartError)
@@ -41,19 +41,19 @@ def __init__(self) -> None:
# ---------------------------------------- IMPLEMENTED ABSTRACT PROPERTIES ----------------------------------------

@property
-def training_progress(self) -> Optional[float]:
+def training_progress(self) -> float | None:
if self._executor is None:
return None
if self.is_cla:
return self._get_progress_from_log_cla()
return self._get_progress_from_log()

@property
-def model_architecture(self):
+def model_architecture(self) -> str:
return 'yolov5_cls' if self.is_cla else 'yolov5'

@property
-def provided_pretrained_models(self) -> List[PretrainedModel]:
+def provided_pretrained_models(self) -> list[PretrainedModel]:
if self.is_cla:
return [PretrainedModel(name='s-cls', label='YOLO v5 small', description='~5fps on Jetson Nano'),
PretrainedModel(name='x-cls', label='YOLO v5 large', description='~5fps on Jetson Nano'),]
@@ -75,7 +75,7 @@ def _can_resume(self) -> bool:
async def _resume(self) -> None:
await self._start(model=str(self.training.training_folder_path / 'result/weights/published/latest.pt'))

-def _get_executor_error_from_log(self) -> Optional[str]:
+def _get_executor_error_from_log(self) -> str | None:
if self._executor is None:
return None
for line in self._executor.get_log_by_lines(tail=50):
@@ -85,7 +85,7 @@ def _get_executor_error_from_log(self) -> Optional[str]:
return 'graphics card not found'
return None

-def _get_new_best_training_state(self) -> Optional[TrainingStateData]:
+def _get_new_best_training_state(self) -> TrainingStateData | None:
if self.is_cla:
weightfile = model_files.get_best(self.training.training_folder_path)
else:
@@ -119,7 +119,7 @@ def _on_metrics_published(self, training_state_data: TrainingStateData) -> None:
model_files.delete_json_for_weightfile(Path(weightfile))
model_files.delete_older_epochs(Path(self.training.training_folder), Path(weightfile))

-async def _get_latest_model_files(self) -> Dict[str, List[str]]:
+async def _get_latest_model_files(self) -> dict[str, list[str]]:
weightfile = (self.training.training_folder_path / "result/weights/published/latest.pt").absolute()
if not os.path.isfile(weightfile):
logging.error('No model found at %s - Training failed!', weightfile)
@@ -139,9 +139,8 @@ async def _get_latest_model_files(self) -> Dict[str, List[str]]:

return {self.model_format: ['/tmp/model.pt', f'{training_path}/hyp.yaml'], 'yolov5_wts': ['/tmp/model.wts']}

-async def _detect(
-self, model_information: ModelInformation, images: List[str],
-model_folder: str) -> List[Detections]:
+async def _detect(self, model_information: ModelInformation, images: list[str],
+model_folder: str) -> list[Detections]:
images_folder = '/tmp/imagelinks_for_detecting'
shutil.rmtree(images_folder, ignore_errors=True)
os.makedirs(images_folder)
@@ -256,7 +255,7 @@ def _try_replace_optimized_hyperparameter(self):
else:
logging.warning('No optimized hyperparameter found (!)')

-def _parse(self, labels_path: str, images_folder: str, model_information: ModelInformation) -> List[Detections]:
+def _parse(self, labels_path: str, images_folder: str, model_information: ModelInformation) -> list[Detections]:
detections = []
if os.path.exists(labels_path):
for filename in os.scandir(labels_path):
@@ -302,7 +301,7 @@ def _get_progress_from_log(self) -> float:
# ---------------------------------------- HELPER METHODS ----------------------------------------

@staticmethod
-def _parse_file_cla(model_info: ModelInformation, filepath: str) -> List[ClassificationDetection]:
+def _parse_file_cla(model_info: ModelInformation, filepath: str) -> list[ClassificationDetection]:
with open(filepath, 'r') as f:
content = f.readlines()
classification_detections = []
@@ -324,9 +323,8 @@ def _parse_file_cla(model_info: ModelInformation, filepath: str) -> List[Classif
return classification_detections

@staticmethod
-def clip_box(
-x: float, y: float, width: float, height: float, img_width: int, img_height: int) -> Tuple[
-float, float, float, float]:
+def clip_box(x: float, y: float, width: float, height: float, img_width: int, img_height: int
+) -> tuple[float, float, float, float]:
'''make sure the box is within the image
x,y is the center of the box
'''
@@ -343,14 +341,14 @@ def clip_box(
return x, y, width, height

@staticmethod
-def clip_point(x: float, y: float, img_width: int, img_height: int) -> Tuple[float, float]:
+def clip_point(x: float, y: float, img_width: int, img_height: int) -> tuple[float, float]:
x = min(max(0, x), img_width)
y = min(max(0, y), img_height)
return x, y

@staticmethod
-def _parse_file(model_info: ModelInformation, images_folder: str, filename: str) -> Tuple[
-List[BoxDetection], List[PointDetection]]:
+def _parse_file(model_info: ModelInformation, images_folder: str, filename: str
+) -> tuple[list[BoxDetection], list[PointDetection]]:
uuid = os.path.splitext(os.path.basename(filename))[0]

# TODO change to approach that does not require to read the image
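clip_box and clip_point are collapsed above apart from their signatures and the note that x, y is the box center. One plausible implementation of that contract, written here as an assumption rather than the repository's actual code (clip_point's body is fully visible in the diff and is reproduced as-is):

def clip_box(x: float, y: float, width: float, height: float,
             img_width: int, img_height: int) -> tuple[float, float, float, float]:
    """Clamp a center-based (x, y, width, height) box to the image bounds."""
    left = max(0.0, x - width / 2)
    top = max(0.0, y - height / 2)
    right = min(float(img_width), x + width / 2)
    bottom = min(float(img_height), y + height / 2)
    width, height = right - left, bottom - top
    return left + width / 2, top + height / 2, width, height


def clip_point(x: float, y: float, img_width: int, img_height: int) -> tuple[float, float]:
    return min(max(0, x), img_width), min(max(0, y), img_height)


assert clip_box(5.0, 5.0, 20.0, 4.0, 100, 100) == (7.5, 5.0, 15.0, 4.0)
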
4 changes: 2 additions & 2 deletions trainer/docker.sh
@@ -44,8 +44,8 @@ fi
# NODE_LIB_VERSION should only be used, to build the corresponding version and deploy to docker
# make sure the remote repository always has the 'latest' tag (otherwise the CI tests will fail)

-SEMANTIC_VERSION=0.1.8
-NODE_LIB_VERSION=0.12.0
+SEMANTIC_VERSION=0.1.9
+NODE_LIB_VERSION=0.13.0

if [ "$2" = "test_latest" ]; then
image="zauberzeug/yolov5-trainer:latest"
1 change: 0 additions & 1 deletion trainer/main.py
@@ -11,7 +11,6 @@
icecream.install()


-# check if env variable 'YOLOV5_MODE' is set to 'cla'
print(f'Uvicorn reload is set to: {os.getenv("UVICORN_RELOAD", "FALSE").lower() == "true"}')
logging.basicConfig(format='%(asctime)s,%(msecs)d %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s',
datefmt='%Y-%m-%d:%H:%M:%S', level=logging.INFO)
15 changes: 5 additions & 10 deletions trainer/pred_det.py
@@ -38,16 +38,11 @@
import torch

from app_code.yolov5.models.common import DetectMultiBackend
-from app_code.yolov5.utils.dataloaders import (IMG_FORMATS, VID_FORMATS,
-LoadImages, LoadScreenshots,
-LoadStreams)
-from app_code.yolov5.utils.general import (LOGGER, Profile, check_file,
-check_img_size, check_imshow,
-check_requirements, colorstr, cv2,
-increment_path, non_max_suppression,
-print_args, scale_boxes,
-strip_optimizer, xyxy2xywh)
-from app_code.yolov5.utils.plots import Annotator, save_one_box
+from app_code.yolov5.utils.dataloaders import LoadImages
+from app_code.yolov5.utils.general import (LOGGER, Profile, check_img_size,
+check_requirements, colorstr,
+non_max_suppression, print_args,
+scale_boxes, xyxy2xywh)
from app_code.yolov5.utils.torch_utils import (select_device,
smart_inference_mode)

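After this cleanup pred_det.py imports only the pieces it actually uses. For orientation, here is a stripped-down inference loop that can be built from exactly those remaining imports. This is a sketch: it assumes the vendored app_code.yolov5 package is importable and that a weights file and an image folder exist; paths and thresholds are placeholders, and the call signatures follow upstream YOLOv5, from which the vendored copy may deviate:

import torch

from app_code.yolov5.models.common import DetectMultiBackend
from app_code.yolov5.utils.dataloaders import LoadImages
from app_code.yolov5.utils.general import check_img_size, non_max_suppression, scale_boxes
from app_code.yolov5.utils.torch_utils import select_device, smart_inference_mode


@smart_inference_mode()
def run(weights: str = 'model.pt', source: str = 'images/', imgsz: int = 640) -> None:
    device = select_device('')  # '' picks the first CUDA device if available, else CPU
    model = DetectMultiBackend(weights, device=device)
    imgsz = check_img_size(imgsz, s=model.stride)  # round to a multiple of the model stride

    for path, im, im0, _, _ in LoadImages(source, img_size=imgsz, stride=model.stride, auto=model.pt):
        im = torch.from_numpy(im).to(device).float() / 255.0
        if im.ndim == 3:
            im = im[None]  # add batch dimension
        pred = non_max_suppression(model(im), conf_thres=0.25, iou_thres=0.45)
        for det in pred:  # one tensor of (x1, y1, x2, y2, conf, cls) rows per image
            if len(det):
                det[:, :4] = scale_boxes(im.shape[2:], det[:, :4], im0.shape).round()
            for *xyxy, conf, cls in det:
                print(path, model.names[int(cls)], float(conf), [float(v) for v in xyxy])
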
3 changes: 2 additions & 1 deletion trainer/test_cuda_init.py
@@ -1,11 +1,12 @@
import signal
import sys
import time
+from typing import Any

import torch


-def signal_handler(sig, frame):
+def signal_handler(sig: int, frame: Any) -> None:
print('\n\nSignal received:', sig, flush=True)
sys.exit(0)

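The handler's frame argument is annotated as Any in the commit; the standard library passes either a frame object or None, so types.FrameType | None would be the stricter alternative. For reference, a standalone registration sketch (not part of the commit):

import signal
import sys
from types import FrameType


def signal_handler(sig: int, frame: FrameType | None) -> None:
    print('\n\nSignal received:', sig, flush=True)
    sys.exit(0)


signal.signal(signal.SIGINT, signal_handler)  # Ctrl+C now exits cleanly
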
