Skip to content

Commit

Permalink
implement keep_intervals method
Browse files Browse the repository at this point in the history
  • Loading branch information
bguo068 committed May 29, 2024
1 parent 29c1237 commit 7755a02
Show file tree
Hide file tree
Showing 7 changed files with 287 additions and 0 deletions.
172 changes: 172 additions & 0 deletions src/table_collection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,17 @@ use crate::metadata::SiteMetadata;
use crate::sys::bindings as ll_bindings;
use crate::sys::TableCollection as LLTableCollection;
use crate::types::Bookmark;
use crate::EdgeTable;
use crate::IndividualTableSortOptions;
use crate::MigrationId;
use crate::MigrationTable;
use crate::MutationId;
use crate::MutationTable;
use crate::PopulationId;
use crate::Position;
use crate::SimplificationOptions;
use crate::SiteId;
use crate::SiteTable;
use crate::TableClearOptions;
use crate::TableEqualityOptions;
use crate::TableIntegrityCheckFlags;
Expand Down Expand Up @@ -1372,4 +1376,172 @@ impl TableCollection {
pub fn as_mut_ptr(&mut self) -> *mut ll_bindings::tsk_table_collection_t {
// Delegates to the wrapped low-level table collection (`self.inner`),
// which hands out the raw mutable pointer to the C struct.
self.inner.as_mut_ptr()
}

/// Truncate the table collection to the specified genome intervals.
///
/// Each interval is treated as half-open, `[start, end)`:
///
/// * Edges and migrations that overlap an interval are kept and clipped to
///   that interval; a row overlapping several intervals yields one clipped
///   row per interval.
/// * Sites whose position lies inside any interval are kept, together with
///   the mutations at those sites. Site ids and mutation parent ids are
///   remapped to the new, compacted tables.
///
/// After the tables have been replaced they are fully sorted. If `simplify`
/// is `true`, the collection is then simplified with respect to the samples
/// returned by `samples_as_vector`, using default simplification options.
///
/// Note that no new provenance will be appended.
///
/// # Parameters
/// * `intervals`: the `(start, end)` intervals to keep. They must be sorted
///   and non-overlapping: every interval needs `start <= end` and a `start`
///   that is not smaller than the previous interval's `end` (the first
///   interval must not start below `0.0`).
/// * `simplify`: whether to simplify the tables after truncation.
///
/// # Errors
/// Any errors from the C API propagate. A [TskitError::RangeError] is
/// returned when `intervals` are invalid (`start > end`), overlapping, or
/// not sorted.
///
/// # Example
/// ```rust
/// use tskit::TreeSequence;
/// let mut tables = TreeSequence::load("./testdata/1.trees")
///     .unwrap()
///     .dump_tables()
///     .unwrap();
/// tables
///     .keep_intervals(vec![(10.0.into(), 130.0.into())].into_iter(), true)
///     .unwrap();
/// ```
pub fn keep_intervals(
    &mut self,
    intervals: impl Iterator<Item = (Position, Position)>,
    simplify: bool,
) -> Result<(), TskitError> {
    use streaming_iterator::StreamingIterator;

    // Build the replacement tables with the sys-level types, which makes it
    // easy to copy raw metadata bytes across.
    let options = 0;
    let mut new_edges = crate::sys::EdgeTable::new(options)?;
    let mut new_migrations = crate::sys::MigrationTable::new(options)?;
    let mut new_sites = crate::sys::SiteTable::new(options)?;
    let mut new_mutations = crate::sys::MutationTable::new(options)?;

    let num_sites = self.sites().num_rows().as_usize();

    // Old site id -> new site id (-1 for sites that are dropped).
    let mut site_map = vec![-1i32; num_sites];

    // Whether each (old) site falls inside at least one interval.
    let mut keep_sites = vec![false; num_sites];

    let mut last_interval = (Position::from(0.0), Position::from(0.0));
    for (s, e) in intervals {
        // Intervals must be well-formed, sorted and non-overlapping.
        if (s > e) || (s < last_interval.1) {
            return Err(TskitError::RangeError(
                "intervals not valid or sorted".into(),
            ));
        }

        keep_sites
            .iter_mut()
            .zip(self.sites_iter())
            .for_each(|(k, site_row)| {
                *k = *k || ((site_row.position >= s) && (site_row.position < e));
            });

        // Streaming iterators plus `while let` keep `?` usable inside the loop.
        let mut edge_iter = self
            .edges()
            .lending_iter()
            .filter(|edge_row| !((edge_row.right <= s) || (edge_row.left >= e)));

        while let Some(edge_row) = edge_iter.next() {
            // Clip the edge to the current interval.
            let left = if edge_row.left < s { s } else { edge_row.left };
            let right = if edge_row.right > e { e } else { edge_row.right };
            new_edges.add_row_with_metadata(
                left.into(),
                right.into(),
                edge_row.parent.into(),
                edge_row.child.into(),
                edge_row.metadata.unwrap_or(&[0u8; 0]),
            )?;
        }

        // Keep the migrations that overlap the interval, i.e. the same
        // predicate as for edges. (This used to be a double negation, which
        // retained exactly the migrations lying outside the interval.)
        let mut migration_iter = self
            .migrations()
            .lending_iter()
            .filter(|mrow| !((mrow.right <= s) || (mrow.left >= e)));

        while let Some(migration_row) = migration_iter.next() {
            // Clip the migration to the current interval, as is done for
            // edges, so nothing extends beyond the retained genome and a
            // migration spanning several intervals is not duplicated with
            // its full span.
            let left = if migration_row.left < s {
                s
            } else {
                migration_row.left
            };
            let right = if migration_row.right > e {
                e
            } else {
                migration_row.right
            };
            new_migrations.add_row_with_metadata(
                (left.into(), right.into()),
                migration_row.node.into(),
                migration_row.source.into(),
                migration_row.dest.into(),
                migration_row.time.into(),
                migration_row.metadata.unwrap_or(&[0u8; 0]),
            )?;
        }
        last_interval = (s, e);
    }

    // Copy the retained sites and record their new ids.
    let mut running_site_id = 0;
    let mut site_iter = self.sites().lending_iter();
    while let Some(site_row) = site_iter.next() {
        let old_id = site_row.id.to_usize().unwrap();
        if keep_sites[old_id] {
            new_sites.add_row_with_metadata(
                site_row.position.into(),
                site_row.ancestral_state,
                site_row.metadata.unwrap_or(&[0u8; 0]),
            )?;
            site_map[old_id] = running_site_id;
            running_site_id += 1;
        }
    }

    // Old mutation id -> new mutation id. A mutation is kept iff its site is
    // kept, so the new id of a kept mutation is the number of kept mutations
    // seen before it. Entries for dropped mutations are never looked up for
    // a kept mutation, because a mutation's parent lives at the same site.
    let mutation_map = {
        let mut v = Vec::with_capacity(self.mutations().num_rows().as_usize());
        let mut n = 0;
        self.mutations().site_slice().iter().for_each(|site| {
            if keep_sites[site.as_usize()] {
                n += 1
            };
            v.push(n - 1);
        });
        v
    };

    let mut mutations_iter = self.mutations().lending_iter();
    while let Some(mutation_row) = mutations_iter.next() {
        let old_site = mutation_row.site.to_usize().unwrap();
        if keep_sites[old_site] {
            let new_site = site_map[old_site];
            let new_parent = {
                if mutation_row.parent.is_null() {
                    mutation_row.parent.into()
                } else {
                    mutation_map[mutation_row.parent.as_usize()]
                }
            };
            new_mutations.add_row_with_metadata(
                new_site,
                mutation_row.node.into(),
                new_parent,
                mutation_row.time.into(),
                mutation_row.derived_state,
                mutation_row.metadata.unwrap_or(&[0u8; 0]),
            )?;
        }
    }

    // Convert the sys-level tables into the public table types.
    let new_edges = EdgeTable::new_from_table(new_edges.as_mut())?;
    let new_migrations = MigrationTable::new_from_table(new_migrations.as_mut())?;
    let new_mutations = MutationTable::new_from_table(new_mutations.as_mut())?;
    let new_sites = SiteTable::new_from_table(new_sites.as_mut())?;

    // Replace the old tables with the new ones.
    self.set_edges(&new_edges).map(|_| ())?;
    self.set_migrations(&new_migrations).map(|_| ())?;
    self.set_mutations(&new_mutations).map(|_| ())?;
    self.set_sites(&new_sites)?;

    // Rows were appended interval by interval, so restore sortedness.
    self.full_sort(TableSortOptions::default())?;

    if simplify {
        let samples = self.samples_as_vector();
        self.simplify(samples.as_slice(), SimplificationOptions::default(), false)?;
    }

    Ok(())
}
}
24 changes: 24 additions & 0 deletions src/trees/treeseq.rs
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,30 @@ impl TreeSequence {
))
}

/// Restrict this tree sequence to the given genome intervals.
///
/// The tables are dumped, truncated with the table collection's
/// `keep_intervals` method, and a new tree sequence (with indexes built)
/// is created from the result.
///
/// Note that no new provenance will be appended.
///
/// # Errors
/// Any errors from the C API propagate. A [TskitError::RangeError] will
/// occur when `intervals` are not sorted.
///
/// # Example
/// ```rust
/// use tskit::TreeSequence;
/// let ts = TreeSequence::load("testdata/1.trees").expect("error loading ts");
/// let new_ts = ts.keep_intervals(vec![(10.0.into(), 130.0.into())].into_iter(), true).unwrap();
/// ```
pub fn keep_intervals(
    self,
    intervals: impl Iterator<Item = (Position, Position)>,
    simplify: bool,
) -> Result<Self, TskitError> {
    self.dump_tables().and_then(|mut truncated| {
        truncated.keep_intervals(intervals, simplify)?;
        let flags = TreeSequenceFlags::default().build_indexes();
        Self::new(truncated, flags)
    })
}

#[cfg(feature = "provenance")]
#[cfg_attr(doc_cfg, doc(cfg(feature = "provenance")))]
/// Add provenance record with a time stamp.
Expand Down
Binary file added testdata/1.trees
Binary file not shown.
Binary file added testdata/2.trees
Binary file not shown.
Binary file added testdata/3.trees
Binary file not shown.
27 changes: 27 additions & 0 deletions testdata/gen_trees.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import msprime
import tskit

# Three populations; A and B split from the ancestral population C at time 1000.
demography = msprime.Demography()
for pop_name, pop_size in (("A", 10_000), ("B", 5_000), ("C", 1_000)):
    demography.add_population(name=pop_name, initial_size=pop_size)
# demography.set_migration_rate("A", "B", 0.01)
demography.add_population_split(time=1000, derived=["A", "B"], ancestral="C")

# Simulate ancestry for one sample from each of A and B, then add mutations.
ancestry: tskit.TreeSequence = msprime.sim_ancestry(
    samples={"A": 1, "B": 1},
    demography=demography,
    random_seed=12,
    recombination_rate=0.00001,
    sequence_length=300.0,
    # record_migrations=True
)
full_ts = msprime.sim_mutations(ancestry, rate=0.0001, random_seed=12)

# Output path -> tree sequence. 1.trees is the full simulation; 2.trees and
# 3.trees are the results of keep_intervals on it, without provenance records.
fixtures = {
    "testdata/1.trees": full_ts,
    "testdata/2.trees": full_ts.keep_intervals([(10, 130)], record_provenance=False),
    "testdata/3.trees": full_ts.keep_intervals(
        [(10, 40), (100, 200)], record_provenance=False
    ),
}

for out_path, tree_seq in fixtures.items():
    tree_seq.dump(out_path)

64 changes: 64 additions & 0 deletions tests/test_tables.rs
Original file line number Diff line number Diff line change
Expand Up @@ -550,3 +550,67 @@ mod test_metadata_round_trips {
);
}
}

/// Checks `keep_intervals` on both `TableCollection` and `TreeSequence`
/// against fixtures stored in `testdata/` (see `testdata/gen_trees.py`,
/// which writes them).
#[test]
fn test_keep_intervals() {
    use tskit::{TableEqualityOptions, TreeSequence};

    // --- table collection, single interval ---
    let mut tables1 = TreeSequence::load("./testdata/1.trees")
        .unwrap()
        .dump_tables()
        .unwrap();
    tables1
        .keep_intervals(vec![(10.0.into(), 130.0.into())].into_iter(), true)
        .unwrap();
    // expected tables
    let tables2 = TreeSequence::load("./testdata/2.trees")
        .unwrap()
        .dump_tables()
        .unwrap();

    assert!(tables1.equals(&tables2, TableEqualityOptions::all()));

    // --- table collection, two intervals ---
    let mut tables1 = TreeSequence::load("./testdata/1.trees")
        .unwrap()
        .dump_tables()
        .unwrap();
    tables1
        .keep_intervals(
            vec![(10.0.into(), 40.0.into()), (100.0.into(), 200.0.into())].into_iter(),
            true,
        )
        .unwrap();
    // expected tables
    let tables3 = TreeSequence::load("./testdata/3.trees")
        .unwrap()
        .dump_tables()
        .unwrap();

    assert!(tables1.equals(&tables3, TableEqualityOptions::all()));

    // --- tree sequence, single interval ---
    let ts1 = TreeSequence::load("./testdata/1.trees")
        .unwrap()
        .keep_intervals(vec![(10.0.into(), 130.0.into())].into_iter(), true)
        .unwrap();
    let ts2 = TreeSequence::load("./testdata/2.trees").unwrap();
    // The comparison result must be asserted; previously it was computed
    // and silently discarded, so this part of the test could never fail.
    assert!(ts1
        .dump_tables()
        .unwrap()
        .equals(&ts2.dump_tables().unwrap(), TableEqualityOptions::all()));

    // --- tree sequence, two intervals ---
    let ts1 = TreeSequence::load("./testdata/1.trees")
        .unwrap()
        .keep_intervals(
            vec![(10.0.into(), 40.0.into()), (100.0.into(), 200.0.into())].into_iter(),
            true,
        )
        .unwrap();
    let ts3 = TreeSequence::load("./testdata/3.trees").unwrap();
    assert!(ts1
        .dump_tables()
        .unwrap()
        .equals(&ts3.dump_tables().unwrap(), TableEqualityOptions::all()));
}

0 comments on commit 7755a02

Please sign in to comment.