Skip to content

Commit

Permalink
Fix bug (#413)
Browse files Browse the repository at this point in the history
  • Loading branch information
sdatkinson authored May 14, 2024
1 parent a681ebb commit 2d06ada
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 2 deletions.
14 changes: 13 additions & 1 deletion nam/models/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,18 @@ def _metadata_loudness_x(cls) -> torch.Tensor:
)
)

@property
def device(self) -> Optional[torch.device]:
    """
    The device on which this model's parameters live, or ``None`` when the
    model has no parameters at all.
    """
    # NAM models are tiny, so every parameter is assumed to share a single
    # device; the first parameter found is therefore representative.
    for parameter in self.parameters():
        return parameter.device
    # No parameters to inspect.
    return None

@property
def sample_rate(self) -> Optional[float]:
    """
    Sample rate associated with this model, as a Python float, or ``None``
    when no sample rate has been set.
    """
    if not self._has_sample_rate:
        return None
    # Stored as a tensor; .item() unwraps it to a plain float.
    return self._sample_rate.item()
Expand Down Expand Up @@ -81,7 +93,7 @@ def _metadata_loudness(self, gain: float = 1.0, db: bool = True) -> float:
:param gain: Multiplies input signal
"""
x = self._metadata_loudness_x()
x = self._metadata_loudness_x().to(self.device)
y = self._at_nominal_settings(gain * x)
loudness = torch.sqrt(torch.mean(torch.square(y)))
if db:
Expand Down
2 changes: 1 addition & 1 deletion nam/models/wavenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,7 @@ def export_weights(self) -> np.ndarray:
weights = torch.cat([layer.export_weights() for layer in self._layers])
if self._head is not None:
weights = torch.cat([weights, self._head.export_weights()])
weights = torch.cat([weights, torch.Tensor([self._head_scale])])
weights = torch.cat([weights.cpu(), torch.Tensor([self._head_scale])])
return weights.detach().cpu().numpy()

def import_weights(self, weights: torch.Tensor):
Expand Down

0 comments on commit 2d06ada

Please sign in to comment.