Allow empirical error and saving CSV files (#129)
* Allow empirical error and saving CSV files

* Tidy some linting

* more linting

* final linting

* Add extra columns to allow simple CSV combining
hyanwong authored and jeromekelleher committed Oct 15, 2018
1 parent 343e3c1 commit 34fc801
Showing 1 changed file with 117 additions and 45 deletions.
162 changes: 117 additions & 45 deletions evaluation.py
@@ -57,21 +57,75 @@ def make_errors(v, p):
return w


def generate_samples(ts, error_p):
"""
Returns samples with bits flipped with a specified probability.
Rejects any variants that result in a fixed column.
"""
G = np.zeros((ts.num_sites, ts.num_samples), dtype=np.int8)
for variant in ts.variants():
done = False
# Reject any columns that have no 1s or no zeros
while not done:
G[variant.index] = make_errors(variant.genotypes, error_p)
s = np.sum(G[variant.index])
done = 0 < s < ts.sample_size
return G
def make_errors_genotype_model(g, error_probs):
"""
Given an empirically estimated error probability matrix, resample for a particular
variant. Determine variant frequency and true genotype (g0, g1, or g2),
then return observed genotype based on row in error_probs with nearest
frequency. Treat each pair of alleles as a diploid individual.
"""
w = np.copy(g)

# Make diploid (iterate each pair of alleles)
genos = [(w[i], w[i+1]) for i in range(0, w.shape[0], 2)]

# Record the true genotypes
g0 = [i for i, x in enumerate(genos) if x == (0, 0)]
g1a = [i for i, x in enumerate(genos) if x == (1, 0)]
g1b = [i for i, x in enumerate(genos) if x == (0, 1)]
g2 = [i for i, x in enumerate(genos) if x == (1, 1)]

for idx in g0:
result = [(0, 0), (1, 0), (1, 1)][
np.random.choice(3, p=error_probs[['p00', 'p01', 'p02']].values[0])]
if result == (1, 0):
genos[idx] = [(0, 1), (1, 0)][np.random.choice(2)]
else:
genos[idx] = result
for idx in g1a:
genos[idx] = [(0, 0), (1, 0), (1, 1)][
np.random.choice(3, p=error_probs[['p10', 'p11', 'p12']].values[0])]
for idx in g1b:
genos[idx] = [(0, 0), (0, 1), (1, 1)][
np.random.choice(3, p=error_probs[['p10', 'p11', 'p12']].values[0])]
for idx in g2:
result = [(0, 0), (1, 0), (1, 1)][
np.random.choice(3, p=error_probs[['p20', 'p21', 'p22']].values[0])]
if result == (1, 0):
genos[idx] = [(0, 1), (1, 0)][np.random.choice(2)]
else:
genos[idx] = result

# Flatten the list of diploid genotype pairs back into a haploid array
return np.array(sum(genos, ()))
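
For reference, a minimal sketch of the error_probs argument: a single-row pandas DataFrame with a 'freq' column plus the nine genotype-transition probabilities named above (the probability values here are invented purely for illustration):

import numpy as np
import pandas as pd

error_probs = pd.DataFrame([{
    'freq': 0.1,
    'p00': 0.98, 'p01': 0.015, 'p02': 0.005,   # observed genotype given true 0/0
    'p10': 0.01, 'p11': 0.98, 'p12': 0.01,     # observed genotype given true 0/1
    'p20': 0.005, 'p21': 0.015, 'p22': 0.98,   # observed genotype given true 1/1
}])
g = np.array([0, 0, 0, 1, 1, 1], dtype=np.int8)  # three diploid individuals
observed = make_errors_genotype_model(g, error_probs)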


def generate_samples(ts, error_param=0):
"""
Generate a samples file from a simulated ts. If error_param is a number, flip
alleles with that probability; if it is a filename, inject error using the
empirically estimated error matrix stored in that CSV file.
"""
assert ts.num_sites != 0
sd = tsinfer.SampleData(sequence_length=ts.sequence_length)
try:
e = float(error_param)
for v in ts.variants():
g = v.genotypes if e == 0 else make_errors(v.genotypes, e)
sd.add_site(position=v.site.position, alleles=v.alleles, genotypes=g)
except ValueError:
# error_param is not a number, so treat it as the name of an empirical error file
error_matrix = pd.read_csv(error_param)
# For each variant, first record the allele frequency
for v in ts.variants():
m = v.genotypes.shape[0]
frequency = np.sum(v.genotypes) / m
# Find closest row in error matrix file
closest_row = (error_matrix['freq']-frequency).abs().argsort()[:1]
closest_freq = error_matrix.iloc[closest_row]
g = make_errors_genotype_model(v.genotypes, closest_freq)
sd.add_site(position=v.site.position, alleles=v.alleles, genotypes=g)
sd.finalise()
return sd
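
A hypothetical usage sketch (the simulation parameters and the CSV path are invented for illustration):

import msprime

ts = msprime.simulate(
    sample_size=10, Ne=10**4, length=10**5, mutation_rate=1e-7, random_seed=42)
sd_clean = generate_samples(ts, 0)      # no error injected
sd_flipped = generate_samples(ts, 0.01) # flip each allele with probability 0.01
sd_empirical = generate_samples(ts, "empirical_error.csv")  # frequency-matched empirical error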


def run_infer(ts, engine=tsinfer.C_ENGINE, path_compression=True, exact_ancestors=False):
@@ -576,12 +630,7 @@ def sim_true_and_inferred_ancestors(args):
"random_seed": rng.randint(1, 2**30)}
ts = msprime.simulate(**sim_args)

# ts = tsinfer.insert_errors(ts, args.error_probability, seed=args.random_seed)
V = generate_samples(ts, args.error_probability)

with tsinfer.SampleData(sequence_length=ts.sequence_length) as sample_data:
for s, v in zip(ts.sites(), V):
sample_data.add_site(s.position, v, ["0", "1"])
sample_data = generate_samples(ts, args.error)

inferred_anc = tsinfer.generate_ancestors(sample_data, engine=args.engine)
true_anc = tsinfer.AncestorData(sample_data)
@@ -616,10 +665,16 @@ def run_ancestor_comparison(args):
estimated_anc_length = estimated_anc.ancestors_length / 1000
exact_anc_length = exact_anc.ancestors_length / 1000
max_length = sample_data.sequence_length / 1000
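# Convert the error argument (a probability or a csv filename) into a
# string that is safe to embed in output file names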
try:
err = float(args.error)
except ValueError:
err = args.error.replace("/", "_")
if err.endswith(".csv"):
err = err[:-len(".csv")]
name_format = os.path.join(
args.destination_dir, "anc-comp_n={}_L={}_mu={}_rho={}_err={}_{{}}".format(
args.sample_size, args.length, args.mutation_rate, args.recombination_rate,
args.error_probability))
err))
if args.store_data:
# TODO Are we using this option for anything?
filename = name_format.format("length.json")
@@ -823,11 +878,16 @@ def run_ancestor_quality(args):
so that we only look at the regions of overlap between true and inferred ancestors
"""
sample_data, exact_anc, estim_anc = sim_true_and_inferred_ancestors(args)

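# Convert the error argument (a probability or a csv filename) into a
# string that is safe to embed in output file names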
try:
err = float(args.error)
except ValueError:
err = args.error.replace("/", "_")
if err.endswith(".csv"):
err = err[:-len(".csv")]
name_format = os.path.join(
args.destination_dir, "anc-qual_n={}_L={}_mu={}_rho={}_err={}_{{}}".format(
args.sample_size, args.length, args.mutation_rate, args.recombination_rate,
args.error_probability))
err))

anc_indices = ancestor_data_by_pos(exact_anc, estim_anc)
shared_positions = np.array(list(sorted(anc_indices.keys())))
@@ -850,7 +910,8 @@
olap_n_should_be_0_low_eq_freq = {}
olap_lft = {}
olap_rgt = {}
true_length = {}
true_len = {}
est_len = {}
true_time = {}
# find the left and right edges of the overlap - iterate by true time in reverse
for i, focal_pos in enumerate(
@@ -911,15 +972,16 @@ def run_ancestor_quality(args):
olap_n_sites[focal_pos] = len(exact_comp)
olap_lft[focal_pos] = olap_start
olap_rgt[focal_pos] = olap_end
true_length[focal_pos] = exact_anc.ancestors_length[:][exact_index]
true_len[focal_pos] = exact_anc.ancestors_length[:][exact_index]
est_len[focal_pos] = estim_anc.ancestors_length[:][estim_index]
true_time[focal_pos] = exact_anc.ancestors_time[:][exact_index]
sites_freq = estim_freq[olap_start_estim:olap_end_estim]
higher_freq = sites_freq[small_estim_mask] > freq[focal_pos]
olap_n_should_be_1_higher_freq[focal_pos] = np.sum(should_be_1 & higher_freq)
olap_n_should_be_0_higher_freq[focal_pos] = np.sum(should_be_0 & higher_freq)
olap_n_should_be_1_low_eq_freq[focal_pos] = np.sum(should_be_1 & ~higher_freq)
olap_n_should_be_0_low_eq_freq[focal_pos] = np.sum(should_be_0 & ~higher_freq)
assert olap_rgt[focal_pos]-olap_lft[focal_pos] <= true_length[focal_pos]
assert olap_rgt[focal_pos]-olap_lft[focal_pos] <= true_len[focal_pos]
assert (olap_n_should_be_1_higher_freq[focal_pos] +
olap_n_should_be_0_higher_freq[focal_pos] +
olap_n_should_be_1_low_eq_freq[focal_pos] +
@@ -1003,13 +1065,13 @@ def run_ancestor_quality(args):

# create the data for use, ordered by real time (and make a new time index)
data = pd.DataFrame.from_records(
[(p, freq[p], olap_n_sites[p], true_length[p], olap_rgt[p]-olap_lft[p],
[(p, freq[p], olap_n_sites[p], true_len[p], est_len[p], olap_rgt[p]-olap_lft[p],
olap_n_should_be_1_higher_freq[p], olap_n_should_be_1_low_eq_freq[p],
olap_n_should_be_0_higher_freq[p], olap_n_should_be_0_low_eq_freq[p],
t, true_time[p])
for t, p in enumerate(sorted(shared_positions, key=lambda x: true_time[x]))],
columns=(
"position", "Frequency", "n_sites", "Real length", "Overlap length",
"position", "Frequency", "n_sites", "Real length", "Estim length", "Overlap",
"err_hiF should be 1", "err_loF should be 1",
"err_hiF should be 0", "err_loF should be 0",
"Known time order", "orig_time"))
@@ -1021,7 +1083,6 @@ def run_ancestor_quality(args):
freq_repeated = np.repeat(np.arange(len(freq_bins)), freq_bins)
# add another column on to the expected freq, as calculated from the actual time
data['expected_Frequency'] = freq_repeated[data["Known time order"].values]
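# e.g. freq_bins == [2, 3] gives freq_repeated == [0, 0, 1, 1, 1] (illustrative
# values), so each ancestor's known-time-order index maps to its expected frequency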

data['n_mismatches'] = (data["err_hiF should be 1"] + data["err_loF should be 1"] +
data["err_hiF should be 0"] + data["err_loF should be 0"])
data['Inaccuracy'] = data.n_mismatches / data.n_sites
@@ -1030,20 +1091,27 @@ def run_ancestor_quality(args):
(data["err_hiF should be 1"] + data["err_loF should be 1"]) / data.n_mismatches
data['err_hiF'] = (data["err_hiF should be 1"] + data["err_hiF should be 0"])
data['err_loF'] = (data["err_loF should be 1"] + data["err_loF should be 0"])
Inaccuracy_label = "Sequence difference in overlapping region"

print("{} ancestors, {} with at least one error".format(
len(data), np.sum(data.n_mismatches != 0)))
print(data[["err_hiF should be 1", "err_loF should be 1",
"err_hiF should be 0", "err_loF should be 0"]].sum())
data[["err_hiF should be 1", "err_loF should be 1",
"err_hiF should be 0", "err_loF should be 0", "n_sites"]].to_csv(
name_format.format("error_data.csv"))

if args.csv_only:
# Add some standard params to the CSV to make it easy to paste CSVs together
data["sample_size"] = args.sample_size
data["seq_length"] = args.length
data["mu"] = args.mutation_rate
data["rho"] = args.recombination_rate
data["seq_error"] = args.error
data.to_csv(name_format.format("error_data.csv"))
return
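# With the extra columns above, the csv files from separate runs can later be
# concatenated directly; a sketch (file name pattern assumed from name_format):
#   import glob
#   combined = pd.concat(
#       (pd.read_csv(f) for f in glob.glob("anc-qual_*_error_data.csv")),
#       ignore_index=True)
#   combined.to_csv("combined_error_data.csv", index=False)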

# Now do the plots
Inaccuracy_label = "Sequence difference in overlapping region"
name = "quality-by-missingness"
x_axis_length_metric = "fraction" # or "absolute"
data['abs_missing_l'] = (data["Real length"] - data["Overlap length"])+1
data['rel_missing_l'] = 1 - (data["Overlap length"] / data["Real length"])
data['abs_missing_l'] = (data["Real length"] - data["Overlap"])+1
data['rel_missing_l'] = 1 - (data["Overlap"] / data["Real length"])
if x_axis_length_metric == "absolute":
x_col = 'abs_missing_l'
ax_params = {
@@ -1095,6 +1163,7 @@ def run_ancestor_quality(args):
x='Known time order', y="Inaccuracy", c='Inference error bias', cmap='coolwarm',
s=data.n_mismatches.values+1)
"""
# Add some tiny labels, to aid identification in a pdf plot
labels = ["{:.1f}\n{:.0f}\n{:.0f}".format(
r['position'], r['orig_time'], r['n_mismatches']) for i, r in data.iterrows()]
for x, y, s in zip(data['Known time order'].values, data["Inaccuracy"].values, labels):
@@ -1167,7 +1236,7 @@ def run_ancestor_quality(args):

name = "quality-by-length"
ax = data.plot.scatter(
x="Overlap length", y="Inaccuracy", c="Frequency", cmap='brg', s=2,
x="Overlap", y="Inaccuracy", c="Frequency", cmap='brg', s=2,
norm=NormalizeBandWidths(band_widths=freq_bins))
ax.set(ylabel=Inaccuracy_label, xscale='log', ylim=(-0.01, 1), xlim=(1))
save_figure(name_format.format(name))
@@ -1434,7 +1503,7 @@ def setup_logging(args):
"for a single instance."))
cli.add_logging_arguments(parser)
parser.set_defaults(runner=run_ancestor_comparison)
parser.add_argument("--sample-size", "-n", type=int, default=60)
parser.add_argument("--sample-size", "-n", type=int, default=100)
parser.add_argument(
"--length", "-l", type=float, default=1, help="Sequence length in MB")
parser.add_argument(
@@ -1444,13 +1513,13 @@ def setup_logging(args):
"--mutation-rate", "-u", type=float, default=1e-8,
help="Mutation rate")
parser.add_argument(
"--error-probability", "-e", type=float, default=0,
help="Error probability")
"--error", "-e", default="0",
help="Error: either a probability or a csv filename to use for empirical error")
parser.add_argument("--random-seed", "-s", type=int, default=None)
parser.add_argument("--destination-dir", "-d", default="")
parser.add_argument(
"--store-data", "-S", action="store_true",
help="Store the raw data.")
help="Store some raw data.")
parser.add_argument(
"--length-scale", "-X", choices=['linear', 'log'], default="linear",
help='Length scale for distances when plotting')
@@ -1467,7 +1536,7 @@ def setup_logging(args):
"for a single instance."))
cli.add_logging_arguments(parser)
parser.set_defaults(runner=run_ancestor_quality)
parser.add_argument("--sample-size", "-n", type=int, default=60)
parser.add_argument("--sample-size", "-n", type=int, default=100)
parser.add_argument(
"--length", "-l", type=float, default=1, help="Sequence length in MB")
parser.add_argument(
Expand All @@ -1477,13 +1546,16 @@ def setup_logging(args):
"--mutation-rate", "-u", type=float, default=1e-8,
help="Mutation rate")
parser.add_argument(
"--error-probability", "-e", type=float, default=0,
help="Error probability")
"--error", "-e", default="0",
help="Error: either a probability or a csv filename to use for empirical error")
parser.add_argument("--random-seed", "-s", type=int, default=None)
parser.add_argument("--destination-dir", "-d", default="")
parser.add_argument(
"--print-bad-ancestors", "-b", nargs='?', const="inferred",
choices=['inferred', 'all'], help="Also print out all the bad ancestor matches")
parser.add_argument(
"--csv-only", "-C", action="store_true",
help='Do not create plots, but output a csv file of the data for later plotting')
parser.add_argument(
"--length-scale", "-X", choices=['linear', 'log'], default="linear",
help='Length scale for distances when plotting')
