From bef4d6c34fcb3df3493f95b72e7634a149809c75 Mon Sep 17 00:00:00 2001 From: Sebastian Ament Date: Thu, 29 Aug 2024 15:04:07 -0700 Subject: [PATCH] Update `compare_mc_analytic_acquisition` tutorial to use LogEI Summary: Updating the analytic vs MC comparison tutorial to use LogEI. Differential Revision: D61997650 --- .../compare_mc_analytic_acquisition.ipynb | 178 +++++++++++------- 1 file changed, 113 insertions(+), 65 deletions(-) diff --git a/tutorials/compare_mc_analytic_acquisition.ipynb b/tutorials/compare_mc_analytic_acquisition.ipynb index c74c7e7334..1a04abe66d 100644 --- a/tutorials/compare_mc_analytic_acquisition.ipynb +++ b/tutorials/compare_mc_analytic_acquisition.ipynb @@ -19,18 +19,23 @@ "showInput": false }, "source": [ - "### Comparison of analytic and MC-based EI" + "### Comparison of analytic and MC-based EI\n", + "Note that we use the analytic and MC variants of the LogEI family of acquisition functions, which remedy numerical issues encountered in the naive implementations. See https://arxiv.org/pdf/2310.20708 for more details." ] }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 158, "metadata": { "collapsed": false, "customOutput": null, "executionStartTime": 1668649205799, "executionStopTime": 1668649205822, "originalKey": "f678d607-be4c-4f37-aed5-3597158432ce", + "output": { + "id": 8143993305683446, + "loadingStatus": "loaded" + }, "requestMsgId": "0aae9d3f-d796-4a18-a4aa-b015b5b582ac" }, "outputs": [], @@ -57,17 +62,30 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 159, "metadata": { "collapsed": false, "customOutput": null, "executionStartTime": 1668649205895, "executionStopTime": 1668649206067, "originalKey": "a7724f86-8b67-4f70-bf57-f0da79b88f52", + "output": { + "id": 1605553740344114, + "loadingStatus": "loaded" + }, "requestMsgId": "25794582-0506-4e89-a112-ba362b7c7e59" }, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[W 240829 14:47:26 659969212:4] The model inputs are of type torch.float32. It is strongly recommended to use double precision in BoTorch, as this improves both precision and stability and can help avoid numerical errors. See https://github.com/pytorch/botorch/discussions/1444\n" + ] + } + ], "source": [ + "torch.manual_seed(seed=12345) # to keep the data conditions the same\n", "train_x = torch.rand(10, 6)\n", "train_obj = neg_hartmann6(train_x).unsqueeze(-1)\n", "model = SingleTaskGP(train_X=train_x, train_Y=train_obj)\n", @@ -87,7 +105,7 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 160, "metadata": { "collapsed": false, "customOutput": null, @@ -98,10 +116,10 @@ }, "outputs": [], "source": [ - "from botorch.acquisition import ExpectedImprovement\n", + "from botorch.acquisition.analytic import LogExpectedImprovement\n", "\n", "best_value = train_obj.max()\n", - "EI = ExpectedImprovement(model=model, best_f=best_value)" + "EI = LogExpectedImprovement(model=model, best_f=best_value)" ] }, { @@ -116,13 +134,17 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 161, "metadata": { "collapsed": false, "customOutput": null, "executionStartTime": 1668649206218, "executionStopTime": 1668649206938, "originalKey": "dc5613c6-2f99-4193-8956-6e710fee5fa2", + "output": { + "id": 422599616946465, + "loadingStatus": "loaded" + }, "requestMsgId": "3df2fc12-7f4c-4abb-b1d2-90bb3b8bf05c" }, "outputs": [], @@ -141,31 +163,35 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 162, "metadata": { "collapsed": false, "customOutput": null, "executionStartTime": 1668649206992, "executionStopTime": 1668649207011, "originalKey": "76fb19a3-c2c2-451a-8c0b-50cb14c55460", + "output": { + "id": 410743084949823, + "loadingStatus": "loaded" + }, "requestMsgId": "a5cbada9-0b7c-41a2-934f-10d9bbe2e316" }, "outputs": [ { "data": { "text/plain": [ - "tensor([[0.4730, 0.0836, 0.8247, 0.5628, 0.2964, 0.6131]])" + "(tensor([-2.6574], grad_fn=),\n", + " tensor([[0.1382, 0.3801, 0.9660, 0.3046, 0.3479, 0.9341]]))" ] }, - "execution_count": 20, - "metadata": { - "bento_obj_id": "140510701845616" - }, + "execution_count": 165, + "metadata": {}, "output_type": "execute_result" } ], "source": [ - "new_point_analytic" + "# NOTE: The acquisition value here is the log of the expected improvement.\n", + "EI(new_point_analytic), new_point_analytic" ] }, { @@ -180,23 +206,27 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 163, "metadata": { "collapsed": false, "customOutput": null, "executionStartTime": 1668649207083, "executionStopTime": 1668649207929, "originalKey": "aaf04cba-3716-4fbd-8baa-2c75dd068860", + "output": { + "id": 495747073400348, + "loadingStatus": "loaded" + }, "requestMsgId": "0e7691f2-34c7-43df-a247-7f7ba95220f1" }, "outputs": [], "source": [ - "from botorch.acquisition import qExpectedImprovement\n", + "from botorch.acquisition.logei import qLogExpectedImprovement\n", "from botorch.sampling import SobolQMCNormalSampler\n", "\n", "\n", "sampler = SobolQMCNormalSampler(sample_shape=torch.Size([512]), seed=0)\n", - "MC_EI = qExpectedImprovement(model, best_f=best_value, sampler=sampler)\n", + "MC_EI = qLogExpectedImprovement(model, best_f=best_value, sampler=sampler, fat=False)\n", "torch.manual_seed(seed=0) # to keep the restart conditions the same\n", "new_point_mc, _ = optimize_acqf(\n", " acq_function=MC_EI,\n", @@ -210,31 +240,35 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 164, "metadata": { "collapsed": false, "customOutput": null, "executionStartTime": 1668649207976, "executionStopTime": 1668649207989, "originalKey": "73ffa9ea-3cff-46eb-91ea-b2f75fdb07f2", + "output": { + "id": 1030708752027469, + "loadingStatus": "loaded" + }, "requestMsgId": "b780cff4-6e90-4e39-8558-b04136e71e94" }, "outputs": [ { "data": { "text/plain": [ - "tensor([[0.4730, 0.0835, 0.8248, 0.5627, 0.2963, 0.6130]])" + "(tensor([-2.6565], grad_fn=),\n", + " tensor([[0.1378, 0.3803, 0.9663, 0.3046, 0.3478, 0.9334]]))" ] }, - "execution_count": 22, - "metadata": { - "bento_obj_id": "140510701845696" - }, + "execution_count": 167, + "metadata": {}, "output_type": "execute_result" } ], "source": [ - "new_point_mc" + "# NOTE: The acquisition value here is the log of the expected improvement.\n", + "MC_EI(new_point_mc), new_point_mc" ] }, { @@ -249,26 +283,28 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 165, "metadata": { "collapsed": false, "customOutput": null, "executionStartTime": 1668649208035, "executionStopTime": 1668649208043, "originalKey": "c5c20ba9-82af-4d07-832f-86ede74f8959", + "output": { + "id": 504527228978384, + "loadingStatus": "loaded" + }, "requestMsgId": "0b3db1ad-6ddb-4f86-9767-e0a486914b33" }, "outputs": [ { "data": { "text/plain": [ - "tensor(0.0002)" + "tensor(0.0008)" ] }, - "execution_count": 23, - "metadata": { - "bento_obj_id": "140510702063760" - }, + "execution_count": 168, + "metadata": {}, "output_type": "execute_result" } ], @@ -292,7 +328,7 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": 166, "metadata": { "collapsed": false, "customOutput": null, @@ -308,7 +344,7 @@ "from botorch.optim import gen_batch_initial_conditions\n", "\n", "resampler = StochasticSampler(sample_shape=torch.Size([512]))\n", - "MC_EI_resample = qExpectedImprovement(model, best_f=best_value, sampler=resampler)\n", + "MC_EI_resample = qLogExpectedImprovement(model, best_f=best_value, sampler=resampler)\n", "bounds = torch.tensor([[0.0] * 6, [1.0] * 6])\n", "\n", "batch_initial_conditions = gen_batch_initial_conditions(\n", @@ -333,55 +369,61 @@ }, { "cell_type": "code", - "execution_count": 25, + "execution_count": 167, "metadata": { "collapsed": false, "customOutput": null, "executionStartTime": 1668649208304, "executionStopTime": 1668649208320, "originalKey": "81c29b36-c663-47e1-8155-ad034c214f53", + "output": { + "id": 824304859816682, + "loadingStatus": "loaded" + }, "requestMsgId": "aac6f703-e046-448a-8abe-1742befb9bf9" }, "outputs": [ { "data": { "text/plain": [ - "tensor([[0.4527, 0.1183, 0.8902, 0.5630, 0.3151, 0.5804]])" + "(tensor([-2.6399], grad_fn=),\n", + " tensor([[0.0000, 0.6789, 0.6553, 0.7695, 0.6079, 0.3511]]))" ] }, - "execution_count": 25, - "metadata": { - "bento_obj_id": "140510701998384" - }, + "execution_count": 170, + "metadata": {}, "output_type": "execute_result" } ], "source": [ - "new_point_torch_Adam" + "# NOTE: The acquisition value here is the log of the expected improvement.\n", + "MC_EI_resample(new_point_torch_Adam), new_point_torch_Adam" ] }, { "cell_type": "code", - "execution_count": 26, + "execution_count": 168, "metadata": { "collapsed": false, "customOutput": null, "executionStartTime": 1668649208364, "executionStopTime": 1668649208372, "originalKey": "17fb0de0-3c5a-414e-9aba-b82710d166c0", + "output": { + "id": 1060067959029219, + "loadingStatus": "loaded" + }, "requestMsgId": "a13ce358-3ee6-43ad-9e3a-16181a8cdc1e" }, "outputs": [ { "data": { "text/plain": [ - "tensor(0.0855)" + "tensor(0.9102)" ] }, - "execution_count": 26, - "metadata": { - "bento_obj_id": "140510701610704" - }, + "execution_count": 171, + "metadata": {}, "output_type": "execute_result" } ], @@ -401,7 +443,7 @@ }, { "cell_type": "code", - "execution_count": 27, + "execution_count": 169, "metadata": { "collapsed": false, "customOutput": null, @@ -427,55 +469,60 @@ }, { "cell_type": "code", - "execution_count": 28, + "execution_count": 170, "metadata": { "collapsed": false, "customOutput": null, "executionStartTime": 1668649208505, "executionStopTime": 1668649208523, "originalKey": "350e456d-0d1c-46dc-a618-0fbba9e0a158", + "output": { + "id": 1743585103114437, + "loadingStatus": "loaded" + }, "requestMsgId": "aa33d42e-c526-4117-88a7-aa3034d82886" }, "outputs": [ { "data": { "text/plain": [ - "tensor([[0.3566, 0.0410, 0.7926, 0.3118, 0.3758, 0.6110]])" + "(tensor([-2.7114], grad_fn=),\n", + " tensor([[0.1133, 0.6818, 0.6779, 0.7704, 0.6066, 0.1224]]))" ] }, - "execution_count": 28, - "metadata": { - "bento_obj_id": "140510702066640" - }, + "execution_count": 173, + "metadata": {}, "output_type": "execute_result" } ], "source": [ - "new_point_torch_SGD" + "MC_EI_resample(new_point_torch_SGD), new_point_torch_SGD" ] }, { "cell_type": "code", - "execution_count": 29, + "execution_count": 171, "metadata": { "collapsed": false, "customOutput": null, "executionStartTime": 1668649208566, "executionStopTime": 1668649208574, "originalKey": "e263cfc7-47a0-4b81-ab33-3aa16320c87e", + "output": { + "id": 1249161419761217, + "loadingStatus": "loaded" + }, "requestMsgId": "3c654fc0-ce64-43c7-a8bf-42935257008a" }, "outputs": [ { "data": { "text/plain": [ - "tensor(0.2928)" + "tensor(1.0570)" ] }, - "execution_count": 29, - "metadata": { - "bento_obj_id": "140510701611584" - }, + "execution_count": 174, + "metadata": {}, "output_type": "execute_result" } ], @@ -485,10 +532,13 @@ } ], "metadata": { + "fileHeader": "", + "fileUid": "8fef7fd4-00ef-428b-946f-953919266648", + "isAdHoc": false, "kernelspec": { - "display_name": "python3", + "display_name": "automl", "language": "python", - "name": "python3" + "name": "bento_kernel_automl" }, "language_info": { "codemirror_mode": { @@ -502,7 +552,5 @@ "pygments_lexer": "ipython3", "version": "3.7.6" } - }, - "nbformat": 4, - "nbformat_minor": 2 + } }