diff --git a/Cargo.toml b/Cargo.toml index a5e7bb30..5bbe9e16 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -27,6 +27,8 @@ serde_json = {version = "1.0.114", optional = true} bincode = {version = "1.3.1", optional = true} tskit-derive = {version = "0.2.0", path = "tskit-derive", optional = true} delegate = "0.12.0" +rand = "0.8.3" + [dev-dependencies] anyhow = {version = "1.0.79"} diff --git a/src/lib.rs b/src/lib.rs index d91935c7..2ffffac4 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -177,4 +177,5 @@ mod tests { } // Testing modules +pub mod test_data; mod test_fixtures; diff --git a/src/table_collection.rs b/src/table_collection.rs index c67d0a69..0eacfff9 100644 --- a/src/table_collection.rs +++ b/src/table_collection.rs @@ -32,6 +32,7 @@ use crate::TskReturnValue; use crate::{EdgeId, NodeId}; use ll_bindings::tsk_id_t; use ll_bindings::tsk_size_t; +use streaming_iterator::StreamingIterator; /// A table collection. /// @@ -1379,30 +1380,51 @@ impl TableCollection { /// Truncate the Table Collection to specified genome intervals. /// - /// # Error - /// Any errors from the C API propagate. An [TskitError::RangeError] will - /// occur when `intervals` are not sorted. + /// # Returns + /// - `Ok(None)`: when truncation leads to an empty edge table. + /// - `Ok(Some(TableCollection))`: when truncation is successfully performed + /// and results in a non-empty edge table. + /// - `Err(TskitError)`: any error from the C API propagates. A + /// [TskitError::RangeError] will occur when `intervals` are not sorted. 
/// /// # Example /// ```rust - /// use tskit::TreeSequence; - /// let mut tables = TreeSequence::load("./testdata/1.trees") + /// # use tskit::test_data::simulation::simulate_two_treesequences; + /// # let intervals = vec![(10.0, 20.0), (70.0, 85.0)]; + /// # let seqlen = 100.0; + /// # let popsize = 100; + /// # let total_generations = 50; + /// # let popsplit_time = 10; + /// # let seed = 123; + /// # + /// # let (full_trees, _expected) = simulate_two_treesequences( + /// # seqlen, + /// # popsize, + /// # total_generations, + /// # popsplit_time, + /// # &intervals, + /// # seed, + /// # ) + /// # .unwrap(); + /// # + /// # let tables = full_trees.dump_tables().unwrap(); + /// + /// let _truncated_tables = tables + /// .keep_intervals(intervals.iter().copied(), true) /// .unwrap() - /// .dump_tables() - /// .unwrap(); - /// tables - /// .keep_intervals(vec![(10.0.into(), 130.0.into())].into_iter(), true) /// .unwrap(); /// ``` /// /// Note that no new provenance will be appended. - pub fn keep_intervals( - &mut self, - intervals: impl Iterator<Item = (Position, Position)>, + pub fn keep_intervals<P>

