Parameter structure abstraction #74

Open
Ebanflo42 opened this issue Apr 4, 2024 · 1 comment

@Ebanflo42
Collaborator

This might be a trivial one, but it's important from a usability perspective.

PjRt executables expect a slice of buffers or literals as input. However, if a model has dozens of parameter tensors, organizing them all into a slice by hand becomes tedious (this is already visible in the `mnist_xla` example). I think our abstract model API should therefore be callable on any user-defined structure implementing `Into<Vec<Literal>>` or `Into<Vec<PjRtBuffer>>`. The gradient engine should likewise return gradients in the same structure if the parameter struct implements `From<Vec<T>>`.
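A minimal sketch of how this could look, assuming a hypothetical `Tensor` alias in place of the real `Literal`/`PjRtBuffer` types and a hypothetical `call_structured` helper (neither is part of this project or of the xla bindings). The parameter struct implements `From` in both directions, which also gives it `Into<Vec<Tensor>>` through the standard library's blanket impl. The helper flattens the struct, runs a slice-based closure standing in for a PjRt executable, and unflattens the result into the same struct type.

```rust
/// Hypothetical stand-in for an XLA `Literal` or `PjRtBuffer`: a flat f32 array.
pub type Tensor = Vec<f32>;

/// Hypothetical two-layer model parameters.
#[derive(Debug, PartialEq)]
pub struct MlpParams {
    pub w1: Tensor,
    pub b1: Tensor,
    pub w2: Tensor,
    pub b2: Tensor,
}

/// Flattening. Implementing `From<MlpParams> for Vec<Tensor>` also gives
/// `MlpParams: Into<Vec<Tensor>>` through the standard library's blanket impl.
impl From<MlpParams> for Vec<Tensor> {
    fn from(p: MlpParams) -> Self {
        // This order fixes the order of the executable's arguments.
        vec![p.w1, p.b1, p.w2, p.b2]
    }
}

/// Unflattening: rebuild the struct from tensors given in the same order.
impl From<Vec<Tensor>> for MlpParams {
    fn from(v: Vec<Tensor>) -> Self {
        assert_eq!(v.len(), 4, "MlpParams expects exactly 4 tensors");
        let mut it = v.into_iter();
        // `unwrap` cannot fail after the length check above.
        MlpParams {
            w1: it.next().unwrap(),
            b1: it.next().unwrap(),
            w2: it.next().unwrap(),
            b2: it.next().unwrap(),
        }
    }
}

/// Hypothetical structure-aware wrapper around a slice-based executable.
/// `exec` stands in for a PjRt call (e.g. a compiled gradient computation)
/// that takes and returns flat lists of tensors.
pub fn call_structured<P, F>(params: P, exec: F) -> P
where
    P: Into<Vec<Tensor>> + From<Vec<Tensor>>,
    F: Fn(&[Tensor]) -> Vec<Tensor>,
{
    let flat: Vec<Tensor> = params.into(); // flatten
    let out = exec(flat.as_slice()); // run on a slice, as PjRt expects
    P::from(out) // unflatten into the caller's structure
}
```

With this shape the positional order of the tensors is defined in one place (the two `From` impls), and a gradient computation returning one tensor per parameter comes back as an `MlpParams` rather than a bare vector.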

This is essentially the equivalent of JAX's pytree flattening and unflattening.
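JAX pytrees also handle nested containers. One possible way to get similar behaviour in Rust is a recursive trait whose impls compose, sketched below under stated assumptions: the `Tree` trait and its method names are hypothetical (not the project's API), and `f32` stands in for a tensor leaf. As with JAX's `tree_flatten`/`tree_unflatten`, leaves are visited in a fixed order; unlike JAX, the static Rust type plays the role of the treedef, so unflattening takes no separate structure argument.

```rust
/// Hypothetical pytree-style trait: a value whose leaves of type `L` can be
/// flattened into a list and rebuilt from one. The Rust type acts as the
/// "treedef", so unflattening needs no separate structure argument.
pub trait Tree<L>: Sized {
    /// Append this value's leaves to `out` in a fixed traversal order.
    fn flatten_into(self, out: &mut Vec<L>);

    /// Rebuild a value by consuming leaves from `leaves` in the same order.
    fn unflatten_from<I: Iterator<Item = L>>(leaves: &mut I) -> Self;

    /// Rough analogue of `jax.tree_util.tree_flatten` (leaves only).
    fn tree_flatten(self) -> Vec<L> {
        let mut out = Vec::new();
        self.flatten_into(&mut out);
        out
    }

    /// Rough analogue of `jax.tree_util.tree_unflatten`; panics if the
    /// number of leaves does not match the structure.
    fn tree_unflatten(leaves: Vec<L>) -> Self {
        let mut it = leaves.into_iter();
        let value = Self::unflatten_from(&mut it);
        assert!(it.next().is_none(), "too many leaves for this structure");
        value
    }
}

/// Leaf case: `f32` stands in for a tensor type such as `Literal`.
impl Tree<f32> for f32 {
    fn flatten_into(self, out: &mut Vec<f32>) {
        out.push(self);
    }
    fn unflatten_from<I: Iterator<Item = f32>>(leaves: &mut I) -> Self {
        leaves.next().expect("too few leaves for this structure")
    }
}

/// Pairs: flatten the left subtree, then the right one.
impl<L, A: Tree<L>, B: Tree<L>> Tree<L> for (A, B) {
    fn flatten_into(self, out: &mut Vec<L>) {
        self.0.flatten_into(out);
        self.1.flatten_into(out);
    }
    fn unflatten_from<I: Iterator<Item = L>>(leaves: &mut I) -> Self {
        let a = A::unflatten_from(leaves);
        let b = B::unflatten_from(leaves);
        (a, b)
    }
}

/// Fixed-size arrays: flatten each element in index order.
impl<L, T: Tree<L>, const N: usize> Tree<L> for [T; N] {
    fn flatten_into(self, out: &mut Vec<L>) {
        for x in self {
            x.flatten_into(out);
        }
    }
    fn unflatten_from<I: Iterator<Item = L>>(leaves: &mut I) -> Self {
        std::array::from_fn(|_| T::unflatten_from(leaves))
    }
}
```

A user-defined parameter struct would then only need a small `Tree` impl that recurses into its fields in declaration order.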

I don't think this needs much design work: we just need to write the abstract type signatures and call `into` and `from` in the right places.

@Ebanflo42
Collaborator Author

@atlv24, could you add the definition of your typeclass and its macros to a file called `src/tree.rs`?
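For illustration only, here is one guess at the kind of macro that could live there. It is hypothetical: it is not atlv24's typeclass, and `param_tree!` and `Tensor` (a stand-in for `Literal`/`PjRtBuffer`) are invented names. The macro wraps a struct definition whose fields are all tensors and generates both `From` conversions in field declaration order, so each parameter struct lists its fields only once.

```rust
/// Hypothetical stand-in for an XLA `Literal` or `PjRtBuffer`.
pub type Tensor = Vec<f32>;

/// Hypothetical macro: wraps a struct whose fields are all `Tensor` and derives
/// flattening (`From<S> for Vec<Tensor>`) and unflattening
/// (`From<Vec<Tensor>> for S`), both in field declaration order.
macro_rules! param_tree {
    (
        $(#[$meta:meta])*
        $vis:vis struct $name:ident {
            $($fvis:vis $field:ident : $ty:ty),+ $(,)?
        }
    ) => {
        $(#[$meta])*
        $vis struct $name {
            $($fvis $field: $ty),+
        }

        impl From<$name> for Vec<Tensor> {
            fn from(p: $name) -> Self {
                vec![$(p.$field),+]
            }
        }

        impl From<Vec<Tensor>> for $name {
            fn from(v: Vec<Tensor>) -> Self {
                // Number of fields, counted from the field names.
                let expected = [$(stringify!($field)),+].len();
                assert_eq!(v.len(), expected, "wrong number of tensors for {}", stringify!($name));
                let mut it = v.into_iter();
                // Struct-literal fields are evaluated in the order written,
                // so tensors are consumed in declaration order.
                $name { $($field: it.next().unwrap()),+ }
            }
        }
    };
}

param_tree! {
    #[derive(Debug, PartialEq)]
    pub struct Mlp {
        pub w1: Tensor,
        pub b1: Tensor,
        pub w2: Tensor,
        pub b2: Tensor,
    }
}
```

A field whose type is not `Tensor` would fail to compile in the generated impls, which is arguably the desired behaviour for a flat parameter struct.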
