diff --git a/nitorch/vb/mixtures.py b/nitorch/vb/mixtures.py
index a06eee7d..aea44df0 100644
--- a/nitorch/vb/mixtures.py
+++ b/nitorch/vb/mixtures.py
@@ -33,7 +33,7 @@ def __init__(self, num_class=2):
 
     # Functions
     def fit(self, X, verbose=1, max_iter=10000, tol=1e-8, fig_num=1, W=None,
-            show_fit=False):
+            show_fit=False, samp=1):
         """ Fit mixture model.
         Args:
             X (torch.tensor): Observed data (N, C).
@@ -50,17 +50,34 @@ def fit(self, X, verbose=1, max_iter=10000, tol=1e-8, fig_num=1, W=None,
             fig_num (int, optional): Defaults to 1.
             W (torch.tensor, optional): Observation weights (N, 1). Defaults to no weights.
             show_fit (bool, optional): Plot mixture fit, defaults to False.
+            samp (int, optional): Sub-sampling factor (fit on every samp-th observation). Defaults to 1.
         Returns:
             Z (torch.tensor): Responsibilities (N, K).
         """
 
+        if not isinstance(samp, int) or samp < 1:
+            raise ValueError(f"samp parameter needs to be an int >= 1, got {samp}")
+
         if verbose:
             t0 = timer()  # Start timer
         # Set random seed
         torch.manual_seed(1)
 
+        if not torch.is_floating_point(X):
+            # Integer data type -> convert to float and add some noise
+            X = X.type(torch.float)
+            X += (0.001 * X.max()) * torch.randn_like(X)
+
         self.dev = X.device
         self.dt = X.dtype
 
+        if samp > 1:
+            # Sub-sample: fit the model on every samp-th observation only
+            X0 = X  # Keep a reference to the full-resolution data
+            X = X[::samp, :]
+            if W is not None:
+                W0 = W  # Keep the full-resolution weights
+                W = W[::samp, :]
+
         if len(X.shape) == 1:
             X = X[:, None]
@@ -86,6 +103,16 @@ def fit(self, X, verbose=1, max_iter=10000, tol=1e-8, fig_num=1, W=None,
         # EM loop
         Z, lb = self._em(X, max_iter=max_iter, tol=tol, verbose=verbose, W=W)
 
+        if samp > 1:
+            # Compute responsibilities at the original resolution
+            X = X0
+            N = X.shape[0]
+            Z = torch.zeros((N, K), dtype=self.dt, device=self.dev)
+            for k in range(K):
+                Z[:, k] = torch.log(self.mp[k]) + self._log_likelihood(X, k)
+            if W is not None: W = W0
+            Z, _ = softmax_lse(Z, lse=True, weights=W)
+
         # Print algorithm info
         if verbose >= 1:
             print('Algorithm finished in {} iterations, '
@@ -277,7 +304,7 @@ def full_resp(Z, msk, dm=[]):
         """ Converts masked responsibilities to full.
         Args:
             Z (torch.tensor): Masked responsibilities (N, K).
-            msk (torch.tensor): Mask of original data (N0, 1).
+            msk (torch.tensor): Mask of original data (N0, ).
            dm (torch.Size, optional): Reshapes Z_full using dm. Defaults to [].
         Returns:
             Z_full (torch.tensor): Full responsibilities (N0, K).
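
For reviewers, a minimal usage sketch of the new samp option follows. It is an illustration only, not part of the patch: it assumes this module exposes a concrete subclass such as GMM whose constructor takes num_class, and that fit returns the responsibilities Z as documented above; adjust the import and class name if the actual API differs.

# Hypothetical usage of fit(..., samp=...) -- illustration only.
# GMM and its num_class argument are assumed; only the samp behaviour
# comes from the patch above.
import torch
from nitorch.vb.mixtures import GMM

# Synthetic integer data: two well-separated intensity clusters, one channel.
X = torch.cat([torch.randint(0, 50, (5000, 1)),
               torch.randint(200, 255, (5000, 1))])

model = GMM(num_class=2)
# EM runs on every 4th observation only; responsibilities are then
# recomputed for all N observations before fit() returns.
Z = model.fit(X, verbose=0, samp=4)
print(Z.shape)  # expected: torch.Size([10000, 2])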