Skip to content

Commit

Permalink
Catch edge cases for clustering and plotting
Browse files Browse the repository at this point in the history
  • Loading branch information
klapo committed Jan 19, 2024
1 parent c60befd commit 4595e24
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 2 deletions.
9 changes: 7 additions & 2 deletions pydmd/costs.py
Original file line number Diff line number Diff line change
Expand Up @@ -704,7 +704,8 @@ def cluster_hyperparameter_sweep(
"""
if n_components_range is None:
n_components_range = np.arange(
np.max((self.svd_rank // 4, 2)), self.svd_rank
np.max((self.svd_rank // 4, 2)),
self.svd_rank // 2,
)
score = np.zeros_like(n_components_range, float)

Expand Down Expand Up @@ -1359,7 +1360,9 @@ def to_xarray(self):
},
attrs={
"svd_rank": self.svd_rank,
"omega_transformation": self._transform_method,
"omega_transformation": self._xarray_sanitize(
self._transform_method
),
"n_slides": self._n_slides,
"window_length": self._window_length,
"num_frequency_bands": self.n_components,
Expand Down Expand Up @@ -1411,6 +1414,8 @@ def from_xarray(self, ds):
self._pydmd_kwargs[new_attr_name] = self._xarray_unsanitize(
ds.attrs[attr]
)
elif "omega_transformation" in attr:
self._transform_method = self._xarray_unsanitize(ds.attrs[attr])

return self

Expand Down
3 changes: 3 additions & 0 deletions pydmd/mrcosts.py
Original file line number Diff line number Diff line change
Expand Up @@ -526,6 +526,9 @@ def plot_local_reconstructions(
elif level > 0:
x_iter, _ = self.costs_array[level - 1].scale_separation()

if not x_iter.shape == (self._n_data_vars, self._n_time_steps):
raise ValueError("Input data has the wrong shape.")

if kwargs is None:
kwargs = {}

Expand Down

0 comments on commit 4595e24

Please sign in to comment.