Skip to content

Commit

Permalink
Update for polars 1.x and fix some hacks
Browse files Browse the repository at this point in the history
Signed-off-by: Devin Petersohn <[email protected]>
  • Loading branch information
sfc-gh-dpetersohn committed Jul 18, 2024
1 parent 64a91d7 commit 0019245
Show file tree
Hide file tree
Showing 3 changed files with 500 additions and 256 deletions.
58 changes: 9 additions & 49 deletions modin/polars/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,25 +205,6 @@ def to_arrow(self):
"""
return polars.from_pandas(self._query_compiler.to_pandas()).to_arrow()

def to_dict(
self, *, as_series: bool = True
) -> dict[str, "Series"] | dict[str, list[Any]]:
"""
Convert the DataFrame to a dictionary representation.
Args:
as_series: Whether to convert the columns to Series.
Returns:
Dictionary representation of the DataFrame.
"""
if as_series:
return {name: self[name] for name in self.columns}
else:
return polars.from_pandas(self._query_compiler.to_pandas()).to_dict(
as_series=as_series
)

def to_jax(self, device=None):
"""
Convert the DataFrame to JAX format.
Expand All @@ -238,20 +219,6 @@ def to_jax(self, device=None):
device=device
)

def to_list(self, *, use_pyarrow: bool | None = None) -> list[Any]:
    """
    Produce a list representation of this dataset.

    The data is converted to pandas, handed to polars, and the resulting
    polars object's ``to_list`` is called.

    Args:
        use_pyarrow: Forwarded unchanged to the polars ``to_list`` call.

    Returns:
        Whatever the polars ``to_list`` call returns.
    """
    from_pandas = polars.from_pandas
    polars_data = from_pandas(self._query_compiler.to_pandas())
    return polars_data.to_list(use_pyarrow=use_pyarrow)

def to_numpy(
self,
*,
Expand Down Expand Up @@ -313,15 +280,6 @@ def cast(self, dtypes, *, strict: bool = True) -> "BasePolarsDataset":
# TODO: support strict
return self.__constructor__(_query_compiler=self._query_compiler.astype(dtypes))

def copy(self):
    """
    Return a new object built from a copy of this object's query compiler.

    Returns:
        The result of ``__constructor__`` called with the copied query compiler.
    """
    build = self.__constructor__
    duplicated_compiler = self._query_compiler.copy()
    return build(_query_compiler=duplicated_compiler)

def clone(self) -> "BasePolarsDataset":
"""
Clone the DataFrame.
Expand All @@ -345,8 +303,6 @@ def drop_nulls(self, subset=None):
_query_compiler=self._query_compiler.dropna(subset=subset, how="any")
)

# Alias: drop_nans is bound to the same function object as drop_nulls.
# NOTE(review): confirm that NaN and null are meant to share one implementation here.
drop_nans = drop_nulls

def explode(self, columns: str, *more_columns: str) -> "BasePolarsDataset":
"""
Explode the given columns to long format.
Expand Down Expand Up @@ -528,9 +484,6 @@ def sample(
def shift(self, n: int = 1, *, fill_value=None) -> "DataFrame":
    """
    Placeholder for shifting values; not implemented.

    Args:
        n: Accepted but unused.
        fill_value: Accepted but unused.

    Raises:
        NotImplementedError: Always.
    """
    message = "not yet"
    raise NotImplementedError(message)

def shift_and_fill(self, fill_value=None, *, n: int = 1) -> "DataFrame":
    """
    Delegate to ``shift`` with the same ``n`` and ``fill_value``.

    Args:
        fill_value: Passed to ``shift`` as ``fill_value``.
        n: Passed to ``shift`` as ``n``.

    Returns:
        Whatever ``self.shift`` returns.
    """
    shift_kwargs = {"n": n, "fill_value": fill_value}
    return self.shift(**shift_kwargs)

def shrink_to_fit(self) -> "DataFrame":
"""
Shrink the DataFrame to fit in memory.
Expand Down Expand Up @@ -605,8 +558,6 @@ def tail(self, n: int = 5) -> "DataFrame":
_query_compiler=self._query_compiler.getitem_row_array(slice(-n, None))
)

# Alias: take_every is bound to the same function object as gather_every.
take_every = gather_every

def to_dummies(
self,
columns: str | Sequence[str] | None = None,
Expand Down Expand Up @@ -706,3 +657,12 @@ def equals(self, other: "BasePolarsDataset", *, null_equal: bool = True) -> bool
@property
def plot(self):
    """
    Expose the ``plot`` attribute of the equivalent polars object.

    The data is converted to pandas and then to polars on every access.
    """
    from_pandas = polars.from_pandas
    polars_data = from_pandas(self._query_compiler.to_pandas())
    return polars_data.plot

def count(self):
    """
    Count the non-null values in each column.

    Calls the query compiler's ``count`` with ``axis=0`` and wraps the
    result with ``__constructor__``.

    Returns:
        New object holding the counts.
    """
    build = self.__constructor__
    counted = self._query_compiler.count(axis=0)
    return build(_query_compiler=counted)
Loading

0 comments on commit 0019245

Please sign in to comment.