diff --git a/novae/_constants.py b/novae/_constants.py index 0e3697f..2b61500 100644 --- a/novae/_constants.py +++ b/novae/_constants.py @@ -57,7 +57,7 @@ class Nums: SWAV_EPSILON: float = 0.05 SINKHORN_ITERATIONS: int = 3 QUEUE_SIZE: int = 3 - QUEUE_WEIGHT_THRESHOLD: float = 0.99 + QUEUE_WEIGHT_THRESHOLD_RATIO: float = 0.9 # misc nums MEAN_NGH_TH_WARNING: float = 3.5 diff --git a/novae/model.py b/novae/model.py index 34ca59c..8a7f952 100644 --- a/novae/model.py +++ b/novae/model.py @@ -50,7 +50,6 @@ def __init__( num_layers: int = 10, batch_size: int = 512, temperature: float = 0.1, - temperature_weight_proto: float = 0.1, num_prototypes: int = 256, panel_subset_size: float = 0.6, background_noise_lambda: float = 8.0, @@ -99,7 +98,7 @@ def __init__( ### Initialize modules self.encoder = GraphEncoder(embedding_size, hidden_size, num_layers, output_size, heads) self.augmentation = GraphAugmentation(panel_subset_size, background_noise_lambda, sensitivity_noise_std) - self.swav_head = SwavHead(self.mode, output_size, num_prototypes, temperature, temperature_weight_proto) + self.swav_head = SwavHead(self.mode, output_size, num_prototypes, temperature) ### Misc self._num_workers = 0 @@ -413,11 +412,12 @@ def plot_prototype_weights(self, **kwargs: int): self.swav_head.queue is not None ), "Swav queue not initialized. Initialize it with `model.init_slide_queue(...)`, then train or fine-tune the model." - weights = self.swav_head.queue_weights().numpy(force=True) + weights, thresholds = self.swav_head.queue_weights() + weights, thresholds = weights.numpy(force=True), thresholds.numpy(force=True) - where_enough_prototypes = (weights >= Nums.QUEUE_WEIGHT_THRESHOLD).sum(1) >= self.swav_head.min_prototypes + where_enough_prototypes = (weights >= thresholds).sum(1) >= self.swav_head.min_prototypes for i in np.where(where_enough_prototypes)[0]: - weights[i, weights[i] < Nums.QUEUE_WEIGHT_THRESHOLD] = 0 + weights[i, weights[i] < thresholds] = 0 for i in np.where(~where_enough_prototypes)[0]: indices_0 = np.argsort(weights[i])[: -self.swav_head.min_prototypes] weights[i, indices_0] = 0 diff --git a/novae/module/swav.py b/novae/module/swav.py index 3a0d81f..dc6f1cb 100644 --- a/novae/module/swav.py +++ b/novae/module/swav.py @@ -27,7 +27,6 @@ def __init__( output_size: int, num_prototypes: int, temperature: float, - temperature_weight_proto: float, ): """SwavHead module, adapted from the paper ["Unsupervised Learning of Visual Features by Contrasting Cluster Assignments"](https://arxiv.org/abs/2006.09882). @@ -41,7 +40,6 @@ def __init__( self.output_size = output_size self.num_prototypes = num_prototypes self.temperature = temperature - self.temperature_weight_proto = temperature_weight_proto self._prototypes = nn.Parameter(torch.empty((self.num_prototypes, self.output_size))) self._prototypes = nn.init.kaiming_uniform_(self._prototypes, a=math.sqrt(5), mode="fan_out") @@ -63,7 +61,7 @@ def init_queue(self, slide_ids: list[str]) -> None: """ del self.queue - shape = (len(slide_ids), self.num_prototypes) + shape = (len(slide_ids), Nums.QUEUE_SIZE, self.num_prototypes) self.register_buffer("queue", torch.full(shape, 1 / self.num_prototypes)) self.slide_label_encoder = {slide_id: i for i, slide_id in enumerate(slide_ids)} @@ -131,20 +129,26 @@ def prototype_ilocs(self, projections: Tensor, slide_id: str | None = None) -> T slide_index = self.slide_label_encoder[slide_id] - self.queue[slide_index] = projections.topk(3, dim=0).values[-1] # top-3 more robust than max + self.queue[slide_index, 1:] = self.queue[slide_index, :-1].clone() + self.queue[slide_index, 0] = projections.topk(3, dim=0).values[-1] - weights = self.queue_weights()[slide_index] - ilocs = torch.where(weights >= Nums.QUEUE_WEIGHT_THRESHOLD)[0] + weights, thresholds = self.queue_weights() + slide_weights = weights[slide_index] - return ilocs if len(ilocs) >= self.min_prototypes else torch.topk(weights, self.min_prototypes).indices + ilocs = torch.where(slide_weights >= thresholds)[0] + return ilocs if len(ilocs) >= self.min_prototypes else torch.topk(slide_weights, self.min_prototypes).indices - def queue_weights(self) -> Tensor: + def queue_weights(self) -> tuple[Tensor, Tensor]: """Convert the queue to a matrix of prototype weight per slide. Returns: A tensor of shape `(n_slides, K)`. """ - return self.sinkhorn(self.queue) * self.num_prototypes + max_projections = self.queue.max(dim=1).values + + thresholds = max_projections.max(0).values * Nums.QUEUE_WEIGHT_THRESHOLD_RATIO + + return max_projections, thresholds @utils.format_docs @torch.no_grad() diff --git a/novae/plot/_heatmap.py b/novae/plot/_heatmap.py index 8d579ef..6faf410 100644 --- a/novae/plot/_heatmap.py +++ b/novae/plot/_heatmap.py @@ -22,8 +22,8 @@ def _weights_clustermap( show_yticklabels: bool = False, show_tissue_legend: bool = True, figsize: tuple[int] = (6, 4), - vmin: float = 0.9, - vmax: float = 1.1, + vmin: float = 0, + vmax: float = 1, **kwargs: int, ) -> None: row_colors = None diff --git a/scripts/config/all_human.yaml b/scripts/config/all_human.yaml index 0ea5c36..19df528 100644 --- a/scripts/config/all_human.yaml +++ b/scripts/config/all_human.yaml @@ -8,11 +8,10 @@ model_kwargs: heads: 16 hidden_size: 128 temperature: 0.1 - temperature_weight_proto: 0.2 num_prototypes: 512 background_noise_lambda: 5 panel_subset_size: 0.8 - min_prototypes_ratio: 0.2 + min_prototypes_ratio: 0.5 fit_kwargs: max_epochs: 30 diff --git a/scripts/config/all_human2.yaml b/scripts/config/all_human2.yaml index 4604aeb..0ce7a8e 100644 --- a/scripts/config/all_human2.yaml +++ b/scripts/config/all_human2.yaml @@ -8,11 +8,10 @@ model_kwargs: heads: 16 hidden_size: 128 temperature: 0.1 - temperature_weight_proto: 0.2 num_prototypes: 512 background_noise_lambda: 5 panel_subset_size: 0.8 - min_prototypes_ratio: 0.2 + min_prototypes_ratio: 0.75 fit_kwargs: max_epochs: 30 diff --git a/scripts/config/all_new.yaml b/scripts/config/all_new.yaml new file mode 100644 index 0000000..c8c931d --- /dev/null +++ b/scripts/config/all_new.yaml @@ -0,0 +1,30 @@ +data: + train_dataset: all + val_dataset: igr/202305031337_hBreast-slide-B-4h-photobleach_VMSC09302 + +model_kwargs: + scgpt_model_dir: /gpfs/workdir/blampeyq/checkpoints/scgpt/scGPT_human + n_hops_view: 3 + heads: 16 + hidden_size: 128 + temperature: 0.1 + num_prototypes: 512 + background_noise_lambda: 5 + panel_subset_size: 0.8 + min_prototypes_ratio: 0.15 + +fit_kwargs: + max_epochs: 30 + lr: 0.0001 + accelerator: "gpu" + num_workers: 8 + patience: 6 + min_delta: 0.025 + +post_training: + n_domains: [15, 20, 25] + log_metrics: true + save_h5ad: true + log_umap: true + log_domains: true + delete_X: true