Skip to content

Commit fc80fee

Browse files
Intron7meeseeksmachine
authored andcommitted
Backport PR #2589: Fixed wrong order for groups with logreg
1 parent ab2ba2d commit fc80fee

File tree

3 files changed

+34
-10
lines changed

3 files changed

+34
-10
lines changed

docs/release-notes/1.9.4.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,4 @@
66
* Support scikit-learn 1.3 {pr}`2515` {smaller}`P Angerer`
77
* Deal with `None` value vanishing from things like `.uns['log1p']` {pr}`2546` {smaller}`SP Shen`
88
* Depend on `igraph` instead of `python-igraph` {pr}`2566` {smaller}`P Angerer`
9+
* {func}`~scanpy.tl.rank_genes_groups` now handles unsorted groups as intended {pr}`2589` {smaller}`S Dicks`

scanpy/tests/test_rank_genes_groups_logreg.py

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,10 @@
22

33
import numpy as np
44
import scanpy as sc
5+
import pandas as pd
56

67

7-
@pytest.mark.parametrize(
8-
"method",
9-
["t-test", "logreg"],
10-
)
8+
@pytest.mark.parametrize('method', ['t-test', 'logreg'])
119
def test_rank_genes_groups_with_renamed_categories(method):
1210
adata = sc.datasets.blobs(n_variables=4, n_centers=3, n_observations=200)
1311
assert np.allclose(adata.X[1], [9.214668, -2.6487126, 4.2020774, 0.51076424])
@@ -30,14 +28,34 @@ def test_rank_genes_groups_with_renamed_categories_use_rep():
3028
adata = sc.datasets.blobs(n_variables=4, n_centers=3, n_observations=200)
3129
assert np.allclose(adata.X[1], [9.214668, -2.6487126, 4.2020774, 0.51076424])
3230

33-
adata.layers["to_test"] = adata.X.copy()
31+
adata.layers['to_test'] = adata.X.copy()
3432
adata.X = adata.X[::-1, :]
3533

3634
sc.tl.rank_genes_groups(
37-
adata, 'blobs', method='logreg', layer="to_test", use_raw=False
35+
adata, 'blobs', method='logreg', layer='to_test', use_raw=False
3836
)
3937
assert adata.uns['rank_genes_groups']['names'].dtype.names == ('0', '1', '2')
4038
assert adata.uns['rank_genes_groups']['names'][0].tolist() == ('1', '3', '0')
4139

42-
sc.tl.rank_genes_groups(adata, 'blobs', method="logreg")
40+
sc.tl.rank_genes_groups(adata, 'blobs', method='logreg')
4341
assert not adata.uns['rank_genes_groups']['names'][0].tolist() == ('3', '1', '0')
42+
43+
44+
def test_rank_genes_groups_with_unsorted_groups():
45+
adata = sc.datasets.blobs(n_variables=10, n_centers=5, n_observations=200)
46+
adata._sanitize()
47+
adata.rename_categories('blobs', ['Zero', 'One', 'Two', 'Three', 'Four'])
48+
bdata = adata.copy()
49+
sc.tl.rank_genes_groups(
50+
adata, 'blobs', groups=['Zero', 'One', 'Three'], method='logreg'
51+
)
52+
sc.tl.rank_genes_groups(
53+
bdata, 'blobs', groups=['One', 'Three', 'Zero'], method='logreg'
54+
)
55+
array_ad = pd.DataFrame(
56+
adata.uns['rank_genes_groups']['scores']['Three']
57+
).to_numpy()
58+
array_bd = pd.DataFrame(
59+
bdata.uns['rank_genes_groups']['scores']['Three']
60+
).to_numpy()
61+
np.testing.assert_equal(array_ad, array_bd)

scanpy/tools/_rank_genes_groups.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -341,12 +341,17 @@ def logreg(self, **kwds):
341341
clf = LogisticRegression(**kwds)
342342
clf.fit(X, self.grouping.cat.codes)
343343
scores_all = clf.coef_
344-
for igroup, _ in enumerate(self.groups_order):
344+
# not all codes necessarily appear in data
345+
existing_codes = np.unique(self.grouping.cat.codes)
346+
for igroup, cat in enumerate(self.groups_order):
345347
if len(self.groups_order) <= 2: # binary logistic regression
346348
scores = scores_all[0]
347349
else:
348-
scores = scores_all[igroup]
349-
350+
# cat code is index of cat value in .categories
351+
cat_code: int = np.argmax(self.grouping.cat.categories == cat)
352+
# index of scores row is index of cat code in array of existing codes
353+
scores_idx: int = np.argmax(existing_codes == cat_code)
354+
scores = scores_all[scores_idx]
350355
yield igroup, scores, None
351356

352357
if len(self.groups_order) <= 2:

0 commit comments

Comments
 (0)