From 00e75e1de6af736cb5b5a9e875269087d9180fc0 Mon Sep 17 00:00:00 2001 From: Moritz Schaefer Date: Fri, 21 Feb 2020 13:33:59 +0100 Subject: [PATCH] Preserve order of samples/classes/labels for PCA plot visualization (plot_pca_2d_projection) --- scikitplot/decomposition.py | 5 ++++- scikitplot/tests/test_decomposition.py | 27 ++++++++++++++++++++++++++ 2 files changed, 31 insertions(+), 1 deletion(-) diff --git a/scikitplot/decomposition.py b/scikitplot/decomposition.py index d3b28e3..f88204f 100644 --- a/scikitplot/decomposition.py +++ b/scikitplot/decomposition.py @@ -163,7 +163,10 @@ def plot_pca_2d_projection(clf, X, y, title='PCA 2-D Projection', fig, ax = plt.subplots(1, 1, figsize=figsize) ax.set_title(title, fontsize=title_fontsize) - classes = np.unique(np.array(y)) + + # Get unique classes from y, preserving order of class occurrence in y + _, class_indexes = np.unique(np.array(y), return_index=True) + classes = np.array(y)[np.sort(class_indexes)] colors = plt.cm.get_cmap(cmap)(np.linspace(0, 1, len(classes))) diff --git a/scikitplot/tests/test_decomposition.py b/scikitplot/tests/test_decomposition.py index f7e555b..3c3e7af 100644 --- a/scikitplot/tests/test_decomposition.py +++ b/scikitplot/tests/test_decomposition.py @@ -9,6 +9,7 @@ from scikitplot.decomposition import plot_pca_component_variance from scikitplot.decomposition import plot_pca_2d_projection +import scikitplot class TestPlotPCAComponentVariance(unittest.TestCase): @@ -81,3 +82,29 @@ def test_biplot(self): clf.fit(self.X) ax = plot_pca_2d_projection(clf, self.X, self.y, biplot=True, feature_labels=load_data().feature_names) + + def test_label_order(self): + ''' + Plot labels should be in the same order as the classes in the provided y-array + ''' + np.random.seed(0) + clf = PCA() + clf.fit(self.X) + + # define y such that the first entry is 1 + y = np.copy(self.y) + y[0] = 1 # load_iris is by default ordered (i.e.: 0 0 0 ... 1 1 1 ... 
2 2 2) + + # test with len(y) == X.shape[0] with multiple rows belonging to the same class + ax = plot_pca_2d_projection(clf, self.X, y, cmap='Spectral') + legend_labels = ax.get_legend_handles_labels()[1] + self.assertListEqual(['1', '0', '2'], legend_labels) + + # test with len(y) == #classes with each row belonging to an individual class + y = list(range(len(y))) + np.random.shuffle(y) + ax = plot_pca_2d_projection(clf, self.X, y, cmap='Spectral') + legend_labels = ax.get_legend_handles_labels()[1] + self.assertListEqual([str(v) for v in y], legend_labels) + +