diff --git a/src/moscot/datasets.py b/src/moscot/datasets.py index 7728d66cc..a308d5867 100644 --- a/src/moscot/datasets.py +++ b/src/moscot/datasets.py @@ -5,12 +5,14 @@ import shutil import tempfile import urllib.request +from itertools import combinations from types import MappingProxyType from typing import Any, Dict, List, Literal, Mapping, Optional, Tuple import networkx as nx import numpy as np import pandas as pd +from scipy.linalg import block_diag import anndata as ad from anndata import AnnData @@ -411,7 +413,8 @@ def simulate_data( Literal indicating whether to add costs corresponding to a specific problem setting. If `None`, no quadratic cost element is generated. lin_cost_matrix - Key where to save the linear cost matrix. If `None`, no linear cost matrix is generated. + Key where to save the linear cost matrix. It is generated according to the pairwise policy. + If `None`, no linear cost matrix is generated. quad_cost_matrix Key where to save the quadratic cost matrices. If `None`, no quadratic cost matrix is generated. @@ -457,9 +460,18 @@ def simulate_data( barcode_dim = kwargs.pop("barcode_dim", 10) adata.obsm["barcode"] = rng.choice(n_intBCs, size=(adata.n_obs, barcode_dim)) if lin_cost_matrix is not None: - raise NotImplementedError("TODO") + adata.uns[lin_cost_matrix] = {} + for i, j in combinations(range(n_distributions), 2): + adata.uns[lin_cost_matrix][(str(i), str(j))] = np.abs( + rng.normal(size=(cells_per_distribution, cells_per_distribution)) + ) if quad_cost_matrix is not None: - raise NotImplementedError("TODO") + quad_costs = [ + (np.abs(rng.normal(size=(cells_per_distribution, cells_per_distribution)))) for i in range(n_distributions) + ] + quad_cm = block_diag(*quad_costs) + np.fill_diagonal(quad_cm, 0) + adata.obsp[quad_cost_matrix] = quad_cm return adata