Skip to content

Commit

Permalink
refactor: apply automatic ruff fixes for python 3.10
Browse files Browse the repository at this point in the history
  • Loading branch information
eginhard committed Nov 29, 2024
1 parent 87a5508 commit 36a212e
Show file tree
Hide file tree
Showing 16 changed files with 109 additions and 115 deletions.
3 changes: 2 additions & 1 deletion trainer/callbacks.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import TYPE_CHECKING, Callable
from collections.abc import Callable
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from trainer import Trainer
Expand Down
30 changes: 13 additions & 17 deletions trainer/config.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import dataclass, field
from typing import Any, Optional, Union
from typing import Any

from coqpit import Coqpit

Expand Down Expand Up @@ -50,13 +50,13 @@ class TrainerArgs(Coqpit):
default=False,
metadata={"help": "Start with evaluation and test."},
)
small_run: Optional[int] = field(
small_run: int | None = field(
default=None,
metadata={
"help": "Only use a subset of the samples for debugging. Set the number of samples to use. Defaults to None. "
},
)
gpu: Optional[int] = field(
gpu: int | None = field(
default=None, metadata={"help": "GPU ID to use if ```CUDA_VISIBLE_DEVICES``` is not set. Defaults to None."}
)
# only for DDP
Expand Down Expand Up @@ -97,14 +97,14 @@ class TrainerConfig(Coqpit):

# Fields for the run
output_path: str = field(default="output")
logger_uri: Optional[str] = field(
logger_uri: str | None = field(
default=None,
metadata={
"help": "URI to save training artifacts by the logger. If not set, logs will be saved in the output_path. Defaults to None"
},
)
run_name: str = field(default="run", metadata={"help": "Name of the run. Defaults to 'run'"})
project_name: Optional[str] = field(default=None, metadata={"help": "Name of the project. Defaults to None"})
project_name: str | None = field(default=None, metadata={"help": "Name of the project. Defaults to None"})
run_description: str = field(
default="🐸Coqui trainer run.",
metadata={"help": "Notes and description about the run. Defaults to '🐸Coqui trainer run.'"},
Expand All @@ -119,17 +119,15 @@ class TrainerConfig(Coqpit):
model_param_stats: bool = field(
default=False, metadata={"help": "Log model parameters stats on the logger dashboard. Defaults to False"}
)
wandb_entity: Optional[str] = field(
default=None, metadata={"help": "Wandb entity to log the run. Defaults to None"}
)
wandb_entity: str | None = field(default=None, metadata={"help": "Wandb entity to log the run. Defaults to None"})
dashboard_logger: str = field(
default="tensorboard", metadata={"help": "Logger to use for the tracking dashboard. Defaults to 'tensorboard'"}
)
# Fields for checkpointing
save_on_interrupt: bool = field(
default=True, metadata={"help": "Save checkpoint on interrupt (Ctrl+C). Defaults to True"}
)
log_model_step: Optional[int] = field(
log_model_step: int | None = field(
default=None,
metadata={
"help": "Save checkpoint to the logger every log_model_step steps. If not defined `save_step == log_model_step`."
Expand All @@ -144,7 +142,7 @@ class TrainerConfig(Coqpit):
default=False, metadata={"help": "Save all best checkpoints and keep the older ones. Defaults to False"}
)
save_best_after: int = field(default=0, metadata={"help": "Wait N steps to save best checkpoints. Defaults to 0"})
target_loss: Optional[str] = field(
target_loss: str | None = field(
default=None, metadata={"help": "Target loss name to select the best model. Defaults to None"}
)
# Fields for eval and test run
Expand All @@ -153,7 +151,7 @@ class TrainerConfig(Coqpit):
run_eval: bool = field(
default=True, metadata={"help": "Run evalulation epoch after training epoch. Defaults to True"}
)
run_eval_steps: Optional[int] = field(
run_eval_steps: int | None = field(
default=None,
metadata={
"help": "Run evalulation epoch after N steps. If None, waits until training epoch is completed. Defaults to None"
Expand Down Expand Up @@ -186,16 +184,14 @@ class TrainerConfig(Coqpit):
metadata={"help": "Step the scheduler after each epoch else step after each iteration. Defaults to True"},
)
# Fields for optimzation
lr: Union[float, list[float]] = field(
lr: float | list[float] = field(
default=0.001, metadata={"help": "Learning rate for each optimizer. Defaults to 0.001"}
)
optimizer: Optional[Union[str, list[str]]] = field(
default=None, metadata={"help": "Optimizer(s) to use. Defaults to None"}
)
optimizer_params: Union[dict[str, Any], list[dict[str, Any]]] = field(
optimizer: str | list[str] | None = field(default=None, metadata={"help": "Optimizer(s) to use. Defaults to None"})
optimizer_params: dict[str, Any] | list[dict[str, Any]] = field(
default_factory=dict, metadata={"help": "Optimizer(s) arguments. Defaults to {}"}
)
lr_scheduler: Optional[Union[str, list[str]]] = field(
lr_scheduler: str | list[str] | None = field(
default=None, metadata={"help": "Learning rate scheduler(s) to use. Defaults to None"}
)
lr_scheduler_params: dict[str, Any] = field(
Expand Down
6 changes: 3 additions & 3 deletions trainer/generic_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import subprocess
from collections.abc import ItemsView
from pathlib import Path
from typing import Any, Union
from typing import Any

import fsspec
import torch
Expand Down Expand Up @@ -78,14 +78,14 @@ def get_commit_hash() -> str:
return commit


def get_experiment_folder_path(root_path: Union[str, os.PathLike[Any]], model_name: str) -> Path:
def get_experiment_folder_path(root_path: str | os.PathLike[Any], model_name: str) -> Path:
"""Get an experiment folder path with the current date and time."""
date_str = datetime.datetime.now().strftime("%B-%d-%Y_%I+%M%p")
commit_hash = get_commit_hash()
return Path(root_path) / f"{model_name}-{date_str}-{commit_hash}"


def remove_experiment_folder(experiment_path: Union[str, os.PathLike[Any]]) -> None:
def remove_experiment_folder(experiment_path: str | os.PathLike[Any]) -> None:
"""Check folder if there is a checkpoint, otherwise remove the folder."""
experiment_path = str(experiment_path)
fs = fsspec.get_mapper(experiment_path).fs
Expand Down
45 changes: 23 additions & 22 deletions trainer/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
import os
import re
import sys
from collections.abc import Callable
from pathlib import Path
from typing import Any, Callable, Optional, Union
from typing import Any
from urllib.parse import urlparse

import fsspec
Expand Down Expand Up @@ -42,7 +43,7 @@ def get_user_data_dir(appname: str) -> Path:
return ans.joinpath(appname)


def copy_model_files(config: Coqpit, out_path: Union[str, os.PathLike[Any]], new_fields: dict) -> None:
def copy_model_files(config: Coqpit, out_path: str | os.PathLike[Any], new_fields: dict) -> None:
"""Copy config.json and other model files to training folder and add new fields.
Args:
Expand All @@ -60,8 +61,8 @@ def copy_model_files(config: Coqpit, out_path: Union[str, os.PathLike[Any]], new


def load_fsspec(
path: Union[str, os.PathLike[Any]],
map_location: Optional[Union[str, Callable[[Storage, str], Storage], torch.device, dict[str, str]]] = None,
path: str | os.PathLike[Any],
map_location: str | Callable[[Storage, str], Storage] | torch.device | dict[str, str] | None = None,
*,
cache: bool = True,
**kwargs,
Expand Down Expand Up @@ -92,7 +93,7 @@ def load_fsspec(

def load_checkpoint(
model: torch.nn.Module,
checkpoint_path: Union[str, os.PathLike[Any]],
checkpoint_path: str | os.PathLike[Any],
*,
use_cuda: bool = False,
eval: bool = False,
Expand All @@ -107,7 +108,7 @@ def load_checkpoint(
return model, state


def save_fsspec(state: Any, path: Union[str, os.PathLike[Any]], **kwargs) -> None:
def save_fsspec(state: Any, path: str | os.PathLike[Any], **kwargs) -> None:
"""Like torch.save but can save to other locations (e.g. s3:// , gs://).
Args:
Expand All @@ -120,14 +121,14 @@ def save_fsspec(state: Any, path: Union[str, os.PathLike[Any]], **kwargs) -> Non


def save_model(
config: Union[dict, Coqpit],
config: dict | Coqpit,
model: torch.nn.Module,
optimizer: torch.optim.Optimizer,
scaler,
current_step: int,
epoch: int,
output_path: Union[str, os.PathLike[Any]],
save_func: Optional[Callable] = None,
output_path: str | os.PathLike[Any],
save_func: Callable | None = None,
**kwargs,
) -> None:
model_state = model.module.state_dict() if hasattr(model, "module") else model.state_dict()
Expand Down Expand Up @@ -163,15 +164,15 @@ def save_model(


def save_checkpoint(
config: Union[dict, Coqpit],
config: dict | Coqpit,
model: torch.nn.Module,
optimizer: torch.optim.Optimizer,
scaler,
current_step: int,
epoch: int,
output_folder: Union[str, os.PathLike[Any]],
save_n_checkpoints: Optional[int] = None,
save_func: Optional[Callable] = None,
output_folder: str | os.PathLike[Any],
save_n_checkpoints: int | None = None,
save_func: Callable | None = None,
**kwargs,
):
file_name = f"checkpoint_{current_step}.pth"
Expand All @@ -194,21 +195,21 @@ def save_checkpoint(


def save_best_model(
current_loss: Union[dict, float],
best_loss: Union[dict[str, Optional[float]], float],
config: Union[dict, Coqpit],
current_loss: dict | float,
best_loss: dict[str, float | None] | float,
config: dict | Coqpit,
model: torch.nn.Module,
optimizer: torch.optim.Optimizer,
scaler,
current_step: int,
epoch: int,
out_path: Union[str, os.PathLike[Any]],
out_path: str | os.PathLike[Any],
*,
keep_all_best: bool = False,
keep_after: int = 0,
save_func: Optional[Callable] = None,
save_func: Callable | None = None,
**kwargs,
) -> Union[dict, float]:
) -> dict | float:
if isinstance(current_loss, dict) and isinstance(best_loss, dict):
use_eval_loss = current_loss["eval_loss"] is not None and best_loss["eval_loss"] is not None
is_save_model = (use_eval_loss and current_loss["eval_loss"] < best_loss["eval_loss"]) or (
Expand Down Expand Up @@ -251,7 +252,7 @@ def save_best_model(
return best_loss


def get_last_checkpoint(path: Union[str, os.PathLike[Any]]) -> tuple[str, str]:
def get_last_checkpoint(path: str | os.PathLike[Any]) -> tuple[str, str]:
"""Get latest checkpoint or/and best model in path.
It is based on globbing for `*.pth` and the RegEx
Expand Down Expand Up @@ -318,7 +319,7 @@ def get_last_checkpoint(path: Union[str, os.PathLike[Any]]) -> tuple[str, str]:
return last_models["checkpoint"], last_models["best_model"]


def keep_n_checkpoints(path: Union[str, os.PathLike[Any]], n: int) -> None:
def keep_n_checkpoints(path: str | os.PathLike[Any], n: int) -> None:
"""Keep only the last n checkpoints in path.
Args:
Expand All @@ -333,7 +334,7 @@ def keep_n_checkpoints(path: Union[str, os.PathLike[Any]], n: int) -> None:


def sort_checkpoints(
output_path: Union[str, os.PathLike[Any]], checkpoint_prefix: str, *, use_mtime: bool = False
output_path: str | os.PathLike[Any], checkpoint_prefix: str, *, use_mtime: bool = False
) -> list[str]:
"""Sort checkpoint paths based on the checkpoint step number.
Expand Down
4 changes: 2 additions & 2 deletions trainer/logging/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,13 @@
logger = logging.getLogger("trainer")


def get_mlflow_tracking_url() -> Union[str, None]:
def get_mlflow_tracking_url() -> str | None:
if "MLFLOW_TRACKING_URI" in os.environ:
return os.environ["MLFLOW_TRACKING_URI"]
return None


def get_ai_repo_url() -> Union[str, None]:
def get_ai_repo_url() -> str | None:
if "AIM_TRACKING_URI" in os.environ:
return os.environ["AIM_TRACKING_URI"]
return None
Expand Down
4 changes: 1 addition & 3 deletions trainer/logging/aim_logger.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from typing import Optional

import torch

from trainer.logging.base_dash_logger import BaseDashboardLogger
Expand All @@ -17,7 +15,7 @@ def __init__(
self,
repo: str,
model_name: str,
tags: Optional[str] = None,
tags: str | None = None,
) -> None:
self._context = None
self.model_name = model_name
Expand Down
2 changes: 1 addition & 1 deletion trainer/logging/base_dash_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def add_text(self, title: str, text: str, step: int) -> None:
pass

@abstractmethod
def add_artifact(self, file_or_dir: Union[str, os.PathLike[Any]], name: str, artifact_type: str, aliases=None):
def add_artifact(self, file_or_dir: str | os.PathLike[Any], name: str, artifact_type: str, aliases=None):
pass

@abstractmethod
Expand Down
4 changes: 2 additions & 2 deletions trainer/logging/clearml_logger.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import os
from typing import Any, Optional
from typing import Any

import torch

Expand Down Expand Up @@ -36,7 +36,7 @@ def __init__(
local_path: str,
project_name: str,
task_name: str,
tags: Optional[str] = None,
tags: str | None = None,
) -> None:
self._context = None
self.local_path = local_path
Expand Down
4 changes: 2 additions & 2 deletions trainer/logging/console_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import logging
import os
from dataclasses import dataclass
from typing import Any, Optional, Union
from typing import Any

from trainer.utils.distributed import rank_zero_only

Expand Down Expand Up @@ -44,7 +44,7 @@ def get_time() -> str:
return now.strftime("%Y-%m-%d %H:%M:%S")

@rank_zero_only
def print_epoch_start(self, epoch: int, max_epoch: int, output_path: Optional[Union[str, os.PathLike[Any]]] = None):
def print_epoch_start(self, epoch: int, max_epoch: int, output_path: str | os.PathLike[Any] | None = None):
self.log_with_flush(
f"\n{tcolors.UNDERLINE}{tcolors.BOLD} > EPOCH: {epoch}/{max_epoch}{tcolors.ENDC}",
)
Expand Down
3 changes: 1 addition & 2 deletions trainer/logging/mlflow_logger.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import os
import tempfile
import traceback
from typing import Optional

import soundfile as sf
import torch
Expand All @@ -23,7 +22,7 @@ def __init__(
self,
log_uri: str,
model_name: str,
tags: Optional[str] = None,
tags: str | None = None,
) -> None:
self.model_name = model_name
self.client = MlflowClient(tracking_uri=os.path.join(log_uri))
Expand Down
4 changes: 2 additions & 2 deletions trainer/model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Optional
from typing import TYPE_CHECKING, Any

import torch
from torch import nn
Expand All @@ -20,7 +20,7 @@ class TrainerModel(ABC, nn.Module):
"""Abstract 🐸TTS class. Every new 🐸TTS model must inherit this."""

@abstractmethod
def forward(self, input: torch.Tensor, *args, aux_input: Optional[dict[str, Any]] = None, **kwargs) -> dict:
def forward(self, input: torch.Tensor, *args, aux_input: dict[str, Any] | None = None, **kwargs) -> dict:
"""Forward ... for the model mainly used in training.
You can be flexible here and use different number of arguments and argument names since it is intended to be
Expand Down
5 changes: 2 additions & 3 deletions trainer/torch.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import contextlib
from collections.abc import Iterator
from typing import Optional

import numpy as np
import torch
Expand Down Expand Up @@ -36,8 +35,8 @@ def __init__(
self,
sampler,
*,
num_replicas: Optional[int] = None,
rank: Optional[int] = None,
num_replicas: int | None = None,
rank: int | None = None,
shuffle: bool = True,
seed: int = 0,
) -> None:
Expand Down
Loading

0 comments on commit 36a212e

Please sign in to comment.