forked from snowflakedb/snowflake-connector-python
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy patharrow_iterator.pyx
228 lines (176 loc) · 7.98 KB
/
arrow_iterator.pyx
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
#
# Copyright (c) 2012-2019 Snowflake Computing Inc. All rights reserved.
#
# distutils: language = c++
# cython: language_level=3
from logging import getLogger
from cpython.ref cimport PyObject
from libc.stdint cimport *
from libcpp cimport bool as c_bool
from libcpp.memory cimport shared_ptr
from libcpp.string cimport string as c_string
from libcpp.vector cimport vector
from .errors import (Error, OperationalError, InterfaceError)
from .errorcode import (ER_FAILED_TO_READ_ARROW_STREAM, ER_FAILED_TO_CONVERT_ROW_TO_PYTHON_TYPE)
logger = getLogger(__name__)

# Iteration units understood by the iterators in this module:
#   EMPTY_UNIT - default; the unit a PyArrowIterator reports before init()
#   ROW_UNIT   - fetch row by row when the user calls `fetchone()`
#   TABLE_UNIT - fetch one Arrow table when the user calls `fetch_pandas()`
EMPTY_UNIT = ''
ROW_UNIT = 'row'
TABLE_UNIT = 'table'
# Base C++ iterator interface and its result type (namespace sf).
cdef extern from "cpp/ArrowIterator/CArrowIterator.hpp" namespace "sf":
    # Value produced by CArrowIterator::next(). PyArrowIterator.__next__ treats
    # a NULL successObj as a failure (and reads `exception` to build the error
    # message) and a successObj of None as the end of the data.
    cdef cppclass ReturnVal:
        PyObject * successObj;
        PyObject * exception;

    # Common base class of the row (chunk) and table iterators declared below.
    cdef cppclass CArrowIterator:
        shared_ptr[ReturnVal] next();
# Iterators used by PyArrowIterator.init() for ROW_UNIT. The constructors get a
# pointer to PyArrowIterator.batches, so that vector presumably has to outlive
# the iterator (TODO confirm in the C++ sources). `except +` turns C++
# exceptions thrown by the constructors into Python exceptions.
cdef extern from "cpp/ArrowIterator/CArrowChunkIterator.hpp" namespace "sf":
    # Row iterator; use_numpy carries the connection's numpy flag through.
    cdef cppclass CArrowChunkIterator(CArrowIterator):
        CArrowChunkIterator(PyObject* context, vector[shared_ptr[CRecordBatch]]* batches, PyObject* use_numpy) except +

    # Variant chosen when use_dict_result is truthy (presumably yields rows as
    # dicts - confirm against the C++ implementation).
    cdef cppclass DictCArrowChunkIterator(CArrowChunkIterator):
        DictCArrowChunkIterator(PyObject* context, vector[shared_ptr[CRecordBatch]]* batches, PyObject* use_numpy) except +
# Iterator used by PyArrowIterator.init() for TABLE_UNIT. Like the row
# iterators it receives a pointer to PyArrowIterator.batches; `except +`
# converts C++ constructor exceptions into Python exceptions.
cdef extern from "cpp/ArrowIterator/CArrowTableIterator.hpp" namespace "sf":
    cdef cppclass CArrowTableIterator(CArrowIterator):
        CArrowTableIterator(PyObject* context, vector[shared_ptr[CRecordBatch]]* batches) except +
# Subset of the Arrow C++ core API used by this module. `nogil` marks these
# members as callable without holding the GIL.
cdef extern from "arrow/api.h" namespace "arrow" nogil:
    # arrow::Status - outcome of an Arrow call. PyArrowIterator.__cinit__ checks
    # ok() after each call and uses message() in its error text.
    cdef cppclass CStatus "arrow::Status":
        CStatus()
        c_string ToString()
        c_string message()
        c_bool ok()
        c_bool IsIOError()
        c_bool IsOutOfMemory()
        c_bool IsInvalid()
        c_bool IsKeyError()
        c_bool IsNotImplemented()
        c_bool IsTypeError()
        c_bool IsCapacityError()
        c_bool IsIndexError()
        c_bool IsSerializationError()

    # arrow::Buffer - appears as an out-parameter type in the I/O declarations.
    cdef cppclass CBuffer" arrow::Buffer":
        CBuffer(const uint8_t* data, int64_t size)

    # Opaque here: record batches are only stored and handed to the C++ iterators.
    cdef cppclass CRecordBatch" arrow::RecordBatch"

    # Batch source; PyArrowIterator.__cinit__ calls ReadNext() until it gets a
    # NULL batch.
    cdef cppclass CRecordBatchReader" arrow::RecordBatchReader":
        CStatus ReadNext(shared_ptr[CRecordBatch]* batch)
