Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH - Automatic support of L2 regulrization in Penalties #150

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from

Conversation

Badr-MOUFAD
Copy link
Collaborator

Given a penalty $f: \mathbb{R}^n \rightarrow \mathbb{R}$, that is already implemented in the package,
It is possible to endow it with L2 regularization to get $\Omega = f + \frac{\mu}{2} \lVert \cdot \rVert$

Indeed for a step, $\sigma$ and gradient $\mathrm{grad}$,
the proximal operator and distance to subdifferential can be written using prox and subdiffdistance of $f$

$$
\mathrm{prox}{\Omega, \sigma}(x) = \mathrm{prox}{f, \frac{\sigma}{1 + \sigma \mu}}(\frac{x}{1 + \sigma \mu})
$$

$$
\mathrm{dist}{\partial \Omega(x)}(-\mathrm{grad}) = \mathrm{dist}{\partial f(x)}(-\mathrm{grad} - \mu x)
$$

Implementation

This can be implemented either through inheritance or a class decorator.
This PR provides a POC of the second approach. Hence to add support for L2 regularization, one only needs to decorate the penalty with overload_with_l2.

Help needed

I unittested to the logic and implementation and everything works as expected. However, I'm running into problems when jit-compiling the class as numba doesn't support *args, **kwargs, which are mandatory to overload the constructor of the penalty.

Any workaround to bypass that?

@Badr-MOUFAD Badr-MOUFAD marked this pull request as draft April 6, 2023 16:37
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants