diff --git a/botorch/utils/multi_objective/hypervolume.py b/botorch/utils/multi_objective/hypervolume.py index 92aaf0e253..3a160ecd72 100644 --- a/botorch/utils/multi_objective/hypervolume.py +++ b/botorch/utils/multi_objective/hypervolume.py @@ -39,8 +39,11 @@ def infer_reference_point( ) -> Tensor: r"""Get reference point for hypervolume computations. - This sets the reference point to be `ref_point = nadir - 0.1 * range` - when there is no pareto_Y that is better than the reference point. + This sets the reference point to be `ref_point = nadir - scale * range` + when there is no `pareto_Y` that is better than `max_ref_point`. + If there's `pareto_Y` better than `max_ref_point`, the reference point + will be set to `max_ref_point - scale * range` if `scale_max_ref_point` + is true and to `max_ref_point` otherwise. [Ishibuchi2011]_ find 0.1 to be a robust multiplier for scaling the nadir point. @@ -50,6 +53,9 @@ def infer_reference_point( Args: pareto_Y: A `n x m`-dim tensor of Pareto-optimal points. max_ref_point: A `m` dim tensor indicating the maximum reference point. + Some elements can be NaN, except when `pareto_Y` is empty, + in which case these dimensions will be treated as if no + `max_ref_point` was provided and set to `nadir - scale * range`. scale: A multiplier used to scale back the reference point based on the range of each objective. scale_max_ref_point: A boolean indicating whether to apply scaling to @@ -58,20 +64,28 @@ def infer_reference_point( Returns: A `m`-dim tensor containing the reference point. """ - if pareto_Y.shape[0] == 0: if max_ref_point is None: raise BotorchError("Empty pareto set and no max ref point provided") + if max_ref_point.isnan().any(): + raise BotorchError("Empty pareto set and max ref point includes NaN.") if scale_max_ref_point: return max_ref_point - scale * max_ref_point.abs() return max_ref_point if max_ref_point is not None: - better_than_ref = (pareto_Y > max_ref_point).all(dim=-1) + non_nan_idx = ~max_ref_point.isnan() + # Count all points exceeding non-NaN reference point as being better. + better_than_ref = (pareto_Y[:, non_nan_idx] > max_ref_point[non_nan_idx]).all( + dim=-1 + ) else: - better_than_ref = torch.full( - pareto_Y.shape[:1], 1, dtype=bool, device=pareto_Y.device + non_nan_idx = torch.ones( + pareto_Y.shape[-1], dtype=torch.bool, device=pareto_Y.device ) - if max_ref_point is not None and better_than_ref.any(): + better_than_ref = torch.ones( + pareto_Y.shape[:1], dtype=torch.bool, device=pareto_Y.device + ) + if max_ref_point is not None and better_than_ref.any() and non_nan_idx.all(): Y_range = pareto_Y[better_than_ref].max(dim=0).values - max_ref_point if scale_max_ref_point: return max_ref_point - scale * Y_range @@ -80,17 +94,28 @@ def infer_reference_point( # no points better than max_ref_point and only a single observation # subtract MIN_Y_RANGE to handle the case that pareto_Y is a singleton # with objective value of 0. - return (pareto_Y - scale * pareto_Y.abs().clamp_min(MIN_Y_RANGE)).view(-1) - # no points better than max_ref_point and multiple observations - # make sure that each dimension of the nadir point is no greater than - # the max_ref_point - nadir = pareto_Y.min(dim=0).values - if max_ref_point is not None: - nadir = torch.min(nadir, max_ref_point) - ideal = pareto_Y.max(dim=0).values - # handle case where all values for one objective are the same - Y_range = (ideal - nadir).clamp_min(MIN_Y_RANGE) - return nadir - scale * Y_range + Y_range = pareto_Y.abs().clamp_min(MIN_Y_RANGE).view(-1) + ref_point = pareto_Y.view(-1) - scale * Y_range + else: + # no points better than max_ref_point and multiple observations + # make sure that each dimension of the nadir point is no greater than + # the max_ref_point + nadir = pareto_Y.min(dim=0).values + if max_ref_point is not None: + nadir[non_nan_idx] = torch.min( + nadir[non_nan_idx], max_ref_point[non_nan_idx] + ) + ideal = pareto_Y.max(dim=0).values + # handle case where all values for one objective are the same + Y_range = (ideal - nadir).clamp_min(MIN_Y_RANGE) + ref_point = nadir - scale * Y_range + # Set not-nan indices - if any - to max_ref_point. + if non_nan_idx.any() and not non_nan_idx.all() and better_than_ref.any(): + if scale_max_ref_point: + ref_point[non_nan_idx] = (max_ref_point - scale * Y_range)[non_nan_idx] + else: + ref_point[non_nan_idx] = max_ref_point[non_nan_idx] + return ref_point class Hypervolume: diff --git a/test/utils/multi_objective/test_hypervolume.py b/test/utils/multi_objective/test_hypervolume.py index 9799e558a3..d59d97949f 100644 --- a/test/utils/multi_objective/test_hypervolume.py +++ b/test/utils/multi_objective/test_hypervolume.py @@ -243,3 +243,60 @@ def test_infer_reference_point(self): ref_point = infer_reference_point(pareto_Y=Y, scale=0.2) self.assertAllClose(ref_point, expected_ref_point) ref_point = infer_reference_point(pareto_Y=Y) + expected_ref_point = nadir - 0.1 * (ideal - nadir) + self.assertAllClose(ref_point, expected_ref_point) + + # Test all NaN max_ref_point. + ref_point = infer_reference_point( + pareto_Y=Y, + max_ref_point=torch.tensor([float("nan"), float("nan")], **tkwargs), + ) + self.assertAllClose(ref_point, expected_ref_point) + # Test partial NaN, partial worse than nadir. + expected_ref_point = nadir.clone() + expected_ref_point[1] = -1e5 + ref_point = infer_reference_point( + pareto_Y=Y, + max_ref_point=torch.tensor([float("nan"), -1e5], **tkwargs), + scale=0.0, + ) + self.assertAllClose(ref_point, expected_ref_point) + # Test partial NaN, partial better than nadir. + expected_ref_point = nadir + ref_point = infer_reference_point( + pareto_Y=Y, + max_ref_point=torch.tensor([float("nan"), 1e5], **tkwargs), + scale=0.0, + ) + self.assertAllClose(ref_point, expected_ref_point) + # Test partial NaN, partial worse than nadir with scale_max_ref_point. + expected_ref_point[1] = -1e5 + expected_ref_point = expected_ref_point - 0.2 * (ideal - expected_ref_point) + ref_point = infer_reference_point( + pareto_Y=Y, + max_ref_point=torch.tensor([float("nan"), -1e5], **tkwargs), + scale=0.2, + scale_max_ref_point=True, + ) + self.assertAllClose(ref_point, expected_ref_point) + # Test with single point in Pareto_Y, worse than ref point. + ref_point = infer_reference_point( + pareto_Y=Y[:1], + max_ref_point=torch.tensor([float("nan"), 1e5], **tkwargs), + ) + expected_ref_point = Y[0] - 0.1 * Y[0].abs() + self.assertTrue(torch.equal(expected_ref_point, ref_point)) + # Test with single point in Pareto_Y, better than ref point. + ref_point = infer_reference_point( + pareto_Y=Y[:1], + max_ref_point=torch.tensor([float("nan"), -1e5], **tkwargs), + scale_max_ref_point=True, + ) + expected_ref_point[1] = -1e5 - 0.1 * Y[0, 1].abs() + self.assertTrue(torch.equal(expected_ref_point, ref_point)) + # Empty pareto_Y with nan ref point. + with self.assertRaisesRegex(BotorchError, "ref point includes NaN"): + ref_point = infer_reference_point( + pareto_Y=Y[:0], + max_ref_point=torch.tensor([float("nan"), -1e5], **tkwargs), + )