diff --git a/opal.py b/opal.py index 61cb569..b881bd3 100755 --- a/opal.py +++ b/opal.py @@ -111,14 +111,13 @@ def print_by_tool(output_dir, pd_metrics): table.fillna('na').to_csv(os.path.join(output_dir, "by_tool", toolname + ".tsv"), sep='\t') -def compute_metrics(sample_metadata, profile, gs_pf_profile, gs_rank_to_taxid_to_percentage, rank_to_taxid_to_percentage, - normalize, branch_length_fun): +def compute_metrics(sample_metadata, profile, gs_pf_profile, gs_rank_to_taxid_to_percentage, rank_to_taxid_to_percentage): # Unifrac if isinstance(profile, PF.Profile): pf_profile = profile else: - pf_profile = PF.Profile(sample_metadata=sample_metadata, profile=profile, branch_length_fun=branch_length_fun) - unifrac = uf.compute_unifrac(gs_pf_profile, pf_profile, normalize) + pf_profile = PF.Profile(sample_metadata=sample_metadata, profile=profile) + unifrac = uf.compute_unifrac(gs_pf_profile, pf_profile) # Shannon shannon = sh.compute_shannon_index(rank_to_taxid_to_percentage) @@ -149,7 +148,7 @@ def load_profiles(gold_standard_file, profiles_files, normalize): return sample_ids_list, gs_samples_list, profiles_list_to_samples_list -def evaluate(gs_samples_list, profiles_list_to_samples_list, labels, normalize, filter_tail_percentage, branch_length_fun): +def evaluate(gs_samples_list, profiles_list_to_samples_list, labels, filter_tail_percentage): gs_id_to_rank_to_taxid_to_percentage = {} gs_id_to_pf_profile = {} pd_metrics = pd.DataFrame() @@ -157,14 +156,12 @@ def evaluate(gs_samples_list, profiles_list_to_samples_list, labels, normalize, for sample in gs_samples_list: sample_id, sample_metadata, profile = sample gs_id_to_rank_to_taxid_to_percentage[sample_id] = load_data.get_rank_to_taxid_to_percentage(profile) - gs_id_to_pf_profile[sample_id] = PF.Profile(sample_metadata=sample_metadata, profile=profile, branch_length_fun=branch_length_fun) + gs_id_to_pf_profile[sample_id] = PF.Profile(sample_metadata=sample_metadata, profile=profile) unifrac, shannon, l1norm, binary_metrics, braycurtis = compute_metrics(sample_metadata, gs_id_to_pf_profile[sample_id], gs_id_to_pf_profile[sample_id], gs_id_to_rank_to_taxid_to_percentage[sample_id], - gs_id_to_rank_to_taxid_to_percentage[sample_id], - normalize, - branch_length_fun) + gs_id_to_rank_to_taxid_to_percentage[sample_id]) pd_metrics = pd.concat([pd_metrics, reformat_pandas(sample_id, c.GS, braycurtis, shannon, binary_metrics, l1norm, unifrac)], ignore_index=True) if filter_tail_percentage: metrics_list = pd_metrics['metric'].unique().tolist() @@ -189,20 +186,17 @@ def evaluate(gs_samples_list, profiles_list_to_samples_list, labels, normalize, unifrac, shannon, l1norm, binary_metrics, braycurtis = compute_metrics(sample_metadata, profile, gs_pf_profile, gs_rank_to_taxid_to_percentage, - rank_to_taxid_to_percentage, - normalize, - branch_length_fun) + rank_to_taxid_to_percentage) rename_as_unfiltered = True if filter_tail_percentage else False pd_metrics = pd.concat([pd_metrics, reformat_pandas(sample_id, label, braycurtis, shannon, binary_metrics, l1norm, unifrac, rename_as_unfiltered)], ignore_index=True) if filter_tail_percentage: rank_to_taxid_to_percentage_filtered = \ load_data.get_rank_to_taxid_to_percentage_filtered(rank_to_taxid_to_percentage, filter_tail_percentage) - unifrac, shannon, l1norm, binary_metrics, braycurtis = compute_metrics(sample_metadata, profile, gs_pf_profile, + profile_filtered = [prediction for prediction in profile if prediction.taxid in rank_to_taxid_to_percentage_filtered[prediction.rank]] + unifrac, shannon, l1norm, binary_metrics, braycurtis = compute_metrics(sample_metadata, profile_filtered, gs_pf_profile, gs_rank_to_taxid_to_percentage, - rank_to_taxid_to_percentage_filtered, - normalize, - branch_length_fun) + rank_to_taxid_to_percentage_filtered) pd_metrics = pd.concat([pd_metrics, reformat_pandas(sample_id, label, braycurtis, shannon, binary_metrics, l1norm, unifrac)], ignore_index=True) one_profile_assessed = True @@ -319,7 +313,6 @@ def main(): group2 = parser.add_argument_group('optional arguments') group2.add_argument('-n', '--normalize', help='Normalize samples', action='store_true') group2.add_argument('-f', '--filter', help='Filter out the predictions with the smallest relative abundances summing up to [FILTER]%% within a rank (affects only precision, default: 0)', type=float) - group2.add_argument('-b', '--branch_length_function', help='UniFrac tree branch length function (default: "lambda x: 1/x", x=tree depth)', required=False, default='lambda x: 1/x') group2.add_argument('-p', '--plot_abundances', help='Plot abundances in the gold standard (can take some minutes)', action='store_true') group2.add_argument('-l', '--labels', help='Comma-separated profiles names', required=False) group2.add_argument('-t', '--time', help='Comma-separated runtimes in hours', required=False) @@ -356,9 +349,7 @@ def main(): pd_metrics = evaluate(gs_samples_list, profiles_list_to_samples_list, labels, - args.normalize, - args.filter, - uf.get_branch_length_function(args.branch_length_function)) + args.filter) time_list, memory_list = get_time_memory(args.time, args.memory, args.profiles_files) if time_list or memory_list: pd_metrics = concat_time_memory(labels, time_list, memory_list, pd_metrics) diff --git a/version.py b/version.py index a682442..d521168 100644 --- a/version.py +++ b/version.py @@ -1 +1 @@ -__version__ = '1.0.9' +__version__ = '1.0.10'