Skip to content

Commit

Permalink
parallel_dispatch: allow passin pytorch tensors directly to kernel calls
Browse files Browse the repository at this point in the history
  • Loading branch information
NaderAlAwar committed Jul 7, 2024
1 parent 4cce38d commit c2583e1
Showing 1 changed file with 9 additions and 0 deletions.
9 changes: 9 additions & 0 deletions pykokkos/interface/parallel_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,18 +121,27 @@ def convert_arrays(kwargs: Dict[str, Any]) -> None:
"""

cp_available: bool
torch_available: bool

try:
import cupy as cp
cp_available = True
except ImportError:
cp_available = False

try:
import torch
torch_available = True
except ImportError:
torch_available = False

for k, v in kwargs.items():
if isinstance(v, np.ndarray):
kwargs[k] = array(v)
elif cp_available and isinstance(v, cp.ndarray):
kwargs[k] = array(v)
elif torch_available and torch.is_tensor(v):
kwargs[k] = array(v)


def parallel_for(*args, **kwargs) -> None:
Expand Down

0 comments on commit c2583e1

Please sign in to comment.