diff --git a/pytensor/tensor/random/basic.py b/pytensor/tensor/random/basic.py
index ba6ffa8eaa..3aeba505a9 100644
--- a/pytensor/tensor/random/basic.py
+++ b/pytensor/tensor/random/basic.py
@@ -1625,8 +1625,7 @@ def rng_fn_scipy(cls, rng, n, p, size):
         return stats.nbinom.rvs(n, p, size=size, random_state=rng)
 
 
-nbinom = NegBinomialRV()
-negative_binomial = NegBinomialRV()
+nbinom = negative_binomial = NegBinomialRV()
 
 
 class BetaBinomialRV(ScipyRandomVariable):
@@ -1808,6 +1807,7 @@ def rng_fn(cls, rng, n, p, size):
 
 
 multinomial = MultinomialRV()
+
 
 vsearchsorted = np.vectorize(np.searchsorted, otypes=[int], signature="(n),()->()")
 
diff --git a/pytensor/tensor/rewriting/basic.py b/pytensor/tensor/rewriting/basic.py
index 284f4af2b8..9e1c6c1a14 100644
--- a/pytensor/tensor/rewriting/basic.py
+++ b/pytensor/tensor/rewriting/basic.py
@@ -759,6 +759,7 @@ def local_remove_useless_assert(fgraph, node):
         return [new_var]
 
 
+@register_infer_shape
 @node_rewriter([Assert])
 def local_remove_all_assert(fgraph, node):
     r"""A rewrite that removes all `Assert`\s from a graph.
@@ -768,9 +769,6 @@ def local_remove_all_assert(fgraph, node):
     See the :ref:`unsafe` section.
 
     """
-    if not isinstance(node.op, Assert):
-        return
-
     return [node.inputs[0]]
 
 
diff --git a/pytensor/tensor/utils.py b/pytensor/tensor/utils.py
index 3c730a3179..930a6f03e4 100644
--- a/pytensor/tensor/utils.py
+++ b/pytensor/tensor/utils.py
@@ -9,6 +9,7 @@
 import pytensor
 from pytensor.graph import FunctionGraph, Variable
 from pytensor.npy_2_compat import normalize_axis_tuple
+from pytensor.tensor.exceptions import NotScalarConstantError
 from pytensor.utils import hash_from_code
 
 
@@ -256,3 +257,32 @@ def faster_ndindex(shape: Sequence[int]):
     https://github.com/numpy/numpy/issues/28921
     """
     return product(*(range(s) for s in shape))
+
+
+def get_static_shape_from_size_variables(
+    size_vars: Sequence[Variable],
+) -> tuple[int | None, ...]:
+    """Get static shape from size variables.
+
+    Parameters
+    ----------
+    size_vars : Sequence[Variable]
+        A sequence of variables representing the size of each dimension.
+
+    Returns
+    -------
+    tuple[int | None, ...]
+        A tuple containing the static length of each dimension, or None if
+        the length is not statically known.
+    """
+    from pytensor.tensor.basic import get_scalar_constant_value
+
+    static_lengths = [None] * len(size_vars)
+    for i, length in enumerate(size_vars):
+        try:
+            static_length = get_scalar_constant_value(length)
+        except NotScalarConstantError:
+            pass
+        else:
+            static_lengths[i] = int(static_length)
+    return tuple(static_lengths)
diff --git a/pytensor/tensor/variable.py b/pytensor/tensor/variable.py
index 7e5d22528a..56fe76da0c 100644
--- a/pytensor/tensor/variable.py
+++ b/pytensor/tensor/variable.py
@@ -349,6 +349,9 @@ def dimshuffle(self, *pattern):
         if (len(pattern) == 1) and (isinstance(pattern[0], list | tuple | np.ndarray)):
             pattern = pattern[0]
         ds_op = pt.elemwise.DimShuffle(input_ndim=self.type.ndim, new_order=pattern)
+        if ds_op.new_order == tuple(range(self.type.ndim)):
+            # No-op
+            return self
         return ds_op(self)
 
     def flatten(self, ndim=1):
diff --git a/pytensor/xtensor/basic.py b/pytensor/xtensor/basic.py
index ab035be346..fdcab95d7c 100644
--- a/pytensor/xtensor/basic.py
+++ b/pytensor/xtensor/basic.py
@@ -67,8 +67,8 @@ def make_node(self, x):
         return Apply(self, [x], [output])
 
 
-def xtensor_from_tensor(x, dims):
-    return XTensorFromTensor(dims=dims)(x)
+def xtensor_from_tensor(x, dims, name=None):
+    return XTensorFromTensor(dims=dims)(x, name=name)
 
 
 class Rename(XTypeCastOp):
diff --git a/pytensor/xtensor/random.py b/pytensor/xtensor/random.py
new file mode 100644
index 0000000000..d59d091af8
--- /dev/null
+++ b/pytensor/xtensor/random.py
@@ -0,0 +1,167 @@
+from collections.abc import Sequence
+from functools import wraps
+
+import pytensor.tensor.random.basic as ptr
+from pytensor.graph.basic import Variable
+from pytensor.tensor.random.op import RandomVariable
+from pytensor.xtensor import as_xtensor
+from pytensor.xtensor.math import sqrt
+from pytensor.xtensor.vectorization import XRV
+
+
+def _as_xrv(
+    core_op: RandomVariable,
+    core_inps_dims_map: Sequence[Sequence[int]] | None = None,
+    core_out_dims_map: Sequence[int] | None = None,
+):
+    """Helper function to define an XRV constructor.
+
+    Parameters
+    ----------
+    core_op : RandomVariable
+        The core random variable operation to wrap.
+    core_inps_dims_map : Sequence[Sequence[int]] | None, optional
+        A sequence of sequences mapping the core dimensions (specified by the user)
+        for each input parameter. This is used when lowering to a RandomVariable operation,
+        to decide the ordering of the core dimensions for each input.
+        If None, it assumes the core dimensions are positional from left to right.
+    core_out_dims_map : Sequence[int] | None, optional
+        A sequence mapping the core dimensions (specified by the user) for the output variable.
+        This is used when lowering to a RandomVariable operation,
+        to decide the ordering of the core dimensions for the output.
+        If None, it assumes the core dimensions are positional from left to right.
+
+    """
+    if core_inps_dims_map is None:
+        # Assume core_dims map positionally from left to right
+        core_inps_dims_map = [tuple(range(ndim)) for ndim in core_op.ndims_params]
+    if core_out_dims_map is None:
+        # Assume core_dims map positionally from left to right
+        core_out_dims_map = tuple(range(core_op.ndim_supp))
+
+    core_dims_needed = max(
+        (*(len(i) for i in core_inps_dims_map), len(core_out_dims_map)), default=0
+    )
+
+    @wraps(core_op)
+    def xrv_constructor(
+        *params,
+        core_dims: Sequence[str] | str | None = None,
+        extra_dims: dict[str, Variable] | None = None,
+        rng: Variable | None = None,
+    ):
+        if core_dims is None:
+            core_dims = ()
+            if core_dims_needed:
+                raise ValueError(
+                    f"{core_op.name} needs {core_dims_needed} core_dims to be specified"
+                )
+        elif isinstance(core_dims, str):
+            core_dims = (core_dims,)
+
+        if len(core_dims) != core_dims_needed:
+            raise ValueError(
+                f"{core_op.name} needs {core_dims_needed} core_dims, but got {len(core_dims)}"
+            )
+
+        full_input_core_dims = tuple(
+            tuple(core_dims[i] for i in inp_dims_map)
+            for inp_dims_map in core_inps_dims_map
+        )
+        full_output_core_dims = tuple(core_dims[i] for i in core_out_dims_map)
+        full_core_dims = (full_input_core_dims, full_output_core_dims)
+
+        if extra_dims is None:
+            extra_dims = {}
+
+        return XRV(core_op, core_dims=full_core_dims, extra_dims=extra_dims.keys())(
+            rng, *extra_dims.values(), *params
+        )
+
+    return xrv_constructor
+
+
+bernoulli = _as_xrv(ptr.bernoulli)
+beta = _as_xrv(ptr.beta)
+betabinom = _as_xrv(ptr.betabinom)
+binomial = _as_xrv(ptr.binomial)
+categorical = _as_xrv(ptr.categorical)
+cauchy = _as_xrv(ptr.cauchy)
+dirichlet = _as_xrv(ptr.dirichlet)
+exponential = _as_xrv(ptr.exponential)
+gamma = _as_xrv(ptr._gamma)
+gengamma = _as_xrv(ptr.gengamma)
+geometric = _as_xrv(ptr.geometric)
+gumbel = _as_xrv(ptr.gumbel)
+halfcauchy = _as_xrv(ptr.halfcauchy)
+halfnormal = _as_xrv(ptr.halfnormal)
+hypergeometric = _as_xrv(ptr.hypergeometric)
+integers = _as_xrv(ptr.integers)
+invgamma = _as_xrv(ptr.invgamma)
+laplace = _as_xrv(ptr.laplace)
+logistic = _as_xrv(ptr.logistic)
+lognormal = _as_xrv(ptr.lognormal)
+multinomial = _as_xrv(ptr.multinomial)
+nbinom = negative_binomial = _as_xrv(ptr.negative_binomial)
+normal = _as_xrv(ptr.normal)
+pareto = _as_xrv(ptr.pareto)
+poisson = _as_xrv(ptr.poisson)
+t = _as_xrv(ptr.t)
+triangular = _as_xrv(ptr.triangular)
+truncexpon = _as_xrv(ptr.truncexpon)
+uniform = _as_xrv(ptr.uniform)
+vonmises = _as_xrv(ptr.vonmises)
+wald = _as_xrv(ptr.wald)
+weibull = _as_xrv(ptr.weibull)
+
+
+def multivariate_normal(
+    mean,
+    cov,
+    *,
+    core_dims: Sequence[str],
+    extra_dims=None,
+    rng=None,
+    method="cholesky",
+):
+    mean = as_xtensor(mean)
+    if len(core_dims) != 2:
+        raise ValueError(
+            f"multivariate_normal requires 2 core_dims, got {len(core_dims)}"
+        )
+
+    # Align core_dims, so that the dim that exists in mean comes before the one that only exists in cov
+    # This will be the core dimension of the output
+    if core_dims[0] not in mean.type.dims:
+        core_dims = core_dims[::-1]
+
+    xop = _as_xrv(ptr.MvNormalRV(method=method))
+    return xop(mean, cov, core_dims=core_dims, extra_dims=extra_dims, rng=rng)
+
+
+def standard_normal(
+    extra_dims: dict[str, Variable] | None = None,
+    rng: Variable | None = None,
+):
+    """Standard normal random variable."""
+    return normal(0, 1, extra_dims=extra_dims, rng=rng)
+
+
+def chisquare(
+    df: Variable,
+    extra_dims: dict[str, Variable] | None = None,
+    rng: Variable | None = None,
+):
+    """Chi-square random variable."""
+    return gamma(df / 2.0, 2.0, extra_dims=extra_dims, rng=rng)
+
+
+def rayleigh(
+    scale: Variable,
+    extra_dims: dict[str, Variable] | None = None,
+    rng: Variable | None = None,
+):
+    """Rayleigh random variable."""
+
+    df = scale * 0 + 2  # Poor man's broadcasting, to pass dimensions of scale to the RV
+    return sqrt(chisquare(df, extra_dims=extra_dims, rng=rng)) * scale
diff --git a/pytensor/xtensor/rewriting/vectorization.py b/pytensor/xtensor/rewriting/vectorization.py
index 9b52022f07..c8ca9c61d4 100644
--- a/pytensor/xtensor/rewriting/vectorization.py
+++ b/pytensor/xtensor/rewriting/vectorization.py
@@ -1,9 +1,10 @@
 from pytensor.graph import node_rewriter
 from pytensor.tensor.blockwise import Blockwise
 from pytensor.tensor.elemwise import Elemwise
+from pytensor.tensor.random.utils import compute_batch_shape
 from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor
 from pytensor.xtensor.rewriting.utils import register_lower_xtensor
-from pytensor.xtensor.vectorization import XBlockwise, XElemwise
+from pytensor.xtensor.vectorization import XRV, XBlockwise, XElemwise
 
 
 @register_lower_xtensor
@@ -53,6 +54,7 @@ def lower_blockwise(fgraph, node):
         tensor_inp = tensor_from_xtensor(inp).dimshuffle(batch_order + core_order)
         tensor_inputs.append(tensor_inp)
 
+    # FIXME: Now that we removed signature, create the signature from op.core_dims
     tensor_op = Blockwise(core_op=node.op.core_op, signature=op.signature)
     tensor_outs = tensor_op(*tensor_inputs, return_list=True)
 
@@ -62,3 +64,49 @@ def lower_blockwise(fgraph, node):
         for (tensor_out, old_out) in zip(tensor_outs, node.outputs, strict=True)
     ]
     return new_outs
+
+
+@register_lower_xtensor
+@node_rewriter(tracks=[XRV])
+def lower_rv(fgraph, node):
+    op: XRV = node.op
+    core_op = op.core_op
+
+    _, old_out = node.outputs
+    rng, *extra_dim_lengths_and_params = node.inputs
+    extra_dim_lengths = extra_dim_lengths_and_params[: len(op.extra_dims)]
+    params = extra_dim_lengths_and_params[len(op.extra_dims) :]
+
+    batch_ndim = old_out.type.ndim - len(op.core_dims[1])
+    param_batch_dims = old_out.type.dims[len(op.extra_dims) : batch_ndim]
+
+    # Convert params Tensors to XTensors, align batch dimensions and place core dimension at the end
+    tensor_params = []
+    for inp, core_dims in zip(params, op.core_dims[0]):
+        inp_dims = inp.type.dims
+        # Align the batch dims of the input, and place the core dims on the right
+        batch_order = [
+            inp_dims.index(batch_dim) if batch_dim in inp_dims else "x"
+            for batch_dim in param_batch_dims
+        ]
+        core_order = [inp_dims.index(core_dim) for core_dim in core_dims]
+        tensor_inp = tensor_from_xtensor(inp).dimshuffle(batch_order + core_order)
+        tensor_params.append(tensor_inp)
+
+    size = None
+    if op.extra_dims:
+        # RV size contains the lengths of all batch dimensions, including those coming from the parameters
+        if tensor_params:
+            param_batch_shape = tuple(
+                compute_batch_shape(tensor_params, ndims_params=core_op.ndims_params)
+            )
+        else:
+            param_batch_shape = ()
+        size = [*extra_dim_lengths, *param_batch_shape]
+
+    # RVs are their own core Op
+    new_next_rng, tensor_out = core_op(*tensor_params, rng=rng, size=size).owner.outputs
+
+    # Convert output Tensors to XTensors
+    new_out = xtensor_from_tensor(tensor_out, dims=old_out.type.dims)
+    return [new_next_rng, new_out]
diff --git a/pytensor/xtensor/shape.py b/pytensor/xtensor/shape.py
index ceded71ec0..60e181cdc7 100644
--- a/pytensor/xtensor/shape.py
+++ b/pytensor/xtensor/shape.py
@@ -10,6 +10,7 @@
 from pytensor.tensor import as_tensor, get_scalar_constant_value
 from pytensor.tensor.exceptions import NotScalarConstantError
 from pytensor.tensor.type import integer_dtypes
+from pytensor.tensor.utils import get_static_shape_from_size_variables
 from pytensor.xtensor.basic import XOp
 from pytensor.xtensor.type import as_xtensor, xtensor
 
@@ -130,14 +131,9 @@ def make_node(self, x, *unstacked_length):
             )
         )
 
-        static_unstacked_lengths = [None] * len(unstacked_lengths)
-        for i, length in enumerate(unstacked_lengths):
-            try:
-                static_length = get_scalar_constant_value(length)
-            except NotScalarConstantError:
-                pass
-            else:
-                static_unstacked_lengths[i] = int(static_length)
+        static_unstacked_lengths = get_static_shape_from_size_variables(
+            unstacked_lengths
+        )
 
         output = xtensor(
             dtype=x.type.dtype,
diff --git a/pytensor/xtensor/type.py b/pytensor/xtensor/type.py
index 96b0a1fd7c..134ce62125 100644
--- a/pytensor/xtensor/type.py
+++ b/pytensor/xtensor/type.py
@@ -65,6 +65,8 @@ def __init__(
         )
         self.ndim = len(self.dims)
         self.name = name
+        self.numpy_dtype = np.dtype(self.dtype)
+        self.filter_checks_isfinite = False
 
     def clone(
         self,
@@ -82,8 +84,9 @@ def clone(
         return type(self)(dtype=dtype, shape=shape, dims=dims, **kwargs)
 
     def filter(self, value, strict=False, allow_downcast=None):
-        # TODO implement this
-        return value
+        return TensorType.filter(
+            self, value, strict=strict, allow_downcast=allow_downcast
+        )
 
     def convert_variable(self, var):
         # TODO: Implement this
@@ -690,7 +693,7 @@ def signature(self):
 
 def xtensor_constant(x, name=None, dims: None | Sequence[str] = None):
     x_dims: tuple[str, ...]
-    if isinstance(x, xr.DataArray):
+    if XARRAY_AVAILABLE and isinstance(x, xr.DataArray):
         xarray_dims = x.dims
         if not all(isinstance(dim, str) for dim in xarray_dims):
             raise NotImplementedError(
@@ -746,17 +749,20 @@ def as_xtensor(x, name=None, dims: Sequence[str] | None = None):
         if isinstance(x.type, XTensorType):
             return x
         if isinstance(x.type, TensorType):
-            if x.type.ndim > 0 and dims is None:
-                raise TypeError(
-                    "non-scalar TensorVariable cannot be converted to XTensorVariable without dims."
-                )
-            return px.basic.xtensor_from_tensor(x, dims)
+            if dims is None:
+                if x.type.ndim == 0:
+                    dims = ()
+                else:
+                    raise TypeError(
+                        "non-scalar TensorVariable cannot be converted to XTensorVariable without dims."
+                    )
+            return px.basic.xtensor_from_tensor(x, dims=dims, name=name)
         else:
            raise TypeError(
                "Variable with type {x.type} cannot be converted to XTensorVariable."
            )
 
    try:
-        return xtensor_constant(x, name=name, dims=dims)
+        return xtensor_constant(x, dims=dims, name=name)
    except TypeError as err:
        raise TypeError(f"Cannot convert {x} to XTensorType {type(x)}") from err
diff --git a/pytensor/xtensor/vectorization.py b/pytensor/xtensor/vectorization.py
index 1fe7dd99d7..4d02958896 100644
--- a/pytensor/xtensor/vectorization.py
+++ b/pytensor/xtensor/vectorization.py
@@ -1,13 +1,38 @@
 from itertools import chain
 
+import numpy as np
+
 from pytensor import scalar as ps
+from pytensor import shared
 from pytensor.graph import Apply, Op
+from pytensor.scalar import discrete_dtypes
 from pytensor.tensor import tensor
-from pytensor.tensor.utils import _parse_gufunc_signature
+from pytensor.tensor.random.type import RandomType
+from pytensor.tensor.utils import (
+    _parse_gufunc_signature,
+    get_static_shape_from_size_variables,
+)
 from pytensor.xtensor.basic import XOp
 from pytensor.xtensor.type import as_xtensor, xtensor
 
 
+def combine_dims_and_shape(inputs):
+    dims_and_shape: dict[str, int | None] = {}
+    for inp in inputs:
+        for dim, dim_length in zip(inp.type.dims, inp.type.shape):
+            if dim not in dims_and_shape:
+                dims_and_shape[dim] = dim_length
+            elif dim_length is not None:
+                # Check for conflicting shapes
+                if (dims_and_shape[dim] is not None) and (
+                    dims_and_shape[dim] != dim_length
+                ):
+                    raise ValueError(f"Dimension {dim} has conflicting shapes")
+                # Keep the non-None shape
+                dims_and_shape[dim] = dim_length
+    return dims_and_shape
+
+
 class XElemwise(XOp):
     __props__ = ("scalar_op",)
 
@@ -22,20 +47,7 @@ def make_node(self, *inputs):
                 f"Wrong number of inputs, expected {self.scalar_op.nin}, got {len(inputs)}"
             )
 
-        dims_and_shape: dict[str, int | None] = {}
-        for inp in inputs:
-            for dim, dim_length in zip(inp.type.dims, inp.type.shape):
-                if dim not in dims_and_shape:
-                    dims_and_shape[dim] = dim_length
-                elif dim_length is not None:
-                    # Check for conflicting shapes
-                    if (dims_and_shape[dim] is not None) and (
-                        dims_and_shape[dim] != dim_length
-                    ):
-                        raise ValueError(f"Dimension {dim} has conflicting shapes")
-                    # Keep the non-None shape
-                    dims_and_shape[dim] = dim_length
-
+        dims_and_shape = combine_dims_and_shape(inputs)
         if dims_and_shape:
             output_dims, output_shape = zip(*dims_and_shape.items())
         else:
@@ -53,7 +65,7 @@ def make_node(self, *inputs):
 
 
 class XBlockwise(XOp):
-    __props__ = ("core_op", "signature", "core_dims")
+    __props__ = ("core_op", "core_dims")
 
     def __init__(
         self,
@@ -63,7 +75,6 @@ def __init__(
     ):
         super().__init__()
         self.core_op = core_op
-        self.signature = signature
         self.inputs_sig, self.outputs_sig = _parse_gufunc_signature(signature)
         self.core_dims = core_dims
 
@@ -74,27 +85,13 @@ def make_node(self, *inputs):
             f"Wrong number of inputs, expected {len(self.inputs_sig)}, got {len(inputs)}"
         )
 
-        dims_and_shape: dict[str, int | None] = {}
-        for inp in inputs:
-            for dim, dim_length in zip(inp.type.dims, inp.type.shape):
-                if dim not in dims_and_shape:
-                    dims_and_shape[dim] = dim_length
-                elif dim_length is not None:
-                    # Check for conflicting shapes
-                    if (dims_and_shape[dim] is not None) and (
-                        dims_and_shape[dim] != dim_length
-                    ):
-                        raise ValueError(f"Dimension {dim} has conflicting shapes")
-                    # Keep the non-None shape
-                    dims_and_shape[dim] = dim_length
+        dims_and_shape = combine_dims_and_shape(inputs)
 
         core_inputs_dims, core_outputs_dims = self.core_dims
-        # TODO: Avoid intermediate dict
-        core_dims = set(chain.from_iterable(core_inputs_dims))
-        batched_dims_and_shape = {
-            k: v for k, v in dims_and_shape.items() if k not in core_dims
-        }
-        batch_dims, batch_shape = zip(*batched_dims_and_shape.items())
+        core_input_dims_set = set(chain.from_iterable(core_inputs_dims))
+        batch_dims, batch_shape = zip(
+            *((k, v) for k, v in dims_and_shape.items() if k not in core_input_dims_set)
+        )
 
         dummy_core_inputs = []
         for inp, core_inp_dims in zip(inputs, core_inputs_dims):
@@ -120,3 +117,135 @@ def make_node(self, *inputs):
             for core_out, core_out_dims in zip(core_node.outputs, core_outputs_dims)
         ]
         return Apply(self, inputs, outputs)
+
+
+class XRV(XOp):
+    """Wrapper for RandomVariable operations that follows xarray-like broadcasting semantics.
+
+    Xarray does not offer random generators, so this class implements a new API.
+
+    It mostly works like a gufunc (or XBlockwise), which specifies core dimensions for inputs and output, and
+    enforces dim-based broadcasting between inputs and output.
+
+    It differs from XBlockwise in a few ways:
+    1. It is restricted to one sample output
+    2. It takes a random generator as the first input and returns the consumed generator as the first output.
+    3. It has the concept of extra dimensions, which determine extra batch dimensions of the output that are not
+       implied by the batch dimensions of the parameters.
+    """
+
+    default_output = 1
+    __props__ = ("core_op", "core_dims", "extra_dims")
+
+    def __init__(
+        self,
+        core_op,
+        core_dims: tuple[tuple[tuple[str, ...], ...], tuple[str, ...]],
+        extra_dims: tuple[str, ...],
+    ):
+        super().__init__()
+        self.core_op = core_op
+        inps_core_dims, out_core_dims = core_dims
+        for operand_dims in (*inps_core_dims, out_core_dims):
+            if len(set(operand_dims)) != len(operand_dims):
+                raise ValueError(f"Operand has repeated dims {operand_dims}")
+        self.core_dims = (tuple(i for i in inps_core_dims), tuple(out_core_dims))
+        if len(set(extra_dims)) != len(extra_dims):
+            raise ValueError("extra_dims must be unique")
+        self.extra_dims = tuple(extra_dims)
+
+    def make_node(self, rng, *extra_dim_lengths_and_params):
+        if rng is None:
+            rng = shared(np.random.default_rng())
+        elif not isinstance(rng.type, RandomType):
+            raise TypeError(
+                "The type of rng should be an instance of RandomGeneratorType"
+            )
+
+        extra_dim_lengths = [
+            as_xtensor(dim_length).values
+            for dim_length in extra_dim_lengths_and_params[: len(self.extra_dims)]
+        ]
+        if not all(
+            (dim_length.type.ndim == 0 and dim_length.type.dtype in discrete_dtypes)
+            for dim_length in extra_dim_lengths
+        ):
+            raise TypeError("All dimension lengths should be scalar discrete dtype.")
+
+        params = [
+            as_xtensor(param)
+            for param in extra_dim_lengths_and_params[len(self.extra_dims) :]
+        ]
+        if len(params) != len(self.core_op.ndims_params):
+            raise ValueError(
+                f"Expected {len(self.core_op.ndims_params)} parameters + {len(self.extra_dims)} dim_lengths, "
+                f"got {len(extra_dim_lengths_and_params)}"
+            )
+
+        param_core_dims, output_core_dims = self.core_dims
+        input_core_dims_set = set(chain.from_iterable(param_core_dims))
+
+        # Check parameters don't have core dimensions they shouldn't have
+        for param, core_param_dims in zip(params, param_core_dims):
+            if invalid_core_dims := (
+                set(param.type.dims) - set(core_param_dims)
+            ).intersection(input_core_dims_set):
+                raise ValueError(
+                    f"Parameter {param} has invalid core dimensions {sorted(invalid_core_dims)}"
+                )
+
+        extra_dims_and_shape = dict(
+            zip(
+                self.extra_dims, get_static_shape_from_size_variables(extra_dim_lengths)
+            )
+        )
+        params_dims_and_shape = combine_dims_and_shape(params)
+
+        # Check that no parameter dims conflict with size dims
+        if conflict_dims := set(extra_dims_and_shape).intersection(
+            params_dims_and_shape
+        ):
+            raise ValueError(
+                f"Size dimensions {sorted(conflict_dims)} conflict with parameter dimensions. They should be unique."
+            )
+
+        batch_dims_and_shape = [
+            (dim, dim_length)
+            for dim, dim_length in (
+                extra_dims_and_shape | params_dims_and_shape
+            ).items()
+            if dim not in input_core_dims_set
+        ]
+        if batch_dims_and_shape:
+            batch_output_dims, batch_output_shape = zip(*batch_dims_and_shape)
+        else:
+            batch_output_dims, batch_output_shape = (), ()
+
+        dummy_core_inputs = []
+        for param, core_param_dims in zip(params, param_core_dims):
+            try:
+                core_static_shape = [
+                    param.type.shape[param.type.dims.index(d)] for d in core_param_dims
+                ]
+            except ValueError:
+                raise ValueError(
+                    f"At least one core dim={core_param_dims} missing from input {param} with dims={param.type.dims}"
+                )
+            dummy_core_inputs.append(
+                tensor(dtype=param.type.dtype, shape=core_static_shape)
+            )
+        core_node = self.core_op.make_node(rng, None, *dummy_core_inputs)
+
+        if len(core_node.outputs) != 2:
+            raise NotImplementedError(
+                "XRandomVariable only supports core ops with two outputs (rng, out)"
+            )
+
+        _, core_out = core_node.outputs
+        out = xtensor(
+            dtype=core_out.type.dtype,
+            shape=batch_output_shape + core_out.type.shape,
+            dims=batch_output_dims + output_core_dims,
+        )
+
+        return Apply(self, [rng, *extra_dim_lengths, *params], [rng.type(), out])
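For reference, the `lower_xtensor` rewrite added above turns an XRV graph into an ordinary `RandomVariable` graph: parameters are dimshuffled so their batch dimensions align by name, and `extra_dims` become the leading entries of the RV's `size`. A rough sketch of how that can be inspected, patterned on the `equal_computations` checks in the tests below (nothing here is new API beyond what this diff adds):

```python
import pytensor.xtensor.random as pxr
from pytensor.graph import rewrite_graph
from pytensor.tensor import tensor
from pytensor.xtensor import as_xtensor

mu = tensor("mu", shape=(3,))
sigma = tensor("sigma", shape=(2,))
out = pxr.normal(
    as_xtensor(mu, dims=("mu_dim",)),
    as_xtensor(sigma, dims=("sigma_dim",)),
)

# After lowering, the graph should match ptr.normal(mu[:, None], sigma[None, :]):
# the named batch dims are realized through ordinary dimshuffles on the parameters.
lowered = rewrite_graph(out.values, include=("lower_xtensor", "canonicalize"))
```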
diff --git a/tests/xtensor/test_random.py b/tests/xtensor/test_random.py
new file mode 100644
index 0000000000..de248c3cb7
--- /dev/null
+++ b/tests/xtensor/test_random.py
@@ -0,0 +1,422 @@
+import inspect
+import re
+from copy import deepcopy
+
+import numpy as np
+import pytest
+
+import pytensor.tensor.random as ptr
+import pytensor.xtensor.random as pxr
+from pytensor import function, shared
+from pytensor.graph import rewrite_graph
+from pytensor.graph.basic import equal_computations
+from pytensor.tensor import broadcast_arrays, tensor
+from pytensor.tensor.random.op import RandomVariable
+from pytensor.tensor.random.type import random_generator_type
+from pytensor.xtensor import as_xtensor, xtensor
+from pytensor.xtensor.random import (
+    categorical,
+    multinomial,
+    multivariate_normal,
+    normal,
+)
+from pytensor.xtensor.vectorization import XRV
+
+
+def lower_rewrite(vars):
+    return rewrite_graph(
+        vars,
+        include=(
+            "lower_xtensor",
+            "canonicalize",
+        ),
+    )
+
+
+def test_all_basic_rvs_are_wrapped():
+    # This ignores wrapper functions
+    pxr_members = {name for name, _ in inspect.getmembers(pxr)}
+    for name, op in inspect.getmembers(ptr.basic):
+        if name == "_gamma":
+            name = "gamma"
+        if isinstance(op, RandomVariable) and name not in pxr_members:
+            raise NotImplementedError(f"Variable {name} not implemented as XRV")
+
+
+def test_updates():
+    rng = shared(np.random.default_rng(40))
+    next_rng, draws = normal(0, 1, rng=rng).owner.outputs
+    fn = function([], [draws], updates=[(rng, next_rng)])
+    res1, res2 = fn(), fn()
+
+    rng = np.random.default_rng(40)
+    expected_res1, expected_res2 = rng.normal(0, 1), rng.normal(0, 1)
+    np.testing.assert_allclose(res1, expected_res1)
+    np.testing.assert_allclose(res2, expected_res2)
+
+
+def test_zero_inputs():
+    class ZeroInputRV(RandomVariable):
+        signature = "->()"
+        dtype = "floatX"
+        name = "ZeroInputRV"
+
+        @classmethod
+        def rng_fn(cls, rng, size=None):
+            return rng.random(size=size)
+
+    zero_input_rv = ZeroInputRV()
+    zero_input_xrv = XRV(zero_input_rv, core_dims=((), ()), extra_dims=["a"])
+
+    rng = random_generator_type("rng")
+    a_size = xtensor("a_size", dims=(), dtype=int)
+    rv = zero_input_xrv(rng, a_size)
+    assert rv.type.dims == ("a",)
+    assert rv.type.shape == (None,)
+
+    rng_test = np.random.default_rng(12345)
+    a_size_val = np.array(5)
+    np.testing.assert_allclose(
+        rv.eval({rng: rng_test, a_size: a_size_val}),
+        rng_test.random(size=(a_size_val,)),
+    )
+
+
+def test_output_dim_does_not_map_from_input_dims():
+    class NewDimRV(RandomVariable):
+        signature = "()->(p)"
+        dtype = "floatX"
+        name = "NewDimRV"
+
+        @classmethod
+        def rng_fn(cls, rng, n, size=None):
+            r = np.stack([n, n + 1], axis=-1)
+            if size is None:
+                return r
+            return np.broadcast_to(r, (*size, 2))
+
+        def _supp_shape_from_params(self, dist_params, param_shapes=None):
+            return (2,)
+
+    new_dim_rv = NewDimRV()
+    new_dim_xrv = XRV(new_dim_rv, core_dims=(((),), ("p",)), extra_dims=["a"])
+
+    a_size = xtensor("a_size", dims=(), dtype=int)
+    rv = new_dim_xrv(None, a_size, 1)
+    assert rv.type.dims == ("a", "p")
+    assert rv.type.shape == (None, 2)
+
+    a_size_val = np.array(5)
+    np.testing.assert_allclose(
+        rv.eval({a_size: a_size_val}), np.broadcast_to((1, 2), (a_size_val, 2))
+    )
+
+
+def test_normal():
+    rng = random_generator_type("rng")
+    c_size = tensor("c_size", shape=(), dtype=int)
+    mu = tensor("mu", shape=(3,))
+    sigma = tensor("sigma", shape=(2,))
+
+    mu_val = np.array([-10, 0.0, 10.0])
+    sigma_val = np.array([1.0, 10.0])
+    c_size_val = np.array(5)
+    rng_val = np.random.default_rng(12345)
+
+    c_size_xr = as_xtensor(c_size, name="c_size_xr")
+    mu_xr = as_xtensor(mu, dims=("mu_dim",), name="mu_xr")
+    sigma_xr = as_xtensor(sigma, dims=("sigma_dim",), name="sigma_xr")
+
+    out = normal(mu_xr, sigma_xr, rng=rng)
+    assert out.type.dims == ("mu_dim", "sigma_dim")
+    assert out.type.shape == (3, 2)
+    assert equal_computations(
+        [lower_rewrite(out.values)],
+        [rewrite_graph(ptr.normal(mu[:, None], sigma[None, :], rng=rng))],
+    )
+
+    out_eval = out.eval(
+        {
+            mu: mu_val,
+            sigma: sigma_val,
+            rng: rng_val,
+        }
+    )
+    out_expected = deepcopy(rng_val).normal(mu_val[:, None], sigma_val[None, :])
+    np.testing.assert_allclose(out_eval, out_expected)
+
+    # Test with batch dimension
+    out = normal(mu_xr, sigma_xr, extra_dims=dict(c_dim=c_size_xr), rng=rng)
+    assert out.type.dims == ("c_dim", "mu_dim", "sigma_dim")
+    assert out.type.shape == (None, 3, 2)
+    lowered_size = (c_size, *broadcast_arrays(mu[:, None], sigma[None, :])[0].shape)
+    assert equal_computations(
+        [lower_rewrite(out.values)],
+        [
+            rewrite_graph(
+                ptr.normal(mu[:, None], sigma[None, :], size=lowered_size, rng=rng)
+            )
+        ],
+    )
+    out_eval = out.eval(
+        {
+            mu: mu_val,
+            sigma: sigma_val,
+            c_size: c_size_val,
+            rng: rng_val,
+        }
+    )
+    out_expected = deepcopy(rng_val).normal(
+        mu_val[:, None],
+        sigma_val[None, :],
+        size=(c_size_val, mu_val.shape[0], sigma_val.shape[0]),
+    )
+    np.testing.assert_allclose(out_eval, out_expected)
+
+    # Test invalid core_dims
+    with pytest.raises(
+        ValueError,
+        match=re.escape("normal needs 0 core_dims, but got 1"),
+    ):
+        normal(mu_xr, sigma_xr, core_dims=("a",), rng=rng)
+
+    # Test invalid extra_dims (conflicting with existing batch dims)
+    with pytest.raises(
+        ValueError,
+        match=re.escape(
+            "Size dimensions ['mu_dim'] conflict with parameter dimensions. They should be unique."
+        ),
+    ):
+        pxr.normal(mu_xr, sigma_xr, extra_dims=dict(mu_dim=c_size_xr), rng=rng)
+
+
+def test_categorical():
+    rng = random_generator_type("rng")
+    p = tensor("p", shape=(2, 3))
+    c_size = tensor("c", shape=(), dtype=int)
+
+    p_xr = as_xtensor(p, dims=("p", "batch_dim"), name="p_xr")
+    c_size_xr = as_xtensor(c_size, name="c_size_xr")
+
+    out = categorical(p_xr, core_dims=("p",), rng=rng)
+    assert out.type.dims == ("batch_dim",)
+    assert out.type.shape == (3,)
+    assert equal_computations(
+        [lower_rewrite(out.values)], [ptr.categorical(p.T, rng=rng)]
+    )
+    np.testing.assert_allclose(
+        out.eval(
+            {
+                p: np.array([[1.0, 0], [0, 1.0], [1.0, 0]]).T,
+                rng: np.random.default_rng(),
+            }
+        ),
+        np.array([0, 1, 0]),
+    )
+
+    out = categorical(
+        p_xr, core_dims=("p",), extra_dims=dict(cp1=c_size_xr + 1, c=c_size_xr), rng=rng
+    )
+    assert out.type.dims == ("cp1", "c", "batch_dim")
+    assert out.type.shape == (None, None, 3)
+    assert equal_computations(
+        [lower_rewrite(out.values)],
+        [
+            rewrite_graph(
+                ptr.categorical(
+                    p.T, size=(1 + c_size, c_size, p[0].shape.squeeze()), rng=rng
+                )
+            )
+        ],
+    )
+    np.testing.assert_allclose(
+        out.eval(
+            {
+                p: np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 0.0]]).T,
+                c_size: np.array(5),
+                rng: np.random.default_rng(),
+            }
+        ),
+        np.broadcast_to([0, 1, 0], shape=(6, 5, 3)),
+    )
+
+    # Test invalid core dims
+    with pytest.raises(
+        ValueError, match="categorical needs 1 core_dims to be specified"
+    ):
+        categorical(p_xr, rng=rng)
+
+    with pytest.raises(
+        ValueError,
+        match=re.escape(
+            "At least one core dim=('px',) missing from input p_xr with dims=('p', 'batch_dim')"
+        ),
+    ):
+        categorical(p_xr, core_dims=("px",), rng=rng)
+
+
+def test_multinomial():
+    rng = random_generator_type("rng")
+    n = tensor("n", shape=(2,))
+    p = tensor("p", shape=(3, None))
+    c_size = tensor("c", shape=(), dtype=int)
+    n_xr = as_xtensor(n, dims=("a",), name="a_xr")
+    p_xr = as_xtensor(p, dims=("p", "a"), name="p_xr")
+    c_size_xr = as_xtensor(c_size, name="c_size_xr")
+    a_size_xr = n_xr.sizes["a"]
+
+    out = multinomial(n_xr, p_xr, core_dims=("p",), rng=rng)
+    assert out.type.dims == ("a", "p")
+    assert out.type.shape == (2, 3)
+    assert equal_computations(
+        [lower_rewrite(out.values)],
+        [ptr.multinomial(n, p.T, size=None, rng=rng)],
+    )
+    # Test we can actually evaluate it
+    np.testing.assert_allclose(
+        out.eval(
+            {
+                n: [5, 10],
+                p: np.array([[1.0, 0, 0], [0, 0, 1.0]]).T,
+                rng: np.random.default_rng(),
+            }
+        ),
+        np.array([[5, 0, 0], [0, 0, 10]]),
+    )
+
+    out = multinomial(
+        n_xr, p_xr, core_dims=("p",), extra_dims=dict(c=c_size_xr), rng=rng
+    )
+    assert out.type.dims == ("c", "a", "p")
+    assert equal_computations(
+        [lower_rewrite(out.values)],
+        [rewrite_graph(ptr.multinomial(n, p.T, size=(c_size, n.shape[0]), rng=rng))],
+    )
+
+    # Test we can actually evaluate it with extra_dims
+    np.testing.assert_allclose(
+        out.eval(
+            {
+                n: [5, 10],
+                p: np.array([[1.0, 0, 0], [0, 0, 1.0]]).T,
+                c_size: 5,
+                rng: np.random.default_rng(),
+            }
+        ),
+        np.broadcast_to(
+            [[5, 0, 0], [0, 0, 10]],
+            shape=(5, 2, 3),
+        ),
+    )
+
+    # Test invalid core_dims
+    with pytest.raises(
+        ValueError, match="multinomial needs 1 core_dims to be specified"
+    ):
+        multinomial(n_xr, p_xr, rng=rng)
+
+    with pytest.raises(ValueError, match="multinomial needs 1 core_dims, but got 2"):
+        multinomial(n_xr, p_xr, core_dims=("p1", "p2"), rng=rng)
+
+    with pytest.raises(
+        ValueError, match=re.escape("Parameter a_xr has invalid core dimensions ['a']")
+    ):
+        # n cannot have a core dimension
+        multinomial(n_xr, p_xr, core_dims=("a",), rng=rng)
+
+    with pytest.raises(
+        ValueError,
+        match=re.escape(
+            "At least one core dim=('px',) missing from input p_xr with dims=('p', 'a')"
+        ),
+    ):
+        multinomial(n_xr, p_xr, core_dims=("px",), rng=rng)
+
+    # Test invalid extra_dims
+    with pytest.raises(
+        ValueError,
+        match=re.escape(
+            "Size dimensions ['a'] conflict with parameter dimensions. They should be unique."
+        ),
+    ):
+        multinomial(
+            n_xr,
+            p_xr,
+            core_dims=("p",),
+            extra_dims=dict(c=c_size_xr, a=a_size_xr),
+            rng=rng,
+        )
+
+
+def test_multivariate_normal():
+    rng = random_generator_type("rng")
+    mu = tensor("mu", shape=(4, 2))
+    cov = tensor("cov", shape=(2, 3, 2, 4))
+
+    mu_xr = as_xtensor(mu, dims=("b1", "rows"), name="mu_xr")
+    cov_xr = as_xtensor(cov, dims=("cols", "b2", "rows", "b1"), name="cov_xr")
+
+    out = multivariate_normal(mu_xr, cov_xr, core_dims=("rows", "cols"), rng=rng)
+    assert out.type.dims == ("b1", "b2", "rows")
+    assert out.type.shape == (4, 3, 2)
+    assert equal_computations(
+        [lower_rewrite(out.values)],
+        [ptr.multivariate_normal(mu[:, None], cov.transpose(3, 1, 2, 0), rng=rng)],
+    )
+
+    # Order of core_dims doesn't matter
+    out = multivariate_normal(mu_xr, cov_xr, core_dims=("cols", "rows"), rng=rng)
+    assert out.type.dims == ("b1", "b2", "rows")
+    assert out.type.shape == (4, 3, 2)
+    assert equal_computations(
+        [lower_rewrite(out.values)],
+        [ptr.multivariate_normal(mu[:, None], cov.transpose(3, 1, 2, 0), rng=rng)],
+    )
+
+    # Test method
+    out = multivariate_normal(
+        mu_xr, cov_xr, core_dims=("rows", "cols"), rng=rng, method="svd"
+    )
+    assert equal_computations(
+        [lower_rewrite(out.values)],
+        [
+            ptr.multivariate_normal(
+                mu[:, None], cov.transpose(3, 1, 2, 0), rng=rng, method="svd"
+            )
+        ],
+    )
+
+    # Test invalid core_dims
+    with pytest.raises(
+        TypeError,
+        match=re.escape(
+            "multivariate_normal() missing 1 required keyword-only argument: 'core_dims'"
+        ),
+    ):
+        multivariate_normal(mu_xr, cov_xr)
+
+    with pytest.raises(
+        ValueError, match="multivariate_normal requires 2 core_dims, got 3"
+    ):
+        multivariate_normal(mu_xr, cov_xr, core_dims=("b1", "rows", "cols"))
+
+    with pytest.raises(
+        ValueError, match=re.escape("Operand has repeated dims ('rows', 'rows')")
+    ):
+        multivariate_normal(mu_xr, cov_xr, core_dims=("rows", "rows"))
+
+    with pytest.raises(
+        ValueError,
+        match=re.escape("Parameter mu_xr has invalid core dimensions ['b1']"),
+    ):
+        # mu cannot have two core_dims
+        multivariate_normal(mu_xr, cov_xr, core_dims=("rows", "b1"))
+
+    with pytest.raises(
+        ValueError,
+        match=re.escape(
+            "At least one core dim=('rows', 'missing_cols') missing from input cov_xr with dims=('cols', 'b2', 'rows', 'b1')"
+        ),
+    ):
+        # cov must have both core_dims
+        multivariate_normal(mu_xr, cov_xr, core_dims=("rows", "missing_cols"))