Skip to content

Commit

Permalink
Add cuda module in core
Browse files Browse the repository at this point in the history
  • Loading branch information
againxx committed Apr 16, 2022
1 parent 1cb015a commit 4954020
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 11 deletions.
2 changes: 2 additions & 0 deletions core/__init__.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .core import *
from . import cuda
28 changes: 17 additions & 11 deletions core.pyi → core/core.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ from typing import (
Tuple,
List,
)
from numpy import ndarray
from numpy import ArrayLike, ndarray

class Blob:
def __init__(self, *args, **kwargs) -> None: ...
Expand Down Expand Up @@ -52,7 +52,8 @@ class Dtype:
def byte_code(self) -> DtypeCode: ...
def byte_size(self) -> int: ...

bool = Dtype.Bool
# bool = Dtype.Bool, this variable conflict with the builtin bool
bool8 = Dtype.Bool
float32 = Dtype.Float32
float64 = Dtype.Float64
int8 = Dtype.Int8
Expand Down Expand Up @@ -101,19 +102,19 @@ class HashMap:
self,
init_capacity: int,
key_dtype: Dtype,
key_element_shape: SizeVector,
key_element_shape: Iterable,
value_dtype: Dtype,
value_element_shape: SizeVector,
value_element_shape: Iterable,
device: Device = Device("CPU:0"),
) -> None: ...
@overload
def __init__(
self,
init_capacity: int,
key_dtype: Dtype,
key_element_shape: SizeVector,
value_dtypes: List[Dtype],
value_element_shapes: List[SizeVector],
key_element_shape: Iterable,
value_dtypes: Sequence[Dtype],
value_element_shapes: Sequence[Iterable],
device: Device = Device("CPU:0"),
) -> None: ...
def activate(self, keys: Tensor) -> Tuple[Tensor, Tensor]: ...
Expand All @@ -125,9 +126,9 @@ class HashMap:
def erase(self, keys: Tensor) -> Tensor: ...
def find(self, keys: Tensor) -> Tuple[Tensor, Tensor]: ...
@overload
def insert(keys: Tensor, values: Tensor) -> Tuple[Tensor, Tensor]: ...
def insert(self, keys: Tensor, values: Tensor) -> Tuple[Tensor, Tensor]: ...
@overload
def insert(keys: Tensor, list_values: Tensor) -> Tuple[Tensor, Tensor]: ...
def insert(self, keys: Tensor, list_values: Sequence[Tensor]) -> Tuple[Tensor, Tensor]: ...
def key_tensor(self) -> Tensor: ...
@classmethod
def load(cls, file_name: str) -> HashMap: ...
Expand Down Expand Up @@ -172,7 +173,7 @@ class Tensor:
@overload
def __init__(
self,
np_array: ndarray,
np_array: ArrayLike,
dtype: Optional[Dtype] = None,
device: Optional[Device] = None,
) -> None: ...
Expand Down Expand Up @@ -295,10 +296,15 @@ class Tensor:
def max(self, dim: Optional[SizeVector] = None, keepdim: bool = False) -> Tensor: ...
def mean(self, dim: Optional[SizeVector] = None, keepdim: bool = False) -> Tensor: ...
def min(self, dim: Optional[SizeVector] = None, keepdim: bool = False) -> Tensor: ...
@overload
def to(self, dtype: Dtype, copy: bool = False) -> Tensor: ...
@overload
def to(self, device: Device, copy: bool = False) -> Tensor: ...
def __getitem__(self, indices: Tensor) -> Tensor: ...

def addmm(input: Tensor, A: Tensor, B: Tensor, alpha: float, beta: float) -> Tensor: ...
def append(self: Tensor, values: Tensor, axis: Optional[int] = None) -> Tensor: ...
def concatenate(tensors: List[Tensor], axis: Optional[int] = None) -> Tensor: ...
def concatenate(tensors: Sequence[Tensor], axis: Optional[int] = None) -> Tensor: ...
def det(A: Tensor) -> float: ...
def inv(A: Tensor) -> Tensor: ...
def lstsq(A: Tensor, B: Tensor) -> Tensor: ...
Expand Down
7 changes: 7 additions & 0 deletions core/cuda.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from typing import Optional
from . import core

def device_count() -> int: ...
def is_available() -> bool: ...
def release_cache() -> None: ...
def synchronize(device: Optional[core.Device] = None) -> None: ...

0 comments on commit 4954020

Please sign in to comment.