diff --git a/fiftyone/brain/__init__.py b/fiftyone/brain/__init__.py index 19a9ad55..83c62593 100644 --- a/fiftyone/brain/__init__.py +++ b/fiftyone/brain/__init__.py @@ -253,6 +253,7 @@ def compute_visualization( batch_size=None, num_workers=None, skip_failures=True, + labels=None, **kwargs, ): """Computes a low-dimensional representation of the samples' media or their @@ -358,6 +359,7 @@ def compute_visualization( batch_size, num_workers, skip_failures, + labels, **kwargs, ) diff --git a/fiftyone/brain/internal/core/visualization.py b/fiftyone/brain/internal/core/visualization.py index d972caf8..ad0f731f 100644 --- a/fiftyone/brain/internal/core/visualization.py +++ b/fiftyone/brain/internal/core/visualization.py @@ -51,6 +51,7 @@ def compute_visualization( batch_size, num_workers, skip_failures, + labels, **kwargs, ): """See ``fiftyone/brain/__init__.py``.""" @@ -104,7 +105,10 @@ def compute_visualization( ) logger.info("Generating visualization...") - points = brain_method.fit(embeddings) + if labels: + points = brain_method.fit(embeddings, labels=labels) + else: + points = brain_method.fit(embeddings) results = VisualizationResults(samples, config, points) brain_method.save_run_results(samples, brain_key, results) @@ -142,7 +146,7 @@ def ensure_requirements(self): ), ) - def fit(self, embeddings): + def fit(self, embeddings, labels=None): _umap = umap.UMAP( n_components=self.config.num_dims, n_neighbors=self.config.num_neighbors, @@ -151,7 +155,26 @@ def fit(self, embeddings): random_state=self.config.seed, verbose=self.config.verbose, ) - return _umap.fit_transform(embeddings) + + if labels is None: + return _umap.fit_transform(embeddings) + else: + categories = {} + curr_cat = 0 + fit_labels = [] + fit_embeddings = [] + for i, l in enumerate(labels): + if l is None: + continue + + if l not in categories: + categories[l] = curr_cat + curr_cat += 1 + + fit_labels.append(categories[l]) + fit_embeddings.append(embeddings[i]) + _umap.fit(np.array(fit_embeddings), y=fit_labels) + return _umap.transform(embeddings) class TSNEVisualization(Visualization):