-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathdlpack.py
96 lines (81 loc) · 2.47 KB
/
dlpack.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import ctypes
# Byte-string name "dltensor".
# NOTE(review): not referenced in this section; presumably the PyCapsule
# name used when wrapping/unwrapping DLManagedTensor capsules — confirm
# against callers.
_c_str_dltensor = b"dltensor"
class DLDevice(ctypes.Structure):
    """ctypes mirror of the DLPack ``DLDevice`` struct.

    Used as the ``device`` field of ``DLTensor``. Field order and C types
    define the memory layout seen by native code; do not reorder or retype.
    """

    _fields_ = [
        # Kind of device as a C int.
        # NOTE(review): presumably a DLDeviceType enum value (CPU, CUDA, ...);
        # the numeric meanings are not defined in this file.
        ("device_type", ctypes.c_int),
        # Device index as a C int — assumed to be the ordinal among devices
        # of the same type; confirm.
        ("device_id", ctypes.c_int),
    ]
class DLDataTypeCode(ctypes.c_uint8):
    """DLPack ``DLDataTypeCode`` stored as an 8-bit unsigned integer.

    Because this subclasses a ctypes fundamental type, reading a structure
    field declared with this type yields an instance of this class (not a
    plain ``int``), so ``str(field)`` gives the readable type name, e.g.
    ``"float"``.
    """

    kDLInt = 0
    kDLUInt = 1
    kDLFloat = 2
    kDLOpaqueHandle = 3
    kDLBfloat = 4
    kDLComplex = 5
    kDLBool = 6

    # Built once at class creation rather than on every __str__ call.
    _NAMES = {
        kDLInt: "int",
        kDLUInt: "uint",
        kDLFloat: "float",
        kDLOpaqueHandle: "opaque_handle",
        kDLBfloat: "bfloat",
        kDLComplex: "complex",
        kDLBool: "bool",
    }

    def __str__(self):
        """Return the short name of this type code, e.g. ``"uint"``.

        Raises:
            KeyError: if the stored value is not a known type code.
        """
        try:
            return self._NAMES[self.value]
        except KeyError:
            # Keep KeyError for existing callers, but say what went wrong.
            raise KeyError(f"unknown DLDataTypeCode {self.value}") from None
class DLDataType(ctypes.Structure):
    """ctypes mirror of the DLPack ``DLDataType`` struct (element type).

    Used as the ``dtype`` field of ``DLTensor``. Field order and C types
    define the memory layout; do not reorder or retype them.
    """

    _fields_ = [
        # Kind of number (int/uint/float/...). Read back as a DLDataTypeCode
        # instance, so str() on it yields the type name.
        ("type_code", DLDataTypeCode),
        # Bits per lane, e.g. 32 for float32.
        ("bits", ctypes.c_uint8),
        # Lanes per element; DLTensor.itemsize multiplies bits by this.
        ("lanes", ctypes.c_uint16),
    ]
# DLDataType encoding for each supported dtype name, given as a
# (type_code, bits, lanes) triple in the same order as DLDataType's fields.
# type_code uses the DLDataTypeCode values: 0 = int, 1 = uint, 2 = float.
# NOTE(review): "bool" is encoded as a 1-bit unsigned integer; confirm that
# consumers expect this rather than an 8-bit encoding.
TYPE_MAP = dict(
    bool=(1, 1, 1),
    int32=(0, 32, 1),
    int64=(0, 64, 1),
    uint32=(1, 32, 1),
    uint64=(1, 64, 1),
    float32=(2, 32, 1),
    float64=(2, 64, 1),
)
class DLTensor(ctypes.Structure):
    """ctypes mirror of the DLPack ``DLTensor`` struct.

    Field order and C types define the memory layout; do not reorder or
    retype them. The ``__array_interface__`` property describes the tensor
    using the NumPy array interface, so consumers of that protocol
    (e.g. ``numpy.asarray``) can view the underlying memory.
    """

    _fields_ = [
        ("data", ctypes.c_void_p),
        ("device", DLDevice),
        ("ndim", ctypes.c_int),
        ("dtype", DLDataType),
        ("shape", ctypes.POINTER(ctypes.c_int64)),
        # Strides are counted in elements; a NULL pointer means the tensor
        # is compact (row-major).
        ("strides", ctypes.POINTER(ctypes.c_int64)),
        ("byte_offset", ctypes.c_uint64),
    ]

    @property
    def itemsize(self):
        """Size in bytes of one element (all lanes), rounded up.

        Rounding up matters for sub-byte types: TYPE_MAP encodes ``bool`` as
        a 1-bit unsigned integer, which floor division would report as a
        0-byte item.
        """
        return (self.dtype.lanes * self.dtype.bits + 7) // 8

    @property
    def __array_interface__(self):
        """Return a version-3 NumPy array-interface dict for this tensor.

        The memory is exported read-only. ``byte_offset`` is added to the
        data pointer itself, because the array interface applies an
        ``offset`` entry only when ``data`` is a buffer object, not a
        ``(pointer, read_only)`` tuple as used here.
        """
        ndim = self.ndim
        itemsize = self.itemsize
        shape = tuple(self.shape[dim] for dim in range(ndim))

        if self.strides:
            # DLPack strides are in elements; the array interface wants bytes.
            strides = tuple(self.strides[dim] * itemsize for dim in range(ndim))
        else:
            # Compact row-major layout: each byte stride is the product of
            # all trailing dimensions times the item size.
            compact = [0] * ndim
            step = itemsize
            for dim in reversed(range(ndim)):
                compact[dim] = step
                step *= shape[dim]
            strides = tuple(compact)

        type_name = str(self.dtype.type_code)
        if type_name == "uint" and self.dtype.bits == 1:
            # TYPE_MAP encodes bool as a 1-bit uint; NumPy's boolean kind is "b".
            kind = "b"
        else:
            # NOTE(review): "bfloat" also yields kind "b", which NumPy reads
            # as boolean rather than bfloat16 — confirm how bfloat16 should
            # be exported.
            kind = type_name[0]
        typestr = "|" + kind + str(itemsize)

        # A NULL c_void_p reads back as None; the interface needs an int.
        address = (self.data or 0) + self.byte_offset
        return dict(
            version=3,
            shape=shape,
            strides=strides,
            data=(address, True),
            # Kept so callers reading this key still find it; the real
            # offset is already folded into the data pointer above.
            offset=0,
            typestr=typestr,
        )
class DLManagedTensor(ctypes.Structure):
    """ctypes mirror of the DLPack ``DLManagedTensor`` struct.

    Bundles a ``DLTensor`` with an opaque context pointer and a deleter
    callback. Field order and C types define the memory layout; do not
    reorder or retype them.
    """

    _fields_ = [
        ("dl_tensor", DLTensor),
        # Opaque pointer, not interpreted in this section.
        # NOTE(review): presumably owned by whoever produced the tensor.
        ("manager_ctx", ctypes.c_void_p),
        # C callback taking one pointer argument and returning nothing.
        # NOTE(review): presumably called with a pointer to this struct to
        # release it; nothing in this section invokes it — confirm.
        ("deleter", ctypes.CFUNCTYPE(None, ctypes.c_void_p)),
    ]

    @property
    def __array_interface__(self):
        """Return the NumPy array-interface dict of the wrapped ``dl_tensor``."""
        return self.dl_tensor.__array_interface__