# Arrow IPC stream reader; PyArrowIterator.__cinit__ uses it to open the
# incoming stream.
cdef extern from "arrow/ipc/api.h" namespace "arrow::ipc" nogil:
    cdef cppclass CRecordBatchStreamReader \
            " arrow::ipc::RecordBatchStreamReader"(CRecordBatchReader):
        # Opens `stream` and stores the reader in *out; the returned CStatus
        # reports whether that succeeded.
        @staticmethod
        CStatus Open(const InputStream* stream,
                     shared_ptr[CRecordBatchReader]* out)
# Arrow I/O interfaces, declared down to RandomAccessFile so that the
# PyReadableFile declared further below can be passed where an InputStream
# is expected.
cdef extern from "arrow/io/api.h" namespace "arrow::io" nogil:
    # arrow::io::FileMode::type, exposed under Cython-visible names.
    enum FileMode" arrow::io::FileMode::type":
        FileMode_READ" arrow::io::FileMode::READ"
        FileMode_WRITE" arrow::io::FileMode::WRITE"
        FileMode_READWRITE" arrow::io::FileMode::READWRITE"

    cdef cppclass FileInterface:
        CStatus Close()
        CStatus Tell(int64_t* position)
        FileMode mode()
        c_bool closed()

    cdef cppclass Readable:
        # The shared_ptr overload of C++ `Read` is exposed as ReadBuffer (still
        # mapped to "Read") to avoid a Cython bug with overloads across
        # multiple layers of inheritance.
        CStatus ReadBuffer" Read"(int64_t nbytes, shared_ptr[CBuffer]* out)
        CStatus Read(int64_t nbytes, int64_t* bytes_read, uint8_t* out)

    # The stream type accepted by CRecordBatchStreamReader.Open.
    cdef cppclass InputStream(FileInterface, Readable):
        pass

    cdef cppclass Seekable:
        CStatus Seek(int64_t position)

    cdef cppclass RandomAccessFile(InputStream, Seekable):
        CStatus GetSize(int64_t* size)
        CStatus ReadAt(int64_t position, int64_t nbytes,
                       int64_t* bytes_read, uint8_t* buffer)
        CStatus ReadAt(int64_t position, int64_t nbytes,
                       shared_ptr[CBuffer]* out)
        c_bool supports_zero_copy()
# Arrow's Python adapter (arrow::py): exposes a Python object `fo` (presumably
# file-like - confirm) as a RandomAccessFile. PyArrowIterator.__cinit__ wraps
# its py_inputstream argument with it.
cdef extern from "arrow/python/api.h" namespace "arrow::py" nogil:
    cdef cppclass PyReadableFile(RandomAccessFile):
        PyReadableFile(object fo)
cdef class EmptyPyArrowIterator:
    """Iterator over nothing: the first next() raises StopIteration.

    Also the base class of PyArrowIterator, which overrides ``__next__`` and
    ``init``.
    """

    def __iter__(self):
        # Iterator protocol: an iterator returns itself from __iter__, so
        # instances (and subclasses) work in for-loops, list(), etc.
        return self

    def __next__(self):
        raise StopIteration

    def init(self, str iter_unit):
        """Accept an iteration unit; there is nothing to set up."""
        pass
