From 55c23470cae9e20986799bf9d7dc4879fec42997 Mon Sep 17 00:00:00 2001 From: Blampey Quentin Date: Wed, 28 Aug 2024 14:32:35 +0200 Subject: [PATCH 1/8] try new queue with max --- novae/_constants.py | 2 +- novae/model.py | 3 +-- novae/module/swav.py | 10 +++++----- novae/plot/_heatmap.py | 4 ++-- scripts/config/all_new.yaml | 30 ++++++++++++++++++++++++++++++ 5 files changed, 39 insertions(+), 10 deletions(-) create mode 100644 scripts/config/all_new.yaml diff --git a/novae/_constants.py b/novae/_constants.py index 0e3697f..fe26a2a 100644 --- a/novae/_constants.py +++ b/novae/_constants.py @@ -57,7 +57,7 @@ class Nums: SWAV_EPSILON: float = 0.05 SINKHORN_ITERATIONS: int = 3 QUEUE_SIZE: int = 3 - QUEUE_WEIGHT_THRESHOLD: float = 0.99 + QUEUE_WEIGHT_THRESHOLD: float = 0.5 # misc nums MEAN_NGH_TH_WARNING: float = 3.5 diff --git a/novae/model.py b/novae/model.py index 34ca59c..3740c50 100644 --- a/novae/model.py +++ b/novae/model.py @@ -50,7 +50,6 @@ def __init__( num_layers: int = 10, batch_size: int = 512, temperature: float = 0.1, - temperature_weight_proto: float = 0.1, num_prototypes: int = 256, panel_subset_size: float = 0.6, background_noise_lambda: float = 8.0, @@ -99,7 +98,7 @@ def __init__( ### Initialize modules self.encoder = GraphEncoder(embedding_size, hidden_size, num_layers, output_size, heads) self.augmentation = GraphAugmentation(panel_subset_size, background_noise_lambda, sensitivity_noise_std) - self.swav_head = SwavHead(self.mode, output_size, num_prototypes, temperature, temperature_weight_proto) + self.swav_head = SwavHead(self.mode, output_size, num_prototypes, temperature) ### Misc self._num_workers = 0 diff --git a/novae/module/swav.py b/novae/module/swav.py index 42bdb1f..597a5f3 100644 --- a/novae/module/swav.py +++ b/novae/module/swav.py @@ -27,7 +27,6 @@ def __init__( output_size: int, num_prototypes: int, temperature: float, - temperature_weight_proto: float, ): """SwavHead module, adapted from the paper ["Unsupervised Learning of Visual Features by Contrasting Cluster Assignments"](https://arxiv.org/abs/2006.09882). @@ -41,7 +40,6 @@ def __init__( self.output_size = output_size self.num_prototypes = num_prototypes self.temperature = temperature - self.temperature_weight_proto = temperature_weight_proto self._prototypes = nn.Parameter(torch.empty((self.num_prototypes, self.output_size))) self._prototypes = nn.init.kaiming_uniform_(self._prototypes, a=math.sqrt(5), mode="fan_out") @@ -130,10 +128,9 @@ def prototype_ilocs(self, projections: Tensor, slide_id: str | None = None) -> T return ... slide_index = self.slide_label_encoder[slide_id] - slide_weights = F.softmax(projections / self.temperature_weight_proto, dim=1).mean(0) self.queue[slide_index, 1:] = self.queue[slide_index, :-1].clone() - self.queue[slide_index, 0] = slide_weights + self.queue[slide_index, 0] = projections.max(0).values weights = self.queue_weights()[slide_index] ilocs = torch.where(weights >= Nums.QUEUE_WEIGHT_THRESHOLD)[0] @@ -146,7 +143,10 @@ def queue_weights(self) -> Tensor: Returns: A tensor of shape `(n_slides, K)`. """ - return self.sinkhorn(self.queue.mean(dim=1)) * self.num_prototypes + max_projections = self.queue.max(dim=1).values + unused_prototypes = max_projections.max(dim=0).values < Nums.QUEUE_WEIGHT_THRESHOLD + max_projections[:, unused_prototypes] = 1 # ensure all prototypes are used + return max_projections @utils.format_docs @torch.no_grad() diff --git a/novae/plot/_heatmap.py b/novae/plot/_heatmap.py index 8d579ef..3af5aae 100644 --- a/novae/plot/_heatmap.py +++ b/novae/plot/_heatmap.py @@ -22,8 +22,8 @@ def _weights_clustermap( show_yticklabels: bool = False, show_tissue_legend: bool = True, figsize: tuple[int] = (6, 4), - vmin: float = 0.9, - vmax: float = 1.1, + vmin: float = 0, + vmax: float = 0.5, **kwargs: int, ) -> None: row_colors = None diff --git a/scripts/config/all_new.yaml b/scripts/config/all_new.yaml new file mode 100644 index 0000000..c8c931d --- /dev/null +++ b/scripts/config/all_new.yaml @@ -0,0 +1,30 @@ +data: + train_dataset: all + val_dataset: igr/202305031337_hBreast-slide-B-4h-photobleach_VMSC09302 + +model_kwargs: + scgpt_model_dir: /gpfs/workdir/blampeyq/checkpoints/scgpt/scGPT_human + n_hops_view: 3 + heads: 16 + hidden_size: 128 + temperature: 0.1 + num_prototypes: 512 + background_noise_lambda: 5 + panel_subset_size: 0.8 + min_prototypes_ratio: 0.15 + +fit_kwargs: + max_epochs: 30 + lr: 0.0001 + accelerator: "gpu" + num_workers: 8 + patience: 6 + min_delta: 0.025 + +post_training: + n_domains: [15, 20, 25] + log_metrics: true + save_h5ad: true + log_umap: true + log_domains: true + delete_X: true From c567b0e24388805f1a4b13f8fd209cf3d6aeb2f9 Mon Sep 17 00:00:00 2001 From: Blampey Quentin Date: Wed, 28 Aug 2024 17:57:52 +0200 Subject: [PATCH 2/8] set up new queue --- novae/_constants.py | 2 +- novae/model.py | 7 ++++--- novae/module/swav.py | 18 ++++++++++-------- novae/plot/_heatmap.py | 2 +- 4 files changed, 16 insertions(+), 13 deletions(-) diff --git a/novae/_constants.py b/novae/_constants.py index fe26a2a..2b61500 100644 --- a/novae/_constants.py +++ b/novae/_constants.py @@ -57,7 +57,7 @@ class Nums: SWAV_EPSILON: float = 0.05 SINKHORN_ITERATIONS: int = 3 QUEUE_SIZE: int = 3 - QUEUE_WEIGHT_THRESHOLD: float = 0.5 + QUEUE_WEIGHT_THRESHOLD_RATIO: float = 0.9 # misc nums MEAN_NGH_TH_WARNING: float = 3.5 diff --git a/novae/model.py b/novae/model.py index 3740c50..8a7f952 100644 --- a/novae/model.py +++ b/novae/model.py @@ -412,11 +412,12 @@ def plot_prototype_weights(self, **kwargs: int): self.swav_head.queue is not None ), "Swav queue not initialized. Initialize it with `model.init_slide_queue(...)`, then train or fine-tune the model." - weights = self.swav_head.queue_weights().numpy(force=True) + weights, thresholds = self.swav_head.queue_weights() + weights, thresholds = weights.numpy(force=True), thresholds.numpy(force=True) - where_enough_prototypes = (weights >= Nums.QUEUE_WEIGHT_THRESHOLD).sum(1) >= self.swav_head.min_prototypes + where_enough_prototypes = (weights >= thresholds).sum(1) >= self.swav_head.min_prototypes for i in np.where(where_enough_prototypes)[0]: - weights[i, weights[i] < Nums.QUEUE_WEIGHT_THRESHOLD] = 0 + weights[i, weights[i] < thresholds] = 0 for i in np.where(~where_enough_prototypes)[0]: indices_0 = np.argsort(weights[i])[: -self.swav_head.min_prototypes] weights[i, indices_0] = 0 diff --git a/novae/module/swav.py b/novae/module/swav.py index 597a5f3..5ae99ec 100644 --- a/novae/module/swav.py +++ b/novae/module/swav.py @@ -130,23 +130,25 @@ def prototype_ilocs(self, projections: Tensor, slide_id: str | None = None) -> T slide_index = self.slide_label_encoder[slide_id] self.queue[slide_index, 1:] = self.queue[slide_index, :-1].clone() - self.queue[slide_index, 0] = projections.max(0).values + self.queue[slide_index, 0] = projections.topk(3, dim=0).values[-1] - weights = self.queue_weights()[slide_index] - ilocs = torch.where(weights >= Nums.QUEUE_WEIGHT_THRESHOLD)[0] + weights, thresholds = self.queue_weights() + slide_weights = weights[slide_index] - return ilocs if len(ilocs) >= self.min_prototypes else torch.topk(weights, self.min_prototypes).indices + ilocs = torch.where(slide_weights >= thresholds)[0] + return ilocs if len(ilocs) >= self.min_prototypes else torch.topk(slide_weights, self.min_prototypes).indices - def queue_weights(self) -> Tensor: + def queue_weights(self) -> tuple[Tensor, Tensor]: """Convert the queue to a matrix of prototype weight per slide. Returns: A tensor of shape `(n_slides, K)`. """ max_projections = self.queue.max(dim=1).values - unused_prototypes = max_projections.max(dim=0).values < Nums.QUEUE_WEIGHT_THRESHOLD - max_projections[:, unused_prototypes] = 1 # ensure all prototypes are used - return max_projections + + thresholds = max_projections.max(0).values * Nums.QUEUE_WEIGHT_THRESHOLD_RATIO + + return max_projections, thresholds @utils.format_docs @torch.no_grad() diff --git a/novae/plot/_heatmap.py b/novae/plot/_heatmap.py index 3af5aae..6faf410 100644 --- a/novae/plot/_heatmap.py +++ b/novae/plot/_heatmap.py @@ -23,7 +23,7 @@ def _weights_clustermap( show_tissue_legend: bool = True, figsize: tuple[int] = (6, 4), vmin: float = 0, - vmax: float = 0.5, + vmax: float = 1, **kwargs: int, ) -> None: row_colors = None From 6efb173146e2524f18aa8b681c627039a13f573b Mon Sep 17 00:00:00 2001 From: Blampey Quentin Date: Thu, 29 Aug 2024 09:52:03 +0200 Subject: [PATCH 3/8] add all_human --- scripts/config/all_human.yaml | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) create mode 100644 scripts/config/all_human.yaml diff --git a/scripts/config/all_human.yaml b/scripts/config/all_human.yaml new file mode 100644 index 0000000..f1b9225 --- /dev/null +++ b/scripts/config/all_human.yaml @@ -0,0 +1,30 @@ +data: + train_dataset: /gpfs/workdir/shared/prime/spatial/human + val_dataset: igr/202305031337_hBreast-slide-B-4h-photobleach_VMSC09302 + +model_kwargs: + scgpt_model_dir: /gpfs/workdir/blampeyq/checkpoints/scgpt/scGPT_human + n_hops_view: 3 + heads: 16 + hidden_size: 128 + temperature: 0.1 + num_prototypes: 512 + background_noise_lambda: 5 + panel_subset_size: 0.8 + min_prototypes_ratio: 0.4 + +fit_kwargs: + max_epochs: 30 + lr: 0.0001 + accelerator: "gpu" + num_workers: 8 + patience: 6 + min_delta: 0.025 + +post_training: + n_domains: [15, 20, 25] + log_metrics: true + save_h5ad: false + log_umap: true + log_domains: true + delete_X: true From 6ae83e7e224978a92dffa52d344d0029b89daff4 Mon Sep 17 00:00:00 2001 From: Blampey Quentin Date: Thu, 29 Aug 2024 10:09:55 +0200 Subject: [PATCH 4/8] add mouse and brain configs --- scripts/config/all_brain.yaml | 30 ++++++++++++++++++++++++++++++ scripts/config/all_mouse.yaml | 30 ++++++++++++++++++++++++++++++ 2 files changed, 60 insertions(+) create mode 100644 scripts/config/all_brain.yaml create mode 100644 scripts/config/all_mouse.yaml diff --git a/scripts/config/all_brain.yaml b/scripts/config/all_brain.yaml new file mode 100644 index 0000000..f0ca89c --- /dev/null +++ b/scripts/config/all_brain.yaml @@ -0,0 +1,30 @@ +data: + train_dataset: /gpfs/workdir/shared/prime/spatial/brain + val_dataset: igr/202305031337_hBreast-slide-B-4h-photobleach_VMSC09302 + +model_kwargs: + scgpt_model_dir: /gpfs/workdir/blampeyq/checkpoints/scgpt/scGPT_brain + n_hops_view: 3 + heads: 16 + hidden_size: 128 + temperature: 0.1 + num_prototypes: 512 + background_noise_lambda: 5 + panel_subset_size: 0.8 + min_prototypes_ratio: 0.4 + +fit_kwargs: + max_epochs: 30 + lr: 0.0001 + accelerator: "gpu" + num_workers: 8 + patience: 6 + min_delta: 0.025 + +post_training: + n_domains: [15, 20, 25] + log_metrics: true + save_h5ad: false + log_umap: true + log_domains: true + delete_X: true diff --git a/scripts/config/all_mouse.yaml b/scripts/config/all_mouse.yaml new file mode 100644 index 0000000..6517aaa --- /dev/null +++ b/scripts/config/all_mouse.yaml @@ -0,0 +1,30 @@ +data: + train_dataset: /gpfs/workdir/shared/prime/spatial/mouse + val_dataset: igr/202305031337_hBreast-slide-B-4h-photobleach_VMSC09302 + +model_kwargs: + scgpt_model_dir: /gpfs/workdir/blampeyq/checkpoints/scgpt/scGPT_human + n_hops_view: 3 + heads: 16 + hidden_size: 128 + temperature: 0.1 + num_prototypes: 512 + background_noise_lambda: 5 + panel_subset_size: 0.8 + min_prototypes_ratio: 0.4 + +fit_kwargs: + max_epochs: 30 + lr: 0.0001 + accelerator: "gpu" + num_workers: 8 + patience: 6 + min_delta: 0.025 + +post_training: + n_domains: [15, 20, 25] + log_metrics: true + save_h5ad: false + log_umap: true + log_domains: true + delete_X: true From 60f80729907d73764affb477ea3843d5ee9c470b Mon Sep 17 00:00:00 2001 From: Blampey Quentin Date: Mon, 2 Sep 2024 14:40:49 +0200 Subject: [PATCH 5/8] new human config --- scripts/config/all_human.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/config/all_human.yaml b/scripts/config/all_human.yaml index f1b9225..79c7a5b 100644 --- a/scripts/config/all_human.yaml +++ b/scripts/config/all_human.yaml @@ -11,7 +11,7 @@ model_kwargs: num_prototypes: 512 background_noise_lambda: 5 panel_subset_size: 0.8 - min_prototypes_ratio: 0.4 + min_prototypes_ratio: 0.5 fit_kwargs: max_epochs: 30 From 486d6c3592bceeb38f6b2dcfe83c99f45b9dff6b Mon Sep 17 00:00:00 2001 From: Blampey Quentin Date: Mon, 2 Sep 2024 16:21:59 +0200 Subject: [PATCH 6/8] minor fix --- scripts/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/utils.py b/scripts/utils.py index c52ea94..e11bee8 100644 --- a/scripts/utils.py +++ b/scripts/utils.py @@ -125,7 +125,7 @@ def _log_umap(model: novae.Novae, adatas: list[AnnData], config: Config, n_obs_t obs_key = model.assign_domains(adatas, n_domains=n_domains) model.batch_effect_correction(adatas, obs_key=obs_key) - latent_conc = np.concat([adata.obsm[Keys.REPR_CORRECTED] for adata in adatas], axis=0) + latent_conc = np.concatenate([adata.obsm[Keys.REPR_CORRECTED] for adata in adatas], axis=0) obs_conc = pd.concat([adata.obs for adata in adatas], axis=0, join="inner") adata_conc = AnnData(obsm={Keys.REPR_CORRECTED: latent_conc}, obs=obs_conc) From 62a30a1cadede7c1e774a4064fde0561c813ed37 Mon Sep 17 00:00:00 2001 From: Blampey Quentin Date: Mon, 2 Sep 2024 18:17:27 +0200 Subject: [PATCH 7/8] add all_human2 --- scripts/config/all_human2.yaml | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) create mode 100644 scripts/config/all_human2.yaml diff --git a/scripts/config/all_human2.yaml b/scripts/config/all_human2.yaml new file mode 100644 index 0000000..8ad35d6 --- /dev/null +++ b/scripts/config/all_human2.yaml @@ -0,0 +1,30 @@ +data: + train_dataset: /gpfs/workdir/shared/prime/spatial/human + val_dataset: igr/202305031337_hBreast-slide-B-4h-photobleach_VMSC09302 + +model_kwargs: + scgpt_model_dir: /gpfs/workdir/blampeyq/checkpoints/scgpt/scGPT_human + n_hops_view: 2 + heads: 16 + hidden_size: 128 + temperature: 0.1 + num_prototypes: 512 + background_noise_lambda: 5 + panel_subset_size: 0.8 + min_prototypes_ratio: 0.75 + +fit_kwargs: + max_epochs: 30 + lr: 0.0001 + accelerator: "gpu" + num_workers: 8 + patience: 6 + min_delta: 0.025 + +post_training: + n_domains: [15, 20, 25] + log_metrics: true + save_h5ad: false + log_umap: true + log_domains: true + delete_X: true From ce8e75da7b9d75a035ad241f157d3d436a2f1d53 Mon Sep 17 00:00:00 2001 From: Blampey Quentin Date: Tue, 3 Sep 2024 10:12:46 +0200 Subject: [PATCH 8/8] fix queue --- novae/module/swav.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/novae/module/swav.py b/novae/module/swav.py index 23cd47b..dc6f1cb 100644 --- a/novae/module/swav.py +++ b/novae/module/swav.py @@ -61,7 +61,7 @@ def init_queue(self, slide_ids: list[str]) -> None: """ del self.queue - shape = (len(slide_ids), self.num_prototypes) + shape = (len(slide_ids), Nums.QUEUE_SIZE, self.num_prototypes) self.register_buffer("queue", torch.full(shape, 1 / self.num_prototypes)) self.slide_label_encoder = {slide_id: i for i, slide_id in enumerate(slide_ids)}