Skip to content

Commit

Permalink
refactor: optim loss of MultivariateFailure
Browse files Browse the repository at this point in the history
- Implement vectorized calculations to replace nested loops for efficiency.
- Utilize logsumexp for better numerical stability in exponential calculations.
  • Loading branch information
bbayukari committed Jan 3, 2024
1 parent 85402d7 commit f70fb37
Showing 1 changed file with 7 additions and 10 deletions.
17 changes: 7 additions & 10 deletions skscope/skmodel.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import numpy as np
import jax.numpy as jnp
from jax.scipy.special import logsumexp
from skscope import ScopeSolver
from sklearn.base import BaseEstimator
from sklearn.covariance import LedoitWolf
Expand Down Expand Up @@ -273,7 +274,7 @@ def custom_objective(alpha):
return loss

solver = ScopeSolver(p, sparsity=self.sparsity)
alpha = solver.solve(custom_objective)
alpha = solver.solve(custom_objective, jit=True)
self.coef_ = np.abs(alpha)
return self

Expand Down Expand Up @@ -466,15 +467,11 @@ def fit(self, X, y, delta, sample_weight=None):
self.n_events = K

def multivariate_failure_objective(params):
Xbeta = jnp.matmul(X, params)
tmp = jnp.ones((n, K))
for i in range(n):
for k in range(K):
tmp = tmp.at[i, k].set(
X[i] @ params
- jnp.log(jnp.matmul(y[:, k] >= y[i, k], jnp.exp(Xbeta)))
)
loss = -jnp.mean(tmp * delta)
Xbeta_expanded = jnp.matmul(X, params)[:, None]
sum_exp_Xbeta = logsumexp(
Xbeta_expanded + jnp.log(y >= y[:, None, :]), axis=1
)
loss = -jnp.mean((Xbeta_expanded - sum_exp_Xbeta) * delta)
return loss

solver = ScopeSolver(p, self.sparsity)
Expand Down

0 comments on commit f70fb37

Please sign in to comment.