cdef class PyArrowIterator(EmptyPyArrowIterator):
    """Iterate over the record batches of an Arrow IPC stream.

    The constructor reads every record batch from ``py_inputstream`` into
    memory. ``init()`` must then be called with ROW_UNIT or TABLE_UNIT to build
    the C++ iterator before ``next()`` is used.
    """
    # context object handed to the C++ iterators (only passed through here)
    cdef object context
    # C++ iterator created by init(); NULL until then, freed in __dealloc__
    cdef CArrowIterator* cIterator
    # unit chosen by the last successful init() ('' before that)
    cdef str unit
    # result of the most recent cIterator.next() call
    cdef shared_ptr[ReturnVal] cret
    # all record batches read in __cinit__; the C++ iterators get a pointer to it
    cdef vector[shared_ptr[CRecordBatch]] batches
    # when truthy, init(ROW_UNIT) builds a DictCArrowChunkIterator
    cdef object use_dict_result
    # cursor whose connection/errorhandler receive reported errors
    cdef object cursor
    # this is the flag indicating whether fetch data as numpy datatypes or not. The flag
    # is passed from the constructor of SnowflakeConnection class. Note, only FIXED, REAL
    # and TIMESTAMP_NTZ will be converted into numpy data types, all other sql types will
    # still be converted into native python types.
    # https://docs.snowflake.net/manuals/user-guide/sqlalchemy.html#numpy-data-type-support
    cdef object use_numpy

    def __cinit__(self, object cursor, object py_inputstream, object arrow_context, object use_dict_result,
                  object numpy):
        # Read every record batch of the Arrow stream in py_inputstream into
        # self.batches. Open/read failures go through Error.errorhandler_wrapper
        # as OperationalError with ER_FAILED_TO_READ_ARROW_STREAM.
        cdef shared_ptr[InputStream] input_stream
        cdef shared_ptr[CRecordBatchReader] reader
        cdef shared_ptr[CRecordBatch] record_batch
        cdef CStatus ret

        # Assign the plain attributes first so the object is consistent even
        # if reading stops early below.
        self.context = arrow_context
        self.cIterator = NULL
        self.unit = ''
        self.use_dict_result = use_dict_result
        self.cursor = cursor
        self.use_numpy = numpy

        input_stream.reset(new PyReadableFile(py_inputstream))
        ret = CRecordBatchStreamReader.Open(input_stream.get(), &reader)
        if not ret.ok():
            Error.errorhandler_wrapper(
                cursor.connection,
                cursor,
                OperationalError,
                {
                    # CStatus.message() is a C++ std::string, which Cython turns
                    # into bytes; decode it so it can be appended to the str prefix.
                    u'msg': u'Failed to open arrow stream: ' + ret.message().decode('utf-8', 'replace'),
                    u'errno': ER_FAILED_TO_READ_ARROW_STREAM
                })
            # Only reached if the error handler returned instead of raising:
            # the stream was not opened, so leave self.batches empty rather than
            # dereferencing `reader`.
            return

        while True:
            ret = reader.get().ReadNext(&record_batch)
            if not ret.ok():
                Error.errorhandler_wrapper(
                    cursor.connection,
                    cursor,
                    OperationalError,
                    {
                        u'msg': u'Failed to read next arrow batch: ' + ret.message().decode('utf-8', 'replace'),
                        u'errno': ER_FAILED_TO_READ_ARROW_STREAM
                    })
                # Handler returned instead of raising: stop reading. record_batch
                # may still hold the previous batch, so re-checking it could
                # append it again or loop forever.
                break
            # A NULL batch marks the end of the stream.
            if record_batch.get() is NULL:
                break
            self.batches.push_back(record_batch)
        logger.debug("Batches read: %d", self.batches.size())

    def __dealloc__(self):
        # Free the C++ iterator created by init(); deleting NULL is a no-op.
        # NOTE(review): the delete goes through the CArrowIterator* base
        # pointer, which assumes CArrowIterator has a virtual destructor in
        # CArrowIterator.hpp - confirm.
        del self.cIterator

    def __next__(self):
        # Return the next row/table from the C++ iterator. StopIteration is
        # raised when the iterator returns None.
        if self.cIterator is NULL:
            # init() has not run; dereferencing NULL would crash the interpreter.
            raise RuntimeError(u'PyArrowIterator.init() must be called before iteration')
        self.cret = self.cIterator.next()
        if not self.cret.get().successObj:
            msg = u'Failed to convert current row, cause: ' + str(<object>self.cret.get().exception)
            Error.errorhandler_wrapper(self.cursor.connection, self.cursor, InterfaceError,
                                       {
                                           u'msg': msg,
                                           u'errno': ER_FAILED_TO_CONVERT_ROW_TO_PYTHON_TYPE
                                       })
            # Reached only if the error handler returned instead of raising.
            # There is no row to return (successObj is NULL), so end the
            # iteration instead of casting NULL to a Python object.
            raise StopIteration
        ret = <object>self.cret.get().successObj
        if ret is None:
            raise StopIteration
        return ret

    def init(self, str iter_unit):
        """Create the C++ iterator for ``iter_unit``.

        ROW_UNIT builds a row iterator (DictCArrowChunkIterator when
        use_dict_result is truthy, otherwise CArrowChunkIterator). TABLE_UNIT
        builds a CArrowTableIterator. Any other value raises
        NotImplementedError and leaves the current iterator untouched.
        Calling init() again frees the previous iterator instead of leaking it.
        """
        if iter_unit != ROW_UNIT and iter_unit != TABLE_UNIT:
            raise NotImplementedError
        # Release the iterator from an earlier init() call (no-op when NULL).
        del self.cIterator
        self.cIterator = NULL
        if iter_unit == ROW_UNIT:
            if self.use_dict_result:
                self.cIterator = new DictCArrowChunkIterator(
                    <PyObject*>self.context, &self.batches, <PyObject*>self.use_numpy)
            else:
                self.cIterator = new CArrowChunkIterator(
                    <PyObject*>self.context, &self.batches, <PyObject*>self.use_numpy)
        else:
            self.cIterator = new CArrowTableIterator(<PyObject*>self.context, &self.batches)
        self.unit = iter_unit