Skip to content

Commit

Permalink
Fix: Reinit on stage completion
Browse files Browse the repository at this point in the history
  • Loading branch information
kurt-stolle committed Feb 24, 2024
1 parent 8e734e2 commit 2d3a0ed
Show file tree
Hide file tree
Showing 7 changed files with 35 additions and 17 deletions.
11 changes: 6 additions & 5 deletions sources/unipercept/_api_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
This file defines some basic API methods for working with UniPercept models, data and other submodules.
"""
from __future__ import annotations
import contextlib

import contextlib
import os
import re
import typing as T
Expand All @@ -25,7 +25,6 @@
import torch.types

import unipercept

from unipercept.data.ops import Op
from unipercept.data.sets import Metadata
from unipercept.model import InputData, ModelBase
Expand Down Expand Up @@ -142,9 +141,9 @@ def read_config(config: ConfigParam) -> DictConfig:
config
A DictConfig object.
"""
from unipercept.engine._engine import _sort_children_by_suffix
from unipercept import file_io
from unipercept.config import load_config, load_config_remote
from unipercept.engine._engine import _sort_children_by_suffix

if isinstance(config, str):
try:
Expand All @@ -167,7 +166,7 @@ def read_config(config: ConfigParam) -> DictConfig:
if not isinstance(obj, DictConfig):
msg = f"Expected a DictConfig, got {obj}"
raise TypeError(msg)

if config_path.suffix == ".yaml":
# Check if the config has a latest checkpoint
models_path = config_path.parent / "outputs" / "checkpoints"
Expand Down Expand Up @@ -273,10 +272,11 @@ def create_engine(config: ConfigParam) -> unipercept.engine.Engine:
engine
A engine instance.
"""
from .config import instantiate
from unipercept.engine import Engine
from unipercept.state import barrier

from .config import instantiate

config = read_config(config)
engine = T.cast(Engine, instantiate(config.ENGINE))

Expand Down Expand Up @@ -536,6 +536,7 @@ def prepare_images(
The dataset metadata.
"""
from torch.utils.data import DataLoader

from unipercept import file_io
from unipercept.model import InputData

Expand Down
11 changes: 10 additions & 1 deletion sources/unipercept/cli/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,16 @@

import sys

from unipercept.cli import backbones, command, datasets, echo, profile, trace, train, evaluate
from unipercept.cli import (
backbones,
command,
datasets,
echo,
evaluate,
profile,
trace,
train,
)

__all__ = ["backbones", "echo", "profile", "trace", "train", "datasets", "evaluate"]

Expand Down
9 changes: 7 additions & 2 deletions sources/unipercept/cli/echo.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

import sys

from unipercept.cli._command import command

__all__ = []
Expand Down Expand Up @@ -44,7 +46,7 @@ def main(args):
print(res)

elif fmt == "json":
import json
import json

res = json.dumps(out, indent=4, ensure_ascii=False)
print(res, file=sys.stdout, flush=True)
Expand All @@ -62,7 +64,10 @@ def main(args):
@command.with_config
def echo(parser):
parser.add_argument(
"--format", default="pprint", help="output format", choices=["yaml", "pprint", "json"]
"--format",
default="pprint",
help="output format",
choices=["yaml", "pprint", "json"],
)
parser.add_argument("--key", default="config", help="key to output")

Expand Down
4 changes: 3 additions & 1 deletion sources/unipercept/cli/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,9 @@ def _main(args):
model_factory = up.model.ModelFactory(lazy_config.MODEL, weights=args.weights)
try:
if args.evaluation:
_logger.info("Running in EVALUATION ONLY MODE. Be aware that no weights are loaded if not provided explicitly!")
_logger.info(
"Running in EVALUATION ONLY MODE. Be aware that no weights are loaded if not provided explicitly!"
)
results = engine.run_evaluation(model_factory)
_logger.info(
"Evaluation results: \n%s", up.log.create_table(results, format="long")
Expand Down
2 changes: 1 addition & 1 deletion sources/unipercept/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@

from __future__ import annotations

from . import collect, io, ops, pseudolabeler, sets, tensors, types, pipes
from . import collect, io, ops, pipes, pseudolabeler, sets, tensors, types
from ._helpers import *
from ._loader import *
9 changes: 5 additions & 4 deletions sources/unipercept/engine/_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"""

from __future__ import annotations
import sys

import enum as E
import functools
import gc
Expand All @@ -13,11 +13,12 @@
import os
import re
import shutil
import sys
import time
import typing as T
import uuid
from datetime import datetime
from uuid import uuid4
import uuid

import torch
import torch._dynamo
Expand All @@ -35,8 +36,8 @@
from torch.utils.data import Dataset
from typing_extensions import override

from unipercept import file_io
import unipercept
from unipercept import file_io
from unipercept.data import DataLoaderFactory
from unipercept.engine._params import EngineParams, EngineStage
from unipercept.engine._trial import Trial, TrialWithParameters
Expand Down Expand Up @@ -1213,7 +1214,7 @@ def _start_experiment_trackers(self, *, restart: bool = True) -> None:
"wandb": {
"name": self.config_name,
"job_type": job_type,
"reinit": False,
"reinit": True,
"group": group_name,
"notes": "\n\n".join(
(
Expand Down
6 changes: 3 additions & 3 deletions sources/unipercept/evaluators/_panoptic.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,10 +224,10 @@ def compute_pq(
total=sample_amt,
disable=not check_main_process(local=True) or not self.show_progress,
)
#mp_context = M.get_context("spawn" if device.type != "cpu" else None)
#with concurrent.futures.ProcessPoolExecutor(
# mp_context = M.get_context("spawn" if device.type != "cpu" else None)
# with concurrent.futures.ProcessPoolExecutor(
# min(cpus_available(), M.cpu_count() // 2, 32), mp_context=mp_context
#) as pool:
# ) as pool:
with concurrent.futures.ThreadPoolExecutor() as pool:
for result in pool.map(compute_at, range(sample_amt)):
progress_bar.update(1)
Expand Down

0 comments on commit 2d3a0ed

Please sign in to comment.