Skip to content

Commit

Permalink
adds option to sample random dict, e.g. {'YRI': 3, 'CEU': 5, 'CHB': 2…
Browse files Browse the repository at this point in the history
…} from sampling_populations
  • Loading branch information
kevinkorfmann committed Feb 3, 2025
1 parent 841eb0b commit 1e74d19
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 0 deletions.
17 changes: 17 additions & 0 deletions stdpopsim/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import copy
import textwrap
import random

import msprime
import warnings
Expand Down Expand Up @@ -290,6 +291,22 @@ def get_samples(self, *args):
)
return samples

def random_sample_counts(self, n, seed=None):
"""
Randomly distributes `n` samples across the given sampling populations.
:param int n: Number of samples to draw.
:param int seed: Optional seed for reproducibility.
:return: Dictionary mapping population names to sample counts.
:rtype: dict
"""
rng = random.Random(seed)
sampled_counts = {pop.name: 0 for pop in self.sampling_populations}
for pop in rng.choices(self.sampling_populations, k=n):
sampled_counts[pop.name] += 1
return sampled_counts


class PiecewiseConstantSize(DemographicModel):
"""
Expand Down
20 changes: 20 additions & 0 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,26 @@ def test_get_sample_sets(self):
sample_populations = [i.population for i in test_samples]
assert sample_populations == [0, 1]

def test_random_sample_counts(self):
species = stdpopsim.get_species("HomSap")
model = species.get_demographic_model("OutOfAfricaArchaicAdmixture_5R19")
counts = model.random_sample_counts(n=10, seed=10)
assert counts == {
"YRI": 3,
"CEU": 5,
"CHB": 2,
}, f"Expected {{'YRI': 3, 'CEU': 5, 'CHB': 2}}, but got {counts}"

def test_random_sample_counts_varied(self):
species = stdpopsim.get_species("HomSap")
model = species.get_demographic_model("OutOfAfricaArchaicAdmixture_5R19")
ns = [5, 15, 20, 25, 30]
for n in ns:
counts = model.random_sample_counts(n=n)
assert (
sum(counts.values()) == n
), f"Expected sum {n}, but got {sum(counts.values())}"

@pytest.mark.filterwarnings("ignore::stdpopsim.DeprecatedFeatureWarning")
def test_deprecated_get_samples(self):
base_mod = self.make_model()
Expand Down

0 comments on commit 1e74d19

Please sign in to comment.