|
34 | 34 |
|
35 | 35 | from pytensor.compile import DeepCopyOp, Function, get_mode
|
36 | 36 | from pytensor.compile.sharedvalue import SharedVariable
|
37 |
| -from pytensor.graph.basic import Constant, Variable, graph_inputs |
| 37 | +from pytensor.graph.basic import Constant, Variable, ancestors, graph_inputs |
38 | 38 | from pytensor.tensor.random.op import RandomVariable
|
39 | 39 | from pytensor.tensor.random.type import RandomType
|
40 | 40 | from pytensor.tensor.variable import TensorConstant, TensorVariable
|
@@ -1241,15 +1241,13 @@ def register_rv(
|
1241 | 1241 | self.add_named_variable(rv_var, dims)
|
1242 | 1242 | self.set_initval(rv_var, initval)
|
1243 | 1243 | else:
|
1244 |
| - if ( |
1245 |
| - isinstance(observed, TensorVariable) |
1246 |
| - and observed.owner is not None |
1247 |
| - and isinstance(observed.owner.op, MinibatchOp) |
1248 |
| - and total_size is None |
1249 |
| - ): |
1250 |
| - warnings.warn( |
1251 |
| - f"total_size not provided for observed variable `{name}` that uses pm.Minibatch" |
1252 |
| - ) |
| 1244 | + if total_size is None and isinstance(observed, TensorVariable): |
| 1245 | + for node in ancestors([observed]): |
| 1246 | + if node.owner is not None and isinstance(node.owner.op, MinibatchOp): |
| 1247 | + warnings.warn( |
| 1248 | + f"total_size not provided for observed variable `{name}` that uses pm.Minibatch" |
| 1249 | + ) |
| 1250 | + break |
1253 | 1251 | if not is_valid_observed(observed):
|
1254 | 1252 | raise TypeError(
|
1255 | 1253 | "Variables that depend on other nodes cannot be used for observed data."
|
|
0 commit comments