From 34fc80169cb5d791255bd5536089e8a488aa2c2b Mon Sep 17 00:00:00 2001
From: Yan Wong
Date: Mon, 15 Oct 2018 16:12:33 +0100
Subject: [PATCH] Allow empirical error and saving CSV files (#129)

* Allow empirical error and saving CSV files
* Tidy some linting
* more linting
* final linting
* Add extra columns to allow simple CSV combining
---
 evaluation.py | 162 ++++++++++++++++++++++++++++++++++++--------------
 1 file changed, 117 insertions(+), 45 deletions(-)

diff --git a/evaluation.py b/evaluation.py
index 3fdcb4dc..b87adbe3 100644
--- a/evaluation.py
+++ b/evaluation.py
@@ -57,21 +57,75 @@ def make_errors(v, p):
     return w
 
 
-def generate_samples(ts, error_p):
-    """
-    Returns samples with a bits flipped with a specified probability.
-
-    Rejects any variants that result in a fixed column.
-    """
-    G = np.zeros((ts.num_sites, ts.num_samples), dtype=np.int8)
-    for variant in ts.variants():
-        done = False
-        # Reject any columns that have no 1s or no zeros
-        while not done:
-            G[variant.index] = make_errors(variant.genotypes, error_p)
-            s = np.sum(G[variant.index])
-            done = 0 < s < ts.sample_size
-    return G
+def make_errors_genotype_model(g, error_probs):
+    """
+    Given an empirically estimated error probability matrix, resample the genotypes
+    of a particular variant. Determine the variant frequency and each true genotype
+    (g0, g1 or g2), then return observed genotypes based on the row of error_probs
+    with the nearest frequency. Treat each pair of alleles as one diploid individual.
+    """
+    w = np.copy(g)
+
+    # Make diploid (iterate over each pair of alleles)
+    genos = [(w[i], w[i+1]) for i in range(0, w.shape[0], 2)]
+
+    # Record the true genotypes
+    g0 = [i for i, x in enumerate(genos) if x == (0, 0)]
+    g1a = [i for i, x in enumerate(genos) if x == (1, 0)]
+    g1b = [i for i, x in enumerate(genos) if x == (0, 1)]
+    g2 = [i for i, x in enumerate(genos) if x == (1, 1)]
+
+    for idx in g0:
+        result = [(0, 0), (1, 0), (1, 1)][
+            np.random.choice(3, p=error_probs[['p00', 'p01', 'p02']].values[0])]
+        if result == (1, 0):
+            genos[idx] = [(0, 1), (1, 0)][np.random.choice(2)]
+        else:
+            genos[idx] = result
+    for idx in g1a:
+        genos[idx] = [(0, 0), (1, 0), (1, 1)][
+            np.random.choice(3, p=error_probs[['p10', 'p11', 'p12']].values[0])]
+    for idx in g1b:
+        genos[idx] = [(0, 0), (0, 1), (1, 1)][
+            np.random.choice(3, p=error_probs[['p10', 'p11', 'p12']].values[0])]
+    for idx in g2:
+        result = [(0, 0), (1, 0), (1, 1)][
+            np.random.choice(3, p=error_probs[['p20', 'p21', 'p22']].values[0])]
+        if result == (1, 0):
+            genos[idx] = [(0, 1), (1, 0)][np.random.choice(2)]
+        else:
+            genos[idx] = result
+
+    return np.array(sum(genos, ()))
+
+
+def generate_samples(ts, error_param=0):
+    """
+    Generate a samples file from a simulated ts, adding genotyping error either at
+    a fixed rate or according to the empirical error matrix given in error_param.
+    Reject any variants that result in a fixed column.
+    """
+    assert ts.num_sites != 0
+    sd = tsinfer.SampleData(sequence_length=ts.sequence_length)
+    try:
+        e = float(error_param)
+        for v in ts.variants():
+            g = v.genotypes if e == 0 else make_errors(v.genotypes, e)
+            sd.add_site(position=v.site.position, alleles=v.alleles, genotypes=g)
+    except ValueError:
+        # error_param is not a number, so treat it as the name of an error matrix file
+        error_matrix = pd.read_csv(error_param)
+        # First record the allele frequency
+        for v in ts.variants():
+            m = v.genotypes.shape[0]
+            frequency = np.sum(v.genotypes) / m
+            # Find the closest row in the error matrix file
+            closest_row = (error_matrix['freq']-frequency).abs().argsort()[:1]
+            closest_freq = error_matrix.iloc[closest_row]
+            g = make_errors_genotype_model(v.genotypes, closest_freq)
+            sd.add_site(position=v.site.position, alleles=v.alleles, genotypes=g)
+    sd.finalise()
+    return sd
 
 
 def run_infer(ts, engine=tsinfer.C_ENGINE, path_compression=True, exact_ancestors=False):
@@ -576,12 +630,7 @@ def sim_true_and_inferred_ancestors(args):
         "random_seed": rng.randint(1, 2**30)}
     ts = msprime.simulate(**sim_args)
 
-    # ts = tsinfer.insert_errors(ts, args.error_probability, seed=args.random_seed)
-    V = generate_samples(ts, args.error_probability)
-
-    with tsinfer.SampleData(sequence_length=ts.sequence_length) as sample_data:
-        for s, v in zip(ts.sites(), V):
-            sample_data.add_site(s.position, v, ["0", "1"])
+    sample_data = generate_samples(ts, args.error)
 
     inferred_anc = tsinfer.generate_ancestors(sample_data, engine=args.engine)
     true_anc = tsinfer.AncestorData(sample_data)
@@ -616,10 +665,16 @@ def run_ancestor_comparison(args):
     estimated_anc_length = estimated_anc.ancestors_length / 1000
     exact_anc_length = exact_anc.ancestors_length / 1000
     max_length = sample_data.sequence_length / 1000
+    try:
+        err = float(args.error)
+    except ValueError:
+        err = args.error.replace("/", "_")
+        if err.endswith(".csv"):
+            err = err[:-len(".csv")]
     name_format = os.path.join(
         args.destination_dir, "anc-comp_n={}_L={}_mu={}_rho={}_err={}_{{}}".format(
             args.sample_size, args.length, args.mutation_rate, args.recombination_rate,
-            args.error_probability))
+            err))
     if args.store_data:
         # TODO Are we using this option for anything?
        filename = name_format.format("length.json")
@@ -823,11 +878,16 @@ def run_ancestor_quality(args):
     so that we only look at the regions of overlap between true and inferred ancestors
     """
     sample_data, exact_anc, estim_anc = sim_true_and_inferred_ancestors(args)
-
+    try:
+        err = float(args.error)
+    except ValueError:
+        err = args.error.replace("/", "_")
+        if err.endswith(".csv"):
+            err = err[:-len(".csv")]
     name_format = os.path.join(
         args.destination_dir, "anc-qual_n={}_L={}_mu={}_rho={}_err={}_{{}}".format(
             args.sample_size, args.length, args.mutation_rate, args.recombination_rate,
-            args.error_probability))
+            err))
 
     anc_indices = ancestor_data_by_pos(exact_anc, estim_anc)
     shared_positions = np.array(list(sorted(anc_indices.keys())))
@@ -850,7 +910,8 @@ def run_ancestor_quality(args):
     olap_n_should_be_0_low_eq_freq = {}
     olap_lft = {}
     olap_rgt = {}
-    true_length = {}
+    true_len = {}
+    est_len = {}
     true_time = {}
     # find the left and right edges of the overlap - iterate by true time in reverse
     for i, focal_pos in enumerate(
@@ -911,7 +972,8 @@ def run_ancestor_quality(args):
         olap_n_sites[focal_pos] = len(exact_comp)
         olap_lft[focal_pos] = olap_start
         olap_rgt[focal_pos] = olap_end
-        true_length[focal_pos] = exact_anc.ancestors_length[:][exact_index]
+        true_len[focal_pos] = exact_anc.ancestors_length[:][exact_index]
+        est_len[focal_pos] = estim_anc.ancestors_length[:][estim_index]
         true_time[focal_pos] = exact_anc.ancestors_time[:][exact_index]
         sites_freq = estim_freq[olap_start_estim:olap_end_estim]
         higher_freq = sites_freq[small_estim_mask] > freq[focal_pos]
@@ -919,7 +981,7 @@ def run_ancestor_quality(args):
         olap_n_should_be_0_higher_freq[focal_pos] = np.sum(should_be_0 & higher_freq)
         olap_n_should_be_1_low_eq_freq[focal_pos] = np.sum(should_be_1 & ~higher_freq)
         olap_n_should_be_0_low_eq_freq[focal_pos] = np.sum(should_be_0 & ~higher_freq)
-        assert olap_rgt[focal_pos]-olap_lft[focal_pos] <= true_length[focal_pos]
+        assert olap_rgt[focal_pos]-olap_lft[focal_pos] <= true_len[focal_pos]
         assert (olap_n_should_be_1_higher_freq[focal_pos] +
                 olap_n_should_be_0_higher_freq[focal_pos] +
                 olap_n_should_be_1_low_eq_freq[focal_pos] +
@@ -1003,13 +1065,13 @@ def run_ancestor_quality(args):
 
     # create the data for use, ordered by real time (and make a new time index)
     data = pd.DataFrame.from_records(
-        [(p, freq[p], olap_n_sites[p], true_length[p], olap_rgt[p]-olap_lft[p],
+        [(p, freq[p], olap_n_sites[p], true_len[p], est_len[p], olap_rgt[p]-olap_lft[p],
           olap_n_should_be_1_higher_freq[p], olap_n_should_be_1_low_eq_freq[p],
           olap_n_should_be_0_higher_freq[p], olap_n_should_be_0_low_eq_freq[p],
           t, true_time[p])
             for t, p in enumerate(sorted(shared_positions, key=lambda x: true_time[x]))],
         columns=(
-            "position", "Frequency", "n_sites", "Real length", "Overlap length",
+            "position", "Frequency", "n_sites", "Real length", "Estim length", "Overlap",
            "err_hiF should be 1", "err_loF should be 1",
            "err_hiF should be 0", "err_loF should be 0",
            "Known time order", "orig_time"))
@@ -1021,7 +1083,6 @@ def run_ancestor_quality(args):
     freq_repeated = np.repeat(np.arange(len(freq_bins)), freq_bins)
     # add another column on to the expected freq, as calculated from the actual time
     data['expected_Frequency'] = freq_repeated[data["Known time order"].values]
-
     data['n_mismatches'] = (data["err_hiF should be 1"] + data["err_loF should be 1"] +
                             data["err_hiF should be 0"] + data["err_loF should be 0"])
     data['Inaccuracy'] = data.n_mismatches / data.n_sites
@@ -1030,20 +1091,27 @@ def run_ancestor_quality(args):
1"] + data["err_loF should be 1"]) / data.n_mismatches data['err_hiF'] = (data["err_hiF should be 1"] + data["err_hiF should be 0"]) data['err_loF'] = (data["err_loF should be 1"] + data["err_loF should be 0"]) - Inaccuracy_label = "Sequence difference in overlapping region" print("{} ancestors, {} with at least one error".format( len(data), np.sum(data.n_mismatches != 0))) print(data[["err_hiF should be 1", "err_loF should be 1", "err_hiF should be 0", "err_loF should be 0"]].sum()) - data[["err_hiF should be 1", "err_loF should be 1", - "err_hiF should be 0", "err_loF should be 0", "n_sites"]].to_csv( - name_format.format("error_data.csv")) - + if args.csv_only: + # Add some standard params to the CSV to make it easy to paste CSVs together + data["sample_size"] = args.sample_size + data["seq_length"] = args.length + data["mu"] = args.mutation_rate + data["rho"] = args.recombination_rate + data["seq_error"] = args.error + data.to_csv(name_format.format("error_data.csv")) + return + + # Now do the plots + Inaccuracy_label = "Sequence difference in overlapping region" name = "quality-by-missingness" x_axis_length_metric = "fraction" # or e.g. "fraction" - data['abs_missing_l'] = (data["Real length"] - data["Overlap length"])+1 - data['rel_missing_l'] = 1 - (data["Overlap length"] / data["Real length"]) + data['abs_missing_l'] = (data["Real length"] - data["Overlap"])+1 + data['rel_missing_l'] = 1 - (data["Overlap"] / data["Real length"]) if x_axis_length_metric == "absolute": x_col = 'abs_missing_l' ax_params = { @@ -1095,6 +1163,7 @@ def run_ancestor_quality(args): x='Known time order', y="Inaccuracy", c='Inference error bias', cmap='coolwarm', s=data.n_mismatches.values+1) """ + # Add some tiny labels, to aid identification in a pdf plot labels = ["{:.1f}\n{:.0f}\n{:.0f}".format( r['position'], r['orig_time'], r['n_mismatches']) for i, r in data.iterrows()] for x,y,s in zip(data['Known time order'].values, data["Inaccuracy"].values, labels): @@ -1167,7 +1236,7 @@ def run_ancestor_quality(args): name = "quality-by-length" ax = data.plot.scatter( - x="Overlap length", y="Inaccuracy", c="Frequency", cmap='brg', s=2, + x="Overlap", y="Inaccuracy", c="Frequency", cmap='brg', s=2, norm=NormalizeBandWidths(band_widths=freq_bins)) ax.set(ylabel=Inaccuracy_label, xscale='log', ylim=(-0.01, 1), xlim=(1)) save_figure(name_format.format(name)) @@ -1434,7 +1503,7 @@ def setup_logging(args): "for a single instance.")) cli.add_logging_arguments(parser) parser.set_defaults(runner=run_ancestor_comparison) - parser.add_argument("--sample-size", "-n", type=int, default=60) + parser.add_argument("--sample-size", "-n", type=int, default=100) parser.add_argument( "--length", "-l", type=float, default=1, help="Sequence length in MB") parser.add_argument( @@ -1444,13 +1513,13 @@ def setup_logging(args): "--mutation-rate", "-u", type=float, default=1e-8, help="Mutation rate") parser.add_argument( - "--error-probability", "-e", type=float, default=0, - help="Error probability") + "--error", "-e", default="0", + help="Error: either a probability or a csv filename to use for empirical error") parser.add_argument("--random-seed", "-s", type=int, default=None) parser.add_argument("--destination-dir", "-d", default="") parser.add_argument( "--store-data", "-S", action="store_true", - help="Store the raw data.") + help="Store some raw data.") parser.add_argument( "--length-scale", "-X", choices=['linear', 'log'], default="linear", help='Length scale for distances when plotting') @@ -1467,7 +1536,7 @@ def 
@@ -1467,7 +1536,7 @@ def setup_logging(args):
             "for a single instance."))
     cli.add_logging_arguments(parser)
     parser.set_defaults(runner=run_ancestor_quality)
-    parser.add_argument("--sample-size", "-n", type=int, default=60)
+    parser.add_argument("--sample-size", "-n", type=int, default=100)
     parser.add_argument(
         "--length", "-l", type=float, default=1, help="Sequence length in MB")
     parser.add_argument(
@@ -1477,13 +1546,16 @@ def setup_logging(args):
         "--mutation-rate", "-u", type=float, default=1e-8,
         help="Mutation rate")
     parser.add_argument(
-        "--error-probability", "-e", type=float, default=0,
-        help="Error probability")
+        "--error", "-e", default="0",
+        help="Error: either a probability or a CSV filename to use for empirical error")
    parser.add_argument("--random-seed", "-s", type=int, default=None)
    parser.add_argument("--destination-dir", "-d", default="")
    parser.add_argument(
        "--print-bad-ancestors", "-b", nargs='?', const="inferred",
        choices=['inferred', 'all'],
        help="Also print out all the bad ancestor matches")
+    parser.add_argument(
+        "--csv-only", "-C", action="store_true",
+        help='Do not create plots, but output a CSV file of the data for later plotting')
     parser.add_argument(
         "--length-scale", "-X", choices=['linear', 'log'], default="linear",
         help='Length scale for distances when plotting')
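
Note on the error-matrix format: generate_samples() reads error_param as a CSV with a 'freq' column plus the nine conditional probabilities p00..p22 used by make_errors_genotype_model() (pXY = probability of observing diploid genotype Y given true genotype X, so each pX0/pX1/pX2 triple should sum to 1). The column names below come from the patch itself; the probability values are made-up placeholders, not real platform error rates:

    # Minimal sketch of the expected error-matrix layout (illustrative values only).
    import io

    import numpy as np
    import pandas as pd

    csv = io.StringIO("""\
    freq,p00,p01,p02,p10,p11,p12,p20,p21,p22
    0.1,0.98,0.019,0.001,0.01,0.98,0.01,0.001,0.019,0.98
    0.5,0.97,0.028,0.002,0.015,0.97,0.015,0.002,0.028,0.97
    """)
    error_matrix = pd.read_csv(csv)

    # Each row of conditional probabilities must sum to 1 per true genotype.
    for g in ("0", "1", "2"):
        cols = ["p{}{}".format(g, o) for o in "012"]
        assert np.allclose(error_matrix[cols].sum(axis=1), 1)

    # Pick the row whose 'freq' is closest to a variant's frequency,
    # exactly as generate_samples() does in the patch above.
    frequency = 0.2
    closest_row = (error_matrix['freq'] - frequency).abs().argsort()[:1]
    closest_freq = error_matrix.iloc[closest_row]
    print(closest_freq)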
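For trying the patch locally, a hedged usage sketch of the two error modes of generate_samples(): a numeric error_param flips alleles via make_errors(), while a non-numeric one is treated as the path to an empirical error CSV (laid out as in the previous sketch). The msprime parameters and the "error.csv" path are illustrative assumptions, as is importing the patched evaluation.py as a module:

    # Usage sketch, assuming msprime is installed and evaluation.py is importable.
    import msprime

    from evaluation import generate_samples

    # Even sample_size is needed, since the empirical model pairs alleles into diploids.
    ts = msprime.simulate(
        sample_size=10, Ne=10**4, length=1e5, recombination_rate=1e-8,
        mutation_rate=1e-8, random_seed=42)

    sd_clean = generate_samples(ts, 0)                # error rate 0: genotypes unchanged
    sd_flipped = generate_samples(ts, 0.01)           # numeric: per-allele flips via make_errors()
    sd_empirical = generate_samples(ts, "error.csv")  # non-numeric: empirical error matrix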