( + self, + intervals: impl Iterator<Item = (P, P)>, simplify: bool, - ) -> Result<(), TskitError> { - use streaming_iterator::StreamingIterator; - + ) -> Result<Option<Self>, TskitError> + where + P: Into<Position>, + { + let mut tables = self; // use tables from sys to allow easier process with metadata let options = 0; let mut new_edges = crate::sys::EdgeTable::new(options)?; @@ -1411,13 +1433,14 @@ impl TableCollection { let mut new_mutations = crate::sys::MutationTable::new(options)?; // for old site id to new site id mapping - let mut site_map = vec![-1i32; self.sites().num_rows().as_usize()]; + let mut site_map = vec![-1i32; tables.sites().num_rows().as_usize()]; // logicals to indicate whether a site (old) will be kept in new site table - let mut keep_sites = vec![false; self.sites().num_rows().try_into()?]; + let mut keep_sites = vec![false; tables.sites().num_rows().try_into()?]; let mut last_interval = (Position::from(0.0), Position::from(0.0)); for (s, e) in intervals { + let (s, e) = (s.into(), e.into()); // make sure intervals are sorted if (s > e) || (s < last_interval.1) { return Err(TskitError::RangeError( @@ -1426,13 +1449,13 @@ impl TableCollection { } keep_sites .iter_mut() - .zip(self.sites_iter()) + .zip(tables.sites_iter()) .for_each(|(k, site_row)| { *k = *k || ((site_row.position >= s) && (site_row.position < e)); }); // use stream_iter and while-let pattern for easier ? 
operator within a loop - let mut edge_iter = self + let mut edge_iter = tables .edges() .lending_iter() .filter(|edge_row| !((edge_row.right <= s) || (edge_row.left >= e))); @@ -1452,7 +1475,7 @@ impl TableCollection { )?; } - let mut migration_iter = self + let mut migration_iter = tables .migrations() .lending_iter() - .filter(|mrow| !!((mrow.right <= s) || (mrow.left >= e))); + .filter(|mrow| !((mrow.right <= s) || (mrow.left >= e))); @@ -1471,7 +1494,7 @@ impl TableCollection { } let mut running_site_id = 0; - let mut site_iter = self.sites().lending_iter(); + let mut site_iter = tables.sites().lending_iter(); while let Some(site_row) = site_iter.next() { let old_id = site_row.id.to_usize().unwrap(); if keep_sites[old_id] { @@ -1486,19 +1509,22 @@ impl TableCollection { } // build mutation_map - let mutation_map = { - let mut v = Vec::with_capacity(keep_sites.len()); + let mutation_map: Vec<_> = { let mut n = 0; - self.mutations().site_slice().iter().for_each(|site| { - if keep_sites[site.as_usize()] { - n += 1 - }; - v.push(n - 1); - }); - v + tables + .mutations() + .site_slice() + .iter() + .map(|site| { + if keep_sites[site.as_usize()] { + n += 1 + }; + n - 1 + }) + .collect() }; - let mut mutations_iter = self.mutations().lending_iter(); + let mut mutations_iter = tables.mutations().lending_iter(); while let Some(mutation_row) = mutations_iter.next() { let old_id = mutation_row.site.to_usize().unwrap(); if keep_sites[old_id] { @@ -1528,20 +1554,25 @@ impl TableCollection { let new_sites = SiteTable::new_from_table(new_sites.as_mut())?; // replace old tables with new tables - self.set_edges(&new_edges).map(|_| ())?; - self.set_migrations(&new_migrations).map(|_| ())?; - self.set_mutations(&new_mutations).map(|_| ())?; - self.set_sites(&new_sites)?; + tables.set_edges(&new_edges)?; + tables.set_migrations(&new_migrations)?; + tables.set_mutations(&new_mutations)?; + tables.set_sites(&new_sites)?; // sort tables - self.full_sort(TableSortOptions::default())?; + 
tables.full_sort(TableSortOptions::default())?; // simplify tables if simplify { - let samples = self.samples_as_vector(); - self.simplify(samples.as_slice(), SimplificationOptions::default(), false)?; + let samples = tables.samples_as_vector(); + tables.simplify(samples.as_slice(), SimplificationOptions::default(), false)?; } - Ok(()) + // return None when edge table is empty + if tables.edges().num_rows() == 0 { + Ok(None) + } else { + Ok(Some(tables)) + } } } diff --git a/src/test_data.rs b/src/test_data.rs new file mode 100644 index 00000000..98b657d9 --- /dev/null +++ b/src/test_data.rs @@ -0,0 +1,373 @@ +/// Mimics the forward simulator example from the tskit C API documentation: +/// <https://tskit.dev/tskit/docs/stable/c-api.html#basic-forwards-simulator> +pub mod simulation { + use core::panic; + + use crate::{ + metadata::{MetadataError, MetadataRoundtrip, PopulationMetadata}, + EdgeId, IndividualId, MutationId, NodeFlags, NodeId, PopulationId, Position, + SimplificationOptions, SiteId, TableCollection, TableSortOptions, TreeSequence, + TreeSequenceFlags, TskitError, + }; + use rand::{rngs::StdRng, Rng, SeedableRng}; + + struct MyMeta { + inner: String, + } + impl From<String> for MyMeta { + fn from(value: String) -> Self { + MyMeta { inner: value } + } + } + impl<'a> From<&'a str> for MyMeta { + fn from(value: &'a str) -> Self { + MyMeta { + inner: value.to_owned(), + } + } + } + + // helper structs, impls and functions + + impl MetadataRoundtrip for MyMeta { + fn encode(&self) -> Result<Vec<u8>, MetadataError> { + Ok(self.inner.as_bytes().to_owned()) + } + fn decode(md: &[u8]) -> Result<Self, MetadataError> + where + Self: Sized, + { + Ok(MyMeta { + inner: String::from_utf8(md.to_owned()).unwrap(), + }) + } + } + + impl PopulationMetadata for MyMeta {} + + fn add_pop(tables: &mut TableCollection, name: &str) -> PopulationId { + tables + .add_population_with_metadata(&MyMeta::from(name)) + .unwrap() + } + + fn add_ind( + tables: &mut TableCollection, + parent1: (NodeId, NodeId), + parent2: (NodeId, NodeId), + ) 
-> IndividualId { + let parent1_ind = tables.nodes().individual(parent1.0).unwrap(); + let parent2_ind = tables.nodes().individual(parent2.0).unwrap(); + let flags = 0u32; + let loc_null = None; + tables + .add_individual(flags, loc_null, [parent1_ind, parent2_ind]) + .unwrap() + } + + fn find_parent( + rng: &mut StdRng, + parents: &[(NodeId, NodeId)], + child_pop: PopulationId, + ) -> ((NodeId, NodeId), PopulationId) { + assert_eq!(parents.len() % 2, 0); + let (pop_anc, pop_1, pop_2) = (0, 1, 2); + let child_pop: i32 = child_pop.into(); + + let pop_size = parents.len(); + let mut parent_pop = child_pop; + + let is_migrant = (child_pop != pop_anc) && rng.gen_bool(0.01); + if is_migrant { + parent_pop = if child_pop == pop_1 { pop_2 } else { pop_1 }; + }; + let parent = match parent_pop { + // pop_anc: any individual may be the parent + 0 => parents[rng.gen_range(0..pop_size)], + // pop_1: first half of the parent vector + 1 => parents[rng.gen_range(0..(pop_size / 2))], + // pop_2: second half of the parent vector + 2 => parents[rng.gen_range((pop_size / 2)..pop_size)], + _ => panic!("wrong population id encountered"), + }; + (parent, parent_pop.into()) + } + + fn find_breakpoint(rng: &mut StdRng, seqlen: Position) -> Position { + // avoid breakpoints at the sequence ends: draw from 1..seqlen (gen_range panics if seqlen < 2) + let seqlen = f64::from(seqlen).floor() as usize; + let sel = rng.gen_range(1..seqlen) as f64; + Position::from(sel) + } + + fn add_node( + tables: &mut TableCollection, + is_sample: bool, + time: usize, + pop: PopulationId, + ind: IndividualId, + ) -> NodeId { + tables + .add_node( + if is_sample { + NodeFlags::new_sample() + } else { + NodeFlags::default() + }, + time as f64, + pop, + ind, + ) + .unwrap() + } + + fn add_edge( + tables: &mut TableCollection, + start: impl Into<Position>, + end: impl Into<Position>, + parent_node: NodeId, + child_node: NodeId, + ) -> EdgeId { + tables + .add_edge(start, end, parent_node, child_node) + .unwrap() + } + + fn find_overlaps<P>

(start: P, end: P, intervals: &[(P, P)], out: &mut Vec<(P, P)>) + where + P: Into<Position> + Copy + PartialOrd, + { + // intervals must be well-formed (start <= end), sorted and non-overlapping + assert!(intervals.iter().all(|(a, b)| *a <= *b)); + assert!(intervals + .iter() + .zip(intervals.iter().skip(1)) + .all(|(p1, p2)| p1.1 <= p2.0)); + // reset the output buffer so it only holds overlaps for this call + out.clear(); + + for (m, n) in intervals { + // no overlap + if (*n <= start) || (end <= *m) { + continue; + } + let new_start = if *m < start { start } else { *m }; + let new_end = if *n < end { *n } else { end }; + out.push((new_start, new_end)); + } + } + + fn find_mutation_pos<P>

(rng: &mut StdRng, s: P, e: P) -> usize + where + P: Into<Position>, + { + let s = f64::from(Into::<Position>::into(s)).ceil() as usize; + let e = f64::from(Into::<Position>::into(e)).floor() as usize; + rng.gen_range(s..e) + } + + fn calc_derived_state(site_last_mutation_order: &[usize], mut_pos: usize) -> [u8; 1] { + [b'a' + + match site_last_mutation_order[mut_pos] + 1 { + x if x > 45 => 45, + x => x as u8, + }] + } + + /// Simulate diploid individuals with migration between two subpopulations. + /// + /// Both full_trees and truncated_trees will be generated + pub fn simulate_two_treesequences<P>

( + seqlen: P, + pop_size: usize, + start_time: usize, + split_time: usize, + intervals: &[(P, P)], + seed: u64, + ) -> Result<(TreeSequence, TreeSequence), TskitError> + where + P: Into<Position> + Copy + PartialOrd, + { + let rng = &mut StdRng::seed_from_u64(seed); + let intervals: Vec<(Position, Position)> = intervals + .iter() + .map(|(a, b)| ((*a).into(), (*b).into())) + .collect(); + assert!(split_time < start_time); + assert_eq!(pop_size % 2, 0); + // tables without truncation + let mut tables = TableCollection::new(seqlen).unwrap(); + // expected tables after truncation + // it is built following `tables` except for positions for edge table + let mut tr_tbls = TableCollection::new(seqlen).unwrap(); + + let mut buffer = Vec::new(); + + // add pop + let pop_anc = add_pop(&mut tables, "ancestor"); + let pop_1 = add_pop(&mut tables, "pop1"); + let pop_2 = add_pop(&mut tables, "pop2"); + + add_pop(&mut tr_tbls, "ancestor"); + add_pop(&mut tr_tbls, "pop1"); + add_pop(&mut tr_tbls, "pop2"); + + // state variables for site/mutation tables + let num_sites = f64::from(seqlen.into()) as usize; + let mut site_last_mutation_order = vec![0usize; num_sites]; + + let mut site_last_mutation_tables = vec![MutationId::NULL; num_sites]; + let mut site_last_mutation_tr_tbls = vec![MutationId::NULL; num_sites]; + + let mut site_id_map_tables = vec![SiteId::NULL; num_sites]; + let mut site_id_map_tr_tbls = vec![SiteId::NULL; num_sites]; + + // base population + let mut parents = Vec::<(NodeId, NodeId)>::with_capacity(pop_size); + for _ in 0..pop_size { + const FLAGS: u32 = 0; + let loc_null = None; + let parent_ind = tables.add_individual(FLAGS, loc_null, None).unwrap(); + tr_tbls.add_individual(FLAGS, loc_null, None).unwrap(); + + let parent_id = ( + add_node(&mut tables, false, start_time, pop_anc, parent_ind), + add_node(&mut tables, false, start_time, pop_anc, parent_ind), + ); + parents.push(parent_id); + // add the same two nodes to tr_tbls so both collections share node ids + add_node(&mut tr_tbls, false, start_time, pop_anc, parent_ind); + 
add_node(&mut tr_tbls, false, start_time, pop_anc, parent_ind); + } + + // offspring population + let mut children = Vec::<(NodeId, NodeId)>::with_capacity(pop_size); + + for t in (0..start_time).rev() { + for i in 0..pop_size { + // select breakpoints + let breakpoint1 = find_breakpoint(rng, seqlen.into()); + let breakpoint2 = find_breakpoint(rng, seqlen.into()); + + // find child pop. NOTE(review): t is used as node time and samples are at t == 0, so `t > split_time` selects the OLDER generations and samples end up in "ancestor" - confirm this direction is intended + let mut child_pop = pop_anc; + if t > split_time { + child_pop = if i < pop_size / 2 { pop_1 } else { pop_2 } + } + + // find parents + let (parent1, _parent1_pop) = find_parent(rng, &parents, child_pop); + let (parent2, _parent2_pop) = find_parent(rng, &parents, child_pop); + + // add individual + let child_ind = add_ind(&mut tables, parent1, parent2); + add_ind(&mut tr_tbls, parent1, parent2); + + // add nodes + let is_sample = t == 0; + let child_id = ( + add_node(&mut tables, is_sample, t, child_pop, child_ind), + add_node(&mut tables, is_sample, t, child_pop, child_ind), + ); + + add_node(&mut tr_tbls, is_sample, t, child_pop, child_ind); + add_node(&mut tr_tbls, is_sample, t, child_pop, child_ind); + + // add edges, sites & mutations to both tables and tr_tbls + let mu = 0.01f64; + for (s, e, p, c) in [ + (0.0.into(), breakpoint1, parent1.0, child_id.0), + (breakpoint1, seqlen.into(), parent1.1, child_id.0), + (0.0.into(), breakpoint2, parent2.0, child_id.1), + (breakpoint2, seqlen.into(), parent2.1, child_id.1), + ] { + add_edge(&mut tables, s, e, p, c); + + let mut_pos = find_mutation_pos(rng, s, e); + let mut mut_prob = f64::from(e - s) * mu; + if mut_prob > 1.0 { + mut_prob = 1.0; + } + let to_add_mut: bool = rng.gen_bool(mut_prob); + let derived_state = &calc_derived_state(&site_last_mutation_order, mut_pos); + let t = t as f64; + + if to_add_mut { + // add site + let site_not_exist = site_id_map_tables[mut_pos] == SiteId::NULL; + if site_not_exist { + site_id_map_tables[mut_pos] = + tables.add_site(mut_pos as f64, Some(&[b'a'])).unwrap(); + } 
+ // add mutation + let parent_mut = site_last_mutation_tables[mut_pos]; + let site = site_id_map_tables[mut_pos]; + let new_mutation = tables + .add_mutation(site, c, parent_mut, t, Some(derived_state)) + .unwrap(); + + site_last_mutation_tables[mut_pos] = new_mutation; + site_last_mutation_order[mut_pos] += 1; + } + + find_overlaps(s, e, &intervals, &mut buffer); + for (s_, e_) in buffer.iter() { + add_edge(&mut tr_tbls, *s_, *e_, p, c); + let mut_pos_f = mut_pos as f64; + + if to_add_mut && (*s_ <= mut_pos_f) && (*e_ > mut_pos_f) { + // add site + let site_not_exist = site_id_map_tr_tbls[mut_pos] == SiteId::NULL; + if site_not_exist { + site_id_map_tr_tbls[mut_pos] = + tr_tbls.add_site(mut_pos as f64, Some(&[b'a'])).unwrap(); + } + // add mutation + let parent_mut = site_last_mutation_tr_tbls[mut_pos]; + let site = site_id_map_tr_tbls[mut_pos]; + let new_mutation = tr_tbls + .add_mutation(site, c, parent_mut, t, Some(derived_state)) + .unwrap(); + site_last_mutation_tr_tbls[mut_pos] = new_mutation; + } + } + } + + // remember this child as a candidate parent for the next generation + children.push(child_id); + } + // NOTE: avoid simplification so that both tables and tr_tbls share the same ids + + // set children as parents and clear children + std::mem::swap(&mut children, &mut parents); + children.clear(); + } + + tables.full_sort(TableSortOptions::all()).unwrap(); + tr_tbls.full_sort(TableSortOptions::all()).unwrap(); + + // simplify + let mut samples = Vec::<NodeId>::with_capacity(pop_size * 2); + parents + .iter() + .for_each(|p| samples.extend([p.0, p.1].iter())); + + tables + .simplify(&samples, SimplificationOptions::default(), false) + .unwrap(); + + tr_tbls + .simplify(&samples, SimplificationOptions::default(), false) + .unwrap(); + + // build indices + tables.build_index().unwrap(); + tr_tbls.build_index().unwrap(); + + // to tree sequences + let full_trees = TreeSequence::new(tables, TreeSequenceFlags::default()).unwrap(); + let truncated_trees = TreeSequence::new(tr_tbls, 
TreeSequenceFlags::default()).unwrap(); + + Ok((full_trees, truncated_trees)) + } +} diff --git a/src/trees/treeseq.rs b/src/trees/treeseq.rs index 6d24b559..7cbb0e76 100644 --- a/src/trees/treeseq.rs +++ b/src/trees/treeseq.rs @@ -352,20 +352,46 @@ impl TreeSequence { /// /// # Example /// ```rust - /// use tskit::TreeSequence; - /// let mut ts = TreeSequence::load("testdata/1.trees").expect("error loading ts"); - /// let new_ts = ts.keep_intervals(vec![(10.0.into(), 130.0.into())].into_iter(), true).unwrap(); + /// # use tskit::test_data::simulation::simulate_two_treesequences; + /// # let intervals = vec![(10.0, 20.0), (70.0, 85.0)]; + /// # let seqlen = 100.0; + /// # let popsize = 100; + /// # let total_generations = 50; + /// # let popsplit_time = 10; + /// # let seed = 123; + /// # + /// # let (full_trees, _expected) = simulate_two_treesequences( + /// # seqlen, + /// # popsize, + /// # total_generations, + /// # popsplit_time, + /// # &intervals, + /// # seed, + /// # ) + /// # .unwrap(); + /// + /// let _truncated_trees = full_trees + /// .keep_intervals(intervals.into_iter(), true) + /// .unwrap() + /// .unwrap(); /// ``` /// /// Note that no new provenance will be appended. - pub fn keep_intervals( + pub fn keep_intervals<P>

( self, - intervals: impl Iterator<Item = (Position, Position)>, + intervals: impl Iterator<Item = (P, P)>, simplify: bool, - ) -> Result<Self, TskitError> { - let mut tables = self.dump_tables()?; - tables.keep_intervals(intervals, simplify)?; - Self::new(tables, TreeSequenceFlags::default().build_indexes()) + ) -> Result<Option<Self>, TskitError> + where + P: Into<Position>, + { + let tables = self.dump_tables()?; + match tables.keep_intervals(intervals, simplify) { + Ok(Some(tables)) => Self::new(tables, TreeSequenceFlags::default().build_indexes()) + .map(Some), + Ok(None) => Ok(None), + Err(e) => Err(e), + } } #[cfg(feature = "provenance")] diff --git a/testdata/1.trees b/testdata/1.trees deleted file mode 100644 index 5383cc3b..00000000 Binary files a/testdata/1.trees and /dev/null differ diff --git a/testdata/2.trees b/testdata/2.trees deleted file mode 100644 index 2e3c5758..00000000 Binary files a/testdata/2.trees and /dev/null differ diff --git a/testdata/3.trees b/testdata/3.trees deleted file mode 100644 index d860e8de..00000000 Binary files a/testdata/3.trees and /dev/null differ diff --git a/testdata/gen_trees.py b/testdata/gen_trees.py deleted file mode 100644 index da86b096..00000000 --- a/testdata/gen_trees.py +++ /dev/null @@ -1,27 +0,0 @@ -import msprime -import tskit - -demography = msprime.Demography() -demography.add_population(name="A", initial_size=10_000) -demography.add_population(name="B", initial_size=5_000) -demography.add_population(name="C", initial_size=1_000) -# demography.set_migration_rate("A", "B", 0.01) -demography.add_population_split(time=1000, derived=["A", "B"], ancestral="C") -ts1: tskit.TreeSequence = msprime.sim_ancestry( - samples={"A": 1, "B": 1}, - demography=demography, - random_seed=12, - recombination_rate=0.00001, - sequence_length=300.0, - # record_migrations=True -) -ts1 = msprime.sim_mutations(ts1, rate=0.0001, random_seed=12) - -ts2 = ts1.keep_intervals([(10, 130)], record_provenance=False) -ts3 = ts1.keep_intervals([(10, 40), (100, 200)], record_provenance=False) - - 
-ts1.dump("testdata/1.trees") -ts2.dump("testdata/2.trees") -ts3.dump("testdata/3.trees") - diff --git a/tests/test_tables.rs b/tests/test_tables.rs index cf1ea9e3..6cb26c96 100644 --- a/tests/test_tables.rs +++ b/tests/test_tables.rs @@ -551,66 +551,46 @@ mod test_metadata_round_trips { } } -#[test] -fn test_keep_intervals() { - use tskit::{TableEqualityOptions, TreeSequence}; - - // test on tables - let mut tables1 = TreeSequence::load("./testdata/1.trees") - .unwrap() - .dump_tables() - .unwrap(); - // run keep_intervals - tables1 - .keep_intervals(vec![(10.0.into(), 130.0.into())].into_iter(), true) - .unwrap(); - // expected tree sequences - let tables2 = TreeSequence::load("./testdata/2.trees") - .unwrap() - .dump_tables() - .unwrap(); - - assert!(tables1.equals(&tables2, TableEqualityOptions::all())); - - // test on tables - let mut tables1 = TreeSequence::load("./testdata/1.trees") - .unwrap() - .dump_tables() - .unwrap(); - // run keep_intervals - tables1 - .keep_intervals( - vec![(10.0.into(), 40.0.into()), (100.0.into(), 200.0.into())].into_iter(), - true, - ) - .unwrap(); - // expected tree sequences - let tables2 = TreeSequence::load("./testdata/3.trees") - .unwrap() - .dump_tables() - .unwrap(); - - assert!(tables1.equals(&tables2, TableEqualityOptions::all())); - - // test on treeseq - let ts1 = TreeSequence::load("./testdata/1.trees") - .unwrap() - .keep_intervals(vec![(10.0.into(), 130.0.into())].into_iter(), true) - .unwrap(); - let ts2 = TreeSequence::load("./testdata/2.trees").unwrap(); - ts1.dump_tables() - .unwrap() - .equals(&ts2.dump_tables().unwrap(), TableEqualityOptions::all()); - - let ts1 = TreeSequence::load("./testdata/1.trees") - .unwrap() - .keep_intervals( - vec![(10.0.into(), 40.0.into()), (100.0.into(), 200.0.into())].into_iter(), - true, - ) - .unwrap(); - let ts3 = TreeSequence::load("./testdata/3.trees").unwrap(); - ts1.dump_tables() - .unwrap() - .equals(&ts3.dump_tables().unwrap(), TableEqualityOptions::all()); +#[cfg(test)] 
+mod keep_intervals { + use tskit::test_data::simulation::simulate_two_treesequences; + #[test] + fn test_keep_intervals() { + use tskit::TableEqualityOptions; + let intervals_lst = vec![ + // vec![(0.0, 1.0)], + vec![(10.0, 20.0), (700.0, 850.0)], + // vec![(900.0, 1000.0)], + ]; + let seqlen = 1000.0; + let popsize = 100; + let total_generations = 500; + let popsplit_time = 100; + + for intervals in intervals_lst { + for seed in [123, 3224] { + let (full_trees, expected) = simulate_two_treesequences( + seqlen, + popsize, + total_generations, + popsplit_time, + &intervals, + seed, + ) + .unwrap(); + + let truncated = full_trees + .keep_intervals(intervals.iter().copied(), true) + .unwrap() + .unwrap(); + + let res = truncated.dump_tables().unwrap().equals( + &expected.dump_tables().unwrap(), + TableEqualityOptions::all(), + ); + + assert!(res); + } + } + } }