From 2d3a0edaee97d9643febbf2e57e77f75f56a535f Mon Sep 17 00:00:00 2001 From: Kurt Stolle Date: Sat, 24 Feb 2024 08:07:15 +0100 Subject: [PATCH] Fix: Reinit on stage completion --- sources/unipercept/_api_config.py | 11 ++++++----- sources/unipercept/cli/__main__.py | 11 ++++++++++- sources/unipercept/cli/echo.py | 9 +++++++-- sources/unipercept/cli/train.py | 4 +++- sources/unipercept/data/__init__.py | 2 +- sources/unipercept/engine/_engine.py | 9 +++++---- sources/unipercept/evaluators/_panoptic.py | 6 +++--- 7 files changed, 35 insertions(+), 17 deletions(-) diff --git a/sources/unipercept/_api_config.py b/sources/unipercept/_api_config.py index 0e61743..cda64c7 100644 --- a/sources/unipercept/_api_config.py +++ b/sources/unipercept/_api_config.py @@ -2,8 +2,8 @@ This file defines some basic API methods for working with UniPercept models, data and other submodules. """ from __future__ import annotations -import contextlib +import contextlib import os import re import typing as T @@ -25,7 +25,6 @@ import torch.types import unipercept - from unipercept.data.ops import Op from unipercept.data.sets import Metadata from unipercept.model import InputData, ModelBase @@ -142,9 +141,9 @@ def read_config(config: ConfigParam) -> DictConfig: config A DictConfig object. """ - from unipercept.engine._engine import _sort_children_by_suffix from unipercept import file_io from unipercept.config import load_config, load_config_remote + from unipercept.engine._engine import _sort_children_by_suffix if isinstance(config, str): try: @@ -167,7 +166,7 @@ def read_config(config: ConfigParam) -> DictConfig: if not isinstance(obj, DictConfig): msg = f"Expected a DictConfig, got {obj}" raise TypeError(msg) - + if config_path.suffix == ".yaml": # Check if the config has a latest checkpoint models_path = config_path.parent / "outputs" / "checkpoints" @@ -273,10 +272,11 @@ def create_engine(config: ConfigParam) -> unipercept.engine.Engine: engine A engine instance. """ - from .config import instantiate from unipercept.engine import Engine from unipercept.state import barrier + from .config import instantiate + config = read_config(config) engine = T.cast(Engine, instantiate(config.ENGINE)) @@ -536,6 +536,7 @@ def prepare_images( The dataset metadata. """ from torch.utils.data import DataLoader + from unipercept import file_io from unipercept.model import InputData diff --git a/sources/unipercept/cli/__main__.py b/sources/unipercept/cli/__main__.py index 971c649..c2dafde 100644 --- a/sources/unipercept/cli/__main__.py +++ b/sources/unipercept/cli/__main__.py @@ -6,7 +6,16 @@ import sys -from unipercept.cli import backbones, command, datasets, echo, profile, trace, train, evaluate +from unipercept.cli import ( + backbones, + command, + datasets, + echo, + evaluate, + profile, + trace, + train, +) __all__ = ["backbones", "echo", "profile", "trace", "train", "datasets", "evaluate"] diff --git a/sources/unipercept/cli/echo.py b/sources/unipercept/cli/echo.py index 61916ef..3bbae20 100644 --- a/sources/unipercept/cli/echo.py +++ b/sources/unipercept/cli/echo.py @@ -1,5 +1,7 @@ from __future__ import annotations + import sys + from unipercept.cli._command import command __all__ = [] @@ -44,7 +46,7 @@ def main(args): print(res) elif fmt == "json": - import json + import json res = json.dumps(out, indent=4, ensure_ascii=False) print(res, file=sys.stdout, flush=True) @@ -62,7 +64,10 @@ def main(args): @command.with_config def echo(parser): parser.add_argument( - "--format", default="pprint", help="output format", choices=["yaml", "pprint", "json"] + "--format", + default="pprint", + help="output format", + choices=["yaml", "pprint", "json"], ) parser.add_argument("--key", default="config", help="key to output") diff --git a/sources/unipercept/cli/train.py b/sources/unipercept/cli/train.py index e5b1da8..4515a68 100644 --- a/sources/unipercept/cli/train.py +++ b/sources/unipercept/cli/train.py @@ -120,7 +120,9 @@ def _main(args): model_factory = up.model.ModelFactory(lazy_config.MODEL, weights=args.weights) try: if args.evaluation: - _logger.info("Running in EVALUATION ONLY MODE. Be aware that no weights are loaded if not provided explicitly!") + _logger.info( + "Running in EVALUATION ONLY MODE. Be aware that no weights are loaded if not provided explicitly!" + ) results = engine.run_evaluation(model_factory) _logger.info( "Evaluation results: \n%s", up.log.create_table(results, format="long") diff --git a/sources/unipercept/data/__init__.py b/sources/unipercept/data/__init__.py index 70793ae..a29a390 100644 --- a/sources/unipercept/data/__init__.py +++ b/sources/unipercept/data/__init__.py @@ -2,6 +2,6 @@ from __future__ import annotations -from . import collect, io, ops, pseudolabeler, sets, tensors, types, pipes +from . import collect, io, ops, pipes, pseudolabeler, sets, tensors, types from ._helpers import * from ._loader import * diff --git a/sources/unipercept/engine/_engine.py b/sources/unipercept/engine/_engine.py index f25584c..35d101e 100644 --- a/sources/unipercept/engine/_engine.py +++ b/sources/unipercept/engine/_engine.py @@ -3,7 +3,7 @@ """ from __future__ import annotations -import sys + import enum as E import functools import gc @@ -13,11 +13,12 @@ import os import re import shutil +import sys import time import typing as T +import uuid from datetime import datetime from uuid import uuid4 -import uuid import torch import torch._dynamo @@ -35,8 +36,8 @@ from torch.utils.data import Dataset from typing_extensions import override -from unipercept import file_io import unipercept +from unipercept import file_io from unipercept.data import DataLoaderFactory from unipercept.engine._params import EngineParams, EngineStage from unipercept.engine._trial import Trial, TrialWithParameters @@ -1213,7 +1214,7 @@ def _start_experiment_trackers(self, *, restart: bool = True) -> None: "wandb": { "name": self.config_name, "job_type": job_type, - "reinit": False, + "reinit": True, "group": group_name, "notes": "\n\n".join( ( diff --git a/sources/unipercept/evaluators/_panoptic.py b/sources/unipercept/evaluators/_panoptic.py index a28187f..88f2b1c 100644 --- a/sources/unipercept/evaluators/_panoptic.py +++ b/sources/unipercept/evaluators/_panoptic.py @@ -224,10 +224,10 @@ def compute_pq( total=sample_amt, disable=not check_main_process(local=True) or not self.show_progress, ) - #mp_context = M.get_context("spawn" if device.type != "cpu" else None) - #with concurrent.futures.ProcessPoolExecutor( + # mp_context = M.get_context("spawn" if device.type != "cpu" else None) + # with concurrent.futures.ProcessPoolExecutor( # min(cpus_available(), M.cpu_count() // 2, 32), mp_context=mp_context - #) as pool: + # ) as pool: with concurrent.futures.ThreadPoolExecutor() as pool: for result in pool.map(compute_at, range(sample_amt)): progress_bar.update(1)