Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Helper: to_torch #194

Open
wants to merge 1 commit into
base: development
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/amrex/Array4.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,9 @@ def array4_to_cupy(self, copy=False, order="F"):
raise ValueError("The order argument must be F or C.")


# torch


def register_Array4_extension(amr):
"""Array4 helper methods"""
import inspect
Expand Down
3 changes: 3 additions & 0 deletions src/amrex/ArrayOfStructs.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,9 @@ def aos_to_cupy(self, copy=False):
return cp.array(self, copy=copy)


# torch


def register_AoS_extension(amr):
"""ArrayOfStructs helper methods"""
import inspect
Expand Down
5 changes: 5 additions & 0 deletions src/amrex/MultiFab.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,9 @@ def mf_to_cupy(self, copy=False, order="F"):
return views


# torch


def register_MultiFab_extension(amr):
"""MultiFab helper methods"""

Expand All @@ -99,3 +102,5 @@ def register_MultiFab_extension(amr):
amr.MultiFab.to_numpy.__doc__ = mf_to_numpy.__doc__

amr.MultiFab.to_cupy = mf_to_cupy

# torch
14 changes: 14 additions & 0 deletions src/amrex/PODVector.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,19 @@ def podvector_to_cupy(self, copy=False):
raise ValueError("Vector is empty.")


def podvector_to_torch(self, copy=False):
"""
Provide PyTorch tensor views into a PODVector (e.g., RealVector, IntVector).

...
"""
import torch

# if CUDA else ...
# pick right device (context? device number?)
return torch.as_tensor(self.to_cupy(copy), device="cuda")


def register_PODVector_extension(amr):
"""PODVector helper methods"""
import inspect
Expand All @@ -82,3 +95,4 @@ def register_PODVector_extension(amr):
):
POD_type.to_numpy = podvector_to_numpy
POD_type.to_cupy = podvector_to_cupy
POD_type.to_torch = podvector_to_torch
14 changes: 14 additions & 0 deletions src/amrex/StructOfArrays.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,19 @@ def soa_to_cupy(self, copy=False):
return soa_view


def soa_to_torch(self, copy=False):
"""
Provide PyTorch tensor views into a StructOfArrays.

...
"""
import torch

# if CUDA else ...
# pick right device (context? device number?)
return torch.as_tensor(self.to_cupy(copy), device="cuda")


def register_SoA_extension(amr):
"""StructOfArrays helper methods"""
import inspect
Expand All @@ -97,3 +110,4 @@ def register_SoA_extension(amr):
):
SoA_type.to_numpy = soa_to_numpy
SoA_type.to_cupy = soa_to_cupy
SoA_type.to_torch = soa_to_torch