Skip to content

Commit

Permalink
fix(python): Issue correct PolarsInefficientMapWarning for lshift/r…
Browse files Browse the repository at this point in the history
…shift operations (#12385)
  • Loading branch information
MarcoGorelli authored Nov 13, 2023
1 parent 7711af2 commit a380439
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 0 deletions.
11 changes: 11 additions & 0 deletions py-polars/polars/utils/udfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ class OpNames:
"BINARY_ADD": "+",
"BINARY_AND": "&",
"BINARY_FLOOR_DIVIDE": "//",
"BINARY_LSHIFT": "<<",
"BINARY_RSHIFT": ">>",
"BINARY_MODULO": "%",
"BINARY_MULTIPLY": "*",
"BINARY_OR": "|",
Expand Down Expand Up @@ -524,6 +526,15 @@ def _expr(self, value: StackEntry, col: str, param_name: str, depth: int) -> str
if not isinstance(self._caller_variables.get(e1, None), dict):
raise NotImplementedError("require dict mapping")
return f"{e2}.{op}({e1})"
elif op == "<<":
# Result of 2**e2 might be float if e2 was negative.
# But, if e1 << e2 was valid, then e2 must have been non-negative.
# Hence, the output of 2**e2 can be safely cast to Int64, which
# may be necessary if chaining operations which assume Int64 output.
return f"({e1}*2**{e2}).cast(pl.Int64)"
elif op == ">>":
# Motivation for the cast is the same as in the '<<' case above.
return f"({e1} / 2**{e2}).cast(pl.Int64)"
else:
expr = f"{e1} {op} {e2}"
return f"({expr})" if depth else expr
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,29 @@
'lambda x: dt.datetime.strptime(x, "%Y-%m-%d")',
'pl.col("d").str.to_datetime(format="%Y-%m-%d")',
),
# ---------------------------------------------
# Bitwise shifts
# ---------------------------------------------
(
"a",
"lambda x: (3 << (32-x)) & 3",
'(3*2**(32 - pl.col("a"))).cast(pl.Int64) & 3',
),
(
"a",
"lambda x: (x << 32) & 3",
'(pl.col("a")*2**32).cast(pl.Int64) & 3',
),
(
"a",
"lambda x: ((32-x) >> (3)) & 3",
'((32 - pl.col("a")) / 2**3).cast(pl.Int64) & 3',
),
(
"a",
"lambda x: (32 >> (3-x)) & 3",
'(32 / 2**(3 - pl.col("a"))).cast(pl.Int64) & 3',
),
]

NOOP_TEST_CASES = [
Expand Down

0 comments on commit a380439

Please sign in to comment.