-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'fuse_cla_and_new_nodelib'
- Loading branch information
Showing
157 changed files
with
21,024 additions
and
1,249 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,3 +3,4 @@ __pycache__/ | |
*.wts | ||
*.engine | ||
*.pt | ||
trainer/runs |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
import logging | ||
import os | ||
from pathlib import Path | ||
from typing import List, Optional, Union | ||
|
||
|
||
def get_best(training_path: Path) -> Optional[Path]: | ||
path = training_path / 'result/weights' | ||
if not path.exists(): | ||
return None | ||
weightfiles = [path / f for f in os.listdir(path) if 'best' in f and f.endswith('.pt')] | ||
if len(weightfiles) == 0: | ||
return None | ||
return weightfiles[0] | ||
|
||
|
||
def get_all_weightfiles(training_path: Path) -> List[Path]: | ||
path = (training_path / 'result/weights').absolute() | ||
if not path.exists(): | ||
return [] | ||
weightfiles = [path / f for f in os.listdir(path) if 'epoch' in f and f.endswith('.pt')] | ||
return weightfiles | ||
|
||
|
||
def _epoch_from_weightfile(weightfile: Path) -> int: | ||
number = weightfile.name[5:-3] | ||
if number == '': | ||
return 0 | ||
return int(number) | ||
|
||
|
||
def delete_older_epochs(training_path: Path, weightfile: Path): | ||
all_weightfiles = get_all_weightfiles(training_path) | ||
|
||
target_epoch = _epoch_from_weightfile(weightfile) | ||
for f in all_weightfiles: | ||
if _epoch_from_weightfile(f) < target_epoch: | ||
_try_remove(f) | ||
delete_json_for_weightfile(f) | ||
|
||
|
||
def delete_json_for_weightfile(weightfile: Path): | ||
_try_remove(weightfile.with_suffix('.json')) | ||
|
||
|
||
def _try_remove(file: Path): | ||
try: | ||
os.remove(file) | ||
except Exception: | ||
logging.exception(f'could not remove {file}') | ||
|
||
|
||
def get_new(training_path: Path) -> Union[Path, None]: | ||
all_weightfiles = get_all_weightfiles(training_path) | ||
if all_weightfiles: | ||
all_weightfiles.sort(key=_epoch_from_weightfile) | ||
return all_weightfiles[-1] | ||
return None |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
import asyncio | ||
import logging | ||
import os | ||
import shutil | ||
import subprocess | ||
from typing import Dict | ||
|
||
import icecream | ||
import pytest | ||
from _pytest.fixtures import SubRequest | ||
# from dotenv import load_dotenv | ||
from learning_loop_node.data_classes import Context | ||
from learning_loop_node.data_exchanger import DataExchanger | ||
from learning_loop_node.loop_communication import LoopCommunicator | ||
|
||
icecream.install() | ||
logging.basicConfig(level=logging.INFO) | ||
|
||
# load_dotenv() | ||
|
||
# -------------------- Session fixtures -------------------- | ||
|
||
|
||
@pytest.fixture(scope="session") | ||
def event_loop(): | ||
"""Overrides pytest default function scoped event loop""" | ||
policy = asyncio.get_event_loop_policy() | ||
loop = policy.new_event_loop() | ||
yield loop | ||
loop.close() | ||
|
||
|
||
@pytest.fixture(scope="session") | ||
def prepare_model(): | ||
"""Download model for testing""" | ||
if not os.path.exists('app_code/tests/test_data/model.pt'): | ||
url = 'https://github.com/ultralytics/yolov5/releases/download/v6.0/yolov5n.pt' | ||
result = subprocess.run(f'curl -L {url} -o app_code/tests/test_data/model.pt', shell=True, check=True) | ||
assert result.returncode == 0 | ||
assert os.path.exists('app_code/tests/test_data/model.pt') | ||
yield | ||
|
||
# -------------------- Class marks -------------------- | ||
|
||
|
||
@pytest.fixture(autouse=True, scope='class') | ||
async def check_marks(request: SubRequest, glc: LoopCommunicator): # pylint: disable=redefined-outer-name | ||
"""Set environment variables for testing and generate project if requested""" | ||
|
||
markers = list(request.node.iter_markers('environment')) | ||
assert len(markers) <= 1, 'Only one environment marker allowed' | ||
if len(markers) == 1: | ||
marker = markers[0] | ||
os.environ['LOOP_ORGANIZATION'] = marker.kwargs['organization'] | ||
os.environ['LOOP_PROJECT'] = marker.kwargs['project'] | ||
os.environ['YOLOV5_MODE'] = marker.kwargs['mode'] | ||
|
||
markers = list(request.node.iter_markers('generate_project')) | ||
assert len(markers) <= 1, 'Only one generate_project marker allowed' | ||
if len(markers) == 1: | ||
marker = markers[0] | ||
configuration: Dict = marker.kwargs['configuration'] | ||
project = configuration['project_name'] | ||
# May not return 200 if project does not exist | ||
await glc.delete(f"/zauberzeug/projects/{project}?keep_images=true") | ||
await asyncio.sleep(1) | ||
assert (await glc.post("/zauberzeug/projects/generator", json=configuration)).status_code == 200 | ||
await asyncio.sleep(1) | ||
yield | ||
# assert (await lc.delete(f"/zauberzeug/projects/{project}?keep_images=true")).status_code == 200 | ||
else: | ||
yield | ||
|
||
|
||
# -------------------- Optional fixtures -------------------- | ||
|
||
@pytest.fixture(scope="session") | ||
async def glc(): | ||
"""The same LoopCommunicator is used for all tests | ||
Credentials are read from environment variables""" | ||
|
||
lc = LoopCommunicator() | ||
await lc.ensure_login() | ||
yield lc | ||
await lc.shutdown() | ||
|
||
|
||
@pytest.fixture() | ||
def data_exchanger(glc: LoopCommunicator): # pylint: disable=redefined-outer-name | ||
context = Context(organization=os.environ['LOOP_ORGANIZATION'], project=os.environ['LOOP_PROJECT']) | ||
dx = DataExchanger(context, glc) | ||
yield dx | ||
|
||
|
||
@pytest.fixture() | ||
def use_training_dir(prepare_model, request: SubRequest): | ||
"""Step into a temporary directory for training tests and back out again""" | ||
|
||
shutil.rmtree('/tmp/test_training', ignore_errors=True) | ||
os.makedirs('/tmp/test_training', exist_ok=True) | ||
shutil.copyfile('app_code/tests/test_data/model.pt', '/tmp/test_training/model.pt') | ||
os.chdir('/tmp/test_training/') | ||
yield | ||
shutil.rmtree('/tmp/test_training', ignore_errors=True) | ||
os.chdir(request.config.invocation_dir) |
File renamed without changes
File renamed without changes
File renamed without changes
File renamed without changes
File renamed without changes
File renamed without changes
File renamed without changes.
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
# Own set of hyps that are forwarded to albumentation | ||
|
||
# Optimizer is hardcoded to SGD | ||
|
||
lr0: 0.001 # initial learning rate (SGD=1E-2, Adam=1E-3) | ||
momentum: 0.843 # SGD momentum/Adam beta1 | ||
decay: 0.00001 # optimizer weight decay | ||
label_smoothing: 0.1 # Label smoothing epsilon | ||
batch_size: 4 | ||
epochs: 3 | ||
|
||
# Augmentation | ||
jitter: 0.4 # colour jitter forr brightness, contrast, satuaration (hue is c-jitter/2) | ||
hue_jitter: 0.1 | ||
min_scale: 0.1 # minimum image scale for augmentation | ||
min_ratio: 0.75 # minimum aspect ratio for augmentation | ||
r90_prob: 0.0 # rotate 90 probability | ||
|
||
# Maybe overwritten by learning loop | ||
hflip: 0.5 # horizontal flip probability | ||
vflip: 0.5 # vertical flip probability |
Oops, something went wrong.