Suggestion to use np.linalg.solve instead of torch.linalg.solve #106

Open
Andyccs opened this issue Apr 18, 2024 · 1 comment

@Andyccs

Andyccs commented Apr 18, 2024

Is your feature request related to a problem? Please describe.

I was trying out the MinimumProfit class. The optimization is really slow because of the calls to np.linalg.solve. By replacing np.linalg.solve with torch.linalg.solve, I was able to speed it up by roughly 2x on CPU (based on informal observation). Example code:

```python
import numpy as np
import pandas as pd
import torch

# Assumes MinimumProfit has already been imported from this repository


class CustomMinimumProfit(MinimumProfit):
    def _mean_passage_time(self, lower: int, upper: int, ar_coeff: float, ar_resid: np.ndarray,
                           granularity: float) -> pd.Series:
        # Build the grid for summation
        grid = granularity * np.arange(lower, upper)

        # Calculate the gaussian kernel
        gaussian = self._gaussian_kernel(ar_coeff, grid, ar_resid)

        # Calculate the mean passage time at each grid point by solving (I - K) T = 1
        k_dim = gaussian.shape[0]
        passage_time = torch.linalg.solve(torch.from_numpy(np.eye(k_dim) - gaussian),
                                          torch.from_numpy(np.ones(k_dim)))
        # Convert back to NumPy so pandas receives a plain float array
        passage_time = passage_time.numpy()
        # Original: passage_time = np.linalg.solve(np.eye(k_dim) - gaussian, np.ones(k_dim))

        # Return a pandas.Series indexed by grid points for easy retrieval
        passage_time_df = pd.Series(passage_time, index=grid)

        return passage_time_df
```

The idea is based on the discussion at https://stackoverflow.com/questions/62099939/solving-linear-equations-on-the-gpu-with-numpy-and-pytorch
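To make the ~2x observation reproducible, here is a rough timing sketch. It builds a linear system with the same shape as the one in `_mean_passage_time`, (I - K) T = 1, but the kernel is a hypothetical stand-in (a discretized Gaussian AR(1) transition density on an arbitrary grid), not the repository's `_gaussian_kernel`. The helper names `build_system`, `time_solver` and `compare`, and the values of `ar_coeff`, `sigma` and the grid, are illustrative assumptions. PyTorch is treated as optional, so the NumPy timing still runs without it. Timings will depend on each library's BLAS/LAPACK build and thread settings.

```python
import time

import numpy as np

try:
    import torch  # optional: only needed for the PyTorch comparison
except ImportError:
    torch = None


def build_system(k_dim: int, ar_coeff: float = 0.5, sigma: float = 1.0):
    """Hypothetical stand-in for the passage-time system (I - K) T = 1.

    K[i, j] = h * N(grid[j]; ar_coeff * grid[i], sigma^2), a simple
    discretization of a Gaussian AR(1) transition density (an assumption,
    not the repository's _gaussian_kernel).
    """
    grid = np.linspace(-2.0, 2.0, k_dim)
    h = grid[1] - grid[0]
    z = (grid[None, :] - ar_coeff * grid[:, None]) / sigma
    kernel = h * np.exp(-0.5 * z ** 2) / (sigma * np.sqrt(2.0 * np.pi))
    return np.eye(k_dim) - kernel, np.ones(k_dim)


def time_solver(solve, a, b, repeats: int = 5) -> float:
    """Return the best wall-clock time (seconds) over `repeats` calls."""
    best = float("inf")
    for _ in range(repeats):
        start = time.perf_counter()
        solve(a, b)
        best = min(best, time.perf_counter() - start)
    return best


def compare(k_dim: int = 500):
    """Solve the same system with NumPy (and PyTorch, if installed)."""
    a, b = build_system(k_dim)
    x_np = np.linalg.solve(a, b)
    results = {"numpy_s": time_solver(np.linalg.solve, a, b)}
    if torch is not None:
        a_t, b_t = torch.from_numpy(a), torch.from_numpy(b)
        x_torch = torch.linalg.solve(a_t, b_t).numpy()
        results["torch_s"] = time_solver(torch.linalg.solve, a_t, b_t)
        # Both backends should agree to floating-point precision
        results["max_abs_diff"] = float(np.max(np.abs(x_np - x_torch)))
    return x_np, results


if __name__ == "__main__":
    _, timings = compare(2000)
    print(timings)
```

Checking `max_abs_diff` alongside the timings helps confirm that a faster backend is also returning the same passage times.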

Describe the solution you'd like
I am not sure whether this repo wants to take on a new dependency on PyTorch, but if the maintainers are willing, using PyTorch to solve the equations would greatly speed up the optimization process.
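One way to soften the dependency concern might be to treat PyTorch as optional: use `torch.linalg.solve` when it is installed and fall back to `np.linalg.solve` otherwise. The sketch below assumes a hypothetical helper, `solve_linear_system`, which is not part of the project's API; it always returns a NumPy array so callers such as `_mean_passage_time` can keep building a `pd.Series` unchanged.

```python
import numpy as np

try:
    import torch
except ImportError:  # PyTorch not installed: fall back to NumPy
    torch = None


def solve_linear_system(a, b) -> np.ndarray:
    """Solve a @ x = b, preferring torch.linalg.solve when PyTorch is available.

    Hypothetical helper for illustration; always returns a NumPy array.
    """
    a = np.asarray(a, dtype=np.float64)
    b = np.asarray(b, dtype=np.float64)
    if torch is not None:
        return torch.linalg.solve(torch.from_numpy(a), torch.from_numpy(b)).numpy()
    return np.linalg.solve(a, b)


# Possible usage inside _mean_passage_time:
# passage_time = solve_linear_system(np.eye(k_dim) - gaussian, np.ones(k_dim))
```

With this pattern, PyTorch could be offered as an optional extra rather than a hard requirement, and users without it would see no change in behavior.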

Describe alternatives you've considered
None.

Additional context
None.

@Jackal08
Member

Thank you for this great suggestion!

I will come back to this. We are very nervous about adding more dependencies, as they cause a lot of issues in the long term, but maybe there is a way.

Could we copy the relevant code from PyTorch into this repo so that we don't need the dependency?
