Skip to content

Commit 60acb8d

Browse files
committed
Pushing changes with the core dispatch module.
1 parent fafedd6 commit 60acb8d

File tree

2 files changed

+174
-1
lines changed

2 files changed

+174
-1
lines changed

.gitignore

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@ __pycache__
2727
\#*\#
2828
build
2929
compiled/*.cpp
30-
core.*
3130
cutils_ext.cpp
3231
dist
3332
doc/.build/

pytensor/link/mlx/dispatch/core.py

Lines changed: 174 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,174 @@
1+
"""
2+
pytensor/link/mlx/dispatch/basic.py
3+
-----------------------------------
4+
5+
First‑cut MLX translations for the most common tensor Ops.
6+
7+
The structure intentionally follows pytensor's JAX dispatcher so that
8+
once these kernels stabilise they can be optimised further (e.g. fusing
9+
element‑wise graphs, adding in‑place updates, RNG thinning, etc.).
10+
"""
11+
from __future__ import annotations
12+
13+
import warnings
14+
import numpy as np
15+
16+
import mlx.core as mx # MLX
17+
from pytensor.link.mlx.dispatch.basic import mlx_funcify # MLX
18+
19+
from pytensor.tensor import get_vector_length
20+
from pytensor.tensor.basic import (
21+
Join, Split, ExtractDiag, Eye, MakeVector,
22+
ScalarFromTensor, TensorFromScalar, Tri,
23+
get_scalar_constant_value,
24+
)
25+
from pytensor.tensor.exceptions import NotScalarConstantError
26+
27+
28+
# ------------------------------------------------------------------
29+
# Join
30+
# ------------------------------------------------------------------
31+
@mlx_funcify.register(Join) # MLX
32+
def mlx_funcify_Join(op, **kwargs):
33+
def join(axis, *tensors):
34+
view = op.view
35+
if (view != -1) and all(
36+
tensors[i].shape[axis] == 0 # MLX
37+
for i in list(range(view)) + list(range(view + 1, len(tensors)))
38+
):
39+
return tensors[view]
40+
41+
return mx.concatenate(tensors, axis=axis) # MLX
42+
43+
return join
44+
45+
46+
# ------------------------------------------------------------------
47+
# Split
48+
# ------------------------------------------------------------------
49+
@mlx_funcify.register(Split) # MLX
50+
def mlx_funcify_Split(op: Split, node, **kwargs):
51+
_, axis_sym, splits_sym = node.inputs
52+
53+
try:
54+
constant_axis = get_scalar_constant_value(axis_sym)
55+
except NotScalarConstantError:
56+
constant_axis = None
57+
warnings.warn(
58+
"Split node does not have a constant axis. MLX implementation may fail."
59+
)
60+
61+
try:
62+
constant_splits = np.array(
63+
[get_scalar_constant_value(splits_sym[i])
64+
for i in range(get_vector_length(splits_sym))]
65+
)
66+
except (ValueError, NotScalarConstantError):
67+
constant_splits = None
68+
warnings.warn(
69+
"Split node does not have constant split positions. MLX implementation may fail."
70+
)
71+
72+
def split(x, axis, splits):
73+
# Resolve constants (avoids tracing extra ops)
74+
if constant_axis is not None:
75+
axis = int(constant_axis)
76+
77+
if constant_splits is not None:
78+
splits = constant_splits
79+
cumsum_splits = np.cumsum(splits[:-1])
80+
else:
81+
# dynamic ‑– keep in graph
82+
splits_arr = mx.array(splits) # MLX
83+
cumsum_splits = mx.cumsum(splits_arr[:-1]).tolist() # python list for mx.split
84+
85+
if len(splits) != op.len_splits:
86+
raise ValueError("Length of 'splits' is not equal to n_splits")
87+
if np.sum(np.asarray(splits)) != x.shape[axis]:
88+
raise ValueError("Split sizes do not sum to the input length on the chosen axis.")
89+
if np.any(np.asarray(splits) < 0):
90+
raise ValueError("Split sizes cannot be negative.")
91+
92+
return mx.split(x, cumsum_splits, axis=axis) # MLX
93+
94+
return split
95+
96+
97+
# ------------------------------------------------------------------
98+
# ExtractDiag
99+
# ------------------------------------------------------------------
100+
@mlx_funcify.register(ExtractDiag) # MLX
101+
def mlx_funcify_ExtractDiag(op, **kwargs):
102+
offset, axis1, axis2 = op.offset, op.axis1, op.axis2
103+
104+
def extract_diag(x, offset=offset, axis1=axis1, axis2=axis2):
105+
return mx.diagonal(x, offset=offset, axis1=axis1, axis2=axis2) # MLX
106+
107+
return extract_diag
108+
109+
110+
# ------------------------------------------------------------------
111+
# Eye
112+
# ------------------------------------------------------------------
113+
@mlx_funcify.register(Eye) # MLX
114+
def mlx_funcify_Eye(op, **kwargs):
115+
dtype = op.dtype
116+
117+
def eye(N, M, k):
118+
return mx.eye(int(N), int(M), int(k), dtype=dtype) # MLX
119+
120+
return eye
121+
122+
123+
# ------------------------------------------------------------------
124+
# MakeVector
125+
# ------------------------------------------------------------------
126+
@mlx_funcify.register(MakeVector) # MLX
127+
def mlx_funcify_MakeVector(op, **kwargs):
128+
def makevector(*x):
129+
return mx.array(x, dtype=op.dtype) # MLX
130+
131+
return makevector
132+
133+
134+
# ------------------------------------------------------------------
135+
# TensorFromScalar (identity for MLX)
136+
# ------------------------------------------------------------------
137+
@mlx_funcify.register(TensorFromScalar) # MLX
138+
def mlx_funcify_TensorFromScalar(op, **kwargs):
139+
def tensor_from_scalar(x):
140+
return x # already an MLX array / scalar
141+
142+
return tensor_from_scalar
143+
144+
145+
# ------------------------------------------------------------------
146+
# ScalarFromTensor
147+
# ------------------------------------------------------------------
148+
@mlx_funcify.register(ScalarFromTensor) # MLX
149+
def mlx_funcify_ScalarFromTensor(op, **kwargs):
150+
def scalar_from_tensor(x):
151+
return mx.array(x).reshape(-1)[0] # MLX
152+
153+
return scalar_from_tensor
154+
155+
156+
# ------------------------------------------------------------------
157+
# Tri
158+
# ------------------------------------------------------------------
159+
@mlx_funcify.register(Tri) # MLX
160+
def mlx_funcify_Tri(op, node, **kwargs):
161+
# node.inputs -> N, M, k
162+
const_args = [getattr(inp, "data", None) for inp in node.inputs]
163+
164+
def tri(*args):
165+
# Replace args with compile‑time constants when available
166+
args = [
167+
arg if const_a is None else const_a
168+
for arg, const_a in zip(args, const_args, strict=True)
169+
]
170+
return mx.tri(*args, dtype=op.dtype) # MLX
171+
172+
return tri
173+
174+
## TODO: change the remaining code paths to use the MLX functions

0 commit comments

Comments
 (0)