diff --git a/bin/train/main.py b/bin/train/main.py index 2666a6b5..aed72a2f 100644 --- a/bin/train/main.py +++ b/bin/train/main.py @@ -33,6 +33,7 @@ def _ensure_graceful_shutdowns(): from nam.data import ConcatDataset, ParametricDataset, Split, init_dataset from nam.models import Model +from nam.models._base import BaseNet # HACK access from nam.util import filter_warnings, timestamp torch.manual_seed(0) @@ -191,6 +192,7 @@ def main_inner( "Train and validation data loaders have different data set sample rates: " f"{dataset_train.sample_rate}, {dataset_validation.sample_rate}" ) + model.net.sample_rate = dataset_train.sample_rate train_dataloader = DataLoader(dataset_train, **learning_config["train_dataloader"]) val_dataloader = DataLoader(dataset_validation, **learning_config["val_dataloader"]) @@ -215,7 +217,6 @@ def main_inner( ) model.cpu() model.eval() - model.net.sample_rate = train_dataloader.dataset.sample_rate if make_plots: plot( model, @@ -226,9 +227,9 @@ def main_inner( show=False, ) plot(model, dataset_validation, show=not no_show) - # Would like to, but this doesn't work for all cases. - # If you're making snapshot models, you may find this convenient to uncomment :) - # model.net.export(outdir) + # Convenient export for snapshot models: + if isinstance(model.net, BaseNet): + model.net.export(outdir) if __name__ == "__main__": diff --git a/nam/data.py b/nam/data.py index ac817c2f..caf1e58e 100644 --- a/nam/data.py +++ b/nam/data.py @@ -753,7 +753,7 @@ def _validate_datasets(cls, datasets: Sequence[Dataset]): def register_dataset_initializer( - name: str, constructor: Callable[[Any], AbstractDataset] + name: str, constructor: Callable[[Any], AbstractDataset], overwrite=False ): """ If you have otehr data set types, you can register their initializer by name using @@ -768,7 +768,7 @@ def register_dataset_initializer( :param name: The name that'll be used in the config to ask for the data set type :param constructor: The constructor that'll be fed the config. """ - if name in _dataset_init_registry: + if name in _dataset_init_registry and not overwrite: raise KeyError( f"A constructor for dataset name '{name}' is already registered!" ) diff --git a/nam/models/_base.py b/nam/models/_base.py index c2836c84..6b30fa77 100644 --- a/nam/models/_base.py +++ b/nam/models/_base.py @@ -24,7 +24,12 @@ class _Base(nn.Module, InitializableFromConfig, Exportable): def __init__(self, sample_rate: Optional[float] = None): super().__init__() - self.sample_rate = sample_rate + self.register_buffer( + "_has_sample_rate", torch.tensor(sample_rate is not None, dtype=torch.bool) + ) + self.register_buffer( + "_sample_rate", torch.tensor(0.0 if sample_rate is None else sample_rate) + ) @abc.abstractproperty def pad_start_default(self) -> bool: @@ -49,6 +54,15 @@ def _metadata_loudness_x(cls) -> torch.Tensor: ) ) + @property + def sample_rate(self) -> Optional[float]: + return self._sample_rate.item() if self._has_sample_rate else None + + @sample_rate.setter + def sample_rate(self, val: Optional[float]): + self._has_sample_rate = torch.tensor(val is not None, dtype=torch.bool) + self._sample_rate = torch.tensor(0.0 if val is None else val) + def _get_export_dict(self): d = super()._get_export_dict() sample_rate_key = "sample_rate" diff --git a/nam/models/base.py b/nam/models/base.py index 01b6215a..11dfc75e 100644 --- a/nam/models/base.py +++ b/nam/models/base.py @@ -12,7 +12,7 @@ from dataclasses import dataclass from enum import Enum -from typing import Dict, NamedTuple, Optional, Tuple +from typing import Any, Dict, NamedTuple, Optional, Tuple import auraloss import logging @@ -210,8 +210,8 @@ def parse_config(cls, config): } @classmethod - def register_net_initializer(cls, name, constructor): - if name in _model_net_init_registry: + def register_net_initializer(cls, name, constructor, overwrite: bool = False): + if name in _model_net_init_registry and not overwrite: raise KeyError( f"A constructor for net name '{name}' is already registered!" ) @@ -238,6 +238,14 @@ def configure_optimizers(self): def forward(self, *args, **kwargs): return self.net(*args, **kwargs) # TODO deprecate--use self.net() instead. + def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: + # Resolves https://github.com/sdatkinson/neural-amp-modeler/issues/351 + self.net.sample_rate = checkpoint["sample_rate"] + + def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None: + # Resolves https://github.com/sdatkinson/neural-amp-modeler/issues/351 + checkpoint["sample_rate"] = self.net.sample_rate + def _shared_step( self, batch ) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, _LossItem]]: diff --git a/tests/test_nam/test_models/test_base.py b/tests/test_nam/test_models/test_base.py index 63417f75..89c81a2b 100644 --- a/tests/test_nam/test_models/test_base.py +++ b/tests/test_nam/test_models/test_base.py @@ -2,8 +2,13 @@ # Created Date: Thursday March 16th 2023 # Author: Steven Atkinson (steven@atkinson.mn) +""" +Tests for the base network and Lightning module +""" + import math from pathlib import Path +from tempfile import TemporaryDirectory from typing import Optional import numpy as np @@ -106,5 +111,46 @@ def mocked_loss( assert obj._mrstft_device == "cpu" +class TestSampleRate(object): + """ + Tests for sample_rate interface + """ + + @pytest.mark.parametrize("expected_sample_rate", (None, 44_100.0, 48_000.0)) + def test_on_init(self, expected_sample_rate: Optional[float]): + model = _MockBaseNet(gain=1.0, sample_rate=expected_sample_rate) + self._wrap_assert(model, expected_sample_rate) + + @pytest.mark.parametrize("expected_sample_rate", (None, 44_100.0, 48_000.0)) + def test_setter(self, expected_sample_rate: Optional[float]): + model = _MockBaseNet(gain=1.0) + model.sample_rate = expected_sample_rate + self._wrap_assert(model, expected_sample_rate) + + @pytest.mark.parametrize("expected_sample_rate", (None, 44_100.0, 48_000.0)) + def test_state_dict(self, expected_sample_rate: Optional[float]): + """ + Assert that it makes it into the state dict + + https://github.com/sdatkinson/neural-amp-modeler/issues/351 + """ + model = _MockBaseNet(gain=1.0, sample_rate=expected_sample_rate) + with TemporaryDirectory() as tmpdir: + model_path = Path(tmpdir, "model.pt") + torch.save(model.state_dict(), model_path) + model2 = _MockBaseNet(gain=1.0) + model2.load_state_dict(torch.load(model_path)) + self._wrap_assert(model2, expected_sample_rate) + + @classmethod + def _wrap_assert(cls, model: _MockBaseNet, expected: Optional[float]): + actual = model.sample_rate + if expected is None: + assert actual is None + else: + assert isinstance(actual, float) + assert actual == expected + + if __name__ == "__main__": pytest.main()