
On the ambiguity of .shape behavior #891

Open

ricardoV94 opened this issue Feb 4, 2025 · 8 comments

Comments

@ricardoV94

ricardoV94 commented Feb 4, 2025

According to #97, a library can decide to return either a tuple[int | None, ...] or a tuple-like object:

The returned value should be a tuple; however, where warranted, an array library may choose to return a custom shape object. If an array library returns a custom shape object, the object must be immutable, must support indexing for dimension retrieval, and must behave similarly to a tuple.

This seems like a recipe for disaster? The second option allows operating on shape graphs, whereas the first would fail when you try to act on None, say to find the size of some dimensions by doing prod(x.shape[1:]) (forced example, so that .size wouldn't be applicable).

In PyTensor we have a distinction between variable.shape and variable.type.shape, which correspond to those two kinds of output. They are flipped, though, and it seems odd to make variable.shape return a tuple with None. It doesn't make sense to build a computation on top of the static shape, because those None entries are not linked to anything.

import numpy as np

import pytensor
import pytensor.tensor as pt

x = pt.tensor("x", shape=(3, None,))
print(x.shape)  # Shape.0
print(x.type.shape)  # (3, None)

# Could not possibly work with x.type.shape
out = pt.broadcast_to(x, (2, x.shape[0], x.shape[1]))
print(out.type.shape)  # (2, 3, None)

assert out.eval({x: np.ones((3, 4))}).shape == (2, 3, 4)
assert out.eval({x: np.ones((3, 5))}).shape == (2, 3, 5)

Besides that, we sometimes also allow users to replace variables with different static shapes, although that's arguably a bit of undefined behavior. It seems to contradict the specification that the shape must be immutable, so I'm happy to say it's out of scope:

new_x = pt.tensor("x", shape=(4, 4))

# Even ignoring the issue of using None for unknown dimensions, the following could not work
# if the original graph was built on top of the static 3 dim length, as that's not "connected" to anything.
new_out = pytensor.graph.clone_replace(out, {x: new_x}, rebuild_strict=False)
print(new_out.type.shape)  # (2, 4, 4)

assert new_out.eval({new_x: np.ones((4, 4))}).shape == (2, 4, 4)

Proposal

Would it make sense to separate the two kinds of shape clearly? Perhaps as variable.shape and variable.static_shape. The first should be valid for building computations on top of variable shapes, statically known or not, while the second would allow libraries to reason as much as possible about what is known (and choose to fail if the provided information is insufficient), without having to probe which kind of shape output a specific library returns.
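
As a rough illustration of the split (the static_shape attribute is hypothetical, and the names are PyTensor-flavoured; this is purely a sketch):

x = pt.tensor("x", shape=(3, None))

# Hypothetical: what is known ahead of time, possibly containing None
x.static_shape  # (3, None)

# Always valid to build computations on, whether dimensions are known or not
flat = x.reshape((x.shape[0] * x.shape[1],))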

This is somewhat related to #839, where a library may need as much information as possible to make a decision. Perhaps a static_value would also make sense, so a library can return the entries that are known ahead of time. Anyway, that should be discussed there.

If both options make sense, I would argue that .shape should behave like pytensor does.

The standard should also specify whether library.shape(x) should match x.shape or x.static_shape. Again, I think it should match the first.

@ricardoV94 ricardoV94 changed the title On ambiguity of .shape behavior On the ambiguity of .shape behavior Feb 4, 2025
@rgommers
Member

rgommers commented Feb 4, 2025

@ricardoV94 thanks for the questions and idea. I'd like to try to get some clarity on the actual problem first. Here is a bit of context and some thoughts:

  • .shape is used by libraries that are fully eager, fully lazy, or allow for mixed eager/lazy execution (e.g., JIT a part of the code like with jax.jit, or able to handle graph breaks like PyTorch with TorchDynamo)
  • None will only be present during lazy execution; when execution is eager the actual shapes are always known
  • the only real issue during the design discussion was deciding on None vs. nan vs. allowing both as the sentinel for "not yet known size of dimension"
  • the design of the standard is meant to be completely agnostic to execution mode - separate APIs for eager vs lazy are undesirable

say to find the size of some dimensions by doing prod(x.shape[1:]) (forced example so that .size wouldn't be applicable).

I assume you meant math.prod here. That's an example where you are trying to use an eager function from outside the library on a shape element; yes, that won't work for lazy arrays. xp.prod(x.shape[1:]), with xp the namespace that x is also from, should work fine and keep things lazy.
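
To illustrate with a minimal sketch (the None case only ever shows up for a lazy, not-yet-evaluated shape):

import math

import numpy as np

x = np.ones((3, 4, 5))
math.prod(x.shape[1:])  # 20 -- fine: eager shapes are plain ints

math.prod((None, 5))    # TypeError -- an eager stdlib function cannot
                        # multiply a placeholder dimension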

In general, lazy implementations have more limitations than eager ones; e.g., you cannot use functions from the stdlib most of the time. That's not specific to the standard though. The standard is carefully designed to not require eager behavior unless it absolutely cannot be avoided, and those few parts have warnings about value-dependent behavior. The most annoying one is __bool__, so one cannot write if expr_evaluating_to_bool: with lazy arrays.
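
For instance (illustrative only; NumPy stands in here for any standard-compliant namespace xp, where for a lazy library the condition would still be an unevaluated expression):

import numpy as np

xp = np
x = xp.asarray([1.0, -2.0, 3.0])

cond = xp.any(x > 0)   # for a lazy library, still an unevaluated 0-d array
if cond:               # if cond: calls cond.__bool__(), which a lazy library
    print("has positive entries")  # must either force computation for, or raise on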

It doesn't make sense to build a computation on top of static shape, because those None are not linked to anything.

I hope the above makes clear that this is not a case that happens in the real world, since static shapes are never unknown.

Are you running into an actual problem using or implementing .shape support?

@ricardoV94
Author

ricardoV94 commented Feb 4, 2025

The point was that this complicates writing code that operates on .shape. Most libraries would be happy to give you an output for x.reshape(x.shape[0] * x.shape[1]), but if I were to follow the first suggestion of the API standard, this would always fail whenever I have to return None for unknown dimensions.

Now, this is fine within a library because I'm allowed to define x.shape as I want. But then what about meta-libraries that want to implement their own version of reshape? They would need to know whether the library returns the first sort of shape or the second. So they cannot be backend-agnostic.
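
A sketch of the problem as I see it (hypothetical helper, purely illustrative):

def flatten(x, xp):
    d0, d1 = x.shape
    if d0 is None or d1 is None:
        # the backend returned a plain tuple with None entries: there is nothing
        # to multiply, so a backend-specific fallback would be needed here
        raise NotImplementedError("unknown dimension")
    # the backend returned concrete ints (or symbolic lengths that support *):
    # this path stays backend-agnostic
    return xp.reshape(x, (d0 * d1,))

# e.g. flatten(np.ones((3, 4)), np) gives a (12,) array, but a backend with
# None in .shape would hit the fallback.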

Am I misunderstanding the scope of the project?

@ricardoV94
Author

ricardoV94 commented Feb 4, 2025

Or, put another way, why would anyone implement x.shape as a tuple with None in their library? Is anyone doing that / interested in that format?

@rgommers
Member

rgommers commented Feb 4, 2025

but if I were to follow the API standard this will always fail if I have to return None for unknown dimensions.

This is just not true? It always works eagerly because static shapes are known, and it always works lazily because x.shape will be fully determined by the time you arrive at the execution of that line within the compute graph.
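
A minimal JAX sketch of what I mean (assuming jax is installed; inside jit-traced code, the shape entries are plain Python ints even though the array values are not available yet):

import jax
import jax.numpy as jnp

@jax.jit
def flatten(x):
    # under tracing, x.shape is a concrete tuple of ints, so this is ordinary
    # integer arithmetic
    return jnp.reshape(x, (x.shape[0] * x.shape[1],))

print(flatten(jnp.ones((3, 4))).shape)  # (12,)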

Am I misunderstanding that the project is not interested in facilitating backend-agnostic libraries?

Yes, that's a misunderstanding, unless I'm misunderstanding what you are saying: one of the key goals of this whole effort is to allow libraries to write code that's agnostic to the library and execution model backing the input arrays.

But then what about meta-libraries that want to implement their own version of reshape? They would need to know if the library is going to do the first sort of shape or the second. So they cannot be backend agnostic.

I don't quite understand this question, so I'll answer the one below.

Or put another way why would anyone implement x.shape as a tuple with None in their library? Is anyone doing/interested in that format?

It's only None within an unevaluated graph. E.g., Dask allows you to build up a graph without calling .compute() and then poke at the raw attributes. That's not what you want to be doing, but that's when you can see None (EDIT: Dask actually still uses nan; JAX uses None).
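
A minimal Dask sketch of such an unevaluated graph (assuming dask is installed; boolean indexing yields an output whose length isn't known until the graph is computed):

import dask.array as da
import numpy as np

x = da.from_array(np.arange(10), chunks=5)
y = x[x > 3]                # data-dependent output length
print(y.shape)              # (nan,) -- unknown until computed
print(y.compute().shape)    # (6,)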

@rgommers
Member

rgommers commented Feb 4, 2025

I'm not that familiar with PyTensor so there's a chance there is something I am missing that's behind your questions. There's also a lot of history here. We are rapidly gaining more experience with lazy libraries and their strengths and limitations when used through the standard, e.g. adding support for JAX and Dask in SciPy and scikit-learn. I'm happy to set up a call if you prefer and talk it through?

@ricardoV94
Author

ricardoV94 commented Feb 4, 2025

I guess my question is, how do I decide which format to offer? Well, it's easy to answer that, because if I want xp.reshape(x, x.shape[0] * x.shape[1]) to work, it will only ever do so if I follow the second format. I can't see how JAX can support that if at some point x.shape[0] is literally None?

But, importantly for me, will another library ever look for None in the output of .shape to decide on behavior? Will they do something suboptimal because I'm returning a symbolic shape, even though almost all the dimensions of my array are static and those are the ones they would need to make the right decision?

For a concrete example, when adding the PyTensor backend to einops, we implemented shape as a mix of symbolic and static (if available) shapes: https://github.com/arogozhnikov/einops/blob/47c742ff94b21dbe2de35ea14fca17d6632f8f73/einops/_backends.py#L699-L704
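
Roughly, the idea there is something like this (a paraphrased sketch, not the exact code from that link):

def mixed_shape(x):
    # prefer the statically known length of each dimension, fall back to the
    # symbolic shape entry when the static length is unknown
    return tuple(
        static if static is not None else x.shape[i]
        for i, static in enumerate(x.type.shape)
    )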

I had to tell the library how to do that specifically for the PyTensor backend (there's something similar for non-eager TensorFlow above).

I guess for Dask the equivalent would be to call .compute() to figure out what the None entries mean?

No idea how the JAX case can be used from the outside.

Maybe the point is that, without the standard, a meta-library like einops has to figure out which backend it is dealing with if it wants to make eager decisions on lazy graphs? That's why I feel this may be connected to #839, although it's about shapes and not values, which is a simpler case?

@ricardoV94
Author

ricardoV94 commented Feb 4, 2025

I'm not that familiar with PyTensor so there's a chance there is something I am missing that's behind your questions. There's also a lot of history here. We are rapidly gaining more experience with lazy libraries and their strengths and limitations when used through the standard, e.g. adding support for JAX and Dask in SciPy and scikit-learn. I'm happy to set up a call if you prefer and talk it through?

I'm sure we're both missing something (me more) :) Feel free to reach out to me

@rgommers
Member

rgommers commented Feb 4, 2025

I guess my question is, how do I decide which format to offer?

I'd say put in the actual values if you have them, and None only if you can't. That's the only reasonable thing to do.

But importantly for me, will another library ever look for None in the output of .shape to decide on behavior?

From what I've seen, this is only done when an algorithm has inherently value-dependent behavior, so there is no way to keep things lazy. E.g.:

if unique(x).shape[0] < 5:
    small_size_algo(...)
else:
    regular_algo(...)

Scikit-learn has a fair amount of code like that, for example, often using unique. At that point you must have values, so the only two choices are to force computation or to raise.

These cases are very hard to support for lazy arrays, and that's more the problem than whether you hit the "must compute or raise" point in .shape or in some other function/attribute.

Feel free to reach out to me at

done!
