It is currently not possible to create or manipulate tensors containing bfloat16 outside of MLX. For example, this does not work:
x=mx.array(jnp.ones((2, 2), dtype=jnp.bfloat16))
This seems to be because everything that relies on memoryview and related mechanisms fails on bf16.
This matters for safetensors, both to implement in-memory loading (for instance weights = safetensors.mlx.load(f.read())) and to lazily load certain tensors.
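To illustrate what in-memory loading needs, here is a rough sketch of splitting a safetensors blob into per-tensor raw bytes, assuming the usual layout (an 8-byte little-endian header length, a JSON header, then the tensor data). `read_safetensors_index` is a hypothetical helper, not part of safetensors or MLX.

```python
import json
import struct


def read_safetensors_index(blob: bytes) -> dict:
    """Map tensor name -> (dtype string, shape, raw bytes view).

    Hypothetical helper; assumes the standard safetensors layout.
    """
    # First 8 bytes: little-endian u64 giving the JSON header size.
    (header_len,) = struct.unpack("<Q", blob[:8])
    header = json.loads(blob[8 : 8 + header_len])
    # Offsets in the header are relative to the start of the data section.
    data = memoryview(blob)[8 + header_len :]
    index = {}
    for name, info in header.items():
        if name == "__metadata__":  # optional free-form metadata, not a tensor
            continue
        start, end = info["data_offsets"]
        index[name] = (info["dtype"], tuple(info["shape"]), data[start:end])
    return index
```

Each entry ends up as exactly the triple discussed in this issue: raw bytes, a shape, and a dtype name such as "BF16". The missing piece is a way to hand that triple to MLX.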
mx.load("file.safetensors") works great.
As far as I can tell, memoryview objects lose the actual dtype anyway.
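A small standard-library illustration of that point: a memoryview over raw bytes only reports unsigned bytes, so nothing tells the consumer that the buffer holds bf16 values. The bit patterns below are the bf16 encodings of 1.0, 2.0, 3.0 and 4.0, chosen purely for illustration.

```python
import struct

# Four bfloat16 values (1.0, 2.0, 3.0, 4.0) written as raw little-endian 16-bit patterns.
raw = struct.pack("<4H", 0x3F80, 0x4000, 0x4040, 0x4080)

view = memoryview(raw)
view_format = view.format      # 'B': plain unsigned bytes, no dtype information
view_itemsize = view.itemsize  # 1, not 2
view_shape = view.shape        # (8,): eight bytes, not a (2, 2) tensor
```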
Any API that takes bytes + shape + dtype would work very generically, I feel (with or without copying, depending on constraints).
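As a sketch of the shape such an API could take, here is a pure-Python stand-in that accepts bytes plus a shape and decodes bf16 by widening each value to float32. `bf16_bytes_to_float32` is a hypothetical name and not an MLX function. A real implementation would presumably keep the bf16 storage instead of converting.

```python
import math
import struct


def bf16_bytes_to_float32(data: bytes, shape: tuple) -> list:
    """Decode little-endian bfloat16 bytes into a flat list of floats.

    Hypothetical helper for illustration only.
    """
    count = math.prod(shape)
    if len(data) != 2 * count:
        raise ValueError(f"expected {2 * count} bytes for shape {shape}, got {len(data)}")
    bits16 = struct.unpack(f"<{count}H", data)
    # bfloat16 is the upper 16 bits of an IEEE float32, so shifting left
    # by 16 yields the bit pattern of the equivalent float32 value.
    bits32 = struct.pack(f"<{count}I", *(b << 16 for b in bits16))
    return list(struct.unpack(f"<{count}f", bits32))


raw = struct.pack("<4H", 0x3F80, 0x4000, 0x4040, 0x4080)
values = bf16_bytes_to_float32(raw, (2, 2))  # flat, row-major order
```

The shape is only used here to validate the byte count. An array library would also use it to set the dimensions of the result.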
This would allow me to correctly implement all supported dtypes on MLX within safetensors itself.
(And it would let others do more advanced things, like loading files directly from network sockets.)
Thanks a lot for this work.
The feature request is not about fixing the memoryview path for JAX -> MLX, but about a way to create tensors from raw bytes instead. (The memoryview example just showcases why this is necessary; we cannot expect JAX or TF to be present for this to work.)
@Narsil I still don't fully understand what API you are looking for or what's missing. Right now you can create an array from a Python memoryview object, which should be pretty flexible: