Skip to content

Commit bde593a

Browse files
committed
.WIP
1 parent 692c53c commit bde593a

File tree

2 files changed

+93
-0
lines changed

2 files changed

+93
-0
lines changed

pytensor/xtensor/vectorization.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,3 +117,23 @@ def make_node(self, *inputs):
117117
for core_out, core_out_dims in zip(core_node.outputs, core_outputs_dims)
118118
]
119119
return Apply(self, inputs, outputs)
120+
121+
122+
class XRandomVariable(XOp):
123+
__props__ = ("dist", "core_dims")
124+
125+
def __init__(self, dist, core_dims: tuple[tuple[tuple[str, ...], ...], tuple[tuple[str, ...], ...]]):
126+
super().__init__()
127+
self.dist = dist
128+
self.core_dims = core_dims,
129+
130+
def make_node(self, *inputs):
131+
inputs = [as_xtensor(inp) for inp in inputs]
132+
if len(inputs) != self.dist.nin:
133+
raise ValueError(
134+
f"Wrong number of inputs, expected {self.dist.nin}, got {len(inputs)}"
135+
)
136+
137+
output_dims, output_shape = inputs[0].type.dims, inputs[0].type.shape
138+
output = xtensor(dtype=self.dist.dtype, dims=output_dims, shape=output_shape)
139+
return Apply(self, inputs, [output])

tests/xtensor/test_random.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
import pytest
2+
3+
import pytensor.tensor.random as ptr
4+
from pytensor.graph.basic import equal_computations
5+
from pytensor.tensor.random.type import random_generator_type
6+
from pytensor.xtensor import xtensor
7+
from pytensor.xtensor.random import multinomial, multivariate_normal, normal, categorical
8+
9+
lower_rewrite = lambda x: x
10+
11+
def test_normal():
12+
pass
13+
14+
def test_categorical():
15+
pass
16+
17+
def test_multinomial():
18+
rng = random_generator_type("rng")
19+
n = xtensor(shape=(2,), dims=("a",))
20+
p = xtensor(shape=(3, None), dims=("p", "a"))
21+
c_size = xtensor(shape=(), dims=(), dtype=int)
22+
a_size = n.sizes["a"]
23+
24+
out = multinomial(n, p, core_dims=("p",), rng=rng)
25+
assert out.type.dims == ("a", "p")
26+
assert out.type.shape == (2, 3)
27+
assert equal_computations(
28+
[lower_rewrite(out)],
29+
[ptr.multinomial(n.values, p.values.T, size=None, rng=rng)],
30+
)
31+
# TODO: Make sure we can actually evaluate it
32+
...
33+
34+
out = multinomial(n, p, core_dims=("p",), size=dict(a=a_size), rng=rng)
35+
assert out.type.dims == ("a", "p")
36+
assert equal_computations(
37+
[lower_rewrite(out)],
38+
[ptr.multinomial(n.values, p.values.T, size=(a_size.values,), rng=rng)],
39+
)
40+
41+
out = multinomial(n, p, core_dims=("p",), size=dict(a=a_size, c=c_size), rng=rng)
42+
assert out.type.dims == ("a", "c", "p")
43+
assert equal_computations(
44+
[lower_rewrite(out)],
45+
[ptr.multinomial(n.values[:, None], p.values.T[:, None, :], size=(a_size.values, c_size.values), rng=rng)],
46+
)
47+
48+
out = multinomial(n, p, core_dims=("p",), size=dict(c=c_size, a=a_size,), rng=rng)
49+
assert out.type.dims == ("c", "a", "p")
50+
assert equal_computations(
51+
[lower_rewrite(out)],
52+
[ptr.multinomial(n.values, p.values.T, size=(c_size.values, a_size.values), rng=rng)],
53+
)
54+
55+
# Test missing core_dims
56+
with pytest.raises(ValueError):
57+
multinomial(n, p, rng=rng)
58+
59+
# Test invalid core_dims
60+
with pytest.raises(ValueError):
61+
# n cannot have a core dimension
62+
multinomial(n, p, core_dims=("a",), rng=rng)
63+
64+
# Test incomplete size
65+
with pytest.raises(ValueError):
66+
multinomial(n, p, core_dims=("p",), size=dict(c=c_size), rng=rng)
67+
68+
69+
def test_multivariate_normal():
70+
pass
71+
72+
def test_new_out_dim()
73+
pass

0 commit comments

Comments
 (0)