Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Drop Python 3.9 support #14

Merged
merged 12 commits into from
Jan 6, 2025
5 changes: 3 additions & 2 deletions .github/actions/setup-uv/action.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@ runs:
using: 'composite'
steps:
- name: Install uv
uses: astral-sh/setup-uv@v3
uses: astral-sh/setup-uv@v4
with:
version: "0.5.1"
version: "0.5.10"
enable-cache: true
cache-dependency-glob: "**/pyproject.toml"
python-version: ${{ matrix.python-version }}
8 changes: 1 addition & 7 deletions .github/workflows/style_check.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,11 @@ on:
pull_request:
types: [opened, synchronize, reopened]
jobs:
test:
lint:
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
python-version: [3.9]
steps:
- uses: actions/checkout@v4
- name: Setup uv
uses: ./.github/actions/setup-uv
- name: Set up Python ${{ matrix.python-version }}
run: uv python install ${{ matrix.python-version }}
- name: Lint check
run: make lint
4 changes: 1 addition & 3 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,12 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: [3.9, "3.10", "3.11", "3.12"]
python-version: ["3.10", "3.11", "3.12"]
uv-resolution: ["lowest-direct", "highest"]
steps:
- uses: actions/checkout@v4
- name: Setup uv
uses: ./.github/actions/setup-uv
- name: Set up Python ${{ matrix.python-version }}
run: uv python install ${{ matrix.python-version }}
- name: Unit tests
run: uv run --resolution=${{ matrix.uv-resolution }} --all-extras coverage run --parallel
- name: Upload coverage data
Expand Down
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ repos:
- id: end-of-file-fixer
- id: trailing-whitespace
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.6.9
rev: v0.8.3
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
Expand Down
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def optimize(self, batch, trainer):
loss_disc = (loss_real + loss_fake) / 2

# step discriminator
_, _ = self.scaled_backward(loss_disc, None, trainer, trainer.optimizer[0])
self.scaled_backward(loss_disc, None, trainer, trainer.optimizer[0])

if trainer.total_steps_done % trainer.grad_accum_steps == 0:
trainer.optimizer[0].step()
Expand All @@ -81,7 +81,7 @@ def optimize(self, batch, trainer):
loss_gen = trainer.criterion(logits, valid)

# step generator
_, _ = self.scaled_backward(loss_gen, None, trainer, trainer.optimizer[1])
self.scaled_backward(loss_gen, None, trainer, trainer.optimizer[1])
if trainer.total_steps_done % trainer.grad_accum_steps == 0:
trainer.optimizer[1].step()
trainer.optimizer[1].zero_grad()
Expand Down
6 changes: 3 additions & 3 deletions examples/train_simple_gan.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def optimize(self, batch, trainer):
loss_disc = (loss_real + loss_fake) / 2

# step discriminator
_, _ = self.scaled_backward(loss_disc, None, trainer, trainer.optimizer[0])
self.scaled_backward(loss_disc, None, trainer, trainer.optimizer[0])

if trainer.total_steps_done % trainer.grad_accum_steps == 0:
trainer.optimizer[0].step()
Expand All @@ -118,13 +118,13 @@ def optimize(self, batch, trainer):
loss_gen = trainer.criterion(logits, valid)

# step generator
_, _ = self.scaled_backward(loss_gen, None, trainer, trainer.optimizer[1])
self.scaled_backward(loss_gen, None, trainer, trainer.optimizer[1])
if trainer.total_steps_done % trainer.grad_accum_steps == 0:
trainer.optimizer[1].step()
trainer.optimizer[1].zero_grad()
return {"model_outputs": logits}, {"loss_gen": loss_gen, "loss_disc": loss_disc}

@torch.no_grad()
@torch.inference_mode()
def eval_step(self, batch, criterion):
imgs, _ = batch

Expand Down
10 changes: 4 additions & 6 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,10 @@ build-backend = "hatchling.build"

