[BUGFIX] Store model's expected sample rate in Lightning checkpoints (#357)

* Rework sample_rate implementation to be storable in PyTorch artifacts

* Store sample rate in checkpoint
sdatkinson authored Jan 9, 2024
1 parent a2f54f0 commit a284650
Showing 5 changed files with 79 additions and 10 deletions.
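The gist of the rework: a plain Python `Optional[float]` attribute never makes it into a module's `state_dict()`, so the net now encodes the sample rate as two registered buffers (a boolean flag plus a float value) behind a `sample_rate` property. A minimal, self-contained sketch of the pattern — `DemoNet` is a stand-in invented for illustration, not a class from this repository:

```python
from typing import Optional

import torch
import torch.nn as nn


class DemoNet(nn.Module):
    """Illustration of the Optional[float]-as-buffers pattern used in the diff."""

    def __init__(self, sample_rate: Optional[float] = None):
        super().__init__()
        # Buffers, unlike plain attributes, are serialized with state_dict().
        self.register_buffer(
            "_has_sample_rate", torch.tensor(sample_rate is not None, dtype=torch.bool)
        )
        self.register_buffer(
            "_sample_rate", torch.tensor(0.0 if sample_rate is None else sample_rate)
        )

    @property
    def sample_rate(self) -> Optional[float]:
        return self._sample_rate.item() if self._has_sample_rate else None


net = DemoNet(sample_rate=48_000.0)
restored = DemoNet()  # starts with no sample rate
restored.load_state_dict(net.state_dict())
assert restored.sample_rate == 48_000.0
```

Keeping the "is it set?" flag in its own boolean buffer avoids a sentinel value like 0.0 being mistaken for a real sample rate.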
9 changes: 5 additions & 4 deletions bin/train/main.py
@@ -33,6 +33,7 @@ def _ensure_graceful_shutdowns():

from nam.data import ConcatDataset, ParametricDataset, Split, init_dataset
from nam.models import Model
+from nam.models._base import BaseNet  # HACK access
from nam.util import filter_warnings, timestamp

torch.manual_seed(0)
@@ -191,6 +192,7 @@ def main_inner(
"Train and validation data loaders have different data set sample rates: "
f"{dataset_train.sample_rate}, {dataset_validation.sample_rate}"
)
+    model.net.sample_rate = dataset_train.sample_rate
train_dataloader = DataLoader(dataset_train, **learning_config["train_dataloader"])
val_dataloader = DataLoader(dataset_validation, **learning_config["val_dataloader"])

@@ -215,7 +217,6 @@
)
model.cpu()
model.eval()
-    model.net.sample_rate = train_dataloader.dataset.sample_rate
if make_plots:
plot(
model,
@@ -226,9 +227,9 @@
show=False,
)
plot(model, dataset_validation, show=not no_show)
-    # Would like to, but this doesn't work for all cases.
-    # If you're making snapshot models, you may find this convenient to uncomment :)
-    # model.net.export(outdir)
+    # Convenient export for snapshot models:
+    if isinstance(model.net, BaseNet):
+        model.net.export(outdir)


if __name__ == "__main__":
4 changes: 2 additions & 2 deletions nam/data.py
@@ -753,7 +753,7 @@ def _validate_datasets(cls, datasets: Sequence[Dataset]):


def register_dataset_initializer(
-    name: str, constructor: Callable[[Any], AbstractDataset]
+    name: str, constructor: Callable[[Any], AbstractDataset], overwrite=False
):
"""
If you have other data set types, you can register their initializer by name using
@@ -768,7 +768,7 @@ def register_dataset_initializer(
:param name: The name that'll be used in the config to ask for the data set type
:param constructor: The constructor that'll be fed the config.
"""
-    if name in _dataset_init_registry:
+    if name in _dataset_init_registry and not overwrite:
raise KeyError(
f"A constructor for dataset name '{name}' is already registered!"
)
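For downstream users of the registry, the new `overwrite` flag means a deliberate re-registration no longer raises. A hedged usage sketch — the `"my_dataset"` name and `my_dataset_init` function are made up for illustration; only `register_dataset_initializer` and its new keyword come from this diff (the same pattern applies to the net-initializer registry in nam/models/base.py below):

```python
from nam.data import register_dataset_initializer


def my_dataset_init(config):
    # Hypothetical initializer: would build and return a dataset from the config.
    ...


register_dataset_initializer("my_dataset", my_dataset_init)

# Registering the same name a second time still raises KeyError by default:
# register_dataset_initializer("my_dataset", my_dataset_init)

# ...but can now be forced explicitly:
register_dataset_initializer("my_dataset", my_dataset_init, overwrite=True)
```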
16 changes: 15 additions & 1 deletion nam/models/_base.py
@@ -24,7 +24,12 @@
class _Base(nn.Module, InitializableFromConfig, Exportable):
def __init__(self, sample_rate: Optional[float] = None):
super().__init__()
-        self.sample_rate = sample_rate
+        self.register_buffer(
+            "_has_sample_rate", torch.tensor(sample_rate is not None, dtype=torch.bool)
+        )
+        self.register_buffer(
+            "_sample_rate", torch.tensor(0.0 if sample_rate is None else sample_rate)
+        )

@abc.abstractproperty
def pad_start_default(self) -> bool:
@@ -49,6 +54,15 @@ def _metadata_loudness_x(cls) -> torch.Tensor:
)
)

+    @property
+    def sample_rate(self) -> Optional[float]:
+        return self._sample_rate.item() if self._has_sample_rate else None
+
+    @sample_rate.setter
+    def sample_rate(self, val: Optional[float]):
+        self._has_sample_rate = torch.tensor(val is not None, dtype=torch.bool)
+        self._sample_rate = torch.tensor(0.0 if val is None else val)
+
def _get_export_dict(self):
d = super()._get_export_dict()
sample_rate_key = "sample_rate"
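A subtlety the new `sample_rate` setter relies on (shown here as a standalone sketch, independent of the NAM classes): assigning a tensor to an attribute name that is already a registered buffer updates the module's buffer registry rather than creating a new plain attribute, so a value set after construction still lands in `state_dict()`.

```python
import torch
import torch.nn as nn


class Holder(nn.Module):
    """Hypothetical module, for illustration only."""

    def __init__(self):
        super().__init__()
        self.register_buffer("_has_sample_rate", torch.tensor(False, dtype=torch.bool))
        self.register_buffer("_sample_rate", torch.tensor(0.0))


h = Holder()
# These plain assignments are routed into the existing buffer entries:
h._has_sample_rate = torch.tensor(True, dtype=torch.bool)
h._sample_rate = torch.tensor(44_100.0)

state = h.state_dict()
assert state["_has_sample_rate"].item()
assert state["_sample_rate"].item() == 44_100.0
```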
14 changes: 11 additions & 3 deletions nam/models/base.py
@@ -12,7 +12,7 @@

from dataclasses import dataclass
from enum import Enum
-from typing import Dict, NamedTuple, Optional, Tuple
+from typing import Any, Dict, NamedTuple, Optional, Tuple

import auraloss
import logging
@@ -210,8 +210,8 @@ def parse_config(cls, config):
}

@classmethod
-    def register_net_initializer(cls, name, constructor):
-        if name in _model_net_init_registry:
+    def register_net_initializer(cls, name, constructor, overwrite: bool = False):
+        if name in _model_net_init_registry and not overwrite:
raise KeyError(
f"A constructor for net name '{name}' is already registered!"
)
@@ -238,6 +238,14 @@ def configure_optimizers(self):
def forward(self, *args, **kwargs):
return self.net(*args, **kwargs) # TODO deprecate--use self.net() instead.

+    def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
+        # Resolves https://github.com/sdatkinson/neural-amp-modeler/issues/351
+        self.net.sample_rate = checkpoint["sample_rate"]
+
+    def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
+        # Resolves https://github.com/sdatkinson/neural-amp-modeler/issues/351
+        checkpoint["sample_rate"] = self.net.sample_rate
+
def _shared_step(
self, batch
) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, _LossItem]]:
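`on_save_checkpoint` and `on_load_checkpoint` are standard `LightningModule` hooks: the Trainer hands them the mutable checkpoint dictionary whenever it writes or restores a `.ckpt` file, so extra top-level keys such as `"sample_rate"` travel alongside the weights. A minimal sketch of the round trip with the hooks called by hand — `Plain` and its attributes are invented for illustration and assume the `pytorch_lightning` import name used by this project:

```python
from typing import Any, Dict

import pytorch_lightning as pl
import torch


class Plain(pl.LightningModule):
    """Hypothetical minimal module, for illustration only."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(1, 1)
        self.sample_rate = None  # a plain attribute: absent from state_dict()

    def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        checkpoint["sample_rate"] = self.sample_rate  # stash alongside the weights

    def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        self.sample_rate = checkpoint["sample_rate"]  # restore on load


# Emulate what the Trainer does around saving and loading a checkpoint:
ckpt: Dict[str, Any] = {}
saver = Plain()
saver.sample_rate = 48_000.0
saver.on_save_checkpoint(ckpt)

loader = Plain()
loader.on_load_checkpoint(ckpt)
assert loader.sample_rate == 48_000.0
```

In the real `Model`, the hooks read and write `self.net.sample_rate`, which is what ties the net's expected sample rate to every Lightning checkpoint (issue #351).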
46 changes: 46 additions & 0 deletions tests/test_nam/test_models/test_base.py
@@ -2,8 +2,13 @@
# Created Date: Thursday March 16th 2023
# Author: Steven Atkinson ([email protected])

"""
Tests for the base network and Lightning module
"""

import math
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Optional

import numpy as np
@@ -106,5 +111,46 @@ def mocked_loss(
assert obj._mrstft_device == "cpu"


class TestSampleRate(object):
"""
Tests for sample_rate interface
"""

@pytest.mark.parametrize("expected_sample_rate", (None, 44_100.0, 48_000.0))
def test_on_init(self, expected_sample_rate: Optional[float]):
model = _MockBaseNet(gain=1.0, sample_rate=expected_sample_rate)
self._wrap_assert(model, expected_sample_rate)

@pytest.mark.parametrize("expected_sample_rate", (None, 44_100.0, 48_000.0))
def test_setter(self, expected_sample_rate: Optional[float]):
model = _MockBaseNet(gain=1.0)
model.sample_rate = expected_sample_rate
self._wrap_assert(model, expected_sample_rate)

@pytest.mark.parametrize("expected_sample_rate", (None, 44_100.0, 48_000.0))
def test_state_dict(self, expected_sample_rate: Optional[float]):
"""
Assert that it makes it into the state dict
https://github.com/sdatkinson/neural-amp-modeler/issues/351
"""
model = _MockBaseNet(gain=1.0, sample_rate=expected_sample_rate)
with TemporaryDirectory() as tmpdir:
model_path = Path(tmpdir, "model.pt")
torch.save(model.state_dict(), model_path)
model2 = _MockBaseNet(gain=1.0)
model2.load_state_dict(torch.load(model_path))
self._wrap_assert(model2, expected_sample_rate)

@classmethod
def _wrap_assert(cls, model: _MockBaseNet, expected: Optional[float]):
actual = model.sample_rate
if expected is None:
assert actual is None
else:
assert isinstance(actual, float)
assert actual == expected


if __name__ == "__main__":
pytest.main()
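
To exercise just the new tests (a standard pytest node ID built from the file path above, shown via `pytest.main` to stay in Python):

```python
import pytest

# Run from the repository root so the relative path resolves.
pytest.main(["-v", "tests/test_nam/test_models/test_base.py::TestSampleRate"])
```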
