From a0cd34d4e25bbc8511041b8fcbf2e21445952bba Mon Sep 17 00:00:00 2001 From: Yi Yang Date: Tue, 3 Sep 2024 03:51:08 -0700 Subject: [PATCH] Add inference mode to TAPIR and TAPIR pytorch colabs to save memory. PiperOrigin-RevId: 670490071 Change-Id: I6eccc08899d008e09f1cc67d01100f40f1175bb3 --- colabs/torch_causal_tapir_demo.ipynb | 16 ++++++++++++++++ colabs/torch_tapir_demo.ipynb | 14 ++++++++++++++ 2 files changed, 30 insertions(+) diff --git a/colabs/torch_causal_tapir_demo.ipynb b/colabs/torch_causal_tapir_demo.ipynb index 302c243..0d028fa 100644 --- a/colabs/torch_causal_tapir_demo.ipynb +++ b/colabs/torch_causal_tapir_demo.ipynb @@ -205,6 +205,20 @@ "model = model.eval()" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "JH5jNtQAfTP4" + }, + "outputs": [], + "source": [ + "# @title Set to Inference Mode to Save Memory {form-width: \"25%\"}\n", + "\n", + "model = model.eval()\n", + "torch.set_grad_enabled(False)" + ] + }, { "cell_type": "code", "execution_count": null, @@ -213,6 +227,8 @@ }, "outputs": [], "source": [ + "# @title Inference Functions {form-width: \"25%\"}\n", + "\n", "def online_model_init(frames, query_points):\n", " \"\"\"Initialize query features for the query points.\"\"\"\n", " frames = preprocess_frames(frames)\n", diff --git a/colabs/torch_tapir_demo.ipynb b/colabs/torch_tapir_demo.ipynb index 6af12cc..62384e4 100644 --- a/colabs/torch_tapir_demo.ipynb +++ b/colabs/torch_tapir_demo.ipynb @@ -218,6 +218,20 @@ "model = model.to(device)" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "5xN1TQIGfrDs" + }, + "outputs": [], + "source": [ + "# @title Set to Inference Mode to Save Memory {form-width: \"25%\"}\n", + "\n", + "model = model.eval()\n", + "torch.set_grad_enabled(False)" + ] + }, { "cell_type": "code", "execution_count": null,