Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(python)!: Allow all DataType objects to be instantiated #12470

Merged
merged 6 commits into from
Dec 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions py-polars/polars/dataframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@
import deltalake
from xlsxwriter import Workbook

from polars import Expr, LazyFrame, Series
from polars import DataType, Expr, LazyFrame, Series
from polars.interchange.dataframe import PolarsDataFrame
from polars.type_aliases import (
AsofJoinStrategy,
Expand Down Expand Up @@ -1206,7 +1206,7 @@ def columns(self, names: Sequence[str]) -> None:
self._df.set_column_names(names)

@property
def dtypes(self) -> list[PolarsDataType]:
def dtypes(self) -> list[DataType]:
"""
Get the datatypes of the columns of this DataFrame.

Expand Down Expand Up @@ -1255,7 +1255,7 @@ def flags(self) -> dict[str, dict[str, bool]]:
return {name: self[name].flags for name in self.columns}

@property
def schema(self) -> SchemaDict:
def schema(self) -> OrderedDict[str, DataType]:
"""
Get a dict[column name, DataType].

Expand Down
19 changes: 12 additions & 7 deletions py-polars/polars/datatypes/classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,19 +77,24 @@ def is_nested(self) -> bool: # noqa: D102
class DataType(metaclass=DataTypeClass):
"""Base class for all Polars data types."""

def __new__(cls, *args: Any, **kwargs: Any) -> PolarsDataType: # type: ignore[misc] # noqa: D102
# this formulation allows for equivalent use of "pl.Type" and "pl.Type()", while
# still respecting types that take initialisation params (eg: Duration/Datetime)
if args or kwargs:
return super().__new__(cls)
return cls

def __reduce__(self) -> Any:
return (_custom_reconstruct, (type(self), object, None), self.__dict__)

def _string_repr(self) -> str:
return _dtype_str_repr(self)

def __eq__(self, other: PolarsDataType) -> bool: # type: ignore[override]
if type(other) is DataTypeClass:
return issubclass(other, type(self))
else:
return isinstance(other, type(self))

def __hash__(self) -> int:
return hash(self.__class__)

def __repr__(self) -> str:
return self.__class__.__name__

@classmethod
def base_type(cls) -> DataTypeClass:
"""
Expand Down
5 changes: 2 additions & 3 deletions py-polars/polars/io/delta.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@
from polars.io.pyarrow_dataset import scan_pyarrow_dataset

if TYPE_CHECKING:
from polars import DataFrame, LazyFrame
from polars.type_aliases import PolarsDataType
from polars import DataFrame, DataType, LazyFrame


def read_delta(
Expand Down Expand Up @@ -320,7 +319,7 @@ def _check_if_delta_available() -> None:
)


def _check_for_unsupported_types(dtypes: list[PolarsDataType]) -> None:
def _check_for_unsupported_types(dtypes: list[DataType]) -> None:
schema_dtypes = unpack_dtypes(*dtypes)
unsupported_types = {Time, Categorical, Null}
overlap = schema_dtypes & unsupported_types
Expand Down
5 changes: 2 additions & 3 deletions py-polars/polars/io/ipc/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,7 @@
if TYPE_CHECKING:
from io import BytesIO

from polars import DataFrame, LazyFrame
from polars.type_aliases import PolarsDataType
from polars import DataFrame, DataType, LazyFrame


def read_ipc(
Expand Down Expand Up @@ -185,7 +184,7 @@ def read_ipc_stream(
)


def read_ipc_schema(source: str | BinaryIO | Path | bytes) -> dict[str, PolarsDataType]:
def read_ipc_schema(source: str | BinaryIO | Path | bytes) -> dict[str, DataType]:
"""
Get the schema of an IPC file without reading data.

Expand Down
6 changes: 3 additions & 3 deletions py-polars/polars/io/parquet/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
if TYPE_CHECKING:
from io import BytesIO

from polars import DataFrame, LazyFrame
from polars.type_aliases import ParallelStrategy, PolarsDataType
from polars import DataFrame, DataType, LazyFrame
from polars.type_aliases import ParallelStrategy


def read_parquet(
Expand Down Expand Up @@ -143,7 +143,7 @@ def read_parquet(

def read_parquet_schema(
source: str | BinaryIO | Path | bytes,
) -> dict[str, PolarsDataType]:
) -> dict[str, DataType]:
"""
Get the schema of a Parquet file without reading data.

Expand Down
6 changes: 3 additions & 3 deletions py-polars/polars/lazyframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@

import pyarrow as pa

from polars import DataFrame, Expr
from polars import DataFrame, DataType, Expr
from polars.dependencies import numpy as np
from polars.type_aliases import (
AsofJoinStrategy,
Expand Down Expand Up @@ -693,7 +693,7 @@ def columns(self) -> list[str]:
return self._ldf.columns()

@property
def dtypes(self) -> list[PolarsDataType]:
def dtypes(self) -> list[DataType]:
"""
Get dtypes of columns in LazyFrame.

Expand All @@ -717,7 +717,7 @@ def dtypes(self) -> list[PolarsDataType]:
return self._ldf.dtypes()

@property
def schema(self) -> SchemaDict:
def schema(self) -> OrderedDict[str, DataType]:
"""
Get a dict[column name, DataType].

Expand Down
4 changes: 2 additions & 2 deletions py-polars/polars/series/datetime.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def median(self) -> dt.date | dt.datetime | dt.timedelta | None:
if s.dtype == Date:
return _to_python_date(int(out))
else:
return _to_python_datetime(int(out), s.dtype.time_unit) # type: ignore[union-attr]
return _to_python_datetime(int(out), s.dtype.time_unit) # type: ignore[attr-defined]
return None

def mean(self) -> dt.date | dt.datetime | None:
Expand All @@ -108,7 +108,7 @@ def mean(self) -> dt.date | dt.datetime | None:
if s.dtype == Date:
return _to_python_date(int(out))
else:
return _to_python_datetime(int(out), s.dtype.time_unit) # type: ignore[union-attr]
return _to_python_datetime(int(out), s.dtype.time_unit) # type: ignore[attr-defined]
return None

def to_string(self, format: str) -> Series:
Expand Down
27 changes: 15 additions & 12 deletions py-polars/polars/series/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@
if TYPE_CHECKING:
import sys

from polars import DataFrame, Expr
from polars import DataFrame, DataType, Expr
from polars.series._numpy import SeriesView
from polars.type_aliases import (
ClosedInterval,
Expand Down Expand Up @@ -365,7 +365,7 @@ def _get_ptr(self) -> tuple[int, int, int]:
return self._s.get_ptr()

@property
def dtype(self) -> PolarsDataType:
def dtype(self) -> DataType:
"""
Get the data type of this Series.

Expand Down Expand Up @@ -398,10 +398,13 @@ def flags(self) -> dict[str, bool]:
return out

@property
def inner_dtype(self) -> PolarsDataType | None:
def inner_dtype(self) -> DataType | None:
"""
Get the inner dtype in of a List typed Series.

.. deprecated:: 0.19.14
Use `Series.dtype.inner` instead.

Returns
-------
DataType
Expand All @@ -412,7 +415,7 @@ def inner_dtype(self) -> PolarsDataType | None:
version="0.19.14",
)
try:
return self.dtype.inner # type: ignore[union-attr]
return self.dtype.inner # type: ignore[attr-defined]
except AttributeError:
return None

Expand Down Expand Up @@ -502,12 +505,12 @@ def _comp(self, other: Any, op: ComparisonOperator) -> Series:
time_unit = "us"
elif self.dtype == Datetime:
# Use local time zone info
time_zone = self.dtype.time_zone # type: ignore[union-attr]
time_zone = self.dtype.time_zone # type: ignore[attr-defined]
if str(other.tzinfo) != str(time_zone):
raise TypeError(
f"Datetime time zone {other.tzinfo!r} does not match Series timezone {time_zone!r}"
)
time_unit = self.dtype.time_unit # type: ignore[union-attr]
time_unit = self.dtype.time_unit # type: ignore[attr-defined]
else:
raise ValueError(
f"cannot compare datetime.datetime to Series of type {self.dtype}"
Expand All @@ -524,7 +527,7 @@ def _comp(self, other: Any, op: ComparisonOperator) -> Series:
return self._from_pyseries(f(d))

elif isinstance(other, timedelta) and self.dtype == Duration:
time_unit = self.dtype.time_unit # type: ignore[union-attr]
time_unit = self.dtype.time_unit # type: ignore[attr-defined]
td = _timedelta_to_pl_timedelta(other, time_unit) # type: ignore[arg-type]
f = get_ffi_func(op + "_<>", Int64, self._s)
assert f is not None
Expand Down Expand Up @@ -4051,9 +4054,9 @@ def convert_to_date(arr: np.ndarray[Any, Any]) -> np.ndarray[Any, Any]:
if self.dtype == Date:
tp = "datetime64[D]"
elif self.dtype == Duration:
tp = f"timedelta64[{self.dtype.time_unit}]" # type: ignore[union-attr]
tp = f"timedelta64[{self.dtype.time_unit}]" # type: ignore[attr-defined]
else:
tp = f"datetime64[{self.dtype.time_unit}]" # type: ignore[union-attr]
tp = f"datetime64[{self.dtype.time_unit}]" # type: ignore[attr-defined]
return arr.astype(tp)

def raise_no_zero_copy() -> None:
Expand All @@ -4066,7 +4069,7 @@ def raise_no_zero_copy() -> None:
writable=writable,
use_pyarrow=use_pyarrow,
)
np_array.shape = (self.len(), self.dtype.width) # type: ignore[union-attr]
np_array.shape = (self.len(), self.dtype.width) # type: ignore[attr-defined]
return np_array

if (
Expand Down Expand Up @@ -6972,7 +6975,7 @@ def is_boolean(self) -> bool:
True

"""
return self.dtype is Boolean
return self.dtype == Boolean

@deprecate_function("Use `Series.dtype == pl.Utf8` instead.", version="0.19.14")
def is_utf8(self) -> bool:
Expand All @@ -6989,7 +6992,7 @@ def is_utf8(self) -> bool:
True

"""
return self.dtype is Utf8
return self.dtype == Utf8

@deprecate_renamed_function("gather_every", version="0.19.14")
def take_every(self, n: int) -> Series:
Expand Down
7 changes: 3 additions & 4 deletions py-polars/polars/series/struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,8 @@
from polars.utils.various import sphinx_accessor

if TYPE_CHECKING:
from polars import DataFrame, Series
from polars import DataFrame, DataType, Series
from polars.polars import PySeries
from polars.type_aliases import SchemaDict
elif os.getenv("BUILDING_SPHINX_DOCS"):
property = sphinx_accessor

Expand Down Expand Up @@ -66,10 +65,10 @@ def rename_fields(self, names: Sequence[str]) -> Series:
"""

@property
def schema(self) -> SchemaDict:
def schema(self) -> OrderedDict[str, DataType]:
"""Get the struct definition as a name/dtype schema dict."""
if getattr(self, "_s", None) is None:
return {}
return OrderedDict()
return OrderedDict(self._s.dtype().to_schema())

def unnest(self) -> DataFrame:
Expand Down
10 changes: 5 additions & 5 deletions py-polars/polars/testing/asserts/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from polars.testing.asserts.utils import raise_assertion_error

if TYPE_CHECKING:
from polars.type_aliases import PolarsDataType
from polars import DataType


def assert_series_equal(
Expand Down Expand Up @@ -252,19 +252,19 @@ def _assert_series_nan_values_match(left: Series, right: Series) -> None:
)


def _comparing_floats(left: PolarsDataType, right: PolarsDataType) -> bool:
def _comparing_floats(left: DataType, right: DataType) -> bool:
return left.is_float() and right.is_float()


def _comparing_lists(left: PolarsDataType, right: PolarsDataType) -> bool:
def _comparing_lists(left: DataType, right: DataType) -> bool:
return left in (List, Array) and right in (List, Array)


def _comparing_structs(left: PolarsDataType, right: PolarsDataType) -> bool:
def _comparing_structs(left: DataType, right: DataType) -> bool:
return left == Struct and right == Struct


def _comparing_nested_floats(left: PolarsDataType, right: PolarsDataType) -> bool:
def _comparing_nested_floats(left: DataType, right: DataType) -> bool:
if not (_comparing_lists(left, right) or _comparing_structs(left, right)):
return False

Expand Down
Loading