Skip to content

Commit 19389fc

Browse files
authored
Make Minibatch warning less conservative
1 parent 518eac9 commit 19389fc

File tree

1 file changed

+8
-10
lines changed

1 file changed

+8
-10
lines changed

pymc/model/core.py

+8-10
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434

3535
from pytensor.compile import DeepCopyOp, Function, get_mode
3636
from pytensor.compile.sharedvalue import SharedVariable
37-
from pytensor.graph.basic import Constant, Variable, graph_inputs
37+
from pytensor.graph.basic import Constant, Variable, ancestors, graph_inputs
3838
from pytensor.tensor.random.op import RandomVariable
3939
from pytensor.tensor.random.type import RandomType
4040
from pytensor.tensor.variable import TensorConstant, TensorVariable
@@ -1241,15 +1241,13 @@ def register_rv(
12411241
self.add_named_variable(rv_var, dims)
12421242
self.set_initval(rv_var, initval)
12431243
else:
1244-
if (
1245-
isinstance(observed, TensorVariable)
1246-
and observed.owner is not None
1247-
and isinstance(observed.owner.op, MinibatchOp)
1248-
and total_size is None
1249-
):
1250-
warnings.warn(
1251-
f"total_size not provided for observed variable `{name}` that uses pm.Minibatch"
1252-
)
1244+
if total_size is None and isinstance(observed, TensorVariable):
1245+
for node in ancestors([observed]):
1246+
if node.owner is not None and isinstance(node.owner.op, MinibatchOp):
1247+
warnings.warn(
1248+
f"total_size not provided for observed variable `{name}` that uses pm.Minibatch"
1249+
)
1250+
break
12531251
if not is_valid_observed(observed):
12541252
raise TypeError(
12551253
"Variables that depend on other nodes cannot be used for observed data."

0 commit comments

Comments
 (0)