FEAT - Use @njit(cache=True) #215

Open
PascalCarrivain opened this issue Dec 14, 2023 · 4 comments

Comments

@PascalCarrivain
Contributor

Use caching option from Numba

I explored the skglm source code a little and noticed that you use the Numba decorator @njit.
I was wondering whether it makes sense to switch to @njit(cache=True).
According to the Numba documentation, caching compiled functions reduces compilation time on later runs.

@mathurinm
Collaborator

Thanks for the pointer @PascalCarrivain, I was not aware of this feature.

It seems to help a lot on a very CPU-bound problem (not for the first compilation, but for the subsequent calls). Between the first run and the second, I changed only the value of cache in the following snippet:

In [1]: %run numba_cache.py
0.5166642830008641
0.10939704200063716
0.11151579199940898
0.11363932000131172
0.10976769300032174

In [2]:                                                                         
Do you really want to exit ([y]/n)? 
(base) ➜  scripts git:(main) ✗ ipython
Python 3.10.12 | packaged by conda-forge | (main, Jun 23 2023, 22:40:32) [GCC 12.3.0]
Type 'copyright', 'credits' or 'license' for more information
IPython 8.6.0 -- An enhanced Interactive Python. Type '?' for help.

In [1]: %run numba_cache.py
0.529917010000645
0.004792376999830594
0.004180643998552114
0.004381413000373868
0.004056714000398642

import numpy as np
import time
from numba import njit

a = np.arange(10_000)

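for i in range(5):
    def my_sum(a):
        acc = 0
        for val in a:
            acc += val
        return acc

    # A new dispatcher is created at each iteration, so the timed line
    # covers compilation (or loading from the on-disk cache) plus one call.
    t0 = time.perf_counter()
    njit(my_sum, cache=True)(a)
    t1 = time.perf_counter()
    print(t1 - t0)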

Can you test the impact of using cache=True in our codebase on a real-life skglm problem, i.e., fitting an estimator on a simple problem?

@PascalCarrivain
Contributor Author

@mathurinm Yes, I will do it late this year or early next year.

@mathurinm
Collaborator

@PascalCarrivain Do you know if this can help the first compilation too?

@PascalCarrivain
Contributor Author

I do not see a huge difference for the first compilation (at least on my projects).
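
One way to check this is to time the first call in two fresh interpreters, so that the second one starts with a populated cache. The sketch below is only an illustration under a few assumptions: the script name cached_sum.py and the helper time_first_call are hypothetical, and the timed function is the toy my_sum from above rather than anything from skglm.

```python
import subprocess
import sys
import tempfile
import textwrap
from pathlib import Path

# Hypothetical stand-alone script: it prints the duration of the first call,
# which includes either a full compilation or a load from the on-disk cache.
SCRIPT = textwrap.dedent("""
    import time
    import numpy as np
    from numba import njit

    @njit(cache=True)
    def my_sum(a):
        acc = 0
        for val in a:
            acc += val
        return acc

    a = np.arange(10_000)
    t0 = time.perf_counter()
    my_sum(a)
    print(time.perf_counter() - t0)
""")


def time_first_call(script_path):
    """Run the script in a fresh interpreter and return the printed duration."""
    out = subprocess.run(
        [sys.executable, str(script_path)],
        capture_output=True, text=True, check=True,
    )
    return float(out.stdout.strip().splitlines()[-1])


with tempfile.TemporaryDirectory() as tmp:
    script_path = Path(tmp) / "cached_sum.py"
    script_path.write_text(SCRIPT)
    cold = time_first_call(script_path)  # empty cache: full compilation
    warm = time_first_call(script_path)  # populated cache: load from disk
    print(f"cold start: {cold:.3f} s, warm start: {warm:.3f} s")
```

If caching behaves as described in this thread, the cold start should be about as slow as without cache=True, and only the warm start should be faster.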
