
Commit 86c0bd0

FEAT-#6434: HDK: Do not convert dictionary columns to string when importing arrow tables

Signed-off-by: Andrey Pavlenko <[email protected]>
AndreyPavlenko committed Jul 31, 2023
1 parent eb7ade3 commit 86c0bd0
Showing 4 changed files with 86 additions and 45 deletions.
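The gist of the change: the HDK import path stops unconditionally decoding dictionary-encoded (categorical) Arrow columns to plain strings and, on new enough pyhdk, hands them to HDK as-is. A minimal pyarrow sketch of the two behaviors (the table and column names are illustrative, not from the commit):

import pyarrow as pa

table = pa.table({"c": pa.array(["x", "y", "x"]).dictionary_encode()})

# Old behavior: the import path always decoded dictionaries, equivalent to:
decoded = table.cast(pa.schema([pa.field("c", pa.string())]))
print(decoded.schema)  # c: string

# New behavior (recent pyhdk): the dictionary column is imported unchanged:
print(table.schema)    # c: dictionary<values=string, indices=int32, ordered=0>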
File 1 of 4: base_worker.py
@@ -22,6 +22,13 @@

 from modin.error_message import ErrorMessage

+_UINT_TO_INT_MAP = {
+    pa.uint8(): pa.int16(),
+    pa.uint16(): pa.int32(),
+    pa.uint32(): pa.int64(),
+    pa.uint64(): pa.int64(),  # May cause overflow
+}
+

 class DbTable(abc.ABC):
     """
@@ -152,15 +159,17 @@ def _genName(cls, name):
         # TODO: reword name in case of caller's mistake
         return name

-    @staticmethod
-    def cast_to_compatible_types(table):
+    @classmethod
+    def cast_to_compatible_types(cls, table, cast_dict):
         """
         Cast PyArrow table to be fully compatible with HDK.

         Parameters
         ----------
         table : pyarrow.Table
             Source table.
+        cast_dict : bool
+            Cast dictionary columns to string.

         Returns
         -------
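The method changes from a staticmethod to a classmethod because it now dispatches to the `_convert_field` helper through `cls`, and the new `cast_dict` flag makes the dictionary-to-string cast opt-in rather than unconditional. A hypothetical direct call (the worker class name is illustrative; the real call site appears in import_arrow_table further down):

# `MyDbWorker` stands in for any BaseDbWorker subclass; the real caller
# passes the version-derived _CAST_DICT flag instead of a literal.
compatible = MyDbWorker.cast_to_compatible_types(table, cast_dict=True)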
@@ -169,66 +178,55 @@ def cast_to_compatible_types(table):
"""
schema = table.schema
new_schema = schema
need_cast = False
uint_to_int_cast = False
new_cols = {}
uint_to_int_map = {
pa.uint8(): pa.int16(),
pa.uint16(): pa.int32(),
pa.uint32(): pa.int64(),
pa.uint64(): pa.int64(), # May cause overflow
}
need_cast = [False]
uint_to_int_cast = [False]

for i, field in enumerate(schema):
# Currently HDK doesn't support Arrow table import with
# dictionary columns. Here we cast dictionaries until support
# is in place.
# https://github.com/modin-project/modin/issues/1738
if pa.types.is_dictionary(field.type):
# Conversion for dictionary of null type to string is not supported
# in Arrow. Build new column for this case for now.
if pa.types.is_null(field.type.value_type):
mask = np.full(table.num_rows, True, dtype=bool)
new_col_data = np.empty(table.num_rows, dtype=str)
new_col = pa.array(new_col_data, pa.string(), mask)
new_cols[i] = new_col
new_field = pa.field(
field.name, pa.string(), field.nullable, field.metadata
)
table = table.set_column(i, new_field, new_col)
elif pa.types.is_string(field.type.value_type):
if cast_dict:
need_cast[0] = True
new_field = pa.field(
field.name, pa.string(), field.nullable, field.metadata
)
else:
new_field = field
else:
need_cast = True
new_field = pa.field(
field.name, pa.string(), field.nullable, field.metadata
)
new_field = cls._convert_field(
field, field.type.value_type, need_cast, uint_to_int_cast
)
if new_field == field:
new_field = pa.field(
field.name,
field.type.value_type,
field.nullable,
field.metadata,
)
need_cast[0] = True
new_schema = new_schema.set(i, new_field)
# HDK doesn't support importing Arrow's date type:
# https://github.com/omnisci/omniscidb/issues/678
elif pa.types.is_date(field.type):
# Arrow's date is the number of days since the UNIX-epoch, so we can convert it
# to a timestamp[s] (number of seconds since the UNIX-epoch) without losing precision
new_field = pa.field(
field.name, pa.timestamp("s"), field.nullable, field.metadata
else:
new_field = cls._convert_field(
field, field.type, need_cast, uint_to_int_cast
)
new_schema = new_schema.set(i, new_field)
need_cast = True
# HDK doesn't support unsigned types
elif pa.types.is_unsigned_integer(field.type):
new_field = pa.field(
field.name,
uint_to_int_map[field.type],
field.nullable,
field.metadata,
)
new_schema = new_schema.set(i, new_field)
need_cast = True
uint_to_int_cast = True

# Such cast may affect the data, so we have to raise a warning about it
if uint_to_int_cast:
ErrorMessage.single_warning(
"HDK does not support unsigned integer types, such types will be rounded up to the signed equivalent."
)

for i, col in new_cols.items():
table = table.set_column(i, new_schema[i], col)

if need_cast:
if need_cast[0]:
try:
table = table.cast(new_schema)
except pa.lib.ArrowInvalid as err:
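`need_cast` and `uint_to_int_cast` turn into one-element lists so that the `_convert_field` helper added below can set the flags through its parameters; assigning to a plain boolean argument would only rebind the helper's local name. The pattern in isolation:

def set_flag(flag):
    # Mutating the list element is visible to the caller;
    # `flag = True` would only rebind the local name.
    flag[0] = True

need_cast = [False]
set_flag(need_cast)
print(need_cast[0])  # True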
@@ -239,6 +237,41 @@

         return table

+    @staticmethod
+    def _convert_field(field, field_type, need_cast, uint_to_int_cast):
+        """
+        Convert the specified arrow field, if required.
+
+        Parameters
+        ----------
+        field : pyarrow.Field
+        field_type : pyarrow.DataType
+        need_cast : list of bool
+        uint_to_int_cast : list of bool
+
+        Returns
+        -------
+        pyarrow.Field
+        """
+        if pa.types.is_date(field_type):
+            # Arrow's date is the number of days since the UNIX-epoch, so we can convert it
+            # to a timestamp[s] (number of seconds since the UNIX-epoch) without losing precision
+            need_cast[0] = True
+            return pa.field(
+                field.name, pa.timestamp("s"), field.nullable, field.metadata
+            )
+        elif pa.types.is_unsigned_integer(field_type):
+            # HDK doesn't support unsigned types
+            need_cast[0] = True
+            uint_to_int_cast[0] = True
+            return pa.field(
+                field.name,
+                _UINT_TO_INT_MAP[field_type],
+                field.nullable,
+                field.metadata,
+            )
+        return field
+
     @classmethod
     @abc.abstractmethod
     def import_arrow_table(cls, table, name=None):
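The date branch preserved here relies on Arrow's date32 being a count of days since the UNIX epoch, so widening to timestamp[s] is lossless. A standalone check of that cast (my example; the commit performs the same conversion via `table.cast(new_schema)`):

import datetime
import pyarrow as pa

arr = pa.array([datetime.date(1970, 1, 2)], type=pa.date32())
print(arr.cast(pa.timestamp("s")))  # [1970-01-02 00:00:00]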
File 2 of 4
@@ -1071,11 +1071,13 @@ def join(
         for left_col, right_col in zip(left_on, right_on):
             left_dt = self._dtypes[left_col]
             right_dt = other._dtypes[right_col]
+            if is_categorical_dtype(left_dt) and is_categorical_dtype(right_dt):
+                left_dt = left_dt.categories.dtype
+                right_dt = right_dt.categories.dtype
             if not (
                 (is_any_int_dtype(left_dt) and is_any_int_dtype(right_dt))
                 or (is_string_dtype(left_dt) and is_string_dtype(right_dt))
                 or (is_datetime64_dtype(left_dt) and is_datetime64_dtype(right_dt))
-                or (is_categorical_dtype(left_dt) and is_categorical_dtype(right_dt))
             ):
                 raise NotImplementedError(
                     f"Join on columns of '{left_dt}' and '{right_dt}' dtypes"
File 3 of 4
@@ -14,16 +14,21 @@
"""Module provides ``HdkWorker`` class."""
from typing import Optional, Tuple, List, Union

from packaging import version

import pyarrow as pa
import os

import pyhdk
from pyhdk.hdk import HDK, QueryNode, ExecutionResult, RelAlgExecutor

from .base_worker import DbTable, BaseDbWorker

from modin.utils import _inherit_docstrings
from modin.config import HdkLaunchParameters, OmnisciFragmentSize, HdkFragmentSize

_CAST_DICT = version.parse(getattr(pyhdk, "__version__", "0")) <= version.parse("0.7.0")


class HdkTable(DbTable):
"""
@@ -116,7 +121,7 @@ def executeRA(cls, query: str, exec_calcite=False):
     @classmethod
     def import_arrow_table(cls, table: pa.Table, name: Optional[str] = None):
         name = cls._genName(name)
-        table = cls.cast_to_compatible_types(table)
+        table = cls.cast_to_compatible_types(table, _CAST_DICT)
         fragment_size = cls.compute_fragment_size(table)
         return HdkTable(cls._hdk().import_arrow(table, name, fragment_size))
File 4 of 4
@@ -1631,6 +1631,7 @@ def merge(df, df2, on_columns, **kwargs):
             on_columns="A",
             constructor_kwargs={"dtype": "category"},
             comparator=lambda df1, df2: df_equals(df1.astype(float), df2.astype(float)),
+            force_lazy=False,
         )

     def test_merge_date(self):
