pathify_input_JSON.py · 249 lines (221 loc) · 9.99 KB · forked from google-deepmind/alphafold3
"""
WARNING: This script has been generated by ChatGPT.
A script that processes a JSON input describing protein/RNA/DNA/ligand sequences,
extracts embedded MSA/template strings, saves them to files, and updates paths
in the JSON. The processed JSON is then saved to a new file. Use --help to see
available command-line options.
Example usage:
python pathify_input_JSON.py input.json
python pathify_input_JSON.py input.json --force
python pathify_input_JSON.py input.json --output-dir mydir --output-json myprocessed.json
"""
# TODO: Parse the input using AlphaFold3 parsers to include validation? Or keep it like this to support any JSON format?
import os
import sys
import json
import argparse
import re


def parse_arguments():
    parser = argparse.ArgumentParser(
        description=(
            "Extract MSAs and templates from AlphaFold3 input JSON, save them to files, "
            "and update the JSON to point to the new files. "
            "WARNING: This script has been generated by ChatGPT."
        )
    )
    parser.add_argument("json_file", help="Path to the input JSON file.")
    parser.add_argument("--output-dir", default=None,
                        help="Name of the output directory. Default is <input_basename>_files.")
    parser.add_argument("--output-json", default=None,
                        help="Name of the output JSON file. Default is <input_basename>_processed.json.")
    parser.add_argument("--force", action="store_true",
                        help="Overwrite existing files if they already exist.")
    return parser.parse_args()


def load_json(json_path):
    """Load and return the JSON content from the provided file path."""
    if not os.path.isfile(json_path):
        print(f"ERROR: {json_path} does not exist.", file=sys.stderr)
        sys.exit(1)
    try:
        with open(json_path, "r") as f:
            data = json.load(f)
        return data
    except json.JSONDecodeError as e:
        print(f"ERROR: Failed to parse JSON file: {e}", file=sys.stderr)
        sys.exit(1)


def create_output_directory(directory_name):
    """Create an output directory if it does not exist. Exit with an error if creation fails."""
    if not os.path.exists(directory_name):
        try:
            os.makedirs(directory_name)
            print(f"Created directory: {directory_name}")
        except OSError as e:
            print(f"ERROR: Failed to create directory {directory_name}: {e}", file=sys.stderr)
            sys.exit(1)


def write_file_if_allowed(content, out_path, force=False):
    """
    Write content to out_path if the file does not yet exist, or if force=True.
    Otherwise, skip writing and leave the existing file untouched.
    """
    if os.path.exists(out_path) and not force:
        print(f"File {out_path} already exists. Use --force to overwrite. Skipping.")
        return
    try:
        with open(out_path, "w") as f:
            f.write(content)
        print(f"Saved file: {out_path}")
    except OSError as e:
        print(f"ERROR: Failed to write file {out_path}: {e}", file=sys.stderr)


def ensure_list(x):
    """
    Ensure x is a list. Wrap a non-empty string in a one-element list;
    return [] for None, empty/whitespace-only strings, or any other type.
    """
    if x is None:
        return []
    if isinstance(x, list):
        return x
    if isinstance(x, str) and x.strip():
        return [x.strip()]
    return []
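
# Illustrative behavior of ensure_list (values are hypothetical):
#   ensure_list(None)        -> []
#   ensure_list("A")         -> ["A"]
#   ensure_list(["A", "B"])  -> ["A", "B"]
#   ensure_list("  ")        -> []   (whitespace-only strings are treated as empty)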


def process_sequence_item(
    item,        # e.g., item["protein"], item["rna"], etc.
    seq_type,    # 'protein', 'rna', 'dna', 'ligand'
    output_dir,
    force=False,
):
    """
    Process a single dict representing a protein/rna/dna/ligand.
    Extract unpairedMsa, pairedMsa, and mmCIF strings from templates,
    write them out if present, and update the item to reference the files.
    """
    # The 'id' field might be a string or a list. Convert to a list for uniform processing.
    seq_ids = ensure_list(item.get("id", None))
    # If no valid id is found, log and skip.
    if not seq_ids:
        print(f"Warning: No valid 'id' found for {seq_type}. Skipping MSA writes.")
        return

    # Handle unpairedMsa
    if "unpairedMsa" in item and item["unpairedMsa"] is not None:
        msa_content = item["unpairedMsa"]
        for sid in seq_ids:
            out_path = os.path.join(output_dir, f"unpairedMsa_{sid}.a3m")
            write_file_if_allowed(msa_content, out_path, force=force)
        # Replace with unpairedMsaPath
        del item["unpairedMsa"]
        if len(seq_ids) == 1:
            item["unpairedMsaPath"] = os.path.join(output_dir, f"unpairedMsa_{seq_ids[0]}.a3m")
        else:
            # If multiple ids, store the paths in a list.
            # NOTE: AlphaFold3 may expect a single path here; the list form mirrors
            # the multi-id 'id' field and may need adjusting for strict parsers.
            item["unpairedMsaPath"] = [
                os.path.join(output_dir, f"unpairedMsa_{sid}.a3m") for sid in seq_ids
            ]

    # Handle pairedMsa
    if "pairedMsa" in item and item["pairedMsa"] is not None:
        msa_content = item["pairedMsa"]
        for sid in seq_ids:
            out_path = os.path.join(output_dir, f"pairedMsa_{sid}.a3m")
            write_file_if_allowed(msa_content, out_path, force=force)
        # Replace with pairedMsaPath
        del item["pairedMsa"]
        if len(seq_ids) == 1:
            item["pairedMsaPath"] = os.path.join(output_dir, f"pairedMsa_{seq_ids[0]}.a3m")
        else:
            item["pairedMsaPath"] = [
                os.path.join(output_dir, f"pairedMsa_{sid}.a3m") for sid in seq_ids
            ]

    # Process templates if present
    if "templates" in item and isinstance(item["templates"], list):
        for t_index, template in enumerate(item["templates"]):
            # If mmcif is present, write it out. Use a .cif extension: the
            # content is an mmCIF structure, not an alignment.
            if "mmcif" in template and template["mmcif"] is not None:
                mmcif_content = template["mmcif"]
                for sid in seq_ids:
                    out_path = os.path.join(output_dir, f"template_{sid}_{t_index}.cif")
                    write_file_if_allowed(mmcif_content, out_path, force=force)
                del template["mmcif"]
                # If only one id, store a single path
                if len(seq_ids) == 1:
                    template["mmcifPath"] = os.path.join(output_dir, f"template_{seq_ids[0]}_{t_index}.cif")
                else:
                    template["mmcifPath"] = [
                        os.path.join(output_dir, f"template_{sid}_{t_index}.cif") for sid in seq_ids
                    ]
            # If mmcifPath is already present, leave it untouched.
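
# Sketch of the template handling above, on a hypothetical entry with id "A"
# and output_dir "out" (all values invented for this example):
#   {"id": "A", "templates": [{"mmcif": "data_...", "templateIndices": [0, 1]}]}
#     -> {"id": "A", "templates": [{"mmcifPath": "out/template_A_0.cif",
#                                   "templateIndices": [0, 1]}]}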


def process_json(data, output_dir, force=False):
    """
    Process the JSON data in place. For each entry in "sequences",
    extract unpaired/paired MSAs or mmCIF strings from templates and
    save them to files. Update the JSON paths accordingly.
    """
    if "sequences" not in data or not isinstance(data["sequences"], list):
        print("No 'sequences' array found in JSON.")
        return data
    for seq_block in data["sequences"]:
        # seq_block is something like {"protein": {...}} or {"rna": {...}};
        # at most one of these keys should be present in each block.
        for seq_type in ["protein", "rna", "dna", "ligand"]:
            if seq_type in seq_block and isinstance(seq_block[seq_type], dict):
                process_sequence_item(
                    seq_block[seq_type],
                    seq_type,
                    output_dir,
                    force=force,
                )
    return data


def merge_template_indices_on_one_line(json_text):
    """
    Serialize templateIndices and queryIndices as single-line arrays.
    This post-processes the pretty-printed JSON, looking for multiline
    "templateIndices": [ ... ] or "queryIndices": [ ... ] arrays.
    """
    # Regex that captures multiline arrays for these specific keys. The whitespace
    # before ']' is consumed outside the captured groups so the closing bracket
    # lands on the same line, and the bracket is matched without a trailing comma
    # so that arrays in the final position of an object are also collapsed.
    pattern = re.compile(r'("(templateIndices|queryIndices)"\s*:\s*\[\s*)([\s\S]*?)\s*(\])')

    def replacer(match):
        start = match.group(1)    # e.g. '"templateIndices": ['
        content = match.group(3)  # everything inside the brackets
        end = match.group(4)      # ']'
        # We assume these arrays contain integers. Re-wrap the content in
        # brackets and parse it as JSON to robustly handle arbitrary spacing.
        temp_str = "[" + content + "]"
        try:
            arr = json.loads(temp_str)
        except json.JSONDecodeError:
            # If parsing fails, return the original match (no change).
            return match.group(0)
        # Convert arr to a single line.
        arr_str = ", ".join(str(x) for x in arr)
        return f"{start}{arr_str}{end}"

    new_text = pattern.sub(replacer, json_text)
    return new_text
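
# A quick, hypothetical sanity check of the collapse performed above:
#   >>> merge_template_indices_on_one_line('"queryIndices": [\n  0,\n  1\n],')
#   '"queryIndices": [0, 1],'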


def save_updated_json(data, output_json_path):
    """
    Save 'data' as nicely indented JSON, with templateIndices and queryIndices on one line.
    """
    try:
        # First do a standard pretty-print.
        pretty_json = json.dumps(data, indent=2)
        # Then post-process templateIndices and queryIndices.
        compact_json = merge_template_indices_on_one_line(pretty_json)
        with open(output_json_path, "w") as f:
            f.write(compact_json)
        print(f"Saved updated JSON: {output_json_path}")
    except OSError as e:
        print(f"ERROR: Could not write updated JSON file {output_json_path}: {e}", file=sys.stderr)


def main():
    args = parse_arguments()
    # Derive defaults for the output directory and the processed JSON.
    input_basename = os.path.splitext(os.path.basename(args.json_file))[0]
    default_output_dir = f"{input_basename}_files"
    default_output_json = f"{input_basename}_processed.json"
    if args.output_dir is None:
        args.output_dir = default_output_dir
    if args.output_json is None:
        args.output_json = default_output_json

    data = load_json(args.json_file)
    create_output_directory(args.output_dir)
    # Process the JSON, writing out MSA/template files.
    updated_data = process_json(data, args.output_dir, force=args.force)
    # Save the updated JSON to the specified output path.
    save_updated_json(updated_data, args.output_json)


if __name__ == "__main__":
    main()