Skip to content

Commit

Permalink
Narrow type hints
Browse files Browse the repository at this point in the history
  • Loading branch information
stinodego committed Nov 15, 2023
1 parent a3e222d commit 3f51fa0
Show file tree
Hide file tree
Showing 15 changed files with 58 additions and 61 deletions.
6 changes: 3 additions & 3 deletions py-polars/polars/dataframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@
import deltalake
from xlsxwriter import Workbook

from polars import Expr, LazyFrame, Series
from polars import DataType, Expr, LazyFrame, Series
from polars.interchange.dataframe import PolarsDataFrame
from polars.type_aliases import (
AsofJoinStrategy,
Expand Down Expand Up @@ -1198,7 +1198,7 @@ def columns(self, names: Sequence[str]) -> None:
self._df.set_column_names(names)

@property
def dtypes(self) -> list[PolarsDataType]:
def dtypes(self) -> list[DataType]:
"""
Get the datatypes of the columns of this DataFrame.
Expand Down Expand Up @@ -1247,7 +1247,7 @@ def flags(self) -> dict[str, dict[str, bool]]:
return {name: self[name].flags for name in self.columns}

@property
def schema(self) -> SchemaDict:
def schema(self) -> Mapping[str, DataType]:
"""
Get a dict[column name, DataType].
Expand Down
5 changes: 2 additions & 3 deletions py-polars/polars/io/delta.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@
from polars.io.pyarrow_dataset import scan_pyarrow_dataset

if TYPE_CHECKING:
from polars import DataFrame, LazyFrame
from polars.type_aliases import PolarsDataType
from polars import DataFrame, DataType, LazyFrame


def read_delta(
Expand Down Expand Up @@ -320,7 +319,7 @@ def _check_if_delta_available() -> None:
)


def _check_for_unsupported_types(dtypes: list[PolarsDataType]) -> None:
def _check_for_unsupported_types(dtypes: list[DataType]) -> None:
schema_dtypes = unpack_dtypes(*dtypes)
unsupported_types = {Time, Categorical, Null}
overlap = schema_dtypes & unsupported_types
Expand Down
5 changes: 2 additions & 3 deletions py-polars/polars/io/ipc/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,7 @@
if TYPE_CHECKING:
from io import BytesIO

from polars import DataFrame, LazyFrame
from polars.type_aliases import PolarsDataType
from polars import DataFrame, DataType, LazyFrame


def read_ipc(
Expand Down Expand Up @@ -185,7 +184,7 @@ def read_ipc_stream(
)


def read_ipc_schema(source: str | BinaryIO | Path | bytes) -> dict[str, PolarsDataType]:
def read_ipc_schema(source: str | BinaryIO | Path | bytes) -> dict[str, DataType]:
"""
Get the schema of an IPC file without reading data.
Expand Down
6 changes: 3 additions & 3 deletions py-polars/polars/io/parquet/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
if TYPE_CHECKING:
from io import BytesIO

from polars import DataFrame, LazyFrame
from polars.type_aliases import ParallelStrategy, PolarsDataType
from polars import DataFrame, DataType, LazyFrame
from polars.type_aliases import ParallelStrategy


def read_parquet(
Expand Down Expand Up @@ -143,7 +143,7 @@ def read_parquet(

def read_parquet_schema(
source: str | BinaryIO | Path | bytes,
) -> dict[str, PolarsDataType]:
) -> dict[str, DataType]:
"""
Get the schema of a Parquet file without reading data.
Expand Down
6 changes: 3 additions & 3 deletions py-polars/polars/lazyframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@

import pyarrow as pa

from polars import DataFrame, Expr
from polars import DataFrame, DataType, Expr
from polars.dependencies import numpy as np
from polars.type_aliases import (
AsofJoinStrategy,
Expand Down Expand Up @@ -693,7 +693,7 @@ def columns(self) -> list[str]:
return self._ldf.columns()

@property
def dtypes(self) -> list[PolarsDataType]:
def dtypes(self) -> list[DataType]:
"""
Get dtypes of columns in LazyFrame.
Expand All @@ -717,7 +717,7 @@ def dtypes(self) -> list[PolarsDataType]:
return self._ldf.dtypes()

@property
def schema(self) -> SchemaDict:
def schema(self) -> Mapping[str, DataType]:
"""
Get a dict[column name, DataType].
Expand Down
4 changes: 2 additions & 2 deletions py-polars/polars/series/datetime.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def median(self) -> dt.date | dt.datetime | dt.timedelta | None:
if s.dtype == Date:
return _to_python_date(int(out))
else:
return _to_python_datetime(int(out), s.dtype.time_unit) # type: ignore[union-attr]
return _to_python_datetime(int(out), s.dtype.time_unit) # type: ignore[attr-defined]
return None

def mean(self) -> dt.date | dt.datetime | None:
Expand All @@ -108,7 +108,7 @@ def mean(self) -> dt.date | dt.datetime | None:
if s.dtype == Date:
return _to_python_date(int(out))
else:
return _to_python_datetime(int(out), s.dtype.time_unit) # type: ignore[union-attr]
return _to_python_datetime(int(out), s.dtype.time_unit) # type: ignore[attr-defined]
return None

def to_string(self, format: str) -> Series:
Expand Down
16 changes: 8 additions & 8 deletions py-polars/polars/series/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@
if TYPE_CHECKING:
import sys

from polars import DataFrame, Expr
from polars import DataFrame, DataType, Expr
from polars.series._numpy import SeriesView
from polars.type_aliases import (
ClosedInterval,
Expand Down Expand Up @@ -361,7 +361,7 @@ def _get_ptr(self) -> tuple[int, int, int]:
return self._s.get_ptr()

@property
def dtype(self) -> PolarsDataType:
def dtype(self) -> DataType:
"""
Get the data type of this Series.
Expand Down Expand Up @@ -394,7 +394,7 @@ def flags(self) -> dict[str, bool]:
return out

@property
def inner_dtype(self) -> PolarsDataType | None:
def inner_dtype(self) -> DataType | None:
"""
Get the inner dtype in of a List typed Series.
Expand Down Expand Up @@ -489,12 +489,12 @@ def _comp(self, other: Any, op: ComparisonOperator) -> Series:
time_unit = "us"
elif self.dtype == Datetime:
# Use local time zone info
time_zone = self.dtype.time_zone # type: ignore[union-attr]
time_zone = self.dtype.time_zone # type: ignore[attr-defined]
if str(other.tzinfo) != str(time_zone):
raise TypeError(
f"Datetime time zone {other.tzinfo!r} does not match Series timezone {time_zone!r}"
)
time_unit = self.dtype.time_unit # type: ignore[union-attr]
time_unit = self.dtype.time_unit # type: ignore[attr-defined]
else:
raise ValueError(
f"cannot compare datetime.datetime to Series of type {self.dtype}"
Expand Down Expand Up @@ -4047,9 +4047,9 @@ def convert_to_date(arr: np.ndarray[Any, Any]) -> np.ndarray[Any, Any]:
if self.dtype == Date:
tp = "datetime64[D]"
elif self.dtype == Duration:
tp = f"timedelta64[{self.dtype.time_unit}]" # type: ignore[union-attr]
tp = f"timedelta64[{self.dtype.time_unit}]" # type: ignore[attr-defined]
else:
tp = f"datetime64[{self.dtype.time_unit}]" # type: ignore[union-attr]
tp = f"datetime64[{self.dtype.time_unit}]" # type: ignore[attr-defined]
return arr.astype(tp)

def raise_no_zero_copy() -> None:
Expand All @@ -4062,7 +4062,7 @@ def raise_no_zero_copy() -> None:
writable=writable,
use_pyarrow=use_pyarrow,
)
np_array.shape = (self.len(), self.dtype.width) # type: ignore[union-attr]
np_array.shape = (self.len(), self.dtype.width) # type: ignore[attr-defined]
return np_array

if (
Expand Down
7 changes: 3 additions & 4 deletions py-polars/polars/series/struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,15 @@

import os
from collections import OrderedDict
from typing import TYPE_CHECKING, Sequence
from typing import TYPE_CHECKING, Mapping, Sequence

from polars.series.utils import expr_dispatch
from polars.utils._wrap import wrap_df
from polars.utils.various import sphinx_accessor

if TYPE_CHECKING:
from polars import DataFrame, Series
from polars import DataFrame, DataType, Series
from polars.polars import PySeries
from polars.type_aliases import SchemaDict
elif os.getenv("BUILDING_SPHINX_DOCS"):
property = sphinx_accessor

Expand Down Expand Up @@ -66,7 +65,7 @@ def rename_fields(self, names: Sequence[str]) -> Series:
"""

@property
def schema(self) -> SchemaDict:
def schema(self) -> Mapping[str, DataType]:
"""Get the struct definition as a name/dtype schema dict."""
if getattr(self, "_s", None) is None:
return {}
Expand Down
10 changes: 5 additions & 5 deletions py-polars/polars/testing/asserts/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from polars.utils.deprecation import issue_deprecation_warning

if TYPE_CHECKING:
from polars.type_aliases import PolarsDataType
from polars import DataType


def assert_series_equal(
Expand Down Expand Up @@ -294,19 +294,19 @@ def _assert_series_nan_values_match(
)


def _comparing_floats(left: PolarsDataType, right: PolarsDataType) -> bool:
def _comparing_floats(left: DataType, right: DataType) -> bool:
return left.is_float() and right.is_float()


def _comparing_lists(left: PolarsDataType, right: PolarsDataType) -> bool:
def _comparing_lists(left: DataType, right: DataType) -> bool:
return left in (List, Array) and right in (List, Array)


def _comparing_structs(left: PolarsDataType, right: PolarsDataType) -> bool:
def _comparing_structs(left: DataType, right: DataType) -> bool:
return left == Struct and right == Struct


def _comparing_nested_numerics(left: PolarsDataType, right: PolarsDataType) -> bool:
def _comparing_nested_numerics(left: DataType, right: DataType) -> bool:
if not (_comparing_lists(left, right) or _comparing_structs(left, right)):
return False

Expand Down
2 changes: 1 addition & 1 deletion py-polars/tests/parametric/test_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def test_series_duration_timeunits(
"us": 1_000,
"ms": 1_000_000,
}
assert nanos == [v * scale[s.dtype.time_unit] for v in s.to_physical()] # type: ignore[union-attr]
assert nanos == [v * scale[s.dtype.time_unit] for v in s.to_physical()] # type: ignore[attr-defined]
assert micros == [int(v / 1_000) for v in nanos]
assert millis == [int(v / 1_000) for v in micros]

Expand Down
2 changes: 1 addition & 1 deletion py-polars/tests/unit/datatypes/test_duration.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,4 @@ def test_duration_cumsum() -> None:
pl.Duration(time_unit="ms"),
pl.Duration(time_unit="ns"),
):
assert df.schema["A"].is_(duration_dtype) is False # type: ignore[arg-type]
assert df.schema["A"].is_(duration_dtype) is False
6 changes: 3 additions & 3 deletions py-polars/tests/unit/datatypes/test_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def test_dtype() -> None:
a = pl.Series("a", [[1, 2, 3], [2, 5], [6, 7, 8, 9]])
assert a.dtype == pl.List
assert a.inner_dtype == pl.Int64
assert a.dtype.inner == pl.Int64 # type: ignore[union-attr]
assert a.dtype.inner == pl.Int64 # type: ignore[attr-defined]
assert a.dtype.is_(pl.List(pl.Int64))

# explicit
Expand All @@ -44,7 +44,7 @@ def test_dtype() -> None:
"dtm": pl.List(pl.Datetime),
}
assert all(tp in pl.NESTED_DTYPES for tp in df.dtypes)
assert df.schema["i"].inner == pl.Int8 # type: ignore[union-attr]
assert df.schema["i"].inner == pl.Int8 # type: ignore[attr-defined]
assert df.rows() == [
(
[1, 2, 3],
Expand Down Expand Up @@ -90,7 +90,7 @@ def test_cast_inner() -> None:
# this creates an inner null type
df = pl.from_pandas(pd.DataFrame(data=[[[]], [[]]], columns=["A"]))
assert (
df["A"].cast(pl.List(int)).dtype.inner == pl.Int64 # type: ignore[union-attr]
df["A"].cast(pl.List(int)).dtype.inner == pl.Int64 # type: ignore[attr-defined]
)


Expand Down
4 changes: 2 additions & 2 deletions py-polars/tests/unit/functions/range/test_datetime_range.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def test_datetime_range() -> None:
time_unit=time_unit,
eager=True,
)
assert rng.dtype.time_unit == time_unit # type: ignore[union-attr]
assert rng.dtype.time_unit == time_unit # type: ignore[attr-defined]
assert rng.shape == (13,)
assert rng.dt[0] == datetime(2020, 1, 1)
assert rng.dt[-1] == datetime(2020, 1, 2)
Expand Down Expand Up @@ -67,7 +67,7 @@ def test_datetime_range() -> None:
datetime(2022, 1, 1), datetime(2022, 1, 1, 0, 1), "987456321ns", eager=True
)
assert len(result) == 61
assert result.dtype.time_unit == "ns" # type: ignore[union-attr]
assert result.dtype.time_unit == "ns" # type: ignore[attr-defined]
assert result.dt.second()[-1] == 59
assert result.cast(pl.Utf8)[-1] == "2022-01-01 00:00:59.247379260"

Expand Down
26 changes: 13 additions & 13 deletions py-polars/tests/unit/io/test_ipc.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def test_ipc_schema(compression: IpcCompression) -> None:
df.write_ipc(f, compression=compression)
f.seek(0)

expected = {"a": pl.Int64, "b": pl.Utf8, "c": pl.Boolean}
expected = {"a": pl.Int64(), "b": pl.Utf8(), "c": pl.Boolean()}
assert pl.read_ipc_schema(f) == expected


Expand All @@ -152,18 +152,18 @@ def test_ipc_schema_from_file(
schema = pl.read_ipc_schema(file_path)

expected = {
"bools": pl.Boolean,
"bools_nulls": pl.Boolean,
"int": pl.Int64,
"int_nulls": pl.Int64,
"floats": pl.Float64,
"floats_nulls": pl.Float64,
"strings": pl.Utf8,
"strings_nulls": pl.Utf8,
"date": pl.Date,
"datetime": pl.Datetime,
"time": pl.Time,
"cat": pl.Categorical,
"bools": pl.Boolean(),
"bools_nulls": pl.Boolean(),
"int": pl.Int64(),
"int_nulls": pl.Int64(),
"floats": pl.Float64(),
"floats_nulls": pl.Float64(),
"strings": pl.Utf8(),
"strings_nulls": pl.Utf8(),
"date": pl.Date(),
"datetime": pl.Datetime(),
"time": pl.Time(),
"cat": pl.Categorical(),
}
assert schema == expected

Expand Down
14 changes: 7 additions & 7 deletions py-polars/tests/unit/series/test_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,8 +125,8 @@ def test_init_inputs(monkeypatch: Any) -> None:
s = pl.Series([date(2023, 1, 1), date(2023, 1, 2)], dtype=pl.Datetime)
assert s.to_list() == [datetime(2023, 1, 1), datetime(2023, 1, 2)]
assert Datetime == s.dtype
assert s.dtype.time_unit == "us" # type: ignore[union-attr]
assert s.dtype.time_zone is None # type: ignore[union-attr]
assert s.dtype.time_unit == "us" # type: ignore[attr-defined]
assert s.dtype.time_zone is None # type: ignore[attr-defined]

# conversion of Date to Datetime with specified timezone and units
tu: TimeUnit = "ms"
Expand All @@ -136,8 +136,8 @@ def test_init_inputs(monkeypatch: Any) -> None:
d2 = datetime(2023, 1, 2, 0, 0, 0, 0, ZoneInfo(tz))
assert s.to_list() == [d1, d2]
assert Datetime == s.dtype
assert s.dtype.time_unit == tu # type: ignore[union-attr]
assert s.dtype.time_zone == tz # type: ignore[union-attr]
assert s.dtype.time_unit == tu # type: ignore[attr-defined]
assert s.dtype.time_zone == tz # type: ignore[attr-defined]

# datetime64: check timeunit (auto-detect, implicit/explicit) and NaT
d64 = pd.date_range(date(2021, 8, 1), date(2021, 8, 3)).values
Expand All @@ -148,10 +148,10 @@ def test_init_inputs(monkeypatch: Any) -> None:
s = pl.Series("dates", d64, dtype)
assert s.to_list() == expected
assert Datetime == s.dtype
assert s.dtype.time_unit == "ns" # type: ignore[union-attr]
assert s.dtype.time_unit == "ns" # type: ignore[attr-defined]

s = pl.Series(values=d64.astype("<M8[ms]"))
assert s.dtype.time_unit == "ms" # type: ignore[union-attr]
assert s.dtype.time_unit == "ms" # type: ignore[attr-defined]
assert expected == s.to_list()

# pandas
Expand Down Expand Up @@ -204,7 +204,7 @@ class TeaShipmentPD(pydantic.BaseModel):
s = pl.Series("t", [t0, t1, t2])

assert isinstance(s, pl.Series)
assert s.dtype.fields == [ # type: ignore[union-attr]
assert s.dtype.fields == [ # type: ignore[attr-defined]
Field("exporter", pl.Utf8),
Field("importer", pl.Utf8),
Field("product", pl.Utf8),
Expand Down

0 comments on commit 3f51fa0

Please sign in to comment.