Skip to content

Commit

Permalink
Add chla_predict and chla_viz functions
Browse files Browse the repository at this point in the history
  • Loading branch information
giswqs committed Dec 8, 2024
1 parent 355e670 commit 7dd2b13
Showing 1 changed file with 230 additions and 7 deletions.
237 changes: 230 additions & 7 deletions hypercoast/chla.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,8 @@ def forward(
x (torch.Tensor): Input tensor.
Returns:
tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Reconstructed tensor, mean, and log variance.
tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Reconstructed tensor,
mean, and log variance.
"""
mu, log_var = self.encode(x)
z = self.reparameterize(mu, log_var)
Expand Down Expand Up @@ -138,6 +139,8 @@ def train(
epochs: int = 200,
device: torch.device = None,
opt: torch.optim.Optimizer = None,
model_path: str = "model/vae_model_PACE.pth",
best_model_path: str = "model/vae_trans_model_best_Chl_PACE.pth",
) -> None:
"""Trains the VAE model.
Expand All @@ -147,6 +150,10 @@ def train(
epochs (int, optional): Number of epochs to train. Defaults to 200.
device (torch.device, optional): Device to train on. Defaults to None.
opt (torch.optim.Optimizer, optional): Optimizer. Defaults to None.
model_path (str, optional): Path to save the model. Defaults to
"model/vae_model_PACE.pth".
best_model_path (str, optional): Path to save the best model. Defaults
to "model/vae_trans_model_best_Chl_PACE.pth"
"""

if device is None:
Expand All @@ -159,7 +166,6 @@ def train(

min_total_loss = float("inf")
# Save the optimal model
best_model_total_path = "model/vae_trans_model_best_Chl_PACE.pth"

for epoch in range(epochs):
total_loss = 0.0
Expand All @@ -177,9 +183,9 @@ def train(

if avg_total_loss < min_total_loss:
min_total_loss = avg_total_loss
torch.save(model.state_dict(), best_model_total_path)
torch.save(model.state_dict(), best_model_path)
# Save the model from the last epoch.
torch.save(model.state_dict(), "model/vae_model_PACE.pth")
torch.save(model.state_dict(), model_path)


def evaluate(
Expand Down Expand Up @@ -220,7 +226,8 @@ def load_real_data(
rrs_file_path (str): Path to the rrs file.
Returns:
tuple[DataLoader, DataLoader, int, int]: Training DataLoader, testing DataLoader, input dimension, output dimension.
tuple[DataLoader, DataLoader, int, int]: Training DataLoader, testing
DataLoader, input dimension, output dimension.
"""
array1 = np.loadtxt(aphy_file_path, delimiter=",", dtype=float)
array2 = np.loadtxt(rrs_file_path, delimiter=",", dtype=float)
Expand Down Expand Up @@ -399,7 +406,7 @@ def plot_results(
x = np.array([-2, 4])
y = slope * x + intercept

plt.figure(figsize=(6, 6))
_ = plt.figure(figsize=(6, 6))

plt.plot(x, y, linestyle="--", color="blue", linewidth=0.8)
lims = [-2, 4]
Expand Down Expand Up @@ -434,9 +441,10 @@ def plot_results(

plt.xticks(fontsize=20)
plt.yticks(fontsize=20)
plt.show()

plt.savefig(os.path.join(save_dir, f"{mode}_plot.pdf"), bbox_inches="tight")
plt.close()
# plt.close()


def save_to_csv(data: np.ndarray, file_path: str) -> None:
Expand All @@ -448,3 +456,218 @@ def save_to_csv(data: np.ndarray, file_path: str) -> None:
"""
df = pd.DataFrame(data)
df.to_csv(file_path, index=False)


def chla_predict(
pace_filepath: str,
best_model_path: str,
chla_data_file: str = None,
device: torch.device = None,
) -> None:
"""Predicts chlorophyll-a concentration using a pre-trained VAE model.
Args:
pace_filepath (str): Path to the PACE dataset file.
best_model_path (str): Path to the pre-trained VAE model file.
chla_data_file (str, optional): Path to save the predicted chlorophyll-a data. Defaults to None.
device (torch.device, optional): Device to perform inference on. Defaults to None.
"""

from .pace import read_pace

if device is None:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load PACE dataset and prepare data
PACE_dataset = read_pace(pace_filepath)
da = PACE_dataset["Rrs"]
wl = da.wavelength.values
Rrs = da.values
latitude = da.latitude.values
longitude = da.longitude.values

# Filter wavelengths between 400 and 703 nm
indices = np.where((wl >= 400) & (wl <= 703))[0]
filtered_Rrs = Rrs[:, :, indices]

# Save filtered Rrs and wavelength
filtered_wl = wl[indices]

# Create a mask that is 1 where all wavelengths for a given pixel have non-NaN values, and 0 otherwise
mask = np.all(~np.isnan(filtered_Rrs), axis=2).astype(int)

# Define input and output dimensions
input_dim = 148
output_dim = 1

# Load test data and mask
test_data = filtered_Rrs
mask_data = mask

# Filter valid data using the mask
mask = mask_data == 1
N = np.sum(mask)
valid_test_data = test_data[mask]

# Normalize data
valid_test_data = np.array(
[
(
MinMaxScaler(feature_range=(1, 10))
.fit_transform(row.reshape(-1, 1))
.flatten()
if not np.isnan(row).any()
else row
)
for row in valid_test_data
]
)
valid_test_data = valid_test_data.reshape(N, input_dim)

# Create DataLoader for test data
test_tensor = TensorDataset(torch.tensor(valid_test_data).float())
test_loader = DataLoader(test_tensor, batch_size=2048, shuffle=False)

# Load the pre-trained VAE model
model = VAE(input_dim, output_dim).to(device)
model.load_state_dict(torch.load(best_model_path, map_location=device))
model.eval()

# Perform inference
predictions_all = []
with torch.no_grad():
for batch in test_loader:
batch = batch[0].to(device)
predictions, _, _ = model(batch) # VAE model inference
predictions_all.append(predictions.cpu().numpy())

# Concatenate all batch predictions
predictions_all = np.vstack(predictions_all)

# Ensure predictions are in the correct shape
if predictions_all.shape[-1] == 1:
predictions_all = predictions_all.squeeze(-1)
# if predictions_all.dim() == 3:
# all_outputs = predictions_all.squeeze(1)

# Initialize output array with NaNs
outputs = np.full((test_data.shape[0], test_data.shape[1]), np.nan)

# Fill in the valid mask positions with predictions
outputs[mask] = predictions_all

# Flatten latitude, longitude, and predictions for output
lat_flat = latitude.flatten()
lon_flat = longitude.flatten()
output_flat = outputs.flatten()

# Combine latitude, longitude, and predictions
final_output = np.column_stack((lat_flat, lon_flat, output_flat))

# Save the final output including latitude and longitude
if chla_data_file is None:
chla_data_file = pace_filepath.replace(".nc", ".npy")
np.save(chla_data_file, final_output)


def chla_viz(
rgb_image_tif_file: str,
chla_data_file: str,
output_file: str,
title: str = "PACE",
figsize: tuple = (12, 8),
cmap: str = "jet",
) -> None:
"""Visualizes the chlorophyll-a concentration over an RGB image.
Args:
rgb_image_tif_file (str): Path to the RGB image file.
chla_data_file (str): Path to the chlorophyll-a data file.
output_file (str): Path to save the output visualization.
title (str, optional): Title of the plot. Defaults to "PACE".
figsize (tuple, optional): Figure size for the plot. Defaults to (12, 8).
cmap (str, optional): Colormap for the chlorophyll-a concentration. Defaults to "jet".
"""

# Read RGB Image
# rgb_image_tif_file = "data/snapshot-2024-08-10T00_00_00Z.tif"

with rasterio.open(rgb_image_tif_file) as dataset:
# Read R、G、B bands
R = dataset.read(1)
G = dataset.read(2)
B = dataset.read(3)

# # Get geographic extent, resolution information.
extent = [
dataset.bounds.left,
dataset.bounds.right,
dataset.bounds.bottom,
dataset.bounds.top,
]
transform = dataset.transform
width, height = dataset.width, dataset.height

# Combine the R, G, B bands into a 3D array.
rgb_image = np.stack((R, G, B), axis=-1)

# Load Chla data
chla_data = np.load(chla_data_file)
# chla_data = final_output

# Extract the latitude, longitude, and concentration values of the chlorophyll-a data.
latitude = chla_data[:, 0]
longitude = chla_data[:, 1]
chla_values = chla_data[:, 2]

# Extract the pixels within the geographic extent of the RGB image.
mask = (
(latitude >= extent[2])
& (latitude <= extent[3])
& (longitude >= extent[0])
& (longitude <= extent[1])
)
latitude = latitude[mask]
longitude = longitude[mask]
chla_values = chla_values[mask]

# Create a grid with the same resolution as the RGB image.
grid_lon = np.linspace(extent[0], extent[1], width)
grid_lat = np.linspace(extent[3], extent[2], height)
grid_lon, grid_lat = np.meshgrid(grid_lon, grid_lat)

# Resample the chlorophyll-a data to the size of the RGB image using interpolation.
chla_resampled = griddata(
(longitude, latitude), chla_values, (grid_lon, grid_lat), method="linear"
)

# Keep NaN values as transparent regions.
chla_resampled = np.ma.masked_invalid(chla_resampled)

plt.figure(figsize=figsize)

plt.imshow(rgb_image / 255.0, extent=extent, origin="upper")

vmin, vmax = 0, 35
im = plt.imshow(
chla_resampled,
extent=extent,
cmap=cmap,
alpha=0.6,
origin="upper",
vmin=vmin,
vmax=vmax,
)

cbar = plt.colorbar(im, orientation="horizontal")
cbar.set_label("Chlorophyll-a Concentration (mg/m³)")

plt.title(title)
plt.xlabel("Longitude")
plt.ylabel("Latitude")

# output_file = "20241024-2.png"
plt.savefig(output_file, dpi=300, bbox_inches="tight", pad_inches=0.1)
print(f"Saved overlay image to {output_file}")

plt.show()

0 comments on commit 7dd2b13

Please sign in to comment.