-
Notifications
You must be signed in to change notification settings - Fork 2.6k
/
ensemble_contact_maps.py
124 lines (109 loc) · 4.57 KB
/
ensemble_contact_maps.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
# Copyright 2019 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Form a weighted average of several distograms.
Can also/instead form a weighted average of a set of distance histogram pickle
files, so long as they have identical hyperparameters.
"""
import os
from absl import app
from absl import flags
from absl import logging
import tensorflow as tf # pylint: disable=g-explicit-tensorflow-version-import
from alphafold_casp13 import distogram_io
from alphafold_casp13 import parsers
flags.DEFINE_list(
'pickle_dirs', [],
'Comma separated list of directories with pickle files to ensemble.')
flags.DEFINE_list(
'weights', [],
'Comma separated list of weights for the pickle files from different dirs.')
flags.DEFINE_string(
'output_dir', None, 'Directory where to save results of the evaluation.')
FLAGS = flags.FLAGS
def ensemble_distance_histograms(pickle_dirs, weights, output_dir):
"""Find all the contact maps in the first dir, then ensemble across dirs."""
if len(pickle_dirs) <= 1:
logging.warning('Pointless to ensemble %d pickle_dirs %s',
len(pickle_dirs), pickle_dirs)
# Carry on if there's one dir, otherwise do nothing.
if not pickle_dirs:
return
tf.io.gfile.makedirs(output_dir)
one_dir_pickle_files = tf.io.gfile.glob(
os.path.join(pickle_dirs[0], '*.pickle'))
assert one_dir_pickle_files, pickle_dirs[0]
original_files = len(one_dir_pickle_files)
logging.info('Found %d files %d in first of %d dirs',
original_files, len(one_dir_pickle_files), len(pickle_dirs))
targets = [os.path.splitext(os.path.basename(f))[0]
for f in one_dir_pickle_files]
skipped = 0
wrote = 0
for t in targets:
dump_file = os.path.join(output_dir, t + '.pickle')
pickle_files = [os.path.join(pickle_dir, t + '.pickle')
for pickle_dir in pickle_dirs]
_, new_dict = ensemble_one_distance_histogram(pickle_files, weights)
if new_dict is not None:
wrote += 1
distogram_io.save_distance_histogram_from_dict(dump_file, new_dict)
msg = 'Distograms Wrote %s %d / %d Skipped %d %s' % (
t, wrote, len(one_dir_pickle_files), skipped, dump_file)
logging.info(msg)
def ensemble_one_distance_histogram(pickle_files, weights):
"""Average the given pickle_files and dump."""
dicts = []
sequence = None
max_dim = None
for picklefile in pickle_files:
if not tf.io.gfile.exists(picklefile):
logging.warning('missing %s', picklefile)
break
logging.info('loading pickle file %s', picklefile)
distance_histogram_dict = parsers.parse_distance_histogram_dict(picklefile)
if sequence is None:
sequence = distance_histogram_dict['sequence']
else:
assert sequence == distance_histogram_dict['sequence'], '%s vs %s' % (
sequence, distance_histogram_dict['sequence'])
dicts.append(distance_histogram_dict)
assert dicts[-1]['probs'].shape[0] == dicts[-1]['probs'].shape[1], (
'%d vs %d' % (dicts[-1]['probs'].shape[0], dicts[-1]['probs'].shape[1]))
assert (dicts[0]['probs'].shape[0:2] == dicts[-1]['probs'].shape[0:2]
), ('%d vs %d' % (dicts[0]['probs'].shape, dicts[-1]['probs'].shape))
if max_dim is None or max_dim < dicts[-1]['probs'].shape[2]:
max_dim = dicts[-1]['probs'].shape[2]
if len(dicts) != len(pickle_files):
logging.warning('length mismatch\n%s\nVS\n%s', dicts, pickle_files)
return sequence, None
ensemble_hist = (
sum(w * c['probs'] for w, c in zip(weights, dicts)) / sum(weights))
new_dict = dict(dicts[0])
new_dict['probs'] = ensemble_hist
return sequence, new_dict
def main(argv):
del argv # Unused.
num_dirs = len(FLAGS.pickle_dirs)
if FLAGS.weights:
assert len(FLAGS.weights) == num_dirs, (
'Supply as many weights as pickle_dirs, or no weights')
weights = [float(w) for w in FLAGS.weights]
else:
weights = [1.0 for w in range(num_dirs)]
ensemble_distance_histograms(
pickle_dirs=FLAGS.pickle_dirs,
weights=weights,
output_dir=FLAGS.output_dir)
if __name__ == '__main__':
app.run(main)