diff --git a/Tutorial_2_JAX_HeroPro+_Colab.ipynb b/Tutorial_2_JAX_HeroPro+_Colab.ipynb
index 9ec600d..07e72b2 100644
--- a/Tutorial_2_JAX_HeroPro+_Colab.ipynb
+++ b/Tutorial_2_JAX_HeroPro+_Colab.ipynb
@@ -407,7 +407,7 @@
       },
       "source": [
         "another_list_of_lists = list_of_lists\n",
-        "print(jax.tree_multimap(lambda x, y: x+y, list_of_lists, another_list_of_lists))"
+        "print(jax.tree_util.tree_map(lambda x, y: x+y, list_of_lists, another_list_of_lists))"
       ],
       "execution_count": null,
       "outputs": []
@@ -418,10 +418,10 @@
         "id": "09Pdhyh2ISb4"
       },
       "source": [
-        "# PyTrees need to have the same structure if we are to apply tree_multimap!\n",
+        "# PyTrees need to have the same structure if we are to apply jax.tree_util.tree_map!\n",
         "another_list_of_lists = deepcopy(list_of_lists)\n",
         "another_list_of_lists.append([23])\n",
-        "print(jax.tree_multimap(lambda x, y: x+y, list_of_lists, another_list_of_lists))"
+        "print(jax.tree_util.tree_map(lambda x, y: x+y, list_of_lists, another_list_of_lists))"
       ],
       "execution_count": null,
       "outputs": []
@@ -493,7 +493,7 @@
         "  # Task: analyze grads and make sure it has the same structure as params\n",
         "\n",
         "  # SGD update\n",
-        "  return jax.tree_multimap(\n",
+        "  return jax.tree_util.tree_map(\n",
         "      lambda p, g: p - lr * g, params, grads  # for every leaf i.e. for every param of MLP\n",
         "  )"
       ],
@@ -979,7 +979,7 @@
         "\n",
         "  # Each device performs its own SGD update, but since we start with the same params\n",
         "  # and synchronise gradients, the params stay in sync on each device.\n",
-        "  new_params = jax.tree_multimap(\n",
+        "  new_params = jax.tree_util.tree_map(\n",
         "      lambda param, g: param - g * lr, params, grads)\n",
         "  \n",
         "  # If we were using Adam or another stateful optimizer,\n",
@@ -1316,4 +1316,4 @@
       "outputs": []
     }
   ]
-}
\ No newline at end of file
+}
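
For reference, a minimal standalone sketch of the migrated API used throughout this patch: jax.tree_util.tree_map accepts multiple pytrees (like the removed jax.tree_multimap did) as long as their structures match. The params/grads dictionaries and the lr value below are illustrative examples, not values taken from the notebook.

    import jax
    import jax.numpy as jnp

    # Two pytrees with identical structure (hypothetical toy parameters/gradients).
    params = {"w": jnp.ones((2, 2)), "b": jnp.zeros(2)}
    grads = {"w": jnp.full((2, 2), 0.1), "b": jnp.full((2,), 0.1)}
    lr = 0.01

    # Element-wise SGD step applied to every leaf of the params pytree,
    # mirroring the update rule in the diff above.
    new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    print(new_params)

If the two pytrees differ in structure (e.g. one has an extra leaf, as in the deepcopy-and-append cell above), jax.tree_util.tree_map raises an error rather than broadcasting.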