diff --git a/py-polars/polars/_utils/udfs.py b/py-polars/polars/_utils/udfs.py index ed91c1920cc9..2e345ba72c06 100644 --- a/py-polars/polars/_utils/udfs.py +++ b/py-polars/polars/_utils/udfs.py @@ -102,6 +102,7 @@ class OpNames: | set(SYNTHETIC) | LOAD_VALUES ) + MATCHABLE_OPS = PARSEABLE_OPS | set(BINARY) | LOAD_ATTR | CALL UNARY_VALUES = frozenset(UNARY.values()) @@ -753,11 +754,18 @@ def __init__( self._function = function self._caller_variables = caller_variables self._original_instructions = list(instructions) - self._rewritten_instructions = self._rewrite( - self._upgrade_instruction(inst) - for inst in self._unpack_superinstructions(self._original_instructions) - if inst.opname not in self._ignored_ops - ) + + normalised_instructions = [] + + for inst in self._unpack_superinstructions(self._original_instructions): + if inst.opname not in self._ignored_ops: + if inst.opname not in OpNames.MATCHABLE_OPS: + self._rewritten_instructions = [] + return + upgraded_inst = self._upgrade_instruction(inst) + normalised_instructions.append(upgraded_inst) + + self._rewritten_instructions = self._rewrite(normalised_instructions) def __len__(self) -> int: return len(self._rewritten_instructions) @@ -810,7 +818,7 @@ def _matches( return instructions return [] - def _rewrite(self, instructions: Iterator[Instruction]) -> list[Instruction]: + def _rewrite(self, instructions: list[Instruction]) -> list[Instruction]: """ Apply rewrite rules, potentially injecting synthetic operations. @@ -818,7 +826,7 @@ def _rewrite(self, instructions: Iterator[Instruction]) -> list[Instruction]: it as needed, pushing updates into "updated_instructions" and returning True/False to indicate if any changes were made. """ - self._instructions = list(instructions) + self._instructions = instructions updated_instructions: list[Instruction] = [] idx = 0 while idx < len(self._instructions):