From 828dc9d63b75873ad5ce35d1b07ca304acea7ed2 Mon Sep 17 00:00:00 2001
From: belsten <belsten@Alexanders-MacBook-Pro.local>
Date: Tue, 19 Nov 2024 19:53:15 -0800
Subject: [PATCH] add priors file and update tests to reflect new import

---
 sparsecoding/priors.py          | 167 ++++++++++++++++++++++++++++++++
 tests/inference/common.py       |   3 +-
 tests/priors/test_l0.py         |   2 +-
 tests/priors/test_spike_slab.py |   2 +-
 4 files changed, 170 insertions(+), 4 deletions(-)
 create mode 100644 sparsecoding/priors.py

diff --git a/sparsecoding/priors.py b/sparsecoding/priors.py
new file mode 100644
index 0000000..026ac43
--- /dev/null
+++ b/sparsecoding/priors.py
@@ -0,0 +1,167 @@
+import torch
+from torch.distributions.laplace import Laplace
+
+from abc import ABC, abstractmethod
+
+
+class Prior(ABC):
+    """A distribution over weights.
+
+    Attributes
+    ----------
+    D : int
+        Number of weights per sample.
+    """
+    @property
+    @abstractmethod
+    def D(self):
+        """Number of weights per sample."""
+
+    @abstractmethod
+    def sample(
+        self,
+        num_samples: int = 1,
+    ):
+        """Sample weights from the prior.
+
+        Parameters
+        ----------
+        num_samples : int, default=1
+            Number of samples.
+
+        Returns
+        -------
+        samples : Tensor, shape [num_samples, self.D]
+            Sampled weights.
+        """
+
+
+class SpikeSlabPrior(Prior):
+    """Prior where weights are drawn from a "spike-and-slab" distribution.
+
+    The "spike" is at 0 and the "slab" is a Laplace distribution.
+
+    See:
+        https://wesselb.github.io/assets/write-ups/Bruinsma,%20Spike%20and%20Slab%20Priors.pdf
+    for a good review of the spike-and-slab model.
+
+    Parameters
+    ----------
+    dim : int
+        Number of weights per sample.
+    p_spike : float
+        The probability of a weight being 0.
+    scale : float
+        The scale of the Laplace distribution (larger is wider).
+    positive_only : bool
+        Ensure that the weights are positive by taking the absolute value
+        of weights sampled from the Laplace distribution.
+    """
+
+    def __init__(
+        self,
+        dim: int,
+        p_spike: float,
+        scale: float,
+        positive_only: bool = True,
+    ):
+        if dim < 0:
+            raise ValueError(f"`dim` should be nonnegative, got {dim}.")
+        if p_spike < 0 or p_spike > 1:
+            raise ValueError(f"Must have 0 <= `p_spike` <= 1, got `p_spike`={p_spike}.")
+        if scale <= 0:
+            raise ValueError(f"`scale` must be positive, got {scale}.")
+
+        self.dim = dim
+        self.p_spike = p_spike
+        self.scale = scale
+        self.positive_only = positive_only
+
+    @property
+    def D(self):
+        return self.dim
+
+    def sample(self, num_samples: int = 1):
+        N = num_samples
+
+        zero_weights = torch.zeros((N, self.D), dtype=torch.float32)
+        slab_weights = Laplace(
+            loc=zero_weights,
+            scale=torch.full((N, self.D), self.scale, dtype=torch.float32),
+        ).sample()  # [N, D]
+
+        if self.positive_only:
+            slab_weights = torch.abs(slab_weights)
+
+        # Zero out each weight independently with probability `p_spike`.
+        spike_over_slab = torch.rand(N, self.D, dtype=torch.float32) < self.p_spike
+        weights = torch.where(
+            spike_over_slab,
+            zero_weights,
+            slab_weights,
+        )
+
+        return weights
+
+
+class L0Prior(Prior):
+    """Prior with a distribution over the l0-norm of the weights.
+
+    A class of priors where the weights are binary;
+    the distribution is over the l0-norm of the weight vector
+    (how many weights are active).
+
+    Parameters
+    ----------
+    prob_distr : Tensor, shape [D], dtype float32
+        Probability distribution over the l0-norm of the weights:
+        `prob_distr[i]` is the probability that (i + 1) weights are active.
+    """
+
+    def __init__(
+        self,
+        prob_distr: torch.Tensor,
+    ):
+        if prob_distr.dim() != 1:
+            raise ValueError(f"`prob_distr` shape must be (D,), got {prob_distr.shape}.")
+        if prob_distr.dtype != torch.float32:
+            raise ValueError(f"`prob_distr` dtype must be torch.float32, got {prob_distr.dtype}.")
+        if not torch.allclose(torch.sum(prob_distr), torch.tensor(1.)):
+            raise ValueError(f"`torch.sum(prob_distr)` must be 1., got {torch.sum(prob_distr)}.")
+
+        self.prob_distr = prob_distr
+
+    @property
+    def D(self):
+        return self.prob_distr.shape[0]
+
+    def sample(
+        self,
+        num_samples: int = 1,
+    ):
+        N = num_samples
+
+        num_active_weights = 1 + torch.multinomial(
+            input=self.prob_distr,
+            num_samples=num_samples,
+            replacement=True,
+        )  # [N]; the +1 maps sampled indices in [0, D-1] to counts in [1, D].
+
+        d_idxs = torch.arange(self.D)
+        active_idx_mask = (
+            d_idxs.reshape(1, self.D)
+            < num_active_weights.reshape(N, 1)
+        )  # [N, D]
+
+        n_idxs = torch.arange(N).reshape(N, 1).expand(N, self.D)  # [N, D]
+        # Shuffle per sample so that it's not always the first weights that are active.
+        shuffled_d_idxs = [torch.randperm(self.D) for _ in range(N)]
+        shuffled_d_idxs = torch.stack(shuffled_d_idxs, dim=0)  # [N, D]
+
+        # Pair of index tensors, each of shape [total number of active weights].
+        active_weight_idxs = n_idxs[active_idx_mask], shuffled_d_idxs[active_idx_mask]
+
+        weights = torch.zeros((N, self.D), dtype=torch.float32)
+        weights[active_weight_idxs] = 1.
+
+        return weights
diff --git a/tests/inference/common.py b/tests/inference/common.py
index 2ff14d2..eb5d574 100644
--- a/tests/inference/common.py
+++ b/tests/inference/common.py
@@ -1,7 +1,6 @@
 import torch
 
-from sparsecoding.priors.l0 import L0Prior
-from sparsecoding.priors.spike_slab import SpikeSlabPrior
+from sparsecoding.priors import L0Prior, SpikeSlabPrior
 from sparsecoding.data.datasets.bars import BarsDataset
 
 torch.manual_seed(1997)
diff --git a/tests/priors/test_l0.py b/tests/priors/test_l0.py
index 7518ad5..12dcce6 100644
--- a/tests/priors/test_l0.py
+++ b/tests/priors/test_l0.py
@@ -1,7 +1,7 @@
 import torch
 import unittest
 
-from sparsecoding.priors.l0 import L0Prior
+from sparsecoding.priors import L0Prior
 
 
 class TestL0Prior(unittest.TestCase):
diff --git a/tests/priors/test_spike_slab.py b/tests/priors/test_spike_slab.py
index df4ec31..20804fe 100644
--- a/tests/priors/test_spike_slab.py
+++ b/tests/priors/test_spike_slab.py
@@ -1,7 +1,7 @@
 import torch
 import unittest
 
-from sparsecoding.priors.spike_slab import SpikeSlabPrior
+from sparsecoding.priors import SpikeSlabPrior
 
 
 class TestSpikeSlabPrior(unittest.TestCase):