diff --git a/skscope/skmodel.py b/skscope/skmodel.py index 298cf64..121d633 100644 --- a/skscope/skmodel.py +++ b/skscope/skmodel.py @@ -1,5 +1,6 @@ import numpy as np import jax.numpy as jnp +from jax.scipy.special import logsumexp from skscope import ScopeSolver from sklearn.base import BaseEstimator from sklearn.covariance import LedoitWolf @@ -273,7 +274,7 @@ def custom_objective(alpha): return loss solver = ScopeSolver(p, sparsity=self.sparsity) - alpha = solver.solve(custom_objective) + alpha = solver.solve(custom_objective, jit=True) self.coef_ = np.abs(alpha) return self @@ -466,15 +467,11 @@ def fit(self, X, y, delta, sample_weight=None): self.n_events = K def multivariate_failure_objective(params): - Xbeta = jnp.matmul(X, params) - tmp = jnp.ones((n, K)) - for i in range(n): - for k in range(K): - tmp = tmp.at[i, k].set( - X[i] @ params - - jnp.log(jnp.matmul(y[:, k] >= y[i, k], jnp.exp(Xbeta))) - ) - loss = -jnp.mean(tmp * delta) + Xbeta_expanded = jnp.matmul(X, params)[:, None] + sum_exp_Xbeta = logsumexp( + Xbeta_expanded + jnp.log(y >= y[:, None, :]), axis=1 + ) + loss = -jnp.mean((Xbeta_expanded - sum_exp_Xbeta) * delta) return loss solver = ScopeSolver(p, self.sparsity)