Skip to content

Commit

Permalink
Add apply_sam function
Browse files Browse the repository at this point in the history
  • Loading branch information
giswqs committed Oct 13, 2024
1 parent 5df1bd5 commit a19e83d
Showing 1 changed file with 148 additions and 1 deletion.
149 changes: 148 additions & 1 deletion hypercoast/pace.py
Original file line number Diff line number Diff line change
Expand Up @@ -891,7 +891,7 @@ def apply_kmeans(
im,
ax=ax,
orientation="vertical",
pad=0.02,
# pad=0.02,
fraction=0.05,
ticks=np.arange(n_clusters),
)
Expand Down Expand Up @@ -969,3 +969,150 @@ def apply_pca(
plt.show()

return pca_data


def apply_sam(
dataset: Union[xr.Dataset, str],
n_components: int = 3,
n_clusters: int = 6,
random_state: int = 0,
plot: bool = True,
figsize: tuple[int, int] = (8, 6),
extent: list[float] = None,
colors: list[str] = None,
title: str = "Spectral Angle Mapper",
**kwargs,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
"""
Applies Spectral Angle Mapper (SAM) to the dataset and optionally plots the results.
Args:
dataset (Union[xr.Dataset, str]): The dataset containing the PACE data or the file path to the dataset.
n_components (int, optional): Number of principal components to compute. Defaults to 3.
n_clusters (int, optional): Number of clusters for K-means. Defaults to 6.
random_state (int, optional): Random state for K-means. Defaults to 0.
plot (bool, optional): Whether to plot the data. Defaults to True.
figsize (Tuple[int, int], optional): Figure size for the plot. Defaults to (8, 6).
extent (List[float], optional): The extent to zoom in to the specified region. Defaults to None.
colors (List[str], optional): Colors for the clusters. Defaults to None.
title (str, optional): Title for the plot. Defaults to "Spectral Angle Mapper".
**kwargs: Additional keyword arguments to pass to the `plt.subplots` function.
Returns:
Tuple[np.ndarray, np.ndarray, np.ndarray]: The best match classification, latitudes, and longitudes.
"""
from sklearn.cluster import KMeans
import matplotlib.colors as mcolors
import cartopy.crs as ccrs
import cartopy.feature as cfeature
from sklearn.decomposition import PCA

if isinstance(dataset, str):
dataset = read_pace(dataset)
elif isinstance(dataset, xr.DataArray):
dataset = dataset.to_dataset()
elif not isinstance(dataset, xr.Dataset):
raise ValueError("dataset must be an xarray Dataset")

da = dataset["Rrs"]

# Reshape data to (n_pixels, n_bands)
reshaped_data = da.values.reshape(-1, da.shape[-1])

# Handle NaNs by removing them
reshaped_data_no_nan = reshaped_data[~np.isnan(reshaped_data).any(axis=1)]

# Apply PCA to reduce dimensionality
pca = PCA(n_components=n_components)
pca_data = pca.fit_transform(reshaped_data_no_nan)

# Apply K-means to find clusters representing endmembers
kmeans = KMeans(n_clusters=n_clusters, random_state=random_state)
kmeans.fit(pca_data)

# The cluster centers in the original spectral space are your endmembers
endmembers = pca.inverse_transform(kmeans.cluster_centers_)

def spectral_angle_mapper(pixel, reference):
norm_pixel = np.linalg.norm(pixel)
norm_reference = np.linalg.norm(reference)
cos_theta = np.dot(pixel, reference) / (norm_pixel * norm_reference)
angle = np.arccos(np.clip(cos_theta, -1, 1))
return angle

# Apply SAM for each pixel and each endmember
angles = np.zeros((reshaped_data_no_nan.shape[0], endmembers.shape[0]))

for i in range(reshaped_data_no_nan.shape[0]):
for j in range(endmembers.shape[0]):
angles[i, j] = spectral_angle_mapper(
reshaped_data_no_nan[i, :], endmembers[j, :]
)

# Find the minimum angle (best match) for each pixel
best_match = np.argmin(angles, axis=1)

# Reshape best_match back to the original spatial dimensions
original_shape = da.shape[:-1] # Get the spatial dimensions
best_match_full = np.full(reshaped_data.shape[0], np.nan)
best_match_full[~np.isnan(reshaped_data).any(axis=1)] = best_match
best_match_full = best_match_full.reshape(original_shape)

latitudes = da.coords["latitude"].values
longitudes = da.coords["longitude"].values

if plot:

if colors is None:
colors = ["#377eb8", "#ff7f00", "#4daf4a", "#f781bf", "#a65628", "#984ea3"]
# Create a custom discrete color map
cmap = mcolors.ListedColormap(colors)
bounds = np.arange(-0.5, n_clusters, 1)
norm = mcolors.BoundaryNorm(bounds, cmap.N)

# Create a figure and axis with the correct map projection
fig, ax = plt.subplots(
figsize=figsize, subplot_kw={"projection": ccrs.PlateCarree()}, **kwargs
)

# Plot the SAM classification results
im = ax.pcolormesh(
longitudes,
latitudes,
best_match_full,
cmap=cmap,
norm=norm,
transform=ccrs.PlateCarree(),
)

# Add geographic features for context
ax.add_feature(cfeature.COASTLINE)
ax.add_feature(cfeature.BORDERS, linestyle=":")
ax.add_feature(cfeature.STATES, linestyle="--")

# Add gridlines
ax.gridlines(draw_labels=True)

# Set the extent to zoom in to the specified region
if extent is not None:
ax.set_extent(extent, crs=ccrs.PlateCarree())

# Add color bar with labels
cbar = plt.colorbar(
im,
ax=ax,
orientation="vertical",
# pad=0.02,
fraction=0.05,
ticks=np.arange(n_clusters),
)
cbar.ax.set_yticklabels([f"Class {i+1}" for i in range(n_clusters)])
cbar.set_label("Water Types", rotation=270, labelpad=20)

# Add title
ax.set_title(title, fontsize=14)

# Show the plot
plt.show()

return best_match_full, latitudes, longitudes

0 comments on commit a19e83d

Please sign in to comment.