Skip to content

Commit

Permalink
feat: add when then chaining back
Browse files Browse the repository at this point in the history
  • Loading branch information
aivanoved committed Sep 11, 2024
1 parent ad5a50a commit 09bea00
Show file tree
Hide file tree
Showing 3 changed files with 269 additions and 178 deletions.
260 changes: 157 additions & 103 deletions narwhals/_pandas_like/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,8 @@ def __init__(
self._then_value = then_value
self._otherwise_value = otherwise_value

self._already_set = self._condition

def __call__(self, df: PandasLikeDataFrame) -> list[PandasLikeSeries]:
from narwhals._expression_parsing import parse_into_expr
from narwhals._pandas_like.namespace import PandasLikeNamespace
Expand Down Expand Up @@ -372,114 +374,166 @@ def __init__(
self._root_names = root_names
self._output_names = output_names

def otherwise(self, value: PandasLikeExpr | PandasLikeSeries | Any) -> PandasLikeExpr:
def otherwise(self, value: PandasLikeExpr | PandasLikeSeries | Any) -> PandasThen:
# type ignore because we are setting the `_call` attribute to a
# callable object of type `PandasWhen`, base class has the attribute as
# only a `Callable`
self._call._otherwise_value = value # type: ignore[attr-defined]
self._function_name = "whenotherwise"
return self

def when(self, *predicates: IntoPandasLikeExpr) -> PandasChainedWhen:
plx = PandasLikeNamespace(self._implementation, self._backend_version)
if predicates:
condition = plx.all_horizontal(*predicates)
else:
msg = "at least one predicate needs to be provided"
raise TypeError(msg)
return PandasChainedWhen(
self, condition, self._depth + 1, self._implementation, self._backend_version
)


class PandasChainedWhen:
def __init__(
self,
above_then: PandasThen | PandasChainedThen,
condition: PandasLikeExpr,
depth: int,
implementation: Implementation,
backend_version: tuple[int, ...],
then_value: Any = None,
otherise_value: Any = None,
) -> None:
self._implementation = implementation
self._depth = depth
self._backend_version = backend_version
self._condition = condition
self._above_then = above_then
self._then_value = then_value
self._otherwise_value = otherise_value

# TODO @aivanoved: this is way slow as during computation time this takes
# quadratic time need to improve this to linear time
self._above_already_set = self._above_then._call._already_set # type: ignore[attr-defined]
self._already_set = self._above_already_set | self._condition

# class PandasChainedWhen:
# def __init__(
# self,
# above_when: PandasWhen | PandasChainedWhen,
# condition: PandasLikeExpr,
# depth: int,
# implementation: Implementation,
# backend_version: tuple[int, ...],
# then_value: Any = None,
# otherise_value: Any = None,
# ) -> None:
# self._implementation = implementation
# self._depth = depth
# self._backend_version = backend_version
# self._condition = condition
# self._above_when = above_when
# self._then_value = then_value
# self._otherwise_value = otherise_value
#
# # TODO @aivanoved: this is way slow as during computation time this takes
# # quadratic time need to improve this to linear time
# self._condition = self._condition & (~self._above_when._already_set) # type: ignore[has-type]
# self._already_set = self._above_when._already_set | self._condition # type: ignore[has-type]
#
# def __call__(self, df: PandasLikeDataFrame) -> list[PandasLikeSeries]:
# from narwhals._pandas_like.namespace import PandasLikeNamespace
#
# plx = PandasLikeNamespace(
# implementation=self._implementation, backend_version=self._backend_version
# )
#
# set_then = self._condition._call(df)[0]
# already_set = self._already_set._call(df)[0]
#
# value_series = plx._create_broadcast_series_from_scalar(
# self._then_value, set_then
# )
# otherwise_series = plx._create_broadcast_series_from_scalar(
# self._otherwise_value, set_then
# )
#
# above_result = self._above_when(df)[0]
#
# result = value_series.zip_with(set_then, above_result).zip_with(
# already_set, otherwise_series
# )
#
# return [result]
#
# def then(self, value: Any) -> PandasChainedThen:
# self._then_value = value
# return PandasChainedThen(
# self,
# depth=self._depth,
# implementation=self._implementation,
# function_name="chainedwhen",
# root_names=None,
# output_names=None,
# backend_version=self._backend_version,
# )
#
#
# class PandasChainedThen(PandasLikeExpr):
# def __init__(
# self,
# call: PandasChainedWhen,
# *,
# depth: int,
# function_name: str,
# root_names: list[str] | None,
# output_names: list[str] | None,
# implementation: Implementation,
# backend_version: tuple[int, ...],
# ) -> None:
# self._implementation = implementation
# self._backend_version = backend_version
#
# self._call = call
# self._depth = depth
# self._function_name = function_name
# self._root_names = root_names
# self._output_names = output_names
#
# def when(
# self,
# *predicates: IntoPandasLikeExpr | Iterable[IntoPandasLikeExpr],
# ) -> PandasChainedWhen:
# return PandasChainedWhen(
# self._call, # type: ignore[arg-type]
# when_processing(
# PandasLikeNamespace(self._implementation, self._backend_version),
# *predicates,
# ),
# depth=self._depth + 1,
# implementation=self._implementation,
# backend_version=self._backend_version,
# )
#
# def otherwise(self, value: Any) -> PandasLikeExpr:
# self._call._otherwise_value = value # type: ignore[attr-defined]
# self._function_name = "chainedwhenotherwise"
# return self
def __call__(self, df: PandasLikeDataFrame) -> list[PandasLikeSeries]:
from narwhals._expression_parsing import parse_into_expr
from narwhals._pandas_like.namespace import PandasLikeNamespace

plx = PandasLikeNamespace(
implementation=self._implementation, backend_version=self._backend_version
)

condition = parse_into_expr(self._condition, namespace=plx)._call(df)[0] # type: ignore[arg-type]
try:
value_series = parse_into_expr(self._then_value, namespace=plx)._call(df)[0] # type: ignore[arg-type]
except TypeError:
# `self._otherwise_value` is a scalar and can't be converted to an expression
value_series = condition.__class__._from_iterable( # type: ignore[call-arg]
[self._then_value] * len(condition),
name="literal",
index=condition._native_series.index,
implementation=self._implementation,
backend_version=self._backend_version,
)
value_series = cast(PandasLikeSeries, value_series)

set_then = condition
set_then_native = set_then._native_series
above_already_set = parse_into_expr(self._above_already_set, namespace=plx)._call(
df # type: ignore[arg-type]
)[0]

value_series_native = value_series._native_series

above_result = self._above_then._call(df)[0]
above_result_native = above_result._native_series
set_then_native = set_then._native_series
above_already_set_native = above_already_set._native_series
if self._otherwise_value is None:
return [
above_result._from_native_series(
value_series_native.where(
~above_already_set_native & set_then_native, above_result_native
)
)
]

try:
otherwise_series = parse_into_expr(
self._otherwise_value, namespace=plx
)._call(df)[0] # type: ignore[arg-type]
except TypeError:
# `self._otherwise_value` is a scalar and can't be converted to an expression
otherwise_series = condition.__class__._from_iterable( # type: ignore[call-arg]
[self._otherwise_value] * len(condition),
name="literal",
index=condition._native_series.index,
implementation=self._implementation,
backend_version=self._backend_version,
)
otherwise_series = cast(PandasLikeSeries, otherwise_series)
return [
above_result.zip_with(
above_already_set, value_series.zip_with(set_then, otherwise_series)
)
]

def then(self, value: Any) -> PandasChainedThen:
self._then_value = value
return PandasChainedThen(
self,
depth=self._depth,
implementation=self._implementation,
function_name="chainedwhen",
root_names=None,
output_names=None,
backend_version=self._backend_version,
)


class PandasChainedThen(PandasLikeExpr):
def __init__(
self,
call: PandasChainedWhen,
*,
depth: int,
function_name: str,
root_names: list[str] | None,
output_names: list[str] | None,
implementation: Implementation,
backend_version: tuple[int, ...],
) -> None:
self._implementation = implementation
self._backend_version = backend_version

self._call = call
self._depth = depth
self._function_name = function_name
self._root_names = root_names
self._output_names = output_names

def when(
self,
*predicates: IntoPandasLikeExpr,
) -> PandasChainedWhen:
plx = PandasLikeNamespace(self._implementation, self._backend_version)
if predicates:
condition = plx.all_horizontal(*predicates)
else:
msg = "at least one predicate needs to be provided"
raise TypeError(msg)
return PandasChainedWhen(
self,
condition,
depth=self._depth + 1,
implementation=self._implementation,
backend_version=self._backend_version,
)

def otherwise(self, value: Any) -> PandasChainedThen:
self._call._otherwise_value = value # type: ignore[attr-defined]
self._function_name = "chainedwhenotherwise"
return self
60 changes: 32 additions & 28 deletions narwhals/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -3995,16 +3995,17 @@ def sum_horizontal(*exprs: IntoExpr | Iterable[IntoExpr]) -> Expr:
)


def _extract_predicates(plx: Any, predicates: IntoExpr | Iterable[IntoExpr]) -> Any:
return [extract_compliant(plx, v) for v in flatten([predicates])]


class When:
def __init__(self, *predicates: IntoExpr | Iterable[IntoExpr]) -> None:
self._predicates = flatten([predicates])

def _extract_predicates(self, plx: Any) -> Any:
return [extract_compliant(plx, v) for v in self._predicates]

def then(self, value: Any) -> Then:
return Then(
lambda plx: plx.when(*self._extract_predicates(plx)).then(
lambda plx: plx.when(*_extract_predicates(plx, self._predicates)).then(
extract_compliant(plx, value)
)
)
Expand All @@ -4017,36 +4018,39 @@ def __init__(self, call: Callable[[Any], Any]) -> None:
def otherwise(self, value: Any) -> Expr:
return Expr(lambda plx: self._call(plx).otherwise(extract_compliant(plx, value)))

def when(self, *predicates: IntoExpr | Iterable[IntoExpr]) -> ChainedWhen:
return ChainedWhen(self, *predicates)

# class ChainedWhen:
# def __init__(
# self,
# above_then: Then | ChainedThen,
# *predicates: IntoExpr | Iterable[IntoExpr],
# ) -> None:
# self._above_then = above_then
# self._predicates = predicates

# def then(self, value: Any) -> ChainedThen:
# return ChainedThen(
# lambda plx: self._above_then._call(plx)
# .when(*_extract_predicates(plx, flatten([self._predicates])))
# .then(value)
# )
class ChainedWhen:
def __init__(
self,
above_then: Then | ChainedThen,
*predicates: IntoExpr | Iterable[IntoExpr],
) -> None:
self._above_then = above_then
self._predicates = flatten([predicates])

def then(self, value: Any) -> ChainedThen:
return ChainedThen(
lambda plx: self._above_then._call(plx)
.when(*_extract_predicates(plx, self._predicates))
.then(value)
)

# class ChainedThen(Expr):
# def __init__(self, call: Callable[[Any], Any]) -> None:
# self._call = call

# def when(
# self,
# *predicates: IntoExpr | Iterable[IntoExpr],
# ) -> ChainedWhen:
# return ChainedWhen(self, *predicates)
class ChainedThen(Expr):
def __init__(self, call: Callable[[Any], Any]) -> None:
self._call = call

# def otherwise(self, value: Any) -> Expr:
# return Expr(lambda plx: self._call(plx).otherwise(value))
def when(
self,
*predicates: IntoExpr | Iterable[IntoExpr],
) -> ChainedWhen:
return ChainedWhen(self, *predicates)

def otherwise(self, value: Any) -> Expr:
return Expr(lambda plx: self._call(plx).otherwise(value))


def when(*predicates: IntoExpr | Iterable[IntoExpr]) -> When:
Expand Down
Loading

0 comments on commit 09bea00

Please sign in to comment.