diff --git a/py-polars/polars/dataframe/frame.py b/py-polars/polars/dataframe/frame.py index c72beb2e4c44..75bd1e7a2ec0 100644 --- a/py-polars/polars/dataframe/frame.py +++ b/py-polars/polars/dataframe/frame.py @@ -126,7 +126,7 @@ import deltalake from xlsxwriter import Workbook - from polars import Expr, LazyFrame, Series + from polars import DataType, Expr, LazyFrame, Series from polars.interchange.dataframe import PolarsDataFrame from polars.type_aliases import ( AsofJoinStrategy, @@ -1206,7 +1206,7 @@ def columns(self, names: Sequence[str]) -> None: self._df.set_column_names(names) @property - def dtypes(self) -> list[PolarsDataType]: + def dtypes(self) -> list[DataType]: """ Get the datatypes of the columns of this DataFrame. @@ -1255,7 +1255,7 @@ def flags(self) -> dict[str, dict[str, bool]]: return {name: self[name].flags for name in self.columns} @property - def schema(self) -> SchemaDict: + def schema(self) -> OrderedDict[str, DataType]: """ Get a dict[column name, DataType]. diff --git a/py-polars/polars/datatypes/classes.py b/py-polars/polars/datatypes/classes.py index e5f86de5e187..6a4124c9dc5c 100644 --- a/py-polars/polars/datatypes/classes.py +++ b/py-polars/polars/datatypes/classes.py @@ -77,19 +77,24 @@ def is_nested(self) -> bool: # noqa: D102 class DataType(metaclass=DataTypeClass): """Base class for all Polars data types.""" - def __new__(cls, *args: Any, **kwargs: Any) -> PolarsDataType: # type: ignore[misc] # noqa: D102 - # this formulation allows for equivalent use of "pl.Type" and "pl.Type()", while - # still respecting types that take initialisation params (eg: Duration/Datetime) - if args or kwargs: - return super().__new__(cls) - return cls - def __reduce__(self) -> Any: return (_custom_reconstruct, (type(self), object, None), self.__dict__) def _string_repr(self) -> str: return _dtype_str_repr(self) + def __eq__(self, other: PolarsDataType) -> bool: # type: ignore[override] + if type(other) is DataTypeClass: + return issubclass(other, type(self)) + else: + return isinstance(other, type(self)) + + def __hash__(self) -> int: + return hash(self.__class__) + + def __repr__(self) -> str: + return self.__class__.__name__ + @classmethod def base_type(cls) -> DataTypeClass: """ diff --git a/py-polars/polars/io/delta.py b/py-polars/polars/io/delta.py index 8521bd05fef4..2f0a21f8f191 100644 --- a/py-polars/polars/io/delta.py +++ b/py-polars/polars/io/delta.py @@ -12,8 +12,7 @@ from polars.io.pyarrow_dataset import scan_pyarrow_dataset if TYPE_CHECKING: - from polars import DataFrame, LazyFrame - from polars.type_aliases import PolarsDataType + from polars import DataFrame, DataType, LazyFrame def read_delta( @@ -320,7 +319,7 @@ def _check_if_delta_available() -> None: ) -def _check_for_unsupported_types(dtypes: list[PolarsDataType]) -> None: +def _check_for_unsupported_types(dtypes: list[DataType]) -> None: schema_dtypes = unpack_dtypes(*dtypes) unsupported_types = {Time, Categorical, Null} overlap = schema_dtypes & unsupported_types diff --git a/py-polars/polars/io/ipc/functions.py b/py-polars/polars/io/ipc/functions.py index f426bbeae2ae..3d520b5cc388 100644 --- a/py-polars/polars/io/ipc/functions.py +++ b/py-polars/polars/io/ipc/functions.py @@ -15,8 +15,7 @@ if TYPE_CHECKING: from io import BytesIO - from polars import DataFrame, LazyFrame - from polars.type_aliases import PolarsDataType + from polars import DataFrame, DataType, LazyFrame def read_ipc( @@ -185,7 +184,7 @@ def read_ipc_stream( ) -def read_ipc_schema(source: str | BinaryIO | Path | bytes) -> dict[str, PolarsDataType]: +def read_ipc_schema(source: str | BinaryIO | Path | bytes) -> dict[str, DataType]: """ Get the schema of an IPC file without reading data. diff --git a/py-polars/polars/io/parquet/functions.py b/py-polars/polars/io/parquet/functions.py index 59554a587b1a..cfce8e1d085c 100644 --- a/py-polars/polars/io/parquet/functions.py +++ b/py-polars/polars/io/parquet/functions.py @@ -16,8 +16,8 @@ if TYPE_CHECKING: from io import BytesIO - from polars import DataFrame, LazyFrame - from polars.type_aliases import ParallelStrategy, PolarsDataType + from polars import DataFrame, DataType, LazyFrame + from polars.type_aliases import ParallelStrategy def read_parquet( @@ -143,7 +143,7 @@ def read_parquet( def read_parquet_schema( source: str | BinaryIO | Path | bytes, -) -> dict[str, PolarsDataType]: +) -> dict[str, DataType]: """ Get the schema of a Parquet file without reading data. diff --git a/py-polars/polars/lazyframe/frame.py b/py-polars/polars/lazyframe/frame.py index 5dc7e57ae3ae..9ec5813c164d 100644 --- a/py-polars/polars/lazyframe/frame.py +++ b/py-polars/polars/lazyframe/frame.py @@ -87,7 +87,7 @@ import pyarrow as pa - from polars import DataFrame, Expr + from polars import DataFrame, DataType, Expr from polars.dependencies import numpy as np from polars.type_aliases import ( AsofJoinStrategy, @@ -693,7 +693,7 @@ def columns(self) -> list[str]: return self._ldf.columns() @property - def dtypes(self) -> list[PolarsDataType]: + def dtypes(self) -> list[DataType]: """ Get dtypes of columns in LazyFrame. @@ -717,7 +717,7 @@ def dtypes(self) -> list[PolarsDataType]: return self._ldf.dtypes() @property - def schema(self) -> SchemaDict: + def schema(self) -> OrderedDict[str, DataType]: """ Get a dict[column name, DataType]. diff --git a/py-polars/polars/series/datetime.py b/py-polars/polars/series/datetime.py index 38f63002c174..19d588298023 100644 --- a/py-polars/polars/series/datetime.py +++ b/py-polars/polars/series/datetime.py @@ -85,7 +85,7 @@ def median(self) -> dt.date | dt.datetime | dt.timedelta | None: if s.dtype == Date: return _to_python_date(int(out)) else: - return _to_python_datetime(int(out), s.dtype.time_unit) # type: ignore[union-attr] + return _to_python_datetime(int(out), s.dtype.time_unit) # type: ignore[attr-defined] return None def mean(self) -> dt.date | dt.datetime | None: @@ -108,7 +108,7 @@ def mean(self) -> dt.date | dt.datetime | None: if s.dtype == Date: return _to_python_date(int(out)) else: - return _to_python_datetime(int(out), s.dtype.time_unit) # type: ignore[union-attr] + return _to_python_datetime(int(out), s.dtype.time_unit) # type: ignore[attr-defined] return None def to_string(self, format: str) -> Series: diff --git a/py-polars/polars/series/series.py b/py-polars/polars/series/series.py index d2ec0e1af915..b7a9e5f8f0a6 100644 --- a/py-polars/polars/series/series.py +++ b/py-polars/polars/series/series.py @@ -112,7 +112,7 @@ if TYPE_CHECKING: import sys - from polars import DataFrame, Expr + from polars import DataFrame, DataType, Expr from polars.series._numpy import SeriesView from polars.type_aliases import ( ClosedInterval, @@ -365,7 +365,7 @@ def _get_ptr(self) -> tuple[int, int, int]: return self._s.get_ptr() @property - def dtype(self) -> PolarsDataType: + def dtype(self) -> DataType: """ Get the data type of this Series. @@ -398,10 +398,13 @@ def flags(self) -> dict[str, bool]: return out @property - def inner_dtype(self) -> PolarsDataType | None: + def inner_dtype(self) -> DataType | None: """ Get the inner dtype in of a List typed Series. + .. deprecated:: 0.19.14 + Use `Series.dtype.inner` instead. + Returns ------- DataType @@ -412,7 +415,7 @@ def inner_dtype(self) -> PolarsDataType | None: version="0.19.14", ) try: - return self.dtype.inner # type: ignore[union-attr] + return self.dtype.inner # type: ignore[attr-defined] except AttributeError: return None @@ -502,12 +505,12 @@ def _comp(self, other: Any, op: ComparisonOperator) -> Series: time_unit = "us" elif self.dtype == Datetime: # Use local time zone info - time_zone = self.dtype.time_zone # type: ignore[union-attr] + time_zone = self.dtype.time_zone # type: ignore[attr-defined] if str(other.tzinfo) != str(time_zone): raise TypeError( f"Datetime time zone {other.tzinfo!r} does not match Series timezone {time_zone!r}" ) - time_unit = self.dtype.time_unit # type: ignore[union-attr] + time_unit = self.dtype.time_unit # type: ignore[attr-defined] else: raise ValueError( f"cannot compare datetime.datetime to Series of type {self.dtype}" @@ -524,7 +527,7 @@ def _comp(self, other: Any, op: ComparisonOperator) -> Series: return self._from_pyseries(f(d)) elif isinstance(other, timedelta) and self.dtype == Duration: - time_unit = self.dtype.time_unit # type: ignore[union-attr] + time_unit = self.dtype.time_unit # type: ignore[attr-defined] td = _timedelta_to_pl_timedelta(other, time_unit) # type: ignore[arg-type] f = get_ffi_func(op + "_<>", Int64, self._s) assert f is not None @@ -4051,9 +4054,9 @@ def convert_to_date(arr: np.ndarray[Any, Any]) -> np.ndarray[Any, Any]: if self.dtype == Date: tp = "datetime64[D]" elif self.dtype == Duration: - tp = f"timedelta64[{self.dtype.time_unit}]" # type: ignore[union-attr] + tp = f"timedelta64[{self.dtype.time_unit}]" # type: ignore[attr-defined] else: - tp = f"datetime64[{self.dtype.time_unit}]" # type: ignore[union-attr] + tp = f"datetime64[{self.dtype.time_unit}]" # type: ignore[attr-defined] return arr.astype(tp) def raise_no_zero_copy() -> None: @@ -4066,7 +4069,7 @@ def raise_no_zero_copy() -> None: writable=writable, use_pyarrow=use_pyarrow, ) - np_array.shape = (self.len(), self.dtype.width) # type: ignore[union-attr] + np_array.shape = (self.len(), self.dtype.width) # type: ignore[attr-defined] return np_array if ( @@ -6972,7 +6975,7 @@ def is_boolean(self) -> bool: True """ - return self.dtype is Boolean + return self.dtype == Boolean @deprecate_function("Use `Series.dtype == pl.Utf8` instead.", version="0.19.14") def is_utf8(self) -> bool: @@ -6989,7 +6992,7 @@ def is_utf8(self) -> bool: True """ - return self.dtype is Utf8 + return self.dtype == Utf8 @deprecate_renamed_function("gather_every", version="0.19.14") def take_every(self, n: int) -> Series: diff --git a/py-polars/polars/series/struct.py b/py-polars/polars/series/struct.py index 6af6baa7f8c0..a12613ed117f 100644 --- a/py-polars/polars/series/struct.py +++ b/py-polars/polars/series/struct.py @@ -9,9 +9,8 @@ from polars.utils.various import sphinx_accessor if TYPE_CHECKING: - from polars import DataFrame, Series + from polars import DataFrame, DataType, Series from polars.polars import PySeries - from polars.type_aliases import SchemaDict elif os.getenv("BUILDING_SPHINX_DOCS"): property = sphinx_accessor @@ -66,10 +65,10 @@ def rename_fields(self, names: Sequence[str]) -> Series: """ @property - def schema(self) -> SchemaDict: + def schema(self) -> OrderedDict[str, DataType]: """Get the struct definition as a name/dtype schema dict.""" if getattr(self, "_s", None) is None: - return {} + return OrderedDict() return OrderedDict(self._s.dtype().to_schema()) def unnest(self) -> DataFrame: diff --git a/py-polars/polars/testing/asserts/series.py b/py-polars/polars/testing/asserts/series.py index cec49db44b89..84fcccf14c39 100644 --- a/py-polars/polars/testing/asserts/series.py +++ b/py-polars/polars/testing/asserts/series.py @@ -18,7 +18,7 @@ from polars.testing.asserts.utils import raise_assertion_error if TYPE_CHECKING: - from polars.type_aliases import PolarsDataType + from polars import DataType def assert_series_equal( @@ -252,19 +252,19 @@ def _assert_series_nan_values_match(left: Series, right: Series) -> None: ) -def _comparing_floats(left: PolarsDataType, right: PolarsDataType) -> bool: +def _comparing_floats(left: DataType, right: DataType) -> bool: return left.is_float() and right.is_float() -def _comparing_lists(left: PolarsDataType, right: PolarsDataType) -> bool: +def _comparing_lists(left: DataType, right: DataType) -> bool: return left in (List, Array) and right in (List, Array) -def _comparing_structs(left: PolarsDataType, right: PolarsDataType) -> bool: +def _comparing_structs(left: DataType, right: DataType) -> bool: return left == Struct and right == Struct -def _comparing_nested_floats(left: PolarsDataType, right: PolarsDataType) -> bool: +def _comparing_nested_floats(left: DataType, right: DataType) -> bool: if not (_comparing_lists(left, right) or _comparing_structs(left, right)): return False diff --git a/py-polars/src/conversion.rs b/py-polars/src/conversion.rs index b40d30eae255..b733c1a40a4d 100644 --- a/py-polars/src/conversion.rs +++ b/py-polars/src/conversion.rs @@ -297,43 +297,78 @@ impl ToPyObject for Wrap { let pl = POLARS.as_ref(py); match &self.0 { - DataType::Int8 => pl.getattr(intern!(py, "Int8")).unwrap().into(), - DataType::Int16 => pl.getattr(intern!(py, "Int16")).unwrap().into(), - DataType::Int32 => pl.getattr(intern!(py, "Int32")).unwrap().into(), - DataType::Int64 => pl.getattr(intern!(py, "Int64")).unwrap().into(), - DataType::UInt8 => pl.getattr(intern!(py, "UInt8")).unwrap().into(), - DataType::UInt16 => pl.getattr(intern!(py, "UInt16")).unwrap().into(), - DataType::UInt32 => pl.getattr(intern!(py, "UInt32")).unwrap().into(), - DataType::UInt64 => pl.getattr(intern!(py, "UInt64")).unwrap().into(), - DataType::Float32 => pl.getattr(intern!(py, "Float32")).unwrap().into(), - DataType::Float64 => pl.getattr(intern!(py, "Float64")).unwrap().into(), + DataType::Int8 => { + let class = pl.getattr(intern!(py, "Int8")).unwrap(); + class.call0().unwrap().into() + }, + DataType::Int16 => { + let class = pl.getattr(intern!(py, "Int16")).unwrap(); + class.call0().unwrap().into() + }, + DataType::Int32 => { + let class = pl.getattr(intern!(py, "Int32")).unwrap(); + class.call0().unwrap().into() + }, + DataType::Int64 => { + let class = pl.getattr(intern!(py, "Int64")).unwrap(); + class.call0().unwrap().into() + }, + DataType::UInt8 => { + let class = pl.getattr(intern!(py, "UInt8")).unwrap(); + class.call0().unwrap().into() + }, + DataType::UInt16 => { + let class = pl.getattr(intern!(py, "UInt16")).unwrap(); + class.call0().unwrap().into() + }, + DataType::UInt32 => { + let class = pl.getattr(intern!(py, "UInt32")).unwrap(); + class.call0().unwrap().into() + }, + DataType::UInt64 => { + let class = pl.getattr(intern!(py, "UInt64")).unwrap(); + class.call0().unwrap().into() + }, + DataType::Float32 => { + let class = pl.getattr(intern!(py, "Float32")).unwrap(); + class.call0().unwrap().into() + }, + DataType::Float64 => { + let class = pl.getattr(intern!(py, "Float64")).unwrap(); + class.call0().unwrap().into() + }, DataType::Decimal(precision, scale) => { - let kwargs = PyDict::new(py); - kwargs.set_item("precision", *precision).unwrap(); - kwargs.set_item("scale", *scale).unwrap(); - pl.getattr(intern!(py, "Decimal")) - .unwrap() - .call((), Some(kwargs)) - .unwrap() - .into() + let class = pl.getattr(intern!(py, "Decimal")).unwrap(); + let args = (*precision, *scale); + class.call1(args).unwrap().into() + }, + DataType::Boolean => { + let class = pl.getattr(intern!(py, "Boolean")).unwrap(); + class.call0().unwrap().into() + }, + DataType::Utf8 => { + let class = pl.getattr(intern!(py, "Utf8")).unwrap(); + class.call0().unwrap().into() + }, + DataType::Binary => { + let class = pl.getattr(intern!(py, "Binary")).unwrap(); + class.call0().unwrap().into() }, - DataType::Boolean => pl.getattr(intern!(py, "Boolean")).unwrap().into(), - DataType::Utf8 => pl.getattr(intern!(py, "Utf8")).unwrap().into(), - DataType::Binary => pl.getattr(intern!(py, "Binary")).unwrap().into(), DataType::Array(inner, size) => { + let class = pl.getattr(intern!(py, "Array")).unwrap(); let inner = Wrap(*inner.clone()).to_object(py); - let list_class = pl.getattr(intern!(py, "Array")).unwrap(); - let kwargs = PyDict::new(py); - kwargs.set_item("inner", inner).unwrap(); - kwargs.set_item("width", size).unwrap(); - list_class.call((), Some(kwargs)).unwrap().into() + let args = (inner, *size); + class.call1(args).unwrap().into() }, DataType::List(inner) => { + let class = pl.getattr(intern!(py, "List")).unwrap(); let inner = Wrap(*inner.clone()).to_object(py); - let list_class = pl.getattr(intern!(py, "List")).unwrap(); - list_class.call1((inner,)).unwrap().into() + class.call1((inner,)).unwrap().into() + }, + DataType::Date => { + let class = pl.getattr(intern!(py, "Date")).unwrap(); + class.call0().unwrap().into() }, - DataType::Date => pl.getattr(intern!(py, "Date")).unwrap().into(), DataType::Datetime(tu, tz) => { let datetime_class = pl.getattr(intern!(py, "Datetime")).unwrap(); datetime_class @@ -346,16 +381,20 @@ impl ToPyObject for Wrap { duration_class.call1((tu.to_ascii(),)).unwrap().into() }, #[cfg(feature = "object")] - DataType::Object(_) => pl.getattr(intern!(py, "Object")).unwrap().into(), + DataType::Object(_) => { + let class = pl.getattr(intern!(py, "Object")).unwrap(); + class.call0().unwrap().into() + }, DataType::Categorical(rev_map) => { if let Some(rev_map) = rev_map { if let RevMapping::Enum(categories, _) = &**rev_map { - let enum_dt = pl.getattr(intern!(py, "Enum")).unwrap(); + let class = pl.getattr(intern!(py, "Enum")).unwrap(); let ca = Utf8Chunked::from_iter(categories); - return enum_dt.call1((Wrap(&ca).to_object(py),)).unwrap().into(); + return class.call1((Wrap(&ca).to_object(py),)).unwrap().into(); } } - pl.getattr(intern!(py, "Categorical")).unwrap().into() + let class = pl.getattr(intern!(py, "Categorical")).unwrap(); + class.call0().unwrap().into() }, DataType::Time => pl.getattr(intern!(py, "Time")).unwrap().into(), DataType::Struct(fields) => { @@ -369,8 +408,14 @@ impl ToPyObject for Wrap { let struct_class = pl.getattr(intern!(py, "Struct")).unwrap(); struct_class.call1((fields,)).unwrap().into() }, - DataType::Null => pl.getattr(intern!(py, "Null")).unwrap().into(), - DataType::Unknown => pl.getattr(intern!(py, "Unknown")).unwrap().into(), + DataType::Null => { + let class = pl.getattr(intern!(py, "Null")).unwrap(); + class.call0().unwrap().into() + }, + DataType::Unknown => { + let class = pl.getattr(intern!(py, "Unknown")).unwrap(); + class.call0().unwrap().into() + }, } } } @@ -429,6 +474,18 @@ impl FromPyObject<'_> for Wrap { }, } }, + "Int8" => DataType::Int8, + "Int16" => DataType::Int16, + "Int32" => DataType::Int32, + "Int64" => DataType::Int64, + "UInt8" => DataType::UInt8, + "UInt16" => DataType::UInt16, + "UInt32" => DataType::UInt32, + "UInt64" => DataType::UInt64, + "Utf8" => DataType::Utf8, + "Binary" => DataType::Binary, + "Boolean" => DataType::Boolean, + "Categorical" => DataType::Categorical(None), "Enum" => { let categories = ob.getattr(intern!(py, "categories")).unwrap(); let categories = categories.extract::>()?.0; @@ -436,6 +493,12 @@ impl FromPyObject<'_> for Wrap { let arr = arr.as_any().downcast_ref::>().unwrap(); create_enum_data_type(arr.clone()) }, + "Date" => DataType::Date, + "Time" => DataType::Time, + "Float32" => DataType::Float32, + "Float64" => DataType::Float64, + "Null" => DataType::Null, + "Unknown" => DataType::Unknown, "Duration" => { let time_unit = ob.getattr(intern!(py, "time_unit")).unwrap(); let time_unit = time_unit.extract::>()?.0; diff --git a/py-polars/tests/parametric/test_series.py b/py-polars/tests/parametric/test_series.py index e70ab742c8d7..27d4062afe76 100644 --- a/py-polars/tests/parametric/test_series.py +++ b/py-polars/tests/parametric/test_series.py @@ -139,7 +139,7 @@ def test_series_duration_timeunits( "us": 1_000, "ms": 1_000_000, } - assert nanos == [v * scale[s.dtype.time_unit] for v in s.to_physical()] # type: ignore[union-attr] + assert nanos == [v * scale[s.dtype.time_unit] for v in s.to_physical()] # type: ignore[attr-defined] assert micros == [int(v / 1_000) for v in nanos] assert millis == [int(v / 1_000) for v in micros] diff --git a/py-polars/tests/unit/datatypes/test_duration.py b/py-polars/tests/unit/datatypes/test_duration.py index 27bd042bb16b..e9db9940c5b5 100644 --- a/py-polars/tests/unit/datatypes/test_duration.py +++ b/py-polars/tests/unit/datatypes/test_duration.py @@ -16,7 +16,7 @@ def test_duration_cum_sum() -> None: pl.Duration(time_unit="ms"), pl.Duration(time_unit="ns"), ): - assert df.schema["A"].is_(duration_dtype) is False # type: ignore[arg-type] + assert df.schema["A"].is_(duration_dtype) is False def test_duration_std_var() -> None: diff --git a/py-polars/tests/unit/datatypes/test_list.py b/py-polars/tests/unit/datatypes/test_list.py index 3ef49b524c4a..bda21e661f15 100644 --- a/py-polars/tests/unit/datatypes/test_list.py +++ b/py-polars/tests/unit/datatypes/test_list.py @@ -18,7 +18,7 @@ def test_dtype() -> None: # inferred a = pl.Series("a", [[1, 2, 3], [2, 5], [6, 7, 8, 9]]) assert a.dtype == pl.List - assert a.dtype.inner == pl.Int64 # type: ignore[union-attr] + assert a.dtype.inner == pl.Int64 # type: ignore[attr-defined] assert a.dtype.is_(pl.List(pl.Int64)) # explicit @@ -43,7 +43,7 @@ def test_dtype() -> None: "dtm": pl.List(pl.Datetime), } assert all(tp.is_nested() for tp in df.dtypes) - assert df.schema["i"].inner == pl.Int8 # type: ignore[union-attr] + assert df.schema["i"].inner == pl.Int8 # type: ignore[attr-defined] assert df.rows() == [ ( [1, 2, 3], @@ -75,8 +75,8 @@ def test_categorical() -> None: .to_series(3) ) - assert out.dtype.inner == pl.Categorical # type: ignore[union-attr] - assert out.dtype.inner.is_nested() is False # type: ignore[union-attr] + assert out.dtype.inner == pl.Categorical # type: ignore[attr-defined] + assert out.dtype.inner.is_nested() is False # type: ignore[attr-defined] def test_cast_inner() -> None: @@ -89,7 +89,7 @@ def test_cast_inner() -> None: # this creates an inner null type df = pl.from_pandas(pd.DataFrame(data=[[[]], [[]]], columns=["A"])) assert ( - df["A"].cast(pl.List(int)).dtype.inner == pl.Int64 # type: ignore[union-attr] + df["A"].cast(pl.List(int)).dtype.inner == pl.Int64 # type: ignore[attr-defined] ) @@ -192,7 +192,7 @@ def test_local_categorical_list() -> None: values = [["a", "b"], ["c"], ["a", "d", "d"]] s = pl.Series(values, dtype=pl.List(pl.Categorical)) assert s.dtype == pl.List - assert s.dtype.inner == pl.Categorical # type: ignore[union-attr] + assert s.dtype.inner == pl.Categorical # type: ignore[attr-defined] assert s.to_list() == values # Check that underlying physicals match diff --git a/py-polars/tests/unit/functions/range/test_datetime_range.py b/py-polars/tests/unit/functions/range/test_datetime_range.py index 3f50616f9a28..0068afd45396 100644 --- a/py-polars/tests/unit/functions/range/test_datetime_range.py +++ b/py-polars/tests/unit/functions/range/test_datetime_range.py @@ -37,7 +37,7 @@ def test_datetime_range() -> None: time_unit=time_unit, eager=True, ) - assert rng.dtype.time_unit == time_unit # type: ignore[union-attr] + assert rng.dtype.time_unit == time_unit # type: ignore[attr-defined] assert rng.shape == (13,) assert rng.dt[0] == datetime(2020, 1, 1) assert rng.dt[-1] == datetime(2020, 1, 2) @@ -67,7 +67,7 @@ def test_datetime_range() -> None: datetime(2022, 1, 1), datetime(2022, 1, 1, 0, 1), "987456321ns", eager=True ) assert len(result) == 61 - assert result.dtype.time_unit == "ns" # type: ignore[union-attr] + assert result.dtype.time_unit == "ns" # type: ignore[attr-defined] assert result.dt.second()[-1] == 59 assert result.cast(pl.Utf8)[-1] == "2022-01-01 00:00:59.247379260" diff --git a/py-polars/tests/unit/interop/test_interop.py b/py-polars/tests/unit/interop/test_interop.py index fa3e92c1f85c..cc30fcfe9ffc 100644 --- a/py-polars/tests/unit/interop/test_interop.py +++ b/py-polars/tests/unit/interop/test_interop.py @@ -698,7 +698,7 @@ def test_from_null_column() -> None: assert df.shape == (2, 1) assert df.columns == ["n/a"] - assert df.dtypes[0] is pl.Null + assert df.dtypes[0] == pl.Null def test_to_pandas_series() -> None: diff --git a/py-polars/tests/unit/io/test_ipc.py b/py-polars/tests/unit/io/test_ipc.py index 6a7161fd9fe2..d71dea5d0374 100644 --- a/py-polars/tests/unit/io/test_ipc.py +++ b/py-polars/tests/unit/io/test_ipc.py @@ -130,7 +130,7 @@ def test_ipc_schema(compression: IpcCompression) -> None: df.write_ipc(f, compression=compression) f.seek(0) - expected = {"a": pl.Int64, "b": pl.Utf8, "c": pl.Boolean} + expected = {"a": pl.Int64(), "b": pl.Utf8(), "c": pl.Boolean()} assert pl.read_ipc_schema(f) == expected @@ -152,18 +152,18 @@ def test_ipc_schema_from_file( schema = pl.read_ipc_schema(file_path) expected = { - "bools": pl.Boolean, - "bools_nulls": pl.Boolean, - "int": pl.Int64, - "int_nulls": pl.Int64, - "floats": pl.Float64, - "floats_nulls": pl.Float64, - "strings": pl.Utf8, - "strings_nulls": pl.Utf8, - "date": pl.Date, - "datetime": pl.Datetime, - "time": pl.Time, - "cat": pl.Categorical, + "bools": pl.Boolean(), + "bools_nulls": pl.Boolean(), + "int": pl.Int64(), + "int_nulls": pl.Int64(), + "floats": pl.Float64(), + "floats_nulls": pl.Float64(), + "strings": pl.Utf8(), + "strings_nulls": pl.Utf8(), + "date": pl.Date(), + "datetime": pl.Datetime(), + "time": pl.Time(), + "cat": pl.Categorical(), } assert schema == expected diff --git a/py-polars/tests/unit/series/test_series.py b/py-polars/tests/unit/series/test_series.py index aa34495621f7..4b86094e58f4 100644 --- a/py-polars/tests/unit/series/test_series.py +++ b/py-polars/tests/unit/series/test_series.py @@ -142,8 +142,8 @@ def test_init_inputs(monkeypatch: Any) -> None: s = pl.Series([date(2023, 1, 1), date(2023, 1, 2)], dtype=pl.Datetime) assert s.to_list() == [datetime(2023, 1, 1), datetime(2023, 1, 2)] assert Datetime == s.dtype - assert s.dtype.time_unit == "us" # type: ignore[union-attr] - assert s.dtype.time_zone is None # type: ignore[union-attr] + assert s.dtype.time_unit == "us" # type: ignore[attr-defined] + assert s.dtype.time_zone is None # type: ignore[attr-defined] # conversion of Date to Datetime with specified timezone and units tu: TimeUnit = "ms" @@ -153,8 +153,8 @@ def test_init_inputs(monkeypatch: Any) -> None: d2 = datetime(2023, 1, 2, 0, 0, 0, 0, ZoneInfo(tz)) assert s.to_list() == [d1, d2] assert Datetime == s.dtype - assert s.dtype.time_unit == tu # type: ignore[union-attr] - assert s.dtype.time_zone == tz # type: ignore[union-attr] + assert s.dtype.time_unit == tu # type: ignore[attr-defined] + assert s.dtype.time_zone == tz # type: ignore[attr-defined] # datetime64: check timeunit (auto-detect, implicit/explicit) and NaT d64 = pd.date_range(date(2021, 8, 1), date(2021, 8, 3)).values @@ -165,10 +165,10 @@ def test_init_inputs(monkeypatch: Any) -> None: s = pl.Series("dates", d64, dtype) assert s.to_list() == expected assert Datetime == s.dtype - assert s.dtype.time_unit == "ns" # type: ignore[union-attr] + assert s.dtype.time_unit == "ns" # type: ignore[attr-defined] s = pl.Series(values=d64.astype(" None: - # check "DataType.__new__" behaviour for all datatypes - all_datatypes = { - dtype - for dtype in (getattr(datatypes, attr) for attr in dir(datatypes)) - if isinstance(dtype, DataTypeClass) +if TYPE_CHECKING: + from polars.datatypes import DataTypeClass + +SIMPLE_DTYPES: list[DataTypeClass] = list( + pl.INTEGER_DTYPES # type: ignore[arg-type] + | pl.FLOAT_DTYPES + | { + pl.Boolean, + pl.Utf8, + pl.Binary, + pl.Time, + pl.Date, + pl.Categorical, + pl.Object, + pl.Null, + pl.Unknown, } - for dtype in all_datatypes: - assert dtype == dtype() +) + + +def test_simple_dtype_init_takes_no_args() -> None: + for dtype in SIMPLE_DTYPES: + with pytest.raises(TypeError): + dtype(10) + + +def test_simple_dtype_init_returns_instance() -> None: + dtype = pl.Int8() + assert isinstance(dtype, pl.Int8) + + +def test_complex_dtype_init_returns_instance() -> None: + dtype = pl.Datetime() + assert isinstance(dtype, pl.Datetime) + assert dtype.time_unit == "us" def test_dtype_temporal_units() -> None: @@ -36,8 +61,8 @@ def test_dtype_temporal_units() -> None: assert pl.Datetime == pl.Datetime(time_unit) assert pl.Duration == pl.Duration(time_unit) - assert pl.Datetime(time_unit) == pl.Datetime() - assert pl.Duration(time_unit) == pl.Duration() + assert pl.Datetime(time_unit) == pl.Datetime + assert pl.Duration(time_unit) == pl.Duration assert pl.Datetime("ms") != pl.Datetime("ns") assert pl.Duration("ns") != pl.Duration("us") diff --git a/py-polars/tests/unit/test_errors.py b/py-polars/tests/unit/test_errors.py index 3efcbdcd5f7a..71511ad19dc4 100644 --- a/py-polars/tests/unit/test_errors.py +++ b/py-polars/tests/unit/test_errors.py @@ -188,7 +188,7 @@ def test_err_bubbling_up_to_lit() -> None: df = pl.DataFrame({"date": [date(2020, 1, 1)], "value": [42]}) with pytest.raises(TypeError): - df.filter(pl.col("date") == pl.Date("2020-01-01")) + df.filter(pl.col("date") == pl.Date("2020-01-01")) # type: ignore[call-arg] def test_error_on_double_agg() -> None: