From f53b581d3bdc44718d971fc7585647eec2ad1270 Mon Sep 17 00:00:00 2001 From: Win Wang <1862202+wiwa@users.noreply.github.com> Date: Mon, 16 Dec 2024 18:15:45 -0500 Subject: [PATCH] Fix Minibatch alignment in Bayesian Neural Network example + Pre-commit hooks (#719) * Fix Minibatch alignment in Bayesian Neural Network example * Run: pre-commit run all-files --------- Co-authored-by: Deepak CH --- .../bayesian_neural_network_advi.ipynb | 11 ++++++++--- .../bayesian_neural_network_advi.myst.md | 11 ++++++++--- 2 files changed, 16 insertions(+), 6 deletions(-) diff --git a/examples/variational_inference/bayesian_neural_network_advi.ipynb b/examples/variational_inference/bayesian_neural_network_advi.ipynb index 84ed5f84..ff18bfa0 100644 --- a/examples/variational_inference/bayesian_neural_network_advi.ipynb +++ b/examples/variational_inference/bayesian_neural_network_advi.ipynb @@ -186,8 +186,12 @@ " }\n", "\n", " with pm.Model(coords=coords) as neural_network:\n", - " ann_input = pm.Data(\"ann_input\", X_train, mutable=True)\n", - " ann_output = pm.Data(\"ann_output\", Y_train, mutable=True)\n", + " # Define minibatch variables\n", + " minibatch_x, minibatch_y = pm.Minibatch(X_train, Y_train, batch_size=50)\n", + "\n", + " # Define data variables using minibatches\n", + " ann_input = pm.Data(\"ann_input\", minibatch_x, mutable=True, dims=(\"obs_id\", \"train_cols\"))\n", + " ann_output = pm.Data(\"ann_output\", minibatch_y, mutable=True, dims=\"obs_id\")\n", "\n", " # Weights from input to hidden layer\n", " weights_in_1 = pm.Normal(\n", @@ -212,7 +216,8 @@ " \"out\",\n", " act_out,\n", " observed=ann_output,\n", - " total_size=Y_train.shape[0], # IMPORTANT for minibatches\n", + " total_size=X_train.shape[0], # IMPORTANT for minibatches\n", + " dims=\"obs_id\",\n", " )\n", " return neural_network\n", "\n", diff --git a/examples/variational_inference/bayesian_neural_network_advi.myst.md b/examples/variational_inference/bayesian_neural_network_advi.myst.md index 28201192..3cb7c325 100644 --- a/examples/variational_inference/bayesian_neural_network_advi.myst.md +++ b/examples/variational_inference/bayesian_neural_network_advi.myst.md @@ -131,8 +131,12 @@ def construct_nn(): } with pm.Model(coords=coords) as neural_network: - ann_input = pm.Data("ann_input", X_train, mutable=True) - ann_output = pm.Data("ann_output", Y_train, mutable=True) + # Define minibatch variables + minibatch_x, minibatch_y = pm.Minibatch(X_train, Y_train, batch_size=50) + + # Define data variables using minibatches + ann_input = pm.Data("ann_input", minibatch_x, mutable=True, dims=("obs_id", "train_cols")) + ann_output = pm.Data("ann_output", minibatch_y, mutable=True, dims="obs_id") # Weights from input to hidden layer weights_in_1 = pm.Normal( @@ -157,7 +161,8 @@ def construct_nn(): "out", act_out, observed=ann_output, - total_size=Y_train.shape[0], # IMPORTANT for minibatches + total_size=X_train.shape[0], # IMPORTANT for minibatches + dims="obs_id", ) return neural_network