[project]
name = "coqui-tts-trainer"
version = "0.2.0"
version = "0.2.1"
description = "General purpose model trainer for PyTorch that is more flexible than it should be, by 🐸Coqui."
readme = "README.md"
requires-python = ">=3.9, <3.13"
requires-python = ">=3.10, <3.13"
license = {text = "Apache-2.0"}
authors = [
{name = "Eren Gölge", email = "[email protected]"}
Expand All @@ -47,7 +47,6 @@ classifiers = [
"Operating System :: OS Independent",
"Programming Language :: Python",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
Expand All @@ -58,7 +57,7 @@ classifiers = [
dependencies = [
"coqpit-config>=0.1.1",
"fsspec>=2023.6.0",
"numpy>=1.24.3; python_version < '3.12'",
"numpy>=1.25.2; python_version < '3.12'",
"numpy>=1.26.0; python_version >= '3.12'",
"packaging>=21.0",
"psutil>=5",
Expand All @@ -73,7 +72,7 @@ dev = [
"coverage>=7",
"pre-commit>=3",
"pytest>=8",
"ruff==0.6.9",
"ruff==0.8.3",
]
test = [
"accelerate>=0.20.0",
Expand Down Expand Up @@ -112,7 +111,6 @@ packages = ["trainer"]

[tool.ruff]
line-length = 120
target-version = "py39"
lint.extend-select = [
"ANN204", # type hints
"B", # bugbear
Expand Down
6 changes: 0 additions & 6 deletions tests/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +0,0 @@
import os


def run_cli(command):
    """Run ``command`` via ``os.system`` and assert that it returned 0.

    Args:
        command: Command line string handed to ``os.system`` unchanged.

    Raises:
        AssertionError: If ``os.system`` returns a non-zero value.
    """
    failure_message = f" [!] command `{command}` failed."
    assert os.system(command) == 0, failure_message
30 changes: 30 additions & 0 deletions tests/test_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import pytest

from trainer.config import TrainerConfig


def test_optimizer_params():
    """Check how ``TrainerConfig`` handles scalar, list and mixed optimizer fields.

    Building the config with every field as a single value, or with every
    field as a list, is expected to succeed. Mixing a scalar ``grad_clip``
    with list-valued siblings is expected to raise ``TypeError``.
    """
    # Every optimizer-related field given as a single value.
    all_scalar = {
        "optimizer": "optimizer",
        "grad_clip": 0.0,
        "lr": 0.1,
        "optimizer_params": {},
        "lr_scheduler": "scheduler",
    }
    TrainerConfig(**all_scalar)

    # Every optimizer-related field given as a two-element list.
    all_lists = {
        "optimizer": ["optimizer1", "optimizer2"],
        "grad_clip": [0.0, 0.0],
        "lr": [0.1, 0.01],
        "optimizer_params": [{}, {}],
        "lr_scheduler": ["scheduler1", "scheduler2"],
    }
    TrainerConfig(**all_lists)

    # Same as the list case except that ``grad_clip`` is a scalar.
    # Fresh literals are used so no objects are shared with the call above.
    mixed = {
        "optimizer": ["optimizer1", "optimizer2"],
        "grad_clip": 0.0,
        "lr": [0.1, 0.01],
        "optimizer_params": [{}, {}],
        "lr_scheduler": ["scheduler1", "scheduler2"],
    }
    expected_message = "Either none or all of these fields must be a list:"
    with pytest.raises(TypeError, match=expected_message):
        TrainerConfig(**mixed)
26 changes: 15 additions & 11 deletions tests/test_continue_train.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,31 @@
from tests import run_cli
import sys

from tests.utils.train_mnist import main as train_mnist

def test_continue_train(tmp_path):
command_train = f"python tests/utils/train_mnist.py --coqpit.output_path {tmp_path}"
run_cli(command_train)

def test_continue_train(tmp_path, monkeypatch):
with monkeypatch.context() as m:
m.setattr(sys, "argv", [sys.argv[0], "--coqpit.output_path", str(tmp_path)])
train_mnist()

continue_path = max(tmp_path.iterdir(), key=lambda p: p.stat().st_mtime)
number_of_checkpoints = len(list(continue_path.glob("*.pth")))

# Continue training from the best model
command_continue = f"python tests/utils/train_mnist.py --continue_path {continue_path} --coqpit.run_eval_steps=1"
run_cli(command_continue)
with monkeypatch.context() as m:
m.setattr(sys, "argv", [sys.argv[0], "--continue_path", str(continue_path), "--coqpit.run_eval_steps", "1"])
train_mnist()

assert number_of_checkpoints < len(list(continue_path.glob("*.pth")))

# Continue training from the last checkpoint
for best in continue_path.glob("best_model*"):
best.unlink()
run_cli(command_continue)

# Continue training from a specific checkpoint
restore_path = continue_path / "checkpoint_5.pth"
command_continue = (
f"python tests/utils/train_mnist.py --restore_path {restore_path} --coqpit.output_path {tmp_path}"
)
run_cli(command_continue)
with monkeypatch.context() as m:
m.setattr(
sys, "argv", [sys.argv[0], "--restore_path", str(restore_path), "--coqpit.output_path", str(tmp_path)]
)
train_mnist()
18 changes: 5 additions & 13 deletions tests/test_train_gan.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def train_step(self, batch, criterion, optimizer_idx):
loss_real = criterion(logits, valid)
return {"model_outputs": logits}, {"loss": loss_real}

@torch.no_grad()
@torch.inference_mode()
def eval_step(self, batch, criterion, optimizer_idx):
return self.train_step(batch, criterion, optimizer_idx)

Expand Down Expand Up @@ -200,7 +200,7 @@ def train_step(self, batch, criterion, optimizer_idx):
loss_real = criterion(logits, valid)
return {"model_outputs": logits}, {"loss": loss_real}

@torch.no_grad()
@torch.inference_mode()
def eval_step(self, batch, criterion, optimizer_idx):
return self.train_step(batch, criterion, optimizer_idx)

Expand Down Expand Up @@ -300,7 +300,7 @@ def optimize(self, batch, trainer):
trainer.optimizer[1].step()
return {"model_outputs": logits}, {"loss_gen": loss_gen, "loss_disc": loss_disc}

@torch.no_grad()
@torch.inference_mode()
def eval_step(self, batch, trainer):
imgs, _ = batch

Expand Down Expand Up @@ -413,7 +413,7 @@ def optimize(self, batch, trainer):
trainer.optimizer[1].zero_grad()
return {"model_outputs": logits}, {"loss_gen": loss_gen, "loss_disc": loss_disc}

@torch.no_grad()
@torch.inference_mode()
def eval_step(self, batch, criterion):
imgs, _ = batch

Expand Down Expand Up @@ -528,7 +528,7 @@ def optimize(self, batch, trainer):
trainer.optimizer[1].zero_grad()
return {"model_outputs": logits}, {"loss_gen": loss_gen, "loss_disc": loss_disc}

@torch.no_grad()
@torch.inference_mode()
def eval_step(self, batch, criterion):
imgs, _ = batch

Expand Down Expand Up @@ -582,11 +582,3 @@ def get_data_loader(self, config, assets, is_eval, samples, verbose, num_gpus, r
print(f"loss_g1: {loss_g1}, loss_g2: {loss_g2}")
assert loss_d1 > loss_d2, f"Discriminator loss should decrease. {loss_d1} > {loss_d2}"
assert loss_g1 > loss_g2, f"Generator loss should decrease. {loss_g1} > {loss_g2}"


if __name__ == "__main__":
test_overfit_mnist_simple_gan()
test_overfit_accelerate_mnist_simple_gan()
test_overfit_manual_optimize_mnist_simple_gan()
test_overfit_manual_optimize_grad_accum_mnist_simple_gan()
test_overfit_manual_accelerate_optimize_grad_accum_mnist_simple_gan()
3 changes: 1 addition & 2 deletions tests/utils/mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,7 @@ def eval_step(self, batch, criterion):
loss = criterion(logits, y)
return {"model_outputs": logits}, {"loss": loss}

@staticmethod
def get_criterion():
def get_criterion(self):
return torch.nn.NLLLoss()

def get_data_loader(self, config, assets, *, is_eval, samples=None, verbose=False, num_gpus=1, rank=0): # pylint: disable=unused-argument
Expand Down
3 changes: 1 addition & 2 deletions tests/utils/train_mnist.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from mnist import MnistModel, MnistModelConfig

from tests.utils.mnist import MnistModel, MnistModelConfig
from trainer import Trainer, TrainerArgs


Expand Down
2 changes: 1 addition & 1 deletion trainer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,4 @@

__version__ = importlib.metadata.version("coqui-tts-trainer")

__all__ = ["TrainerArgs", "TrainerConfig", "Trainer", "TrainerModel"]
__all__ = ["Trainer", "TrainerArgs", "TrainerConfig", "TrainerModel"]
23 changes: 23 additions & 0 deletions trainer/_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from collections.abc import Callable
from typing import TYPE_CHECKING, Any, TypeAlias, TypedDict

import torch

if TYPE_CHECKING:
import matplotlib
import numpy.typing as npt
import plotly

from trainer.trainer import Trainer


# NumPy array of any dtype. Written as a string because ``numpy.typing`` is
# imported only under TYPE_CHECKING and is not available at runtime here.
Audio: TypeAlias = "npt.NDArray[Any]"
# Either a matplotlib or a plotly figure. Written as a string because both
# libraries are imported only under TYPE_CHECKING.
Figure: TypeAlias = "matplotlib.figure.Figure | plotly.graph_objects.Figure"
# Alias for ``torch.optim.lr_scheduler._LRScheduler`` (a private torch name).
LRScheduler: TypeAlias = torch.optim.lr_scheduler._LRScheduler

# Callable that takes a ``Trainer`` and returns ``None``. "Trainer" is a
# forward reference because the class is imported only under TYPE_CHECKING.
Callback: TypeAlias = Callable[["Trainer"], None]


class LossDict(TypedDict):
    """Typed dictionary with a ``train_loss`` and an optional ``eval_loss`` entry."""

    # Always a float.
    train_loss: float
    # May be None; presumably when no evaluation loss is available (confirm against the trainer).
    eval_loss: float | None
Loading
Loading