This might be a trivial one, but it's important from a usability perspective.
PjRt executables expect a slice of buffers or literals as input to run on. However, if you have a model with dozens of parameter tensors, organizing them all into a slice manually becomes tedious (this is already visible in the mnist_xla example). So I think our abstract model API should basically be callable on any user-defined structure implementing Into<Vec<Literal>> or Into<Vec<PjRtBuffer>>. The gradient engine should also return gradients in that same structure if the parameter struct implements From<Vec<T>>.
This is basically the equivalent of JAX tree flattening/unflattening.
I don't think we have to do anything very thoughtful for this one: just write the abstract type signatures and call into and from appropriately.
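A minimal sketch of what this could look like. Everything here is hypothetical: `Literal` is a local stand-in for the real `xla::Literal` type, and `MlpParams` / `run_model` are made-up names illustrating the flatten/unflatten round trip, not an actual API.

```rust
// Stand-in for the real xla::Literal type (assumption, not the crate's API).
#[derive(Debug, Clone, PartialEq)]
struct Literal(Vec<f32>);

// Hypothetical user-defined parameter struct for a small two-layer model.
struct MlpParams {
    w1: Literal,
    b1: Literal,
    w2: Literal,
    b2: Literal,
}

// Flattening: the model API can now accept MlpParams directly,
// since Into<Vec<Literal>> is derived from this From impl.
impl From<MlpParams> for Vec<Literal> {
    fn from(p: MlpParams) -> Self {
        vec![p.w1, p.b1, p.w2, p.b2]
    }
}

// Unflattening: the gradient engine can hand gradients back
// in the same structure as the parameters.
impl From<Vec<Literal>> for MlpParams {
    fn from(mut v: Vec<Literal>) -> Self {
        assert_eq!(v.len(), 4, "expected exactly 4 parameter tensors");
        let b2 = v.pop().unwrap();
        let w2 = v.pop().unwrap();
        let b1 = v.pop().unwrap();
        let w1 = v.pop().unwrap();
        MlpParams { w1, b1, w2, b2 }
    }
}

// Sketch of a model API generic over any flattenable/unflattenable
// parameter structure; the real version would call the PjRt executable
// on the flattened slice instead of passing it through.
fn run_model<P>(params: P) -> P
where
    P: Into<Vec<Literal>> + From<Vec<Literal>>,
{
    let flat: Vec<Literal> = params.into(); // what the executable sees
    // ... executable.run(&flat) would go here ...
    P::from(flat) // e.g. gradients, returned in the original structure
}

fn main() {
    let params = MlpParams {
        w1: Literal(vec![1.0]),
        b1: Literal(vec![2.0]),
        w2: Literal(vec![3.0]),
        b2: Literal(vec![4.0]),
    };
    let round_trip = run_model(params);
    assert_eq!(round_trip.w2, Literal(vec![3.0]));
    println!("round-trip ok");
}
```

The two `From` impls are the whole contract: users never touch the flat `Vec`, and the ordering is defined in exactly one place per struct.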