diff --git a/sources/unipercept/_api_config.py b/sources/unipercept/_api_config.py index e7d870c..cda64c7 100644 --- a/sources/unipercept/_api_config.py +++ b/sources/unipercept/_api_config.py @@ -2,8 +2,8 @@ This file defines some basic API methods for working with UniPercept models, data and other submodules. """ from __future__ import annotations -import contextlib +import contextlib import os import re import typing as T @@ -25,7 +25,6 @@ import torch.types import unipercept - from unipercept.data.ops import Op from unipercept.data.sets import Metadata from unipercept.model import InputData, ModelBase @@ -142,9 +141,9 @@ def read_config(config: ConfigParam) -> DictConfig: config A DictConfig object. """ - from unipercept.engine._engine import _sort_children_by_suffix from unipercept import file_io from unipercept.config import load_config, load_config_remote + from unipercept.engine._engine import _sort_children_by_suffix if isinstance(config, str): try: @@ -273,10 +272,11 @@ def create_engine(config: ConfigParam) -> unipercept.engine.Engine: engine A engine instance. """ - from .config import instantiate from unipercept.engine import Engine from unipercept.state import barrier + from .config import instantiate + config = read_config(config) engine = T.cast(Engine, instantiate(config.ENGINE)) @@ -536,6 +536,7 @@ def prepare_images( The dataset metadata. """ from torch.utils.data import DataLoader + from unipercept import file_io from unipercept.model import InputData diff --git a/sources/unipercept/cli/__main__.py b/sources/unipercept/cli/__main__.py index 05250b1..c2dafde 100644 --- a/sources/unipercept/cli/__main__.py +++ b/sources/unipercept/cli/__main__.py @@ -11,10 +11,10 @@ command, datasets, echo, + evaluate, profile, trace, train, - evaluate, ) __all__ = ["backbones", "echo", "profile", "trace", "train", "datasets", "evaluate"] diff --git a/sources/unipercept/cli/echo.py b/sources/unipercept/cli/echo.py index 48d4152..3bbae20 100644 --- a/sources/unipercept/cli/echo.py +++ b/sources/unipercept/cli/echo.py @@ -1,5 +1,7 @@ from __future__ import annotations + import sys + from unipercept.cli._command import command __all__ = [] diff --git a/sources/unipercept/data/__init__.py b/sources/unipercept/data/__init__.py index 70793ae..a29a390 100644 --- a/sources/unipercept/data/__init__.py +++ b/sources/unipercept/data/__init__.py @@ -2,6 +2,6 @@ from __future__ import annotations -from . import collect, io, ops, pseudolabeler, sets, tensors, types, pipes +from . import collect, io, ops, pipes, pseudolabeler, sets, tensors, types from ._helpers import * from ._loader import * diff --git a/sources/unipercept/engine/_engine.py b/sources/unipercept/engine/_engine.py index f25584c..35d101e 100644 --- a/sources/unipercept/engine/_engine.py +++ b/sources/unipercept/engine/_engine.py @@ -3,7 +3,7 @@ """ from __future__ import annotations -import sys + import enum as E import functools import gc @@ -13,11 +13,12 @@ import os import re import shutil +import sys import time import typing as T +import uuid from datetime import datetime from uuid import uuid4 -import uuid import torch import torch._dynamo @@ -35,8 +36,8 @@ from torch.utils.data import Dataset from typing_extensions import override -from unipercept import file_io import unipercept +from unipercept import file_io from unipercept.data import DataLoaderFactory from unipercept.engine._params import EngineParams, EngineStage from unipercept.engine._trial import Trial, TrialWithParameters @@ -1213,7 +1214,7 @@ def _start_experiment_trackers(self, *, restart: bool = True) -> None: "wandb": { "name": self.config_name, "job_type": job_type, - "reinit": False, + "reinit": True, "group": group_name, "notes": "\n\n".join( ( diff --git a/sources/unipercept/engine/_params.py b/sources/unipercept/engine/_params.py index f5b3078..caea485 100644 --- a/sources/unipercept/engine/_params.py +++ b/sources/unipercept/engine/_params.py @@ -124,9 +124,7 @@ class EngineParams: full_determinism: bool = False seed: int = 42 - max_grad_norm: float = D.field( - default=5.0, metadata={"help": "Max gradient norm."} - ) + max_grad_norm: float = D.field(default=5.0, metadata={"help": "Max gradient norm."}) # Memory tracker memory_tracker: bool = D.field(