
added new feature for group
NPSDC committed Apr 17, 2024
1 parent 1b24c30 commit b33b2a1
Showing 6 changed files with 76 additions and 33 deletions.
3 changes: 2 additions & 1 deletion Cargo.toml
@@ -1,6 +1,6 @@
[package]
name = "treeterminus"
version = "0.2.0"
version = "0.3.0"
authors = ["Noor Pratap Singh <[email protected]>", "Rob Patro <[email protected]>"]
edition = "2021"

@@ -39,6 +39,7 @@ assert_cmd = "0.12.0"
serde-pickle = "0.6"
serde_stacker = "0.1"
run_script = "^0.7.0"
statrs = "0.16.0"

[dev-dependencies]
predicates = "1.0.2"
16 changes: 16 additions & 0 deletions Changelog.md
@@ -0,0 +1,16 @@
# Changelog

## [0.3.0] - 2024-04-16
### Added
- New flag `red_quant` added to the `group` subcommand

### Fixed
- Default code for computing the threshold for reduction in infRV

## [0.2.0] - 2023-05-11
### Added
- PHYLIP function called from inside Rust
- Multiple instances of TreeTerminus can be run

## [0.1.0] - 2022-11-04
### Added
- Initial release

12 changes: 6 additions & 6 deletions src/collapse.rs
@@ -24,7 +24,7 @@ fn create_union_find(g: &[String], ntxps: usize) -> UnionFind<usize> {
let mut unionfind_struct = UnionFind::new(ntxps);
let mut visited: Vec<i32> = vec![-1; ntxps];
let mut count = 0;
for (_i, group) in g.iter().enumerate() {
for group in g.iter() {
let g_set: Vec<usize> = group
.clone()
.split('_')
@@ -56,7 +56,7 @@ fn get_merged_bparts(
) -> HashMap<String, HashMap<String, u32>> {
let all_groups: Vec<String> = all_groups_bpart.keys().cloned().collect();
let mut merged_bparts: HashMap<String, HashMap<String, u32>> = HashMap::new();
for (_j, old_g) in all_groups.iter().enumerate() {
for old_g in all_groups.iter() {
let f_txp = old_g
.clone()
.split('_')
@@ -67,7 +67,7 @@
let m_group = strings.join("_").to_string();
let m_bpart_key = merged_bparts
.entry(sort_group_id(&m_group.clone()))
.or_insert_with(HashMap::new);
.or_default();

for (b_part, count) in all_groups_bpart.get(&old_g.clone()).unwrap().iter() {
let c_count = m_bpart_key.entry(b_part.clone()).or_insert(0);
@@ -97,7 +97,7 @@ fn find_groups_in_merged(
let m_group = strings.join("_").to_string();
merged_groups
.entry(m_group)
.or_insert_with(Vec::new)
.or_default()
.push(all_groups[j].clone());
}
merged_groups
@@ -145,7 +145,7 @@ fn get_group_trees(
for (_i, samp_hash) in samp_group_trees.iter().enumerate() {
let mut g_vec: Vec<String> = Vec::new();
let mut s_trees: Vec<String> = Vec::new();
for (_j, g) in groups.iter().enumerate() {
for g in groups.iter() {
if samp_hash.contains_key(g) {
g_vec.push(g.clone());
//println!("{}\t{:?}",g, samp_group_trees[_i].get(g).unwrap().traverse_tree());
@@ -268,7 +268,7 @@ pub fn use_phylip(dir_paths: &[&str], out: &String, all_groups: &[String], ntxps
let mut samp_group_trees: Vec<HashMap<String, TreeNode>> = Vec::new(); //Vector containing group trees from each sample
let mut msamp_nwk_file: Vec<File> = Vec::new(); //Vector containing newick trees corresponding to each group
// Storing group trees in each sample in an array along with ....
for (_i, dname) in dir_paths.iter().enumerate() {
for dname in dir_paths.iter() {
let compo: Vec<&str> = dname.rsplit('/').collect();
let experiment_name = compo[0];
let mut prefix_path = out.clone();
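Aside: the `collapse.rs` changes are mechanical cleanups — unused `enumerate()` indices are dropped, and `or_insert_with(HashMap::new)` / `or_insert_with(Vec::new)` become `or_default()`. The two entry forms are equivalent whenever the value type implements `Default`; a standalone sketch (illustrative names, not from this repo):

```rust
use std::collections::HashMap;

fn main() {
    let mut bparts: HashMap<String, HashMap<String, u32>> = HashMap::new();

    // Equivalent ways to get-or-create the inner map for a key:
    bparts
        .entry("g1_g2".to_string())
        .or_insert_with(HashMap::new) // explicit constructor
        .insert("bpart_a".to_string(), 1);
    bparts
        .entry("g1_g2".to_string())
        .or_default() // relies on HashMap implementing Default
        .insert("bpart_b".to_string(), 2);

    assert_eq!(bparts["g1_g2"].len(), 2);
}
```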
25 changes: 19 additions & 6 deletions src/main.rs
@@ -92,6 +92,12 @@ fn do_group(sub_m: &ArgMatches) -> Result<bool, io::Error> {
.parse::<f64>()
.expect("could not parse inf percentile");

let red_quant = sub_m
.value_of("red_quant")
.unwrap()
.parse::<f64>()
.expect("could not parse reduction in inferential variance");

let mut dir_paths: Vec<String> = Vec::new();
if mean_inf {
let sd = read_dir(dname.clone());
@@ -349,18 +355,17 @@ fn do_group(sub_m: &ArgMatches) -> Result<bool, io::Error> {
let thr = match thr_bool {
true => {
if !mean_inf {
util::get_threshold(&gibbs_array, p, seed, &file_list_out)
util::get_threshold(&gibbs_array, p, seed, &file_list_out, red_quant)
} else {
let mut thresh = 0.0;
for gb in gibbs_array_vec.iter() {
thresh += util::get_threshold(gb, p, seed, &file_list_out);
thresh += util::get_threshold(gb, p, seed, &file_list_out, red_quant);
}
thresh / (gibbs_array_vec.len() as f64)
}
}
false => 1e7,
};

println!("threshold: {}", thr);
println!("{}", eq_class.ntarget);

@@ -461,6 +466,7 @@ fn do_group(sub_m: &ArgMatches) -> Result<bool, io::Error> {
"allele_mode":asemode,
"txp_mode":txpmode,
"inf_perc":inf_perc,
"red_quant":red_quant,
"p":p,
"thr":thr,
"ntxps":eq_class.ntarget,
@@ -540,10 +546,10 @@ fn do_collapse(sub_m: &ArgMatches) -> Result<bool, io::Error> {
//let node_vec = group_bipart.entry(node.id.clone()).or_insert(Vec::<String>::new());
let dir_group_key = dir_bipart_counter
.entry(req_group.clone())
.or_insert_with(HashMap::new);
.or_default();
let overall_group_key = bipart_counter
.entry(req_group.clone())
.or_insert_with(HashMap::new);
.or_default();

//binary_tree::compute_bipart_count(node, &mut bipart_counter, &mut dir_bipart_counter, &node_set, node_vec);
group_keys.push(req_group.clone());
@@ -585,7 +591,7 @@ fn do_collapse(sub_m: &ArgMatches) -> Result<bool, io::Error> {
fn main() -> io::Result<()> {
let matches = App::new("TreeTerminus")
.setting(AppSettings::ArgRequiredElseHelp)
.version("0.1.0")
.version("0.3.0")
.author("Singh et al.")
// .about("Data-driven grouping of transcripts to reduce inferential uncertainty")
.subcommand(
@@ -668,6 +674,13 @@ fn main() -> io::Result<()> {
.default_value("0")
.help("inferential variance percentile threshold that determines whether a transcript will be considered for grouping")
)
.arg(
Arg::with_name("red_quant")
.long("red_quant")
.takes_value(true)
.default_value("2.5")
.help("Reduction in inferential variance percentile threshold that determines to detemine if transcripts should be grouped")
)
)
.subcommand(
SubCommand::with_name("consensus")
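Aside: a compact sketch of how the new `red_quant` value travels from the clap definition to a parsed `f64`, mirroring the two `main.rs` hunks above (standalone example, assuming the clap 2.x API used elsewhere in this file):

```rust
use clap::{App, Arg};

fn main() {
    let matches = App::new("example")
        .arg(
            Arg::with_name("red_quant")
                .long("red_quant")
                .takes_value(true)
                .default_value("2.5"), // same default as the `group` subcommand
        )
        .get_matches();

    // unwrap() is safe here: default_value guarantees the flag has a value.
    let red_quant = matches
        .value_of("red_quant")
        .unwrap()
        .parse::<f64>()
        .expect("could not parse reduction in inferential variance");
    println!("red_quant = {}", red_quant);
}
```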
2 changes: 1 addition & 1 deletion src/salmon_types.rs
@@ -162,7 +162,7 @@ impl<'a> Iterator for IterEqList<'a> {
}
self.pos += 1;
let p = self.inner.offsets[i];
let l = self.inner.offsets[(i + 1)] - p;
let l = self.inner.offsets[i + 1] - p;
Some((
&self.inner.labels[p..(p + l)],
&self.inner.weights[p..(p + l)],
51 changes: 32 additions & 19 deletions src/util.rs
@@ -27,7 +27,9 @@ use refinery::Partition;
use crate::binary_tree::{get_binary_rooted_newick_string, sort_group_id, TreeNode};
use crate::salmon_types::{EdgeInfo, EqClassExperiment, FileList, MetaInfo, TxpRecord};
use flate2::read::GzDecoder;
use statrs::distribution::{ContinuousCDF, Normal};
use std::iter::FromIterator;

(CI annotation — GitHub Actions / Test Suite (beta), src/util.rs line 31: warning: the item `FromIterator` is imported redundantly)

// use flate2::write::GzEncoder;
// use flate2::Compression;

@@ -66,7 +68,11 @@ fn conv_names(g: &str, tnames: &[String]) -> String {
// }

// impl MapTrait for HashMap<String, HashMap<String, u32>> {
pub fn bipart_writer(part_hash:&HashMap<String, HashMap<String, u32>>, g_bp_file: &mut File, tnames: &[String]) -> Result<bool, io::Error> {
pub fn bipart_writer(
part_hash: &HashMap<String, HashMap<String, u32>>,
g_bp_file: &mut File,
tnames: &[String],
) -> Result<bool, io::Error> {
//let l = group_bipart.len();
//let mut i = 0;
for (group_id, bpart_hash) in part_hash {
@@ -287,7 +293,7 @@ pub fn get_map_bw_ent(
let mut ent2_map = HashMap::<String, usize>::new();
*ent1_ent2map = vec![0; tnames.len()];
let mut j = 0;
for (_i, l) in buf_reader.lines().enumerate() {
for l in buf_reader.lines() {
let s = l.expect("Can't read line");
let mut iter = s.split_ascii_whitespace();
let ent1: String = iter.next().expect("Txp/Allele name").to_string();
@@ -325,7 +331,7 @@ pub fn get_t2g(
let mut genenames = Vec::<String>::new();

let mut gene_id = 0;
for (_i, l) in buf_reader.lines().enumerate() {
for l in buf_reader.lines() {
let s = l.expect("Can't read line");
let mut iter = s.split_ascii_whitespace();
let transcript: String = iter.next().expect("expect transcript name").to_string();
@@ -348,7 +354,7 @@ pub fn group_reader(filename: &std::path::Path) -> Vec<Vec<usize>> {
let buf_reader = BufReader::new(file);

let mut groups = Vec::new();
for (_i, l) in buf_reader.lines().enumerate() {
for l in buf_reader.lines() {
let s = l.unwrap();
let v: Vec<_> = s.trim().rsplit(',').collect();
let group: Vec<usize> = v.iter().map(|n| n.parse::<usize>().unwrap()).collect();
@@ -431,6 +437,7 @@ pub fn get_threshold(
infrv_quant: f64,
seed: u64,
file_list: &FileList,
red_quant: f64,
) -> f64 {
println!("Calculating threshold");
let gibbs_mat_sum = gibbs_mat.sum_axis(Axis(1));
@@ -454,6 +461,7 @@
// let infrv_array = variance(&gibbs_mat, Axis(1));
let mut converged = false;
let starting_num_samples = (gibbs_nz.len() as f64) * 1.;
// let starting_num_samples = 1000 as f64;
println!("\n\nstarting samp : {}\n\n", starting_num_samples);

let mut starting_num_samples = starting_num_samples as usize;
@@ -463,6 +471,7 @@

// let mut rng = thread_rng();
let mut rng = Pcg64::seed_from_u64(seed);
let std_norm = Normal::new(0.0, 1.0).unwrap();
while !converged {
//starting_num_samples < gibbs_nz.len(){
let die_range = Uniform::new(0, gibbs_nz.len());
@@ -502,25 +511,29 @@
print!("dice roll: {}\r", dice_iter);
}
}
// calculate threhold
// calculate threshold
// z=(x-mu)/sigma, => x = mu + z*sigma
// We assume reduction in inferential relative variance follows a normal distribution
// x = mu + mad*1.48*quant_norm(q),
// since sd = mad*1.48 (a more robust estimator of sd for a normal distribution);
// similarly, the z-score can be obtained by applying the inverse cumulative distribution function to the quantile
sampled_infrv.sort();
let mean = mean_sum / (dice_iter as f64);
let shifted_samples: Vec<f64> = sampled_infrv
.iter()
.map(|s| s.to_f64().unwrap() - mean)
.collect();
let shifted_samples_pos: Vec<f64> = shifted_samples
.iter()
.map(|s| s.to_f64().unwrap() - mean)
.collect();

let mid = shifted_samples_pos.len() / 2;
let median = shifted_samples_pos[mid];
/* let shifted_samples_pos: Vec<f64> = shifted_samples
.iter()
.map(|s| s.to_f64().unwrap() - mean)
.collect(); */
let mid = shifted_samples.len() / 2;
let mad = shifted_samples[mid];
//let median = sampled_infrv[sampled_infrv.len()/2].to_f64().unwrap();
new_threshold = mean - (median * 1.48 * 1.95);
//let sinfrv : Vec<f64> = sampled_infrv.iter().map(|x| x.into_inner()).collect();
//new_threshold = rgsl::statistics::quantile_from_sorted_data(&sinfrv, 1, sinfrv.len(), 0.025);

new_threshold = mean + (mad.abs() * 1.48 * std_norm.inverse_cdf(red_quant / 100.0));

// let sinfrv : Vec<f64> = sampled_infrv.iter().map(|x| x.into_inner()).collect();
if ((new_threshold - old_threshold) / new_threshold) < 0.001 {
//- new_threshold).abs() < 1e-3{
converged = true;
@@ -796,7 +809,7 @@ pub fn eq_experiment_to_graph(
let mut golden_collapses = 0;
let mut t_golden_collapses = 0;

for (_, p) in part_vec.iter().enumerate() {
for p in part_vec.iter() {
if p.len() > 1 {
//println!("{:?}", p);
if valid_transcripts[p[0]] {
@@ -1004,7 +1017,7 @@ pub fn eq_experiment_to_graph(
let e = og.find_edge(va, vb);
match e {
Some(ei) => {
let mut ew = og.edge_weight_mut(ei).unwrap();
let ew = og.edge_weight_mut(ei).unwrap();
ew.count += eq_count;
ew.eqlist.push(i);
}
@@ -1373,7 +1386,7 @@ pub fn work_on_component(

let xn = pg::graph::NodeIndex::new(*x);
let u_to_x_inner = og.find_edge(source_node, xn).unwrap();
let mut u_to_x_info_inner = og.edge_weight_mut(u_to_x_inner).unwrap();
let u_to_x_info_inner = og.edge_weight_mut(u_to_x_inner).unwrap();
let curr_state = u_to_x_info_inner.state;

let delta = match mean_inf {
@@ -1505,7 +1518,7 @@ pub fn work_on_component(
v_to_x_eq = v_to_x_info.eqlist.clone();
}

let mut u_to_x_info = og.edge_weight_mut(u_to_x_inner).unwrap();
let u_to_x_info = og.edge_weight_mut(u_to_x_inner).unwrap();

// v_to_x_eq.sort();
let intersecting_eqlist = intersect(&v_to_x_eq, &u_to_x_info.eqlist);
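Aside: the net effect of the `get_threshold` rework is to replace the hard-coded z-score of 1.95 with the standard-normal inverse CDF evaluated at the user-supplied `red_quant` quantile. Under the normality assumption stated in the comments, with the standard deviation estimated robustly as 1.48 × MAD, the cutoff becomes

```latex
x_q = \mu + 1.48 \cdot \mathrm{MAD} \cdot \Phi^{-1}\!\left(\tfrac{q}{100}\right)
```

For the default `red_quant` of 2.5, Φ⁻¹(0.025) ≈ −1.96, so the new formula recovers roughly the previous cutoff of μ − 1.95 × 1.48 × MAD while letting the user tune the quantile.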
