diff --git a/examples/variational_inference/bayesian_neural_network_advi.ipynb b/examples/variational_inference/bayesian_neural_network_advi.ipynb index 84ed5f84..ff18bfa0 100644 --- a/examples/variational_inference/bayesian_neural_network_advi.ipynb +++ b/examples/variational_inference/bayesian_neural_network_advi.ipynb @@ -186,8 +186,12 @@ " }\n", "\n", " with pm.Model(coords=coords) as neural_network:\n", - " ann_input = pm.Data(\"ann_input\", X_train, mutable=True)\n", - " ann_output = pm.Data(\"ann_output\", Y_train, mutable=True)\n", + " # Define minibatch variables\n", + " minibatch_x, minibatch_y = pm.Minibatch(X_train, Y_train, batch_size=50)\n", + "\n", + " # Define data variables using minibatches\n", + " ann_input = pm.Data(\"ann_input\", minibatch_x, mutable=True, dims=(\"obs_id\", \"train_cols\"))\n", + " ann_output = pm.Data(\"ann_output\", minibatch_y, mutable=True, dims=\"obs_id\")\n", "\n", " # Weights from input to hidden layer\n", " weights_in_1 = pm.Normal(\n", @@ -212,7 +216,8 @@ " \"out\",\n", " act_out,\n", " observed=ann_output,\n", - " total_size=Y_train.shape[0], # IMPORTANT for minibatches\n", + " total_size=X_train.shape[0], # IMPORTANT for minibatches\n", + " dims=\"obs_id\",\n", " )\n", " return neural_network\n", "\n", diff --git a/examples/variational_inference/bayesian_neural_network_advi.myst.md b/examples/variational_inference/bayesian_neural_network_advi.myst.md index 28201192..3cb7c325 100644 --- a/examples/variational_inference/bayesian_neural_network_advi.myst.md +++ b/examples/variational_inference/bayesian_neural_network_advi.myst.md @@ -131,8 +131,12 @@ def construct_nn(): } with pm.Model(coords=coords) as neural_network: - ann_input = pm.Data("ann_input", X_train, mutable=True) - ann_output = pm.Data("ann_output", Y_train, mutable=True) + # Define minibatch variables + minibatch_x, minibatch_y = pm.Minibatch(X_train, Y_train, batch_size=50) + + # Define data variables using minibatches + ann_input = pm.Data("ann_input", minibatch_x, mutable=True, dims=("obs_id", "train_cols")) + ann_output = pm.Data("ann_output", minibatch_y, mutable=True, dims="obs_id") # Weights from input to hidden layer weights_in_1 = pm.Normal( @@ -157,7 +161,8 @@ def construct_nn(): "out", act_out, observed=ann_output, - total_size=Y_train.shape[0], # IMPORTANT for minibatches + total_size=X_train.shape[0], # IMPORTANT for minibatches + dims="obs_id", ) return neural_network