
Use numpy C-API einsum for unoptimized Einsum #1356

Open
ricardoV94 opened this issue Apr 9, 2025 · 0 comments

Description

When an Einsum can't be optimized (because the static shapes are unknown), it stays as an OpFromGraph. In that case we could replace it with a COp (registered as a cxx_only rewrite) that calls the NumPy C function directly:

https://numpy.org/devdocs/reference/c-api/array.html#c.PyArray_EinsteinSum
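As a rough illustration of what such a COp would compute, here is a minimal NumPy-only sketch. It assumes that `np.einsum(..., optimize=False)` is a fair Python-level stand-in: with that setting NumPy skips the contraction-path search and hands the whole expression to its C einsum routine in one call. The helper name `unoptimized_einsum` is hypothetical and is not part of PyTensor or NumPy; the actual proposal would call `PyArray_EinsteinSum` from the Op's C code instead.

```python
import numpy as np


def unoptimized_einsum(subscripts: str, *operands: np.ndarray) -> np.ndarray:
    """Hypothetical stand-in for the proposed COp's computation.

    With optimize=False, np.einsum does not search for a contraction path;
    it evaluates the whole expression in a single call to NumPy's C einsum
    routine. No static shape information is needed ahead of time.
    """
    return np.einsum(subscripts, *operands, optimize=False)


rng = np.random.default_rng(0)
# Shapes are only known at runtime, which is the case the issue targets.
a = rng.normal(size=(3, 4))
b = rng.normal(size=(4, 5))
c = rng.normal(size=(5, 2))

out = unoptimized_einsum("ij,jk,kl->il", a, b, c)
print(out.shape)  # (3, 2)
print(np.allclose(out, a @ b @ c))  # True
```

Because no path is precomputed, operand shapes only need to be known when the function runs, which is exactly when the Einsum is left as an unoptimized OpFromGraph. The trade-off is that one fused loop over all indices can be slower than a well-chosen sequence of pairwise contractions for large operands.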

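The existing rewrite only inlines Einsums that are already optimized, so unoptimized ones are left as an OpFromGraph:

```python
from typing import cast

from pytensor.graph.basic import Apply
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.rewriting.basic import node_rewriter
from pytensor.tensor.einsum import Einsum
from pytensor.tensor.rewriting.basic import register_specialize
from pytensor.tensor.rewriting.ofg import inline_ofg_node
from pytensor.tensor.variable import TensorVariable


@register_specialize
@node_rewriter([Einsum])
def inline_optimized_einsum(
    fgraph: FunctionGraph, node: Apply
) -> list[TensorVariable] | None:
    """Inline einsums that are already optimized.

    This allows the inner graph to be optimized with the rest of the graph,
    now that the contraction order has been determined.
    """
    op: Einsum = node.op
    if not op.optimized:
        return None
    return cast(list[TensorVariable], inline_ofg_node(node))
```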
