Skip to content

Commit

Permalink
Add apply_pca function
Browse files Browse the repository at this point in the history
  • Loading branch information
giswqs committed Oct 12, 2024
1 parent d90af5b commit 7a50660
Showing 1 changed file with 62 additions and 0 deletions.
62 changes: 62 additions & 0 deletions hypercoast/pace.py
Original file line number Diff line number Diff line change
Expand Up @@ -899,3 +899,65 @@ def apply_kmeans(
plt.show()

return cluster_labels, latitudes, longitudes


def apply_pca(
    dataset: Union[xr.Dataset, str],
    n_components: int = 3,
    plot: bool = True,
    figsize: tuple[int, int] = (8, 6),
    x_component: int = 0,
    y_component: int = 1,
    color: str = "blue",
    title: str = "PCA of Spectral Data",
    **kwargs,
) -> np.ndarray:
    """
    Applies Principal Component Analysis (PCA) to the dataset and optionally plots the results.

    The ``"Rrs"`` variable of the dataset is flattened to a 2-D array whose
    columns are the entries of its last axis, rows containing any NaN are
    dropped, and the remaining rows are passed to
    ``sklearn.decomposition.PCA.fit_transform``.

    Args:
        dataset (xr.Dataset | str): The dataset containing the PACE data or the file path to the dataset.
            A string is loaded with ``read_pace``.
        n_components (int, optional): Number of principal components to compute. Defaults to 3.
        plot (bool, optional): Whether to plot the data. Defaults to True.
        figsize (tuple[int, int], optional): Figure size for the plot. Defaults to (8, 6).
        x_component (int, optional): The (0-based) principal component to plot on the x-axis. Defaults to 0.
        y_component (int, optional): The (0-based) principal component to plot on the y-axis. Defaults to 1.
        color (str, optional): Color of the scatter plot points. Defaults to "blue".
            Ignored when ``c`` or ``color`` is supplied through ``**kwargs``.
        title (str, optional): Title for the plot. Defaults to "PCA of Spectral Data".
        **kwargs: Additional keyword arguments to pass to the `plt.scatter` function.
            ``s`` defaults to 1 when not given.

    Returns:
        np.ndarray: The PCA-transformed data, one row per NaN-free input row.

    Raises:
        ValueError: If ``dataset`` is neither a string nor an ``xr.Dataset``, or
            if no row is left after removing rows that contain NaN.
        IndexError: If ``plot`` is True and ``x_component`` or ``y_component``
            does not index a computed component.
    """
    from sklearn.decomposition import PCA

    if isinstance(dataset, str):
        dataset = read_pace(dataset)
    elif not isinstance(dataset, xr.Dataset):
        raise ValueError("dataset must be an xarray Dataset")

    da = dataset["Rrs"]

    # Collapse every leading dimension into a single sample axis and keep the
    # last axis as the feature axis (assumes the spectral axis is last --
    # TODO confirm against read_pace output).
    spectra = da.values.reshape(-1, da.shape[-1])

    # PCA cannot handle missing values, so drop every row with at least one NaN.
    valid_rows = ~np.isnan(spectra).any(axis=1)
    spectra = spectra[valid_rows]
    if spectra.shape[0] == 0:
        # Fail with an explicit message instead of scikit-learn's generic
        # "0 sample(s)" error.
        raise ValueError(
            "No valid pixels left for PCA: every pixel in 'Rrs' contains NaN values"
        )

    # Apply PCA to reduce dimensionality
    pca = PCA(n_components=n_components)
    pca_data = pca.fit_transform(spectra)

    if plot:
        # Validate before any figure is created so a bad index does not leave
        # an empty figure behind. Negative indices stay accepted, as with
        # plain NumPy indexing.
        n_available = pca_data.shape[1]
        for arg_name, index in (
            ("x_component", x_component),
            ("y_component", y_component),
        ):
            if not -n_available <= index < n_available:
                raise IndexError(
                    f"{arg_name}={index} is out of range for "
                    f"{n_available} principal component(s)"
                )

        # Work on a copy and let an explicit ``c``/``color`` from the caller
        # win; passing both to plt.scatter would raise.
        scatter_kwargs = dict(kwargs)
        scatter_kwargs.setdefault("s", 1)
        if "c" not in scatter_kwargs and "color" not in scatter_kwargs:
            scatter_kwargs["c"] = color

        plt.figure(figsize=figsize)
        plt.scatter(
            pca_data[:, x_component], pca_data[:, y_component], **scatter_kwargs
        )
        plt.title(title)
        plt.xlabel(f"Principal Component {x_component + 1}")
        plt.ylabel(f"Principal Component {y_component + 1}")
        plt.show()

    return pca_data

0 comments on commit 7a50660

Please sign in to comment.