Skip to content

Commit f7f5352

Browse files
xmnlab and jp-harvey authored
Add load_table array support for columnar method (#311)
* Basic working array load
* Fix columnar array load_table issue
* Add load_table tests for None and empty array

Co-authored-by: jp-harvey <[email protected]>
1 parent 91520f0 commit f7f5352

6 files changed

+384
-146
lines changed

.gitignore

+3
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,9 @@ ENV/
9393
# pycharm
9494
.idea/
9595

96+
# vscode
97+
.vscode/
98+
9699
# Rope project settings
97100
.ropeproject
98101

pymapd/_pandas_loaders.py

+55-11
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,8 @@
2626

2727

2828
def get_mapd_dtype(data):
29-
"Get the OmniSci type"
29+
"""Get the OmniSci type"""
30+
3031
if is_object_dtype(data):
3132
return get_mapd_type_from_object(data)
3233
else:
@@ -119,37 +120,80 @@ def build_input_columnar(
119120

120121
dfs = np.array_split(df, chunks)
121122
cols_array = []
123+
122124
for df in dfs:
123125
input_cols = []
124126

125127
colindex = 0
126128
for col in col_names:
127-
data = df[col]
129+
data = df.loc[:, [col]]
128130

129-
mapd_type = col_types[colindex][0]
131+
mapd_type = col_types[colindex].type
132+
is_array = col_types[colindex].is_array
133+
scale = col_types[colindex].scale
134+
has_nulls = data[col].hasnans
130135

131-
has_nulls = data.hasnans
132136
if has_nulls:
133-
nulls = data.isnull().values.tolist()
137+
nulls = data[col].isnull().values.tolist()
134138
else:
135139
nulls = [False] * len(df)
136140

141+
if is_array:
142+
# Expand the dataframe so each array item has
143+
# its own field in the dataframe.
144+
data = data.iloc[:, 0].apply(pd.Series)
145+
137146
if mapd_type in {'TIME', 'TIMESTAMP', 'DATE', 'BOOL'}:
138147
# requires a cast to integer
139-
data = thrift_cast(data, mapd_type, 0)
148+
for c in data:
149+
data.loc[:, c] = thrift_cast(
150+
data=data[c], mapd_type=mapd_type
151+
)
140152

141153
if mapd_type in ['DECIMAL']:
142154
# requires a calculation be done using the scale
143155
# then cast to int
144-
data = thrift_cast(data, mapd_type, col_types[colindex][1])
156+
for c in data:
157+
data.loc[:, c] = thrift_cast(
158+
data=data[c],
159+
mapd_type=mapd_type,
160+
scale=scale,
161+
is_array=is_array,
162+
)
145163

146164
if has_nulls:
147-
data = data.fillna(mapd_to_na[mapd_type])
165+
if not is_array:
166+
for c in data:
167+
data.loc[:, c] = data[c].fillna(mapd_to_na[mapd_type])
168+
169+
if is_array:
170+
data = data.apply(lambda x: [i for i in x.dropna()], axis=1)
171+
if has_nulls:
172+
data[nulls] = mapd_to_na[mapd_type]
148173

149174
if mapd_type not in ['FLOAT', 'DOUBLE', 'VARCHAR', 'STR']:
150-
data = data.astype('int64')
151-
# use .values so that indexes don't have to be serialized too
152-
kwargs = {mapd_to_slot[mapd_type]: data.values}
175+
if is_array:
176+
data = data.apply(
177+
lambda _array: [int(item) for item in _array]
178+
if isinstance(_array, list)
179+
else None
180+
)
181+
else:
182+
for c in data:
183+
data.loc[:, c] = data.loc[:, c].astype('int64')
184+
185+
# If this is an array column, we need the data to be a series
186+
# of TColumn objects of type mapd_type.
187+
if is_array:
188+
data = data.apply(
189+
lambda x: TColumn(
190+
data=TColumnData(**{mapd_to_slot[mapd_type]: x})
191+
)
192+
)
193+
kwargs = {'arr_col': data}
194+
else:
195+
kwargs = {mapd_to_slot[mapd_type]: data.iloc[:, 0].values}
196+
153197
input_cols.append(TColumn(data=TColumnData(**kwargs), nulls=nulls))
154198
colindex += 1
155199
cols_array.append(input_cols)

pymapd/connection.py

+31-30
Original file line numberDiff line numberDiff line change
@@ -262,8 +262,8 @@ def __init__(
262262
proto = TBinaryProtocol.TBinaryProtocolAccelerated(transport)
263263
else:
264264
raise ValueError(
265-
"`protocol` should be one of",
266-
" ['http', 'https', 'binary'],",
265+
"`protocol` should be one of"
266+
" ['http', 'https', 'binary'],"
267267
" got {} instead".format(protocol),
268268
)
269269
self._user = user
@@ -433,7 +433,7 @@ def select_ipc_gpu(
433433
from cudf.core.dataframe import DataFrame # noqa
434434
except ImportError:
435435
raise ImportError(
436-
"The 'cudf' package is required for " "`select_ipc_gpu`"
436+
"The 'cudf' package is required for `select_ipc_gpu`"
437437
)
438438

439439
self.register_runtime_udfs()
@@ -771,37 +771,38 @@ def load_table_columnar(
771771
order to avoid loading inconsistent values into DATE column.
772772
"""
773773

774-
if isinstance(data, pd.DataFrame):
775-
table_details = self.get_table_details(table_name)
776-
# Validate that there are the same number of columns in the table
777-
# as there are in the dataframe. No point trying to load the data
778-
# if this is not the case
779-
if len(table_details) != len(data.columns):
780-
raise ValueError(
781-
'Number of columns in dataframe ({}) does not \
782-
match number of columns in OmniSci table \
783-
({})'.format(
784-
len(data.columns), len(table_details)
785-
)
786-
)
774+
if not isinstance(data, pd.DataFrame):
775+
raise TypeError('Unknown type {}'.format(type(data)))
787776

788-
col_names = (
789-
[i[0] for i in table_details]
790-
if col_names_from_schema
791-
else list(data)
777+
table_details = self.get_table_details(table_name)
778+
# Validate that there are the same number of columns in the table
779+
# as there are in the dataframe. No point trying to load the data
780+
# if this is not the case
781+
if len(table_details) != len(data.columns):
782+
raise ValueError(
783+
'Number of columns in dataframe ({}) does not \
784+
match number of columns in OmniSci table \
785+
({})'.format(
786+
len(data.columns), len(table_details)
787+
)
792788
)
793789

794-
col_types = [(i[1], i[4]) for i in table_details]
790+
col_names = (
791+
[i.name for i in table_details]
792+
if col_names_from_schema
793+
else list(data)
794+
)
795+
796+
col_types = table_details
797+
798+
input_cols = _pandas_loaders.build_input_columnar(
799+
data,
800+
preserve_index=preserve_index,
801+
chunk_size_bytes=chunk_size_bytes,
802+
col_types=col_types,
803+
col_names=col_names,
804+
)
795805

796-
input_cols = _pandas_loaders.build_input_columnar(
797-
data,
798-
preserve_index=preserve_index,
799-
chunk_size_bytes=chunk_size_bytes,
800-
col_types=col_types,
801-
col_names=col_names,
802-
)
803-
else:
804-
raise TypeError("Unknown type {}".format(type(data)))
805806
for cols in input_cols:
806807
self._client.load_table_binary_columnar(
807808
self._session, table_name, cols

tests/conftest.py

+13
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import subprocess
22
import time
3+
from uuid import uuid4
4+
35
import pytest
46
from thrift.transport import TSocket, TTransport
57
from thrift.transport.TSocket import TTransportException
@@ -189,3 +191,14 @@ def _tests_table_no_nulls(n_samples):
189191
}
190192

191193
return pd.DataFrame(d)
194+
195+
196+
@pytest.fixture
197+
def tmp_table(con) -> str:
198+
table_name = 'table_{}'.format(uuid4().hex)
199+
con.execute("drop table if exists {};".format(table_name))
200+
201+
try:
202+
yield table_name
203+
finally:
204+
con.execute("drop table if exists {};".format(table_name))

tests/test_integration.py

+49-11
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,11 @@
2525
TOmniSciException.__hash__ = lambda x: id(x)
2626

2727

28+
def _cursor2df(cursor):
29+
col_names = [c.name for c in cursor.description]
30+
return pd.DataFrame(cursor.fetchall(), columns=col_names)
31+
32+
2833
@pytest.mark.usefixtures("mapd_server")
2934
class TestIntegration:
3035
def test_connect_binary(self):
@@ -666,17 +671,50 @@ def test_load_empty_table_arrow(self, con):
666671
self.check_empty_insert(result, data)
667672
con.execute("drop table if exists baz;")
668673

669-
def test_load_table_columnar(self, con):
670-
671-
con.execute("drop table if exists baz;")
672-
con.execute("create table baz (a int, b float, c text);")
673-
674-
df = pd.DataFrame(
675-
{"a": [1, 2, 3], "b": [1.1, 2.2, 3.3], "c": ['a', '2', '3']},
676-
columns=['a', 'b', 'c'],
677-
)
678-
con.load_table_columnar("baz", df)
679-
con.execute("drop table if exists baz;")
674+
@pytest.mark.parametrize(
675+
'df, table_fields',
676+
[
677+
(
678+
pd.DataFrame(
679+
{
680+
"a": [1, 2, 3],
681+
"b": [1.1, 2.2, 3.3],
682+
"c": ['a', '2', '3'],
683+
},
684+
),
685+
'a int, b float, c text',
686+
),
687+
(
688+
pd.DataFrame(
689+
[
690+
{'ary': [2, 3, 4]},
691+
{'ary': [4444]},
692+
{'ary': []},
693+
{'ary': None},
694+
{'ary': [2, 3, 4]},
695+
]
696+
),
697+
'ary INT[]',
698+
),
699+
(
700+
pd.DataFrame(
701+
[
702+
{'ary': [2, 3, 4], 'strtest': 'teststr'},
703+
{'ary': None, 'strtest': 'teststr'},
704+
{'ary': [4444], 'strtest': 'teststr'},
705+
{'ary': [], 'strtest': 'teststr'},
706+
{'ary': [2, 3, 4], 'strtest': 'teststr'},
707+
]
708+
),
709+
'ary INT[], strtest TEXT',
710+
),
711+
],
712+
)
713+
def test_load_table_columnar(self, con, tmp_table, df, table_fields):
714+
con.execute("create table {} ({});".format(tmp_table, table_fields))
715+
con.load_table_columnar(tmp_table, df)
716+
result = _cursor2df(con.execute('select * from {}'.format(tmp_table)))
717+
pd.testing.assert_frame_equal(df, result)
680718

681719
def test_load_infer(self, con):
682720

0 commit comments

Comments
 (0)