diff --git a/muon/_core/plot.py b/muon/_core/plot.py
index 752518c..52be000 100644
--- a/muon/_core/plot.py
+++ b/muon/_core/plot.py
@@ -1,4 +1,4 @@
-from typing import Union, List, Optional, Iterable, Sequence, Dict
+from typing import Dict, Iterable, List, Optional, Sequence, Union
 import warnings
 
 from matplotlib.axes import Axes
@@ -22,7 +22,7 @@ def scatter(
     data: Union[AnnData, MuData],
     x: Optional[str] = None,
     y: Optional[str] = None,
-    color: Optional[str] = None,
+    color: Optional[Union[str, Sequence[str]]] = None,
     use_raw: Optional[bool] = None,
     layers: Optional[Union[str, Sequence[str]]] = None,
     **kwargs,
@@ -42,8 +42,8 @@ def scatter(
         x coordinate
     y : Optional[str]
         y coordinate
-    color : Optional[str], optional (default: None)
-        Key for variables or annotations of observations (.obs columns),
+    color : Optional[Union[str, Sequence[str]]], optional (default: None)
+        Keys or a single key for variables or annotations of observations (.obs columns),
         or a hex colour specification.
     use_raw : Optional[bool], optional (default: None)
         Use `.raw` attribute of the modality where a feature (from `color`) is derived from.
@@ -71,7 +71,7 @@ def scatter(
             color_obs = _get_values(data, color, use_raw=use_raw, layer=layers[2])
             color_obs = pd.DataFrame({color: color_obs})
         else:
-            raise TypeError("Expected color to be a string.")
+            color_obs = _get_values(data, color, use_raw=use_raw, layer=layers[2])
 
         color_obs.index = data.obs_names
         obs = pd.concat([obs, color_obs], axis=1, ignore_index=False)