1
+ """
2
+ pytensor/link/mlx/dispatch/basic.py
3
+ -----------------------------------
4
+
5
+ First‑cut MLX translations for the most common tensor Ops.
6
+
7
+ The structure intentionally follows pytensor's JAX dispatcher so that
8
+ once these kernels stabilise they can be optimised further (e.g. fusing
9
+ element‑wise graphs, adding in‑place updates, RNG thinning, etc.).
10
+ """
11
+ from __future__ import annotations
12
+
13
+ import warnings
14
+ import numpy as np
15
+
16
+ import mlx .core as mx # MLX
17
+ from pytensor .link .mlx .dispatch .basic import mlx_funcify # MLX
18
+
19
+ from pytensor .tensor import get_vector_length
20
+ from pytensor .tensor .basic import (
21
+ Join , Split , ExtractDiag , Eye , MakeVector ,
22
+ ScalarFromTensor , TensorFromScalar , Tri ,
23
+ get_scalar_constant_value ,
24
+ )
25
+ from pytensor .tensor .exceptions import NotScalarConstantError
26
+
27
+
28
+ # ------------------------------------------------------------------
29
+ # Join
30
+ # ------------------------------------------------------------------
31
@mlx_funcify.register(Join)  # MLX
def mlx_funcify_Join(op, **kwargs):
    """Return an MLX thunk that concatenates ``tensors`` along ``axis``."""

    def join(axis, *tensors):
        view = op.view
        if view != -1:
            # Indices of every input other than the "view" tensor.
            others = (*range(view), *range(view + 1, len(tensors)))
            # If all other inputs are empty along the join axis, the join is
            # a no-op and we can hand back the selected tensor directly.
            if all(tensors[i].shape[axis] == 0 for i in others):  # MLX
                return tensors[view]

        return mx.concatenate(tensors, axis=axis)  # MLX

    return join
44
+
45
+
46
+ # ------------------------------------------------------------------
47
+ # Split
48
+ # ------------------------------------------------------------------
49
@mlx_funcify.register(Split)  # MLX
def mlx_funcify_Split(op: Split, node, **kwargs):
    """Return an MLX thunk that splits ``x`` into ``op.len_splits`` parts."""
    _, axis_sym, splits_sym = node.inputs

    # Try to freeze the axis at dispatch time so the runtime thunk stays cheap.
    try:
        constant_axis = get_scalar_constant_value(axis_sym)
    except NotScalarConstantError:
        constant_axis = None
        warnings.warn(
            "Split node does not have a constant axis. MLX implementation may fail."
        )

    # Likewise try to freeze the split sizes.
    try:
        constant_splits = np.array(
            [
                get_scalar_constant_value(splits_sym[i])
                for i in range(get_vector_length(splits_sym))
            ]
        )
    except (ValueError, NotScalarConstantError):
        constant_splits = None
        warnings.warn(
            "Split node does not have constant split positions. MLX implementation may fail."
        )

    def split(x, axis, splits):
        # Prefer the compile-time constants resolved above (avoids tracing
        # extra ops in the MLX graph).
        if constant_axis is not None:
            axis = int(constant_axis)

        if constant_splits is not None:
            splits = constant_splits
            boundaries = np.cumsum(splits[:-1])
        else:
            # Dynamic sizes: compute the cumulative boundaries with MLX and
            # convert to a Python list, which is what mx.split expects.
            boundaries = mx.cumsum(mx.array(splits)[:-1]).tolist()  # MLX

        if len(splits) != op.len_splits:
            raise ValueError("Length of 'splits' is not equal to n_splits")
        if np.sum(np.asarray(splits)) != x.shape[axis]:
            raise ValueError("Split sizes do not sum to the input length on the chosen axis.")
        if np.any(np.asarray(splits) < 0):
            raise ValueError("Split sizes cannot be negative.")

        return mx.split(x, boundaries, axis=axis)  # MLX

    return split
95
+
96
+
97
+ # ------------------------------------------------------------------
98
+ # ExtractDiag
99
+ # ------------------------------------------------------------------
100
@mlx_funcify.register(ExtractDiag)  # MLX
def mlx_funcify_ExtractDiag(op, **kwargs):
    """Return an MLX thunk extracting a diagonal from its input."""

    # Bind the Op's static attributes as defaults so the thunk needs no
    # closure lookups at call time.
    def extract_diag(x, offset=op.offset, axis1=op.axis1, axis2=op.axis2):
        return mx.diagonal(x, offset=offset, axis1=axis1, axis2=axis2)  # MLX

    return extract_diag
108
+
109
+
110
+ # ------------------------------------------------------------------
111
+ # Eye
112
+ # ------------------------------------------------------------------
113
@mlx_funcify.register(Eye)  # MLX
def mlx_funcify_Eye(op, **kwargs):
    """Return an MLX thunk building an N x M matrix with ones on diagonal k."""
    # NOTE(review): assumes mx.eye accepts op.dtype as-is — confirm the
    # dtype spec MLX expects here.
    dtype = op.dtype

    def eye(N, M, k):
        # Inputs may arrive as 0-d arrays; coerce them to Python ints.
        return mx.eye(int(N), int(M), int(k), dtype=dtype)  # MLX

    return eye
121
+
122
+
123
+ # ------------------------------------------------------------------
124
+ # MakeVector
125
+ # ------------------------------------------------------------------
126
@mlx_funcify.register(MakeVector)  # MLX
def mlx_funcify_MakeVector(op, **kwargs):
    """Return an MLX thunk packing its scalar inputs into a 1-d vector."""
    dtype = op.dtype

    def makevector(*scalars):
        return mx.array(scalars, dtype=dtype)  # MLX

    return makevector
132
+
133
+
134
+ # ------------------------------------------------------------------
135
+ # TensorFromScalar (identity for MLX)
136
+ # ------------------------------------------------------------------
137
@mlx_funcify.register(TensorFromScalar)  # MLX
def mlx_funcify_TensorFromScalar(op, **kwargs):
    """MLX draws no scalar/tensor distinction, so this Op is the identity."""

    def tensor_from_scalar(x):
        # Already an MLX array / scalar — nothing to convert.
        return x

    return tensor_from_scalar
143
+
144
+
145
+ # ------------------------------------------------------------------
146
+ # ScalarFromTensor
147
+ # ------------------------------------------------------------------
148
@mlx_funcify.register(ScalarFromTensor)  # MLX
def mlx_funcify_ScalarFromTensor(op, **kwargs):
    """Return an MLX thunk taking the single element of a one-element tensor."""

    def scalar_from_tensor(x):
        # Flatten to 1-d, then index the first (and only expected) element.
        flat = mx.array(x).reshape(-1)  # MLX
        return flat[0]

    return scalar_from_tensor
154
+
155
+
156
+ # ------------------------------------------------------------------
157
+ # Tri
158
+ # ------------------------------------------------------------------
159
@mlx_funcify.register(Tri)  # MLX
def mlx_funcify_Tri(op, node, **kwargs):
    """Return an MLX thunk for ``Tri`` (triangular matrix of ones)."""
    # node.inputs -> N, M, k. Constant inputs expose their value via ``.data``;
    # symbolic inputs yield None and are resolved at call time instead.
    const_args = [getattr(inp, "data", None) for inp in node.inputs]

    def tri(*args):
        # Substitute compile-time constants where they are known.
        resolved = [
            runtime if fixed is None else fixed
            for runtime, fixed in zip(args, const_args, strict=True)
        ]
        return mx.tri(*resolved, dtype=op.dtype)  # MLX

    return tri
173
+
174
# TODO: change the remaining code paths to use the MLX functions.