Skip to content

Commit 6d39c23

Browse files
committed
Check all ancestors for MinibatchOp
1 parent 518eac9 commit 6d39c23

File tree

1 file changed

+11
-9
lines changed

1 file changed

+11
-9
lines changed

pymc/model/core.py

+11-9
Original file line numberDiff line numberDiff line change
@@ -1241,15 +1241,17 @@ def register_rv(
12411241
self.add_named_variable(rv_var, dims)
12421242
self.set_initval(rv_var, initval)
12431243
else:
1244-
if (
1245-
isinstance(observed, TensorVariable)
1246-
and observed.owner is not None
1247-
and isinstance(observed.owner.op, MinibatchOp)
1248-
and total_size is None
1249-
):
1250-
warnings.warn(
1251-
f"total_size not provided for observed variable `{name}` that uses pm.Minibatch"
1252-
)
1244+
if total_size is None:
1245+
for node in ancestors([observed]):
1246+
if (
1247+
isinstance(node, TensorVariable)
1248+
and node.owner is not None
1249+
and isinstance(node.owner.op, MinibatchOp)
1250+
):
1251+
warnings.warn(
1252+
f"total_size not provided for observed variable `{name}` that uses pm.Minibatch"
1253+
)
1254+
break
12531255
if not is_valid_observed(observed):
12541256
raise TypeError(
12551257
"Variables that depend on other nodes cannot be used for observed data."

0 commit comments

Comments
 (0)