Skip to content

Commit

Permalink
simulate cost matrices in simulate_data (#671)
Browse files Browse the repository at this point in the history
* simulate cost matrices

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* mypy and ruff fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* mypy fix

* pairwise linear cms

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: Arina Danilina <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Mar 20, 2024
1 parent 7f1530d commit 9667dc7
Showing 1 changed file with 15 additions and 3 deletions.
18 changes: 15 additions & 3 deletions src/moscot/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@
import shutil
import tempfile
import urllib.request
from itertools import combinations
from types import MappingProxyType
from typing import Any, Dict, List, Literal, Mapping, Optional, Tuple

import networkx as nx
import numpy as np
import pandas as pd
from scipy.linalg import block_diag

import anndata as ad
from anndata import AnnData
Expand Down Expand Up @@ -411,7 +413,8 @@ def simulate_data(
Literal indicating whether to add costs corresponding to a specific problem setting.
If `None`, no quadratic cost element is generated.
lin_cost_matrix
Key where to save the linear cost matrix. If `None`, no linear cost matrix is generated.
Key where to save the linear cost matrix. It is generated according to the pairwise policy.
If `None`, no linear cost matrix is generated.
quad_cost_matrix
Key where to save the quadratic cost matrices. If `None`, no quadratic cost matrix is generated.
Expand Down Expand Up @@ -457,9 +460,18 @@ def simulate_data(
barcode_dim = kwargs.pop("barcode_dim", 10)
adata.obsm["barcode"] = rng.choice(n_intBCs, size=(adata.n_obs, barcode_dim))
if lin_cost_matrix is not None:
raise NotImplementedError("TODO")
adata.uns[lin_cost_matrix] = {}
for i, j in combinations(range(n_distributions), 2):
adata.uns[lin_cost_matrix][(str(i), str(j))] = np.abs(
rng.normal(size=(cells_per_distribution, cells_per_distribution))
)
if quad_cost_matrix is not None:
raise NotImplementedError("TODO")
quad_costs = [
(np.abs(rng.normal(size=(cells_per_distribution, cells_per_distribution)))) for i in range(n_distributions)
]
quad_cm = block_diag(*quad_costs)
np.fill_diagonal(quad_cm, 0)
adata.obsp[quad_cost_matrix] = quad_cm
return adata


Expand Down

0 comments on commit 9667dc7

Please sign in to comment.