Feature: include CLI overrides in the run name on WandB
kurt-stolle committed Mar 21, 2024
1 parent bc22541 commit 2b2be75
Showing 25 changed files with 195 additions and 352 deletions.
1 change: 0 additions & 1 deletion scripts/wandb_cleanup.py
@@ -8,7 +8,6 @@
 import pprint
 
 import wandb
-
 from unipercept.integrations.wandb_integration import artifact_historic_delete
 
 if __name__ == "__main__":
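Note: the deletion above and the import moves in the files below follow from one housekeeping change: `import wandb` is regrouped from the third-party block into the first-party block next to the `unipercept` imports. A minimal sketch of the resulting layout, assuming the repository's import sorter (isort, or Ruff's isort rules) now treats `wandb` as a first-party package:

```python
# Standard-library imports form the first block.
import pprint

# Third-party imports form the second block.
import torch

# First-party imports form the last block; `wandb` now sorts here,
# immediately before the `unipercept` imports.
import wandb
from unipercept.integrations.wandb_integration import artifact_historic_delete
```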
9 changes: 4 additions & 5 deletions sources/unipercept/_api_config.py
@@ -73,9 +73,10 @@ def _read_model_wandb(path: str) -> str:
     from unipercept import file_io
 
     run = _wandb_read_run(path)
-    import wandb
     from wandb.sdk.wandb_run import Run
 
+    import wandb
+
     assert path.startswith(WANDB_RUN_PREFIX)
 
     _logger.info("Reading W&B model checkpoint from %s", path)
@@ -390,8 +391,7 @@ def create_dataset(
     variant: T.Optional[str | re.Pattern],
     batch_size: int,
     return_loader: bool = True,
-) -> tuple[torch.utils.data.DataLoader[InputData], Metadata]:
-    ...
+) -> tuple[torch.utils.data.DataLoader[InputData], Metadata]: ...
 
 
 @T.overload
@@ -400,8 +400,7 @@ def create_dataset(
     variant: T.Optional[str | re.Pattern],
     batch_size: int,
     return_loader: bool = False,
-) -> tuple[T.Iterator[InputData], Metadata]:
-    ...
+) -> tuple[T.Iterator[InputData], Metadata]: ...
 
 
 def create_dataset(
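Most of the remaining hunks in this commit apply one mechanical style change: the body of each `@T.overload` stub moves from an indented `...` on its own line onto the signature line. The two spellings are equivalent to a type checker; only the runtime implementation that follows the stubs contains real code. A minimal self-contained sketch of the pattern (`fetch` is an illustrative name, not from this repository):

```python
import typing as T


@T.overload
def fetch(key: int) -> list[int]: ...  # new style: one-line ellipsis body


@T.overload
def fetch(key: str) -> list[str]: ...


def fetch(key):
    # The single runtime implementation; the overload stubs above are
    # never executed and exist only for static type checkers.
    return [key]
```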
27 changes: 9 additions & 18 deletions sources/unipercept/_api_data.py
@@ -24,57 +24,48 @@ def __dir__() -> list[str]:
 @T.overload
 def get_dataset(
     query: T.Literal["cityscapes"],  # noqa: U100
-) -> type[unisets.cityscapes.CityscapesDataset]:
-    ...
+) -> type[unisets.cityscapes.CityscapesDataset]: ...
 
 @T.overload
 def get_dataset(
     query: T.Literal["cityscapes-vps"],  # noqa: U100
-) -> type[unisets.cityscapes.CityscapesVPSDataset]:
-    ...
+) -> type[unisets.cityscapes.CityscapesVPSDataset]: ...
 
 @T.overload
 def get_dataset(
     query: T.Literal["kitti-360"],  # noqa: U100
-) -> type[unisets.kitti_360.KITTI360Dataset]:
-    ...
+) -> type[unisets.kitti_360.KITTI360Dataset]: ...
 
 @T.overload
 def get_dataset(
     query: T.Literal["kitti-step"],  # noqa: U100
-) -> type[unisets.kitti_step.KITTISTEPDataset]:
-    ...
+) -> type[unisets.kitti_step.KITTISTEPDataset]: ...
 
 @T.overload
 def get_dataset(
     query: T.Literal["kitti-sem"],  # noqa: U100
-) -> type[unisets.kitti_sem.SemKITTIDataset]:
-    ...
+) -> type[unisets.kitti_sem.SemKITTIDataset]: ...
 
 @T.overload
 def get_dataset(
     query: T.Literal["vistas"],  # noqa: U100
-) -> type[unisets.vistas.VistasDataset]:
-    ...
+) -> type[unisets.vistas.VistasDataset]: ...
 
 @T.overload
 def get_dataset(
     query: T.Literal["wilddash"],  # noqa: U100
-) -> type[unisets.wilddash.WildDashDataset]:
-    ...
+) -> type[unisets.wilddash.WildDashDataset]: ...
 
 @T.overload
 def get_dataset(
     query: str,  # noqa: U100
-) -> type[unisets.PerceptionDataset]:
-    ...
+) -> type[unisets.PerceptionDataset]: ...
 
 @T.overload
 def get_dataset(
     query: None,  # noqa: U100
     **kwargs: T.Any,  # noqa: U100
-) -> type[unisets.PerceptionDataset]:
-    ...
+) -> type[unisets.PerceptionDataset]: ...
 
 
 def get_dataset(
2 changes: 1 addition & 1 deletion sources/unipercept/cli/_config.py
@@ -63,7 +63,7 @@ def __call__(self, parser, namespace, values, option_string=None):
         cfg = up.read_config(name)
         cfg = self.apply_overrides(cfg, overrides)
         cfg["CLI"] = name
-        cfg["CLI_OVERRIDES"] = " ".join(overrides)
+        cfg["CLI_OVERRIDES"] = list(overrides)
 
         setattr(namespace, self.dest + "_path", name)
         setattr(namespace, self.dest + "_overrides", overrides)
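Storing `CLI_OVERRIDES` as a list instead of a pre-joined string keeps each `key=value` token addressable, which the engine change further down relies on when it builds the run name. A minimal sketch of the difference, with illustrative override values:

```python
overrides = ("engine.params.lr=0.01", "model.backbone=resnet50")

# Before this commit: the tokens were flattened into one string at parse time.
joined = " ".join(overrides)  # "engine.params.lr=0.01 model.backbone=resnet50"

# After: the list survives inside the config, so consumers can still
# inspect, filter, or re-join the individual overrides.
cfg = {"CLI_OVERRIDES": list(overrides)}
assert cfg["CLI_OVERRIDES"][0] == "engine.params.lr=0.01"
```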
21 changes: 7 additions & 14 deletions sources/unipercept/config.py
@@ -118,8 +118,7 @@ def apply(f: EnvFilter | str, v: T.Any, /) -> bool:
 @T.overload
 def get_env(
     __type: type[_R], /, *keys: str, default: _R, filter: EnvFilter = EnvFilter.TRUTHY
-) -> _R:
-    ...
+) -> _R: ...
 
 
 @T.overload
@@ -129,8 +128,7 @@ def get_env(
     *keys: str,
     default: _R | None = None,
     filter: EnvFilter = EnvFilter.TRUTHY,
-) -> _R | None:
-    ...
+) -> _R | None: ...
 
 
 @functools.cache
@@ -534,12 +532,10 @@ def safe_update(cfg, key, value):
 if T.TYPE_CHECKING:
 
     class LazyObject(T.Generic[_L]):
-        def __getattr__(self, name: str) -> T.Any:
-            ...
+        def __getattr__(self, name: str) -> T.Any: ...
 
         @override
-        def __setattr__(self, __name: str, __value: Any) -> None:
-            ...
+        def __setattr__(self, __name: str, __value: Any) -> None: ...
 
 else:
     import types
@@ -628,18 +624,15 @@ def migrate_target(v: T.Any) -> T.Any:
 
 
 @T.overload
-def instantiate(cfg: T.Sequence[LazyObject[_L]], /) -> T.Sequence[_L]:
-    ...
+def instantiate(cfg: T.Sequence[LazyObject[_L]], /) -> T.Sequence[_L]: ...
 
 
 @T.overload
-def instantiate(cfg: LazyObject[_L], /) -> _L:
-    ...
+def instantiate(cfg: LazyObject[_L], /) -> _L: ...
 
 
 @T.overload
-def instantiate(cfg: T.Mapping[T.Any, LazyObject[_L]], /) -> T.Mapping[T.Any, _L]:
-    ...
+def instantiate(cfg: T.Mapping[T.Any, LazyObject[_L]], /) -> T.Mapping[T.Any, _L]: ...
 
 
 def instantiate(cfg: T.Any, /) -> T.Any:
9 changes: 3 additions & 6 deletions sources/unipercept/data/_loader.py
@@ -472,18 +472,15 @@ def queue_size(self) -> int:
 
     @property
     @abc.abstractmethod
-    def indices(self) -> T.Iterator[_I]:
-        ...
+    def indices(self) -> T.Iterator[_I]: ...
 
     @property
     @abc.abstractmethod
-    def sample_count(self) -> int:
-        ...
+    def sample_count(self) -> int: ...
 
     @property
     @abc.abstractmethod
-    def total_count(self) -> int:
-        ...
+    def total_count(self) -> int: ...
 
     @property
     def generator(self) -> torch.Generator:
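The hunk above keeps the stacked `@property` / `@abc.abstractmethod` decorators and only shortens the stub bodies: each declares a read-only attribute that concrete samplers must provide. A minimal sketch of the pattern, with hypothetical class names:

```python
import abc


class BaseSampler(abc.ABC):
    @property
    @abc.abstractmethod
    def sample_count(self) -> int: ...  # abstract read-only attribute


class ListSampler(BaseSampler):
    def __init__(self, indices: list[int]) -> None:
        self._indices = indices

    @property
    def sample_count(self) -> int:
        # Concrete subclasses satisfy the abstract property with a real one.
        return len(self._indices)


assert ListSampler([1, 2, 3]).sample_count == 3
```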
3 changes: 1 addition & 2 deletions sources/unipercept/data/ops.py
@@ -77,8 +77,7 @@ def _run(self, inputs: InputData) -> InputData:
     if T.TYPE_CHECKING:
 
         @override
-        def __call__(self, inputs: InputData) -> InputData:
-            ...
+        def __call__(self, inputs: InputData) -> InputData: ...
 
 
 class CloneOp(Op):
6 changes: 2 additions & 4 deletions sources/unipercept/data/tensors/helpers.py
@@ -89,8 +89,7 @@ def multi_read(
     key: Any,
     *,
     no_entries: Literal[NoEntriesAction.ERROR] | Literal["error"],
-) -> Callable[Concatenate[Sequence[Mapping[Any, Any]], _ReadParams], _ReadReturn]:
-    ...
+) -> Callable[Concatenate[Sequence[Mapping[Any, Any]], _ReadParams], _ReadReturn]: ...
 
 
 @overload
@@ -101,8 +100,7 @@ def multi_read(
     no_entries: Literal[NoEntriesAction.NONE] | Literal["none"] = NoEntriesAction.NONE,
 ) -> Callable[
     Concatenate[Sequence[Mapping[Any, Any]], _ReadParams], _ReadReturn | None
-]:
-    ...
+]: ...
 
 
 def multi_read(
12 changes: 6 additions & 6 deletions sources/unipercept/engine/_engine.py
@@ -28,7 +28,6 @@
 import torch.optim
 import torch.types
 import torch.utils.data
-import wandb
 from omegaconf import DictConfig, OmegaConf
 from PIL import Image as pil_image
 from tabulate import tabulate
@@ -37,6 +36,7 @@
 from torch.utils.data import Dataset
 from typing_extensions import override
 
+import wandb
 from unipercept import file_io
 from unipercept.data import DataLoaderFactory
 from unipercept.engine._params import EngineParams, EvaluationSuite, TrainingStage
@@ -634,17 +634,15 @@ def build_training_dataloader(
         dataloader: DataLoaderFactory,
         batch_size: int,
         gradient_accumulation: None = None,
-    ) -> tuple[torch.utils.data.DataLoader, int, None]:
-        ...
+    ) -> tuple[torch.utils.data.DataLoader, int, None]: ...
 
     @T.overload
     def build_training_dataloader(
         self,
         dataloader: DataLoaderFactory,
         batch_size: int,
         gradient_accumulation: int,
-    ) -> tuple[torch.utils.data.DataLoader, int, int]:
-        ...
+    ) -> tuple[torch.utils.data.DataLoader, int, int]: ...
 
     def build_training_dataloader(
         self,
@@ -1247,7 +1245,9 @@ def _start_experiment_trackers(self, *, restart: bool = True) -> None:
         # Set up tracker-specific parameters
         specific_kwargs = {
             "wandb": {
-                "name": self.config_name,
+                "name": " ".join(
+                    [self.config_name, *self.config.get("CLI_OVERRIDES", [])]
+                ),
                 "job_type": job_type,
                 "reinit": True,
                 "group": group_name,
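This is the heart of the feature: the W&B run name is now the config name followed by the CLI overrides recorded earlier, so each run is labelled with the exact flags it was launched with, and `self.config.get("CLI_OVERRIDES", [])` degrades gracefully for configs saved before this commit. A minimal sketch of the naming behaviour (config name and override strings are illustrative):

```python
config_name = "cityscapes/multidvps"
cli_overrides = ["engine.params.lr=0.01"]

# Mirrors the join in _start_experiment_trackers.
run_name = " ".join([config_name, *cli_overrides])
assert run_name == "cityscapes/multidvps engine.params.lr=0.01"

# Without overrides, the name reduces to the config name alone.
assert " ".join([config_name]) == "cityscapes/multidvps"
```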
9 changes: 3 additions & 6 deletions sources/unipercept/engine/accelerate.py
@@ -102,23 +102,20 @@ def find_executable_batch_size(
         function: _Fin[_P, _R],
         *,
         starting_batch_size: int = 128,
-    ) -> _Fout[_P, _R]:
-        ...
+    ) -> _Fout[_P, _R]: ...
 
     @T.overload
     def find_executable_batch_size(
         function: None = None,
         *,
         starting_batch_size: int = 128,
-    ) -> T.Callable[[_Fin[_P, _R]], _Fout[_P, _R]]:
-        ...
+    ) -> T.Callable[[_Fin[_P, _R]], _Fout[_P, _R]]: ...
 
     def find_executable_batch_size(
         function: _Fin | None = None,
         *,
         starting_batch_size: int = 128,
-    ) -> T.Callable[[_Fin[_P, _R]], _Fout[_P, _R]] | _Fout[_P, _R]:
-        ...
+    ) -> T.Callable[[_Fin[_P, _R]], _Fout[_P, _R]] | _Fout[_P, _R]: ...
 
 else:
     find_executable_batch_size = accelerate.utils.find_executable_batch_size
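Here the one-line stubs sit inside an `if T.TYPE_CHECKING:` block: the precise overload signatures are visible to type checkers only, while at runtime the name is simply re-bound to the `accelerate` implementation. A minimal sketch of this shim pattern around a hypothetical third-party `double` function:

```python
import typing as T

if T.TYPE_CHECKING:
    # Seen only by static type checkers: precise per-type signatures.
    @T.overload
    def double(value: int) -> int: ...
    @T.overload
    def double(value: str) -> str: ...
    def double(value): ...

else:
    # Executed at runtime: alias the real implementation unchanged.
    def _double(value):
        return value * 2

    double = _double


assert double(2) == 4
assert double("ab") == "abab"
```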
3 changes: 1 addition & 2 deletions sources/unipercept/engine/callbacks.py
@@ -377,8 +377,7 @@ def __call__(
         state: State,
         control: Signal,
         **kwargs,
-    ) -> Signal | None:
-        ...
+    ) -> Signal | None: ...
 
 
 CallbackType: T.TypeAlias = CallbackProtocol | type[CallbackProtocol]
9 changes: 3 additions & 6 deletions sources/unipercept/evaluators/_base.py
@@ -56,23 +56,20 @@ def update(
         storage: TensorDictBase,  # noqa: U100
         inputs: TensorDictBase,  # noqa: U100
         outputs: TensorDictBase,  # noqa: U100
-    ) -> None:
-        ...
+    ) -> None: ...
 
     @abc.abstractmethod
    def compute(
         self,
         storage: TensorDictBase,  # noqa: U100
         **kwargs: T.Unpack[EvaluatorComputeKWArgs],  # noqa: U100
-    ) -> dict[str, int | float | str | bool | dict]:
-        ...
+    ) -> dict[str, int | float | str | bool | dict]: ...
 
     @abc.abstractmethod
     def plot(
         self,
         storage: TensorDictBase,  # noqa: U100
-    ) -> dict[str, pil_image.Image]:
-        ...
+    ) -> dict[str, pil_image.Image]: ...
 
     def _show_table(self, msg: str, tab: pd.DataFrame) -> None:
         from unipercept.log import create_table
(Diffs for the remaining 13 of the 25 changed files were not loaded on this page.)