Skip to content

Commit

Permalink
Merge
Browse files Browse the repository at this point in the history
  • Loading branch information
kurt-stolle committed Feb 24, 2024
2 parents 64698aa + 2d3a0ed commit 2db5a29
Show file tree
Hide file tree
Showing 6 changed files with 15 additions and 13 deletions.
9 changes: 5 additions & 4 deletions sources/unipercept/_api_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
This file defines some basic API methods for working with UniPercept models, data and other submodules.
"""
from __future__ import annotations
import contextlib

import contextlib
import os
import re
import typing as T
Expand All @@ -25,7 +25,6 @@
import torch.types

import unipercept

from unipercept.data.ops import Op
from unipercept.data.sets import Metadata
from unipercept.model import InputData, ModelBase
Expand Down Expand Up @@ -142,9 +141,9 @@ def read_config(config: ConfigParam) -> DictConfig:
config
A DictConfig object.
"""
from unipercept.engine._engine import _sort_children_by_suffix
from unipercept import file_io
from unipercept.config import load_config, load_config_remote
from unipercept.engine._engine import _sort_children_by_suffix

if isinstance(config, str):
try:
Expand Down Expand Up @@ -273,10 +272,11 @@ def create_engine(config: ConfigParam) -> unipercept.engine.Engine:
engine
A engine instance.
"""
from .config import instantiate
from unipercept.engine import Engine
from unipercept.state import barrier

from .config import instantiate

config = read_config(config)
engine = T.cast(Engine, instantiate(config.ENGINE))

Expand Down Expand Up @@ -536,6 +536,7 @@ def prepare_images(
The dataset metadata.
"""
from torch.utils.data import DataLoader

from unipercept import file_io
from unipercept.model import InputData

Expand Down
2 changes: 1 addition & 1 deletion sources/unipercept/cli/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@
command,
datasets,
echo,
evaluate,
profile,
trace,
train,
evaluate,
)

__all__ = ["backbones", "echo", "profile", "trace", "train", "datasets", "evaluate"]
Expand Down
2 changes: 2 additions & 0 deletions sources/unipercept/cli/echo.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

import sys

from unipercept.cli._command import command

__all__ = []
Expand Down
2 changes: 1 addition & 1 deletion sources/unipercept/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@

from __future__ import annotations

from . import collect, io, ops, pseudolabeler, sets, tensors, types, pipes
from . import collect, io, ops, pipes, pseudolabeler, sets, tensors, types
from ._helpers import *
from ._loader import *
9 changes: 5 additions & 4 deletions sources/unipercept/engine/_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"""

from __future__ import annotations
import sys

import enum as E
import functools
import gc
Expand All @@ -13,11 +13,12 @@
import os
import re
import shutil
import sys
import time
import typing as T
import uuid
from datetime import datetime
from uuid import uuid4
import uuid

import torch
import torch._dynamo
Expand All @@ -35,8 +36,8 @@
from torch.utils.data import Dataset
from typing_extensions import override

from unipercept import file_io
import unipercept
from unipercept import file_io
from unipercept.data import DataLoaderFactory
from unipercept.engine._params import EngineParams, EngineStage
from unipercept.engine._trial import Trial, TrialWithParameters
Expand Down Expand Up @@ -1213,7 +1214,7 @@ def _start_experiment_trackers(self, *, restart: bool = True) -> None:
"wandb": {
"name": self.config_name,
"job_type": job_type,
"reinit": False,
"reinit": True,
"group": group_name,
"notes": "\n\n".join(
(
Expand Down
4 changes: 1 addition & 3 deletions sources/unipercept/engine/_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,9 +124,7 @@ class EngineParams:

full_determinism: bool = False
seed: int = 42
max_grad_norm: float = D.field(
default=5.0, metadata={"help": "Max gradient norm."}
)
max_grad_norm: float = D.field(default=5.0, metadata={"help": "Max gradient norm."})

# Memory tracker
memory_tracker: bool = D.field(
Expand Down

0 comments on commit 2db5a29

Please sign in to comment.