From 4098ce946d449bfa51c3b70627164442d3cb3190 Mon Sep 17 00:00:00 2001 From: Matteo Di Bernardo Date: Mon, 2 Sep 2024 23:50:03 +0000 Subject: [PATCH] add sampling for evaluation --- scripts/evaluate_md.py | 213 +++++--------------- scripts/training_loop_resnet18_linear_md.py | 15 +- src/embed_time/evaluate_static.py | 131 ++++++++++++ src/embed_time/model.py | 4 +- 4 files changed, 189 insertions(+), 174 deletions(-) create mode 100644 src/embed_time/evaluate_static.py diff --git a/scripts/evaluate_md.py b/scripts/evaluate_md.py index 4331fbe..dfe694f 100644 --- a/scripts/evaluate_md.py +++ b/scripts/evaluate_md.py @@ -1,176 +1,61 @@ -#%% -import os -import numpy as np -import torch -from torch.utils.data import DataLoader -from torch.nn import functional as F -from torchvision.transforms import v2 +from embed_time.evaluate_static import ModelEvaluator import pandas as pd import matplotlib.pyplot as plt -from sklearn.decomposition import PCA -from matplotlib.colors import ListedColormap -import yaml - -from embed_time.dataset_static import ZarrCellDataset -from embed_time.dataloader_static import collate_wrapper -from embed_time.model_VAE_resnet18_linear import VAEResNet18_Linear - -# Utility Functions -def read_config(yaml_path): - with open(yaml_path, 'r') as file: - config = yaml.safe_load(file) - mean = [float(i) for i in config['Dataset mean'][0].split()] - std = [float(i) for i in config['Dataset std'][0].split()] - return np.array(mean), np.array(std) - -def load_checkpoint(checkpoint_path, model, device): - checkpoint = torch.load(checkpoint_path, map_location=device) - model.load_state_dict(checkpoint['model_state_dict']) - return model, checkpoint['epoch'] +import os -# Model Evaluation Function -def evaluate_model(model, dataloader, device): - model.eval() - total_loss = total_mse = total_kld = 0 - all_latent_vectors = [] - all_metadata = [] +def plot_cell_data(original, reconstruction): + fig, axes = plt.subplots(2, 4, figsize=(20, 10)) - with torch.no_grad(): - for batch in dataloader: - data = batch['cell_image'].to(device) - metadata = [batch['gene'], batch['barcode'], batch['stage']] - - recon_batch, _, mu, logvar = model(data) - mse = F.mse_loss(recon_batch, data, reduction='sum') - kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) - loss = mse + kld * 1e-5 - - total_loss += loss.item() - total_mse += mse.item() - total_kld += kld.item() - - all_latent_vectors.append(mu.cpu()) - all_metadata.extend(zip(*metadata)) - - avg_loss = total_loss / len(dataloader.dataset) - avg_mse = total_mse / len(dataloader.dataset) - avg_kld = total_kld / len(dataloader.dataset) - latent_vectors = torch.cat(all_latent_vectors, dim=0) - - return avg_loss, avg_mse, avg_kld, latent_vectors, all_metadata - -# Visualization Functions -def plot_reconstructions(model, dataloader, device): - model.eval() - with torch.no_grad(): - batch = next(iter(dataloader)) - data = batch['cell_image'].to(device) - recon_batch, _, _, _ = model(data) - - image_idx = np.random.randint(data.shape[0]) - original = data[image_idx].cpu().numpy() - reconstruction = recon_batch[image_idx].cpu().numpy() - - fig, axes = plt.subplots(2, 4, figsize=(20, 10)) - - for j in range(4): - axes[0, j].imshow(original[j], cmap='gray') - axes[0, j].set_title(f'Original Channel {j+1}') - axes[0, j].axis('off') - axes[1, j].imshow(reconstruction[j], cmap='gray') - axes[1, j].set_title(f'Reconstructed Channel {j+1}') - axes[1, j].axis('off') - - plt.tight_layout() - plt.show() - - print(f"Image shape: {original.shape}") - print(f"Reconstruction shape: {reconstruction.shape}") - print(f"Original image min/max values: {original.min():.4f}/{original.max():.4f}") - print(f"Reconstructed image min/max values: {reconstruction.min():.4f}/{reconstruction.max():.4f}") - -def create_pca_plots(train_latents, val_latents, train_df, val_df): - pca = PCA(n_components=2) - train_latents_pca = pca.fit_transform(train_latents) - val_latents_pca = pca.transform(val_latents) - - fig, axes = plt.subplots(2, 3, figsize=(18, 12)) + for j in range(4): + axes[0, j].imshow(original[j], cmap='gray', vmin=-1, vmax=1) + axes[0, j].set_title(f'Original Channel {j+1}') + axes[0, j].axis('off') + axes[1, j].imshow(reconstruction[j], cmap='gray', vmin=-1, vmax=1) + axes[1, j].set_title(f'Reconstructed Channel {j+1}') + axes[1, j].axis('off') - def create_color_map(n): - return ListedColormap(plt.cm.tab20(np.linspace(0, 1, n))) - - attributes = ['stage', 'barcode', 'gene'] - for i, attr in enumerate(attributes): - for j, (latents_pca, df) in enumerate([(train_latents_pca, train_df), (val_latents_pca, val_df)]): - unique_values = df[attr].unique() - color_map = create_color_map(len(unique_values)) - color_dict = {value: i for i, value in enumerate(unique_values)} - colors = [color_dict[value] for value in df[attr]] - - scatter = axes[j, i].scatter(latents_pca[:, 0], latents_pca[:, 1], c=colors, s=5, cmap=color_map) - axes[j, i].set_title(f"{'Training' if j == 0 else 'Validation'} Latent Space (PCA) - Colored by {attr}") - axes[j, i].set_xlabel("PC1") - axes[j, i].set_ylabel("PC2") - - cbar = plt.colorbar(scatter, ax=axes[j, i], ticks=range(len(unique_values))) - cbar.set_ticklabels(unique_values) - plt.tight_layout() plt.show() -#%% -# Main Execution -if __name__ == "__main__": - # Setup - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - # Model initialization and loading - model = VAEResNet18_Linear(nc=4, z_dim=72, input_spatial_dim=[96,96]) - checkpoint_dir = "/mnt/efs/dlmbl/G-et/checkpoints/static/Matteo/20240902_1450_resnet_linear_test/" - checkpoints = sorted(os.listdir(checkpoint_dir), key=lambda x: os.path.getmtime(os.path.join(checkpoint_dir, x))) - checkpoint_path = os.path.join(checkpoint_dir, checkpoints[-1]) - model, epoch = load_checkpoint(checkpoint_path, model, device) - model = model.to(device) - print(model) - - # Dataset parameters - parent_dir = '/mnt/efs/dlmbl/S-md/' - csv_file = '/mnt/efs/dlmbl/G-et/csv/dataset_split_2.csv' - channels = [0, 1, 2, 3] - transform = "masks" - crop_size = 96 - normalizations = v2.Compose([v2.CenterCrop(crop_size)]) - yaml_file_path = "/mnt/efs/dlmbl/G-et/yaml/dataset_info_20240901_155625.yaml" - dataset_mean, dataset_std = read_config(yaml_file_path) - - # Dataset and DataLoader creation - metadata_keys = ['gene', 'barcode', 'stage'] - images_keys = ['cell_image'] + print(f"Image shape: {original.shape}") + print(f"Reconstruction shape: {reconstruction.shape}") + print(f"Original image min/max values: {original.min():.4f}/{original.max():.4f}") + print(f"Reconstructed image min/max values: {reconstruction.min():.4f}/{reconstruction.max():.4f}") + +# Your configuration +config = { + 'model': 'VAEResNet18_Linear', + 'nc': 4, + 'z_dim': 32, + 'input_spatial_dim': [96, 96], + 'checkpoint_dir': "/mnt/efs/dlmbl/G-et/da_testing/training_logs/", + 'parent_dir': '/mnt/efs/dlmbl/S-md/', + 'csv_file': '/mnt/efs/dlmbl/G-et/csv/dataset_split_17_sampled.csv', + 'channels': [0, 1, 2, 3], + 'transform': "masks", + 'crop_size': 96, + 'yaml_file_path': "/mnt/efs/dlmbl/G-et/yaml/dataset_info_20240901_155625.yaml", + 'batch_size': 16, + 'num_workers': 8, + 'metadata_keys': ['gene', 'barcode', 'stage'], + 'images_keys': ['cell_image'], + 'kld_weight': 1e-5, + 'output_dir': '/mnt/efs/dlmbl/G-et/latent_space_data/', + 'sampling_number': 3 +} + +# Initialize ModelEvaluator +evaluator = ModelEvaluator(config) - dataset_train = ZarrCellDataset(parent_dir, csv_file, 'train', channels, transform, normalizations, None, dataset_mean, dataset_std) - dataset_val = ZarrCellDataset(parent_dir, csv_file, 'val', channels, transform, normalizations, None, dataset_mean, dataset_std) - - dataloader_train = DataLoader(dataset_train, batch_size=16, shuffle=True, num_workers=8, collate_fn=collate_wrapper(metadata_keys, images_keys)) - dataloader_val = DataLoader(dataset_val, batch_size=16, shuffle=True, num_workers=8, collate_fn=collate_wrapper(metadata_keys, images_keys)) - - # Model evaluation - print("Evaluating on training data...") - train_loss, train_mse, train_kld, train_latents, train_metadata = evaluate_model(model, dataloader_train, device) - print(f"Training - Loss: {train_loss:.4f}, MSE: {train_mse:.4f}, KLD: {train_kld:.4f}") - - print("Evaluating on validation data...") - val_loss, val_mse, val_kld, val_latents, val_metadata = evaluate_model(model, dataloader_val, device) - print(f"Validation - Loss: {val_loss:.4f}, MSE: {val_mse:.4f}, KLD: {val_kld:.4f}") - - # Create DataFrames - train_df = pd.DataFrame(train_metadata, columns=['gene', 'barcode', 'stage']) - train_df = pd.concat([train_df, pd.DataFrame(train_latents.numpy())], axis=1) +train_df = evaluator.evaluate('train') +val_df = evaluator.evaluate('val') - val_df = pd.DataFrame(val_metadata, columns=['gene', 'barcode', 'stage']) - val_df = pd.concat([val_df, pd.DataFrame(val_latents.numpy())], axis=1) +# save train_df and val_df to csv in graphs subdirectory +model_name = "17_genes_resnet18_linear_latent32" - # Visualizations - plot_reconstructions(model, dataloader_val, device) - plot_reconstructions(model, dataloader_train, device) - create_pca_plots(train_latents.numpy(), val_latents.numpy(), train_df, val_df) +os.makedirs("latent", exist_ok=True) +train_df.to_csv(f"latent/{model_name}_train.csv", index=False) +val_df.to_csv(f"latent/{model_name}_val.csv", index=False) +print(val_df.shape) -#%% \ No newline at end of file +print("Evaluation complete. Latent dimensions extracted and saved.") \ No newline at end of file diff --git a/scripts/training_loop_resnet18_linear_md.py b/scripts/training_loop_resnet18_linear_md.py index 7eda768..7bf4821 100644 --- a/scripts/training_loop_resnet18_linear_md.py +++ b/scripts/training_loop_resnet18_linear_md.py @@ -43,7 +43,7 @@ def read_config(yaml_path): device = torch.device("cpu") # Basic values for logging -model_name = "static_resnet_linear_vae_md" +model_name = "static_resnet_linear_vae_md_nomask" find_port = True # Function to find an available port @@ -72,10 +72,10 @@ def launch_tensorboard(log_dir): # Define variables for the dataset read in parent_dir = '/mnt/efs/dlmbl/S-md/' -csv_file = '/mnt/efs/dlmbl/G-et/csv/dataset_split_2.csv' +csv_file = '/mnt/efs/dlmbl/G-et/csv/dataset_split_17_sampled.csv' split = 'train' channels = [0, 1, 2, 3] -transform = "masks" +transform = None crop_size = 96 normalizations = v2.Compose([v2.CenterCrop(crop_size)]) yaml_file_path = "/mnt/efs/dlmbl/G-et/yaml/dataset_info_20240901_155625.yaml" @@ -98,8 +98,7 @@ def launch_tensorboard(log_dir): ) # Create the model -vae = VAEResNet18_Linear(nc = 4, z_dim = 72, input_spatial_dim = [96,96]) - +vae = VAEResNet18_Linear(nc = 4, z_dim = 32, input_spatial_dim = [96,96]) torchview.draw_graph( vae, @@ -114,7 +113,7 @@ def launch_tensorboard(log_dir): vae = vae.to(device) # Define the optimizer -optimizer = torch.optim.Adam(vae.parameters(), lr=1e-4) +optimizer = torch.optim.Adam(vae.parameters(), lr=1e-3) def loss_function(recon_x, x, mu, logvar): MSE = F.mse_loss(recon_x, x, reduction='mean') @@ -148,7 +147,7 @@ def train( recon_batch, z, mu, logvar = vae(data) MSE, KLD = loss_function(recon_batch, data, mu, logvar) - loss = MSE + KLD * 1e-5 + loss = MSE + KLD * 1e-4 loss.backward() train_loss += loss.item() @@ -244,7 +243,7 @@ def train( # Training loop output_dir = '/mnt/efs/dlmbl/G-et/' -run_name= "resnet_linear_test" +run_name= "resnet_linear_17_32dim_nomask" folder_suffix = datetime.now().strftime("%Y%m%d_%H%M_") + run_name log_path = output_dir + "logs/static/Matteo/"+ folder_suffix + "/" diff --git a/src/embed_time/evaluate_static.py b/src/embed_time/evaluate_static.py new file mode 100644 index 0000000..cb016df --- /dev/null +++ b/src/embed_time/evaluate_static.py @@ -0,0 +1,131 @@ +import os +import numpy as np +import torch +from torch.utils.data import DataLoader +from torch.nn import functional as F +from torchvision.transforms import v2 +import pandas as pd +import yaml +import argparse + +from embed_time.dataset_static import ZarrCellDataset +from embed_time.dataloader_static import collate_wrapper +from embed_time.model_VAE_resnet18_linear import VAEResNet18_Linear +from embed_time.model_VAE_resnet18 import VAEResNet18 +from embed_time.model import VAE, Encoder, Decoder + +class ModelEvaluator(): + def __init__(self, config): + self.config = config + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.model = self._init_model() + self.dataset_mean, self.dataset_std = self._read_config() + + def _init_model(self): + if self.config['model'] == 'VAEResNet18': + model = VAEResNet18(nc=self.config['nc'], z_dim=self.config['z_dim'], input_spatial_dim=self.config['input_spatial_dim']) + elif self.config['model'] == 'VAEResNet18_Linear': + model = VAEResNet18_Linear(nc=self.config['nc'], z_dim=self.config['z_dim'], input_spatial_dim=self.config['input_spatial_dim']) + elif self.config['model'] == 'VAE': + encoder = Encoder(self.config['nc'], self.config['z_dim']) + decoder = Decoder(self.config['z_dim'], self.config['h_dim1'], self.config['h_dim2'], self.config['nc'], self.config['output_shape']) + model = VAE(encoder, decoder) + checkpoints = sorted(os.listdir(self.config['checkpoint_dir']), key=lambda x: os.path.getmtime(os.path.join(self.config['checkpoint_dir'], x))) + checkpoint_path = os.path.join(self.config['checkpoint_dir'], checkpoints[-1]) + model, _ = self._load_checkpoint(checkpoint_path, model) + return model.to(self.device) + + def _read_config(self): + with open(self.config['yaml_file_path'], 'r') as file: + yaml_config = yaml.safe_load(file) + mean = [float(i) for i in yaml_config['Dataset mean'][0].split()] + std = [float(i) for i in yaml_config['Dataset std'][0].split()] + return np.array(mean), np.array(std) + + def _load_checkpoint(self, checkpoint_path, model): + checkpoint = torch.load(checkpoint_path, map_location=self.device) + model.load_state_dict(checkpoint['model_state_dict']) + return model, checkpoint['epoch'] + + def _create_dataloader(self, split): + dataset = ZarrCellDataset( + self.config['parent_dir'], + self.config['csv_file'], + split, + self.config['channels'], + self.config['transform'], + v2.Compose([v2.CenterCrop(self.config['crop_size'])]), + None, + self.dataset_mean, + self.dataset_std + ) + return DataLoader( + dataset, + batch_size=self.config['batch_size'], + shuffle=False, + num_workers=self.config['num_workers'], + collate_fn=collate_wrapper(self.config['metadata_keys'], self.config['images_keys']) + ) + + def evaluate_model(self, dataloader): + self.model.eval() + total_loss = total_mse = total_kld = 0 + all_latent_vectors = [] + all_metadata = [] + + with torch.no_grad(): + for batch in dataloader: + data = batch['cell_image'].to(self.device) + metadata = [batch[key] for key in self.config['metadata_keys']] + + recon_batch, _, mu, logvar = self.model(data) + mse = F.mse_loss(recon_batch, data, reduction='sum') + kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) + loss = mse + kld * self.config['kld_weight'] + + total_loss += loss.item() + total_mse += mse.item() + total_kld += kld.item() + + if self.config['sampling_number'] > 1: + print('Sampling {} times...'.format(self.config['sampling_number'])) + for i in range(self.config['sampling_number']): + # Sample from the latent space + z = self.model.reparameterize(mu, logvar) + # save zs and metadata into additional latent representations + all_latent_vectors.append(z.cpu()) + all_metadata.extend(zip(*metadata)) + else: + all_latent_vectors.append(mu.cpu()) + all_metadata.extend(zip(*metadata)) + + # Sample from the latent space + + avg_loss = total_loss / len(dataloader.dataset) + avg_mse = total_mse / len(dataloader.dataset) + avg_kld = total_kld / len(dataloader.dataset) + latent_vectors = torch.cat(all_latent_vectors, dim=0) + + return avg_loss, avg_mse, avg_kld, latent_vectors, all_metadata + + def evaluate(self, split): + dataloader = self._create_dataloader(split) + print(f"Evaluating on {split} data...") + loss, mse, kld, latents, metadata = self.evaluate_model(dataloader) + print(f"{split.capitalize()} - Loss: {loss:.4f}, MSE: {mse:.4f}, KLD: {kld:.4f}") + + # Create DataFrame + df = pd.DataFrame(metadata, columns=self.config['metadata_keys']) + latent_df = pd.DataFrame(latents.numpy(), columns=[f'latent_{i}' for i in range(latents.shape[1])]) + df = pd.concat([df, latent_df], axis=1) + + return df + +def parse_args(): + parser = argparse.ArgumentParser(description="Model Evaluation Script") + parser.add_argument("--config", type=str, required=True, help="Path to the configuration YAML file") + return parser.parse_args() + +def load_config(config_path): + with open(config_path, 'r') as file: + return yaml.safe_load(file) diff --git a/src/embed_time/model.py b/src/embed_time/model.py index d5a40a1..08866e2 100644 --- a/src/embed_time/model.py +++ b/src/embed_time/model.py @@ -115,7 +115,7 @@ def check_shapes(self, data_shape, z_dim): print("Error in checking shapes") raise (e) - def sampling(self, mu, log_var): + def reparametrize(self, mu, log_var): std = torch.exp(0.5 * log_var) eps = torch.randn_like(std) z = eps.mul(std).add_(mu) @@ -123,5 +123,5 @@ def sampling(self, mu, log_var): def forward(self, x): mu, log_var = self.encoder(x) - z = self.sampling(mu, log_var) + z = self.reparametrize(mu, log_var) return self.decoder(z), mu, log_var