Tensor Dim, support Dim.capacity > max(Dim.dyn_size_ext) #1641

Open
albertz opened this issue Nov 6, 2024 · 0 comments

albertz commented Nov 6, 2024

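Currently, in multiple places (where exactly still needs to be identified), we assume for some tensor `x: Tensor` that `max(x.dims[i].dyn_size_ext.raw_tensor) == x.raw_tensor.shape[i]`. In other words, we assume that the padded size of axis `i` equals the longest sequence length along that axis.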
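To make the assumption concrete, here is a minimal NumPy sketch (not RETURNN code). `seq_lens` stands in for `x.dims[i].dyn_size_ext.raw_tensor`, and the helper `naive_seq_mask` is hypothetical. It derives the time axis of the mask from `max(seq_lens)`, so the mask only lines up with the data when `max(seq_lens) == x.shape[1]`:

```python
import numpy as np


def naive_seq_mask(seq_lens: np.ndarray) -> np.ndarray:
    """Hypothetical helper: [batch, time] mask with max(seq_lens) time frames.

    Implicitly assumes the padded time axis of the data equals max(seq_lens).
    """
    max_len = int(seq_lens.max())
    return np.arange(max_len)[None, :] < seq_lens[:, None]


seq_lens = np.array([3, 5])

# Case 1: data padded exactly to the longest sequence, so the assumption holds.
x_tight = np.ones((2, 5))
masked_tight = np.where(naive_seq_mask(seq_lens), x_tight, 0.0)

# Case 2: data padded to a larger capacity (e.g. a fixed size of 8).
x_padded = np.ones((2, 8))
mismatch_detected = False
try:
    np.where(naive_seq_mask(seq_lens), x_padded, 0.0)
except ValueError as exc:
    # The mask has shape (2, 5) but the data has shape (2, 8): broadcasting fails.
    mismatch_detected = True
    print("shape mismatch:", exc)
```

Here the failure is loud, but code that slices or reshapes with `max(seq_lens)` can also go wrong silently.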

We want to support the case where `max(x.dims[i].dyn_size_ext.raw_tensor) < x.raw_tensor.shape[i]`. Since we always have `x.dims[i].capacity == x.raw_tensor.shape[i]`, this means we want to support `dim.capacity > max(dim.dyn_size_ext.raw_tensor)`.
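A minimal NumPy sketch of code that stays correct when `capacity > max(seq_lens)`: the mask is built from the padded size of the data (`x.shape[1]`, i.e. the capacity) instead of from `max(seq_lens)`, and reductions only cover valid frames. The helper names `seq_mask` and `masked_mean_over_time` and the `[batch, time, feature]` layout are assumptions for illustration, not the RETURNN API:

```python
import numpy as np


def seq_mask(seq_lens: np.ndarray, capacity: int) -> np.ndarray:
    """[batch, capacity] boolean mask, True for frames inside each sequence.

    `capacity` comes from the data shape, so it may exceed max(seq_lens).
    """
    return np.arange(capacity)[None, :] < seq_lens[:, None]


def masked_mean_over_time(x: np.ndarray, seq_lens: np.ndarray) -> np.ndarray:
    """Mean over axis 1 of x [batch, time, feature], ignoring padded frames."""
    capacity = x.shape[1]  # static padded size, not max(seq_lens)
    mask = seq_mask(seq_lens, capacity)[:, :, None]  # [batch, time, 1]
    summed = np.where(mask, x, 0.0).sum(axis=1)  # [batch, feature]
    return summed / seq_lens[:, None]


seq_lens = np.array([3, 5])
x = np.arange(2 * 8, dtype=np.float64).reshape(2, 8, 1)  # capacity 8 > max(seq_lens) == 5
# Fill all padded frames with NaN: they must not leak into the result.
x[~seq_mask(seq_lens, x.shape[1])] = np.nan
result = masked_mean_over_time(x, seq_lens)
print(result)  # 1.0 for the first sequence, 10.0 for the second
```

Taking the mask size from the data shape, not from the lengths, means the same code works whether or not the batch fills its capacity.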

This is needed, for example, for JAX, where only static shapes are possible. There, we would compile our code for a few predefined shapes (fixed batch dim, fixed spatial dim), and the batching would prepare each batch to fit one of these predefined shapes (e.g. via bucketing). This often means that the longest sequence in a batch is shorter than the spatial capacity of the chosen shape.
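The bucketing idea can be sketched in plain NumPy (no JAX, to keep it self-contained). The bucket sizes `BUCKETS`, the fixed `BATCH_SIZE`, and the helpers `pick_bucket` and `make_static_batch` are all hypothetical. Each batch is padded to the smallest bucket that fits its longest sequence, so the array shape is always one of a few static shapes (under `jax.jit`, one compilation per distinct shape), while `seq_lens` keeps the true lengths:

```python
import numpy as np

# Hypothetical predefined spatial sizes; each would be compiled once.
BUCKETS = (8, 16, 32)
BATCH_SIZE = 4  # hypothetical fixed batch dim


def pick_bucket(max_len: int, buckets=BUCKETS) -> int:
    """Smallest predefined capacity that fits the longest sequence."""
    for size in sorted(buckets):
        if max_len <= size:
            return size
    raise ValueError(f"length {max_len} exceeds largest bucket {max(buckets)}")


def make_static_batch(seqs):
    """Pads 1D sequences into a [BATCH_SIZE, capacity] array.

    Returns (data, seq_lens). Unused batch slots get length 0.
    """
    if len(seqs) > BATCH_SIZE:
        raise ValueError("too many sequences for the fixed batch size")
    seq_lens = np.zeros(BATCH_SIZE, dtype=np.int32)
    seq_lens[: len(seqs)] = [len(s) for s in seqs]
    capacity = pick_bucket(int(seq_lens.max()))
    data = np.zeros((BATCH_SIZE, capacity), dtype=np.float32)
    for b, s in enumerate(seqs):
        data[b, : len(s)] = s
    return data, seq_lens


data, seq_lens = make_static_batch([[1, 2, 3], [4, 5, 6, 7, 8, 9, 10, 11, 12, 13]])
print(data.shape, seq_lens)  # capacity 16 > max(seq_lens) == 10
```

Here the padded axis has capacity 16 while the longest sequence has length 10, which is exactly the `dim.capacity > max(dim.dyn_size_ext.raw_tensor)` case.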
