diff --git a/test/test_utils.py b/test/test_utils.py
index 49dc553de3e..ffcad425aeb 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -432,6 +432,22 @@ def test_draw_keypoints_visibility_default():
     assert_equal(result, expected)
 
 
+def test_draw_keypoints_dtypes():
+    image_uint8 = torch.randint(0, 256, size=(3, 100, 100), dtype=torch.uint8)
+    image_float = to_dtype(image_uint8, torch.float, scale=True)
+
+    out_uint8 = utils.draw_keypoints(image_uint8, keypoints)
+    out_float = utils.draw_keypoints(image_float, keypoints)
+
+    assert out_uint8.dtype == torch.uint8
+    assert out_uint8 is not image_uint8
+
+    assert out_float.is_floating_point()
+    assert out_float is not image_float
+
+    torch.testing.assert_close(out_uint8, to_dtype(out_float, torch.uint8, scale=True), rtol=0, atol=1)
+
+
 def test_draw_keypoints_errors():
     h, w = 10, 10
     img = torch.full((3, 100, 100), 0, dtype=torch.uint8)
diff --git a/torchvision/utils.py b/torchvision/utils.py
index 79e533d4663..734cb127db1 100644
--- a/torchvision/utils.py
+++ b/torchvision/utils.py
@@ -336,13 +336,13 @@ def draw_keypoints(
 
     """
     Draws Keypoints on given RGB image.
-    The values of the input image should be uint8 between 0 and 255.
+    The image values should be uint8 in [0, 255] or float in [0, 1].
     Keypoints can be drawn for multiple instances at a time.
 
     This method allows that keypoints and their connectivity are drawn based on the visibility of this keypoint.
 
     Args:
-        image (Tensor): Tensor of shape (3, H, W) and dtype uint8.
+        image (Tensor): Tensor of shape (3, H, W) and dtype uint8 or float.
         keypoints (Tensor): Tensor of shape (num_instances, K, 2) the K keypoint locations for each of the N instances,
             in the format [x, y].
         connectivity (List[Tuple[int, int]]]): A List of tuple where each tuple contains a pair of keypoints
@@ -363,7 +363,7 @@ def draw_keypoints(
             For more details, see :ref:`draw_keypoints_with_visibility`.
 
     Returns:
-        img (Tensor[C, H, W]): Image Tensor of dtype uint8 with keypoints drawn.
+        img (Tensor[C, H, W]): Image Tensor with keypoints drawn.
     """
 
     if not torch.jit.is_scripting() and not torch.jit.is_tracing():
@@ -371,8 +371,8 @@ def draw_keypoints(
     # validate image
     if not isinstance(image, torch.Tensor):
         raise TypeError(f"The image must be a tensor, got {type(image)}")
-    elif image.dtype != torch.uint8:
-        raise ValueError(f"The image dtype must be uint8, got {image.dtype}")
+    elif not (image.dtype == torch.uint8 or image.is_floating_point()):
+        raise ValueError(f"The image dtype must be uint8 or float, got {image.dtype}")
     elif image.dim() != 3:
         raise ValueError("Pass individual images, not batches")
     elif image.size()[0] != 3:
@@ -397,6 +397,12 @@ def draw_keypoints(
             f"Got {visibility.shape = } and {keypoints.shape = }"
         )
 
+    original_dtype = image.dtype
+    if original_dtype.is_floating_point:
+        from torchvision.transforms.v2.functional import to_dtype  # noqa
+
+        image = to_dtype(image, dtype=torch.uint8, scale=True)
+
     ndarr = image.permute(1, 2, 0).cpu().numpy()
     img_to_draw = Image.fromarray(ndarr)
     draw = ImageDraw.Draw(img_to_draw)
@@ -428,7 +434,10 @@ def draw_keypoints(
                     width=width,
                 )
 
-    return torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1).to(dtype=torch.uint8)
+    out = torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1)
+    if original_dtype.is_floating_point:
+        out = to_dtype(out, dtype=original_dtype, scale=True)
+    return out
 
 
 # Flow visualization code adapted from https://github.com/tomrunia/OpticalFlow_Visualization