From e30edc821539e3266555c5c7d46d2ecf04a6a520 Mon Sep 17 00:00:00 2001 From: Bing Guo Date: Mon, 27 May 2024 23:01:11 -0400 Subject: [PATCH] feat: added keep intervals for tables and tree sequences --- src/table_collection.rs | 206 ++++++++++++++++ src/test_fixtures.rs | 511 ++++++++++++++++++++++++++++++++++++++++ src/trees/treeseq.rs | 60 +++++ 3 files changed, 777 insertions(+) diff --git a/src/table_collection.rs b/src/table_collection.rs index f9d6f9df7..117d79e7e 100644 --- a/src/table_collection.rs +++ b/src/table_collection.rs @@ -10,13 +10,17 @@ use crate::metadata::SiteMetadata; use crate::sys::bindings as ll_bindings; use crate::sys::TableCollection as LLTableCollection; use crate::types::Bookmark; +use crate::EdgeTable; use crate::IndividualTableSortOptions; use crate::MigrationId; +use crate::MigrationTable; use crate::MutationId; +use crate::MutationTable; use crate::PopulationId; use crate::Position; use crate::SimplificationOptions; use crate::SiteId; +use crate::SiteTable; use crate::TableClearOptions; use crate::TableEqualityOptions; use crate::TableIntegrityCheckFlags; @@ -1372,4 +1376,206 @@ impl TableCollection { pub fn as_mut_ptr(&mut self) -> *mut ll_bindings::tsk_table_collection_t { self.inner.as_mut_ptr() } + + /// Truncate the [TableCollection] to specified genome intervals. + /// + /// # Return + /// - `Ok(None)`: when truncation leads to empty edge table. + /// - `Ok(Some(TableCollection))`: when truncation is successfully performed + /// and results in non-empty edge table. + /// - `Err(TskitError)`: Any errors from the C API propagate. A + /// [TskitError::RangeError] will occur when `intervals` are not + /// sorted or overlap. Note that as `tskit` currently does not support `simplify` + /// on [TableCollection] with a non-empty migration table, calling + /// `keep_intervals` on those [TableCollection] with `simplify` set to + /// `true` will return an error. 
+ /// + /// # Example + /// ```rust + /// # use tskit::*; + /// # let snode = NodeFlags::new_sample(); + /// # let anode = NodeFlags::default(); + /// # let pop = PopulationId::NULL; + /// # let ind = IndividualId::NULL; + /// # let seqlen = 100.0; + /// # let (t0, t10) = (0.0, 10.0); + /// # let (left, right) = (0.0, 100.0); + /// # let sim_opts = SimplificationOptions::default(); + /// # + /// # let mut tables = TableCollection::new(seqlen).unwrap(); + /// # let child1 = tables.add_node(snode, t0, pop, ind).unwrap(); + /// # let child2 = tables.add_node(snode, t0, pop, ind).unwrap(); + /// # let parent = tables.add_node(anode, t10, pop, ind).unwrap(); + /// # + /// # tables.add_edge(left, right, parent, child1).unwrap(); + /// # tables.add_edge(left, right, parent, child2).unwrap(); + /// # tables.full_sort(TableSortOptions::all()).unwrap(); + /// # tables.simplify(&[child1, child2], sim_opts, false).unwrap(); + /// # tables.build_index().unwrap(); + /// # + /// let intervals = [(0.0, 10.0), (90.0, 100.0)].into_iter(); + /// tables.keep_intervals(intervals, true).unwrap().unwrap(); + /// ``` + /// + /// Note that no new provenance will be appended. + pub fn keep_intervals

( + self, + intervals: impl Iterator, + simplify: bool, + ) -> Result, TskitError> + where + P: Into, + { + use streaming_iterator::StreamingIterator; + let mut tables = self; + // use tables from sys to allow easier process with metadata + let options = 0; + let mut new_edges = crate::sys::EdgeTable::new(options)?; + let mut new_migrations = crate::sys::MigrationTable::new(options)?; + let mut new_sites = crate::sys::SiteTable::new(options)?; + let mut new_mutations = crate::sys::MutationTable::new(options)?; + + // for old site id to new site id mapping + let mut site_map = vec![-1i32; tables.sites().num_rows().as_usize()]; + + // logicals to indicate whether a site (old) will be kept in new site table + let mut keep_sites = vec![false; tables.sites().num_rows().try_into()?]; + + let mut last_interval = (Position::from(0.0), Position::from(0.0)); + for (s, e) in intervals { + let (s, e) = (s.into(), e.into()); + // make sure intervals are sorted + if (s > e) || (s < last_interval.1) { + return Err(TskitError::RangeError( + "intervals not valid or sorted".into(), + )); + } + keep_sites + .iter_mut() + .zip(tables.sites_iter()) + .for_each(|(k, site_row)| { + *k = *k || ((site_row.position >= s) && (site_row.position < e)); + }); + + // use stream_iter and while-let pattern for easier ? 
operator within a loop + let mut edge_iter = tables + .edges() + .lending_iter() + .filter(|edge_row| !((edge_row.right <= s) || (edge_row.left >= e))); + + while let Some(edge_row) = edge_iter.next() { + new_edges.add_row_with_metadata( + if edge_row.left < s { s } else { edge_row.left }.into(), + if edge_row.right > e { + e + } else { + edge_row.right + } + .into(), + edge_row.parent.into(), + edge_row.child.into(), + edge_row.metadata.unwrap_or(&[0u8; 0]), + )?; + } + + let mut migration_iter = tables + .migrations() + .lending_iter() + .filter(|mrow| !((mrow.right <= s) || (mrow.left >= e))); + + while let Some(migration_row) = migration_iter.next() { + new_migrations.add_row_with_metadata( + (migration_row.left.into(), migration_row.right.into()), + migration_row.node.into(), + migration_row.source.into(), + migration_row.dest.into(), + migration_row.time.into(), + migration_row.metadata.unwrap_or(&[0u8; 0]), + )?; + } + last_interval = (s, e); + } + + let mut running_site_id = 0; + let mut site_iter = tables.sites().lending_iter(); + while let Some(site_row) = site_iter.next() { + let old_id = site_row.id.to_usize().unwrap(); + if keep_sites[old_id] { + new_sites.add_row_with_metadata( + site_row.position.into(), + site_row.ancestral_state, + site_row.metadata.unwrap_or(&[0u8; 0]), + )?; + site_map[old_id] = running_site_id; + running_site_id += 1; + } + } + + // build mutation_map + let mutation_map: Vec<_> = { + let mut n = 0; + tables + .mutations() + .site_slice() + .iter() + .map(|site| { + if keep_sites[site.as_usize()] { + n += 1 + }; + n - 1 + }) + .collect() + }; + + let mut mutations_iter = tables.mutations().lending_iter(); + while let Some(mutation_row) = mutations_iter.next() { + let old_id = mutation_row.site.to_usize().unwrap(); + if keep_sites[old_id] { + let new_site = site_map[old_id]; + let new_parent = { + if mutation_row.parent.is_null() { + mutation_row.parent.into() + } else { + mutation_map[mutation_row.parent.as_usize()] + } + }; + 
new_mutations.add_row_with_metadata( + new_site, + mutation_row.node.into(), + new_parent, + mutation_row.time.into(), + mutation_row.derived_state, + mutation_row.metadata.unwrap_or(&[0u8; 0]), + )?; + } + } + + // convert sys version of tables to non-sys version of tables + let new_edges = EdgeTable::new_from_table(new_edges.as_mut())?; + let new_migrations = MigrationTable::new_from_table(new_migrations.as_mut())?; + let new_mutations = MutationTable::new_from_table(new_mutations.as_mut())?; + let new_sites = SiteTable::new_from_table(new_sites.as_mut())?; + + // replace old tables with new tables + tables.set_edges(&new_edges).map(|_| ())?; + tables.set_migrations(&new_migrations).map(|_| ())?; + tables.set_mutations(&new_mutations).map(|_| ())?; + tables.set_sites(&new_sites)?; + + // sort tables + tables.full_sort(TableSortOptions::default())?; + + // simplify tables + if simplify { + let samples = tables.samples_as_vector(); + tables.simplify(samples.as_slice(), SimplificationOptions::default(), false)?; + } + + // return None when edge table is empty + if tables.edges().num_rows() == 0 { + Ok(None) + } else { + Ok(Some(tables)) + } + } } diff --git a/src/test_fixtures.rs b/src/test_fixtures.rs index cbff1faaf..0cbce82ca 100644 --- a/src/test_fixtures.rs +++ b/src/test_fixtures.rs @@ -83,3 +83,514 @@ pub mod bad_metadata { impl crate::metadata::MutationMetadata for F {} impl crate::metadata::MutationMetadata for Ff {} } + +/// mimic the c simulate function in tskit c api document +/// https://tskit.dev/tskit/docs/stable/c-api.html#basic-forwards-simulator +#[cfg(test)] +pub mod simulation { + use core::panic; + + use crate::{ + metadata::{MetadataError, MetadataRoundtrip, PopulationMetadata}, + EdgeId, IndividualId, MutationId, NodeFlags, NodeId, PopulationId, Position, + SimplificationOptions, SiteId, TableCollection, TableSortOptions, TreeSequence, + TreeSequenceFlags, TskitError, + }; + use rand::{rngs::StdRng, Rng, SeedableRng}; + + struct MyMeta { + 
inner: String, + } + impl From for MyMeta { + fn from(value: String) -> Self { + MyMeta { inner: value } + } + } + impl<'a> From<&'a str> for MyMeta { + fn from(value: &'a str) -> Self { + MyMeta { + inner: value.to_owned(), + } + } + } + + // helper structs, impls and functions + + impl MetadataRoundtrip for MyMeta { + fn encode(&self) -> Result, MetadataError> { + Ok(self.inner.as_bytes().to_owned()) + } + fn decode(md: &[u8]) -> Result + where + Self: Sized, + { + Ok(MyMeta { + inner: String::from_utf8(md.to_owned()).unwrap(), + }) + } + } + + impl PopulationMetadata for MyMeta {} + + fn add_pop(tables: &mut TableCollection, name: &str) -> PopulationId { + tables + .add_population_with_metadata(&MyMeta::from(name)) + .unwrap() + } + + fn add_ind( + tables: &mut TableCollection, + parent1: (NodeId, NodeId), + parent2: (NodeId, NodeId), + ) -> IndividualId { + let parent1_ind = tables.nodes().individual(parent1.0).unwrap(); + let parent2_ind = tables.nodes().individual(parent2.0).unwrap(); + let flags = 0u32; + let loc_null = None; + tables + .add_individual(flags, loc_null, [parent1_ind, parent2_ind]) + .unwrap() + } + + fn find_parent( + rng: &mut StdRng, + parents: &[(NodeId, NodeId)], + child_pop: PopulationId, + ) -> ((NodeId, NodeId), PopulationId) { + assert_eq!(parents.len() % 2, 0); + let (pop_anc, pop_1, pop_2) = (0, 1, 2); + let child_pop: i32 = child_pop.into(); + + let pop_size = parents.len(); + let mut parent_pop = child_pop; + + let is_migrant = (child_pop != pop_anc) && rng.gen_bool(0.01); + if is_migrant { + parent_pop = if child_pop == pop_1 { pop_2 } else { pop_1 }; + }; + let parent = match parent_pop { + // pop_anc + 0 => parents[rng.gen_range(0..pop_size)], + // pop_1 + 1 => parents[rng.gen_range(0..(pop_size / 2))], + // pop_2 + 2 => parents[rng.gen_range((pop_size / 2)..pop_size)], + _ => panic!("wrong population id encountered"), + }; + (parent, parent_pop.into()) + } + + fn find_breakpoint(rng: &mut StdRng, seqlen: Position) -> Position 
{ + // avoid breakpoints at the sequence edges + let seqlen = f64::from(seqlen).floor() as usize; + let sel = rng.gen_range(1..seqlen) as f64; + Position::from(sel) + } + + fn add_node( + tables: &mut TableCollection, + is_sample: bool, + time: usize, + pop: PopulationId, + ind: IndividualId, + ) -> NodeId { + tables + .add_node( + if is_sample { + NodeFlags::new_sample() + } else { + NodeFlags::default() + }, + time as f64, + pop, + ind, + ) + .unwrap() + } + + fn add_edge( + tables: &mut TableCollection, + start: impl Into, + end: impl Into, + parent_node: NodeId, + child_node: NodeId, + ) -> EdgeId { + tables + .add_edge(start, end, parent_node, child_node) + .unwrap() + } + + fn find_overlaps

(start: P, end: P, intervals: &Vec<(P, P)>, out: &mut Vec<(P, P)>) + where + P: Into + Copy + PartialOrd, + { + // assert intervals is sorted + assert!(intervals.iter().all(|(a, b)| *a <= *b)); + assert!(intervals + .iter() + .zip(intervals.iter().skip(1)) + .all(|(p1, p2)| p1.1 <= p2.0)); + // clear out + out.clear(); + + for (m, n) in intervals { + // no overlap + if (*n <= start) || (end <= *m) { + continue; + } + let new_start = if *m < start { start } else { *m }; + let new_end = if *n < end { *n } else { end }; + out.push((new_start, new_end)); + } + } + + fn find_mutation_pos

(rng: &mut StdRng, s: P, e: P) -> usize + where + P: Into, + { + let s = f64::from(Into::::into(s)).ceil() as usize; + let e = f64::from(Into::::into(e)).floor() as usize; + rng.gen_range(s..e) + } + + fn calc_derived_state(site_last_mutation_order: &[usize], mut_pos: usize) -> [u8; 1] { + [b'a' + + match site_last_mutation_order[mut_pos] + 1 { + x if x > 45 => 45u8, + x => x as u8, + }] + } + + /// simulate diploid individuals with migration between two subpopulations + /// + /// Both full_trees and truncated_trees will be generated + pub fn simulate_two_treesequences

( + seqlen: P, + pop_size: usize, + start_time: usize, + split_time: usize, + intervals: &[(P, P)], + seed: u64, + ) -> Result<(TreeSequence, TreeSequence), TskitError> + where + P: Into + Copy + PartialOrd, + { + let rng = &mut StdRng::seed_from_u64(seed); + let intervals: Vec<(Position, Position)> = intervals + .iter() + .map(|(a, b)| ((*a).into(), (*b).into())) + .collect(); + assert!(split_time < start_time); + assert_eq!(pop_size % 2, 0); + // tables without truncation + let mut tables = TableCollection::new(seqlen).unwrap(); + // expected tables after truncation + // it is built following `tables` except for positions for edge table + let mut tr_tbls = TableCollection::new(seqlen).unwrap(); + + let mut buffer = Vec::new(); + + // add pop + let pop_anc = add_pop(&mut tables, "ancestor"); + let pop_1 = add_pop(&mut tables, "pop1"); + let pop_2 = add_pop(&mut tables, "pop2"); + + add_pop(&mut tr_tbls, "ancestral"); + add_pop(&mut tr_tbls, "pop1"); + add_pop(&mut tr_tbls, "pop2"); + + // state variables for site/mutation tables + let num_sites = f64::from(seqlen.into()) as usize; + let mut site_last_mutation_order = vec![0usize; num_sites]; + + let mut site_last_mutation_tables = vec![MutationId::NULL; num_sites]; + let mut site_last_mutation_tr_tbls = vec![MutationId::NULL; num_sites]; + + let mut site_id_map_tables = vec![SiteId::NULL; num_sites]; + let mut site_id_map_tr_tbls = vec![SiteId::NULL; num_sites]; + + // base population + let mut parents = Vec::<(NodeId, NodeId)>::with_capacity(pop_size); + for _ in 0..pop_size { + const FLAGS: u32 = 0; + let loc_null = None; + let parent_ind = tables.add_individual(FLAGS, loc_null, None).unwrap(); + tr_tbls.add_individual(FLAGS, loc_null, None).unwrap(); + + let parent_id = ( + add_node(&mut tables, false, start_time, pop_anc, parent_ind), + add_node(&mut tables, false, start_time, pop_anc, parent_ind), + ); + parents.push(parent_id); + // + add_node(&mut tr_tbls, false, start_time, pop_anc, parent_ind); + 
add_node(&mut tr_tbls, false, start_time, pop_anc, parent_ind); + } + + // offspring population + let mut children = Vec::<(NodeId, NodeId)>::with_capacity(pop_size); + + for t in (0..start_time).rev() { + for i in 0..pop_size { + // select breakpoints + let breakpoint1 = find_breakpoint(rng, seqlen.into()); + let breakpoint2 = find_breakpoint(rng, seqlen.into()); + + // find child pop + let mut child_pop = pop_anc; + if t > split_time { + child_pop = if i < pop_size / 2 { pop_1 } else { pop_2 } + } + + // find parents + let (parent1, _parent1_pop) = find_parent(rng, &parents, child_pop); + let (parent2, _parent2_pop) = find_parent(rng, &parents, child_pop); + + // add individual + let child_ind = add_ind(&mut tables, parent1, parent2); + add_ind(&mut tr_tbls, parent1, parent2); + + // add nodes + let is_sample = t == 0; + let child_id = ( + add_node(&mut tables, is_sample, t, child_pop, child_ind), + add_node(&mut tables, is_sample, t, child_pop, child_ind), + ); + + add_node(&mut tr_tbls, is_sample, t, child_pop, child_ind); + add_node(&mut tr_tbls, is_sample, t, child_pop, child_ind); + + // add edges, sites & mutations to both tables and tr_tabls + let mu = 0.01f64; + for (s, e, p, c) in [ + (0.0.into(), breakpoint1, parent1.0, child_id.0), + (breakpoint1, seqlen.into(), parent1.1, child_id.0), + (0.0.into(), breakpoint2, parent2.0, child_id.1), + (breakpoint2, seqlen.into(), parent2.1, child_id.1), + ] { + add_edge(&mut tables, s, e, p, c); + + let mut_pos = find_mutation_pos(rng, s, e); + let mut mut_prob = f64::from(e - s) * mu; + if mut_prob > 1.0 { + mut_prob = 1.0; + } + let to_add_mut: bool = rng.gen_bool(mut_prob); + let derived_state = &calc_derived_state(&site_last_mutation_order, mut_pos); + let t = t as f64; + + if to_add_mut { + // add site + let site_not_exist = site_id_map_tables[mut_pos] == SiteId::NULL; + if site_not_exist { + site_id_map_tables[mut_pos] = + tables.add_site(mut_pos as f64, Some(&[b'a'])).unwrap(); + } + // add mutation + let 
parent_mut = site_last_mutation_tables[mut_pos]; + let site = site_id_map_tables[mut_pos]; + let new_mutation = tables + .add_mutation(site, c, parent_mut, t, Some(derived_state)) + .unwrap(); + + site_last_mutation_tables[mut_pos] = new_mutation; + site_last_mutation_order[mut_pos] += 1; + } + + find_overlaps(s, e, &intervals, &mut buffer); + for (s_, e_) in buffer.iter() { + add_edge(&mut tr_tbls, *s_, *e_, p, c); + let mut_pos_f = mut_pos as f64; + + if to_add_mut && (*s_ <= mut_pos_f) && (*e_ > mut_pos_f) { + // add site + let site_not_exist = site_id_map_tr_tbls[mut_pos] == SiteId::NULL; + if site_not_exist { + site_id_map_tr_tbls[mut_pos] = + tr_tbls.add_site(mut_pos as f64, Some(&[b'a'])).unwrap(); + } + // add mutation + let parent_mut = site_last_mutation_tr_tbls[mut_pos]; + let site = site_id_map_tr_tbls[mut_pos]; + let new_mutation = tr_tbls + .add_mutation(site, c, parent_mut, t, Some(derived_state)) + .unwrap(); + site_last_mutation_tr_tbls[mut_pos] = new_mutation; + } + } + } + + // add edges for tr_tbls + children.push(child_id); + } + // NOTE: avoid simplifcation so that both tables and tr_tables share the same ids + + // set children as parents and clear children + std::mem::swap(&mut children, &mut parents); + children.clear(); + } + + let sort_opts = TableSortOptions::all(); + tables.full_sort(sort_opts).unwrap(); + tr_tbls.full_sort(sort_opts).unwrap(); + + // simplify + let mut samples = Vec::::with_capacity(pop_size * 2); + parents + .iter() + .for_each(|p| samples.extend([p.0, p.1].iter())); + + let simplify_opts = SimplificationOptions::default(); + tables.simplify(&samples, simplify_opts, false).unwrap(); + tr_tbls.simplify(&samples, simplify_opts, false).unwrap(); + + // build indices + tables.build_index().unwrap(); + tr_tbls.build_index().unwrap(); + + // to tree sequences + let treeseq_opts = TreeSequenceFlags::default(); + let full_trees = TreeSequence::new(tables, treeseq_opts).unwrap(); + let truncated_trees = 
TreeSequence::new(tr_tbls, treeseq_opts).unwrap(); + + Ok((full_trees, truncated_trees)) + } + + pub fn generate_simple_treesequence(add_migration_records: bool) -> TreeSequence { + let snode = NodeFlags::new_sample(); + let anode = NodeFlags::default(); + let pop = PopulationId::NULL; + let ind = IndividualId::NULL; + let seqlen = 100.0; + let (t0, t10) = (0.0, 10.0); + let (left, right) = (0.0, 100.0); + + let sim_opts = SimplificationOptions::default(); + let mut tables = TableCollection::new(seqlen).unwrap(); + let child1 = tables.add_node(snode, t0, pop, ind).unwrap(); + let child2 = tables.add_node(snode, t0, pop, ind).unwrap(); + let parent = tables.add_node(anode, t10, pop, ind).unwrap(); + tables.add_edge(left, right, parent, child1).unwrap(); + tables.add_edge(left, right, parent, child2).unwrap(); + + tables.full_sort(TableSortOptions::all()).unwrap(); + let id_map = tables + .simplify(&[child1, child2], sim_opts, true) + .unwrap() + .unwrap() + .to_owned(); + + // add migration records after simplification to avoid errors when + // simplifying a treesequence that contains a nonempty migration table + if add_migration_records { + let pop_anc = tables.add_population().unwrap(); + let pop_1 = tables.add_population().unwrap(); + let pop_2 = tables.add_population().unwrap(); + // get new ids after simplifcation + let child1 = id_map[child1.to_usize().unwrap()]; + let child2 = id_map[child2.to_usize().unwrap()]; + tables + .add_migration((left, right), child1, (pop_anc, pop_1), t0 + 1.0) + .unwrap(); + tables + .add_migration((left, right), child2, (pop_anc, pop_2), t0 + 5.0) + .unwrap(); + } + + tables.build_index().unwrap(); + + let flags = TreeSequenceFlags::default(); + TreeSequence::new(tables, flags).unwrap() + } +} + +#[cfg(test)] +mod keep_intervals { + use crate::*; + + use super::simulation::{generate_simple_treesequence, simulate_two_treesequences}; + + #[test] + fn test_keep_intervals_invalid_input() { + let intervals_lst = vec![ + vec![(20.0, 
10.0)], // out of order + vec![(10.0, 20.0), (19.0, 30.0)], // overlapping intervals + ]; + for intervals in intervals_lst { + let add_migration_table = false; + let trees = generate_simple_treesequence(add_migration_table); + let res = trees.keep_intervals(intervals.into_iter(), true); + assert!(res.is_err()); + } + } + + #[test] + fn test_keep_intervals_nonempty_migration_table() { + let intervals = [(10.0, 20.0)]; + + let add_migration_table = true; + let to_simplify = true; + let trees = generate_simple_treesequence(add_migration_table); + let res = trees.keep_intervals(intervals.iter().copied(), to_simplify); + assert!(res.is_err()); + + let add_migration_table = true; + let to_simply = false; + let trees = generate_simple_treesequence(add_migration_table); + let res = trees.keep_intervals(intervals.iter().copied(), to_simply); + assert!(res.is_ok()); + + let add_migration_table = false; + let to_simply = true; + let trees = generate_simple_treesequence(add_migration_table); + let res = trees.keep_intervals(intervals.iter().copied(), to_simply); + assert!(res.is_ok()); + + let add_migration_table = false; + let to_simply = false; + let trees = generate_simple_treesequence(add_migration_table); + let res = trees.keep_intervals(intervals.iter().copied(), to_simply); + assert!(res.is_ok()); + } + + #[test] + fn test_keep_intervals() { + let seqlen = 1000.0; + let intervals_lst = vec![ + vec![(seqlen + 1.0, seqlen + 2.0)], // out of range: > seqlen + vec![(10.0, 20.0), (700.0, 850.0)], // multiple intervals + vec![(10.0, 20.0)], // single intervals + ]; + let popsize = 50; + let total_time = 300; + let split_time = 20; + + for intervals in intervals_lst { + for seed in [123, 3224] { + let (full_trees, exepected) = simulate_two_treesequences( + seqlen, popsize, total_time, split_time, &intervals, seed, + ) + .unwrap(); + + if exepected.edges().num_rows() > 0 { + let truncated = full_trees + .keep_intervals(intervals.iter().copied(), true) + .expect("error") + 
.expect("empty table"); + + // dump tables for comparison + let truncated = truncated.dump_tables().unwrap(); + let expected = exepected.dump_tables().unwrap(); + + let res = truncated.equals(&expected, TableEqualityOptions::all()); + assert!(res); + } else { + let trucated = full_trees + .keep_intervals(intervals.iter().copied(), true) + .unwrap(); + assert!(trucated.is_none()); + } + } + } + } +} diff --git a/src/trees/treeseq.rs b/src/trees/treeseq.rs index f02d91ec8..3e550aea9 100644 --- a/src/trees/treeseq.rs +++ b/src/trees/treeseq.rs @@ -344,6 +344,66 @@ impl TreeSequence { )) } + /// Truncate the [TreeSequence] to specified genome intervals. + /// + /// - `Ok(None)`: when truncation leads to empty edge table. + /// - `Ok(Some(TreeSequence))`: when truncation is successfully performed + /// and results in non-empty edge table. + /// - `Err(TskitError)`: Any errors from the C API propagate. A + /// [TskitError::RangeError] will occur when `intervals` are not + /// sorted or overlap. Note that as `tskit` currently does not support `simplify` + /// on [TreeSequence] with a non-empty migration table, calling + /// `keep_intervals` on those [TreeSequence] with `simplify` set to `true` + /// will return an error. 
+ /// + /// # Example + /// ```rust + /// # use tskit::*; + /// # let snode = NodeFlags::new_sample(); + /// # let anode = NodeFlags::default(); + /// # let pop = PopulationId::NULL; + /// # let ind = IndividualId::NULL; + /// # let seqlen = 100.0; + /// # let (t0, t10) = (0.0, 10.0); + /// # let (left, right) = (0.0, 100.0); + /// # let sim_opts = SimplificationOptions::default(); + /// # + /// # let mut tables = TableCollection::new(seqlen).unwrap(); + /// # let child1 = tables.add_node(snode, t0, pop, ind).unwrap(); + /// # let child2 = tables.add_node(snode, t0, pop, ind).unwrap(); + /// # let parent = tables.add_node(anode, t10, pop, ind).unwrap(); + /// # + /// # tables.add_edge(left, right, parent, child1).unwrap(); + /// # tables.add_edge(left, right, parent, child2).unwrap(); + /// # tables.full_sort(TableSortOptions::all()).unwrap(); + /// # tables.simplify(&[child1, child2], sim_opts, false).unwrap(); + /// # tables.build_index().unwrap(); + /// # + /// # let trees = TreeSequence::new(tables, TreeSequenceFlags::default()).unwrap(); + /// # + /// let intervals = [(0.0, 10.0), (90.0, 100.0)].into_iter(); + /// trees.keep_intervals(intervals, true).unwrap().unwrap(); + /// ``` + /// + /// Note that no new provenance will be appended. + pub fn keep_intervals

( + self, + intervals: impl Iterator, + simplify: bool, + ) -> Result, TskitError> + where + P: Into, + { + let tables = self.dump_tables()?; + match tables.keep_intervals(intervals, simplify) { + Ok(Some(tables)) => { + Self::new(tables, TreeSequenceFlags::default().build_indexes()).map(Some) + } + Ok(None) => Ok(None), + Err(e) => Err(e), + } + } + #[cfg(feature = "provenance")] #[cfg_attr(doc_cfg, doc(cfg(feature = "provenance")))] /// Add provenance record with a time stamp.