Skip to content

Commit

Permalink
Revert target_id
Browse files Browse the repository at this point in the history
  • Loading branch information
brianloyal committed Dec 11, 2024
1 parent a225bc2 commit 67f798f
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 12 deletions.
15 changes: 11 additions & 4 deletions assets/containers/alphafold-predict/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,10 @@
import matplotlib.pyplot as plt
from resource import getrusage, RUSAGE_SELF

from uuid import uuid4

logging.set_verbosity(logging.INFO)

flags.DEFINE_string("target_id", None, "Target id info.")
flags.DEFINE_string("features_path", None, "Path to features pkl file.")
flags.DEFINE_string("model_dir", None, "Path to unzipped model dir.")
flags.DEFINE_string(
Expand Down Expand Up @@ -108,18 +108,20 @@ def plot_pae(pae, output) -> None:


def predict_structure(
target_id: str,
output_dir: str,
features_path: str,
model_runners: Dict[str, model.RunModel],
amber_relaxer: relax.AmberRelaxation,
random_seed: int,
):
"""Predicts structure using AlphaFold for the given sequence."""
logging.info("Predicting target")
logging.info("Predicting %s", target_id)
metrics = {
"model_name": "AlphaFold",
"model_version": "2.3.1",
"start_time": strftime("%d %b %Y %H:%M:%S +0000", gmtime()),
"target_id": target_id,
}
timings = {}
metrics["timings"] = {}
Expand Down Expand Up @@ -148,7 +150,7 @@ def predict_structure(
num_models = len(model_runners)
metrics["model_results"] = {}
for model_index, (model_name, model_runner) in enumerate(model_runners.items()):
logging.info("Running model %s", model_name)
logging.info("Running model %s on %s", model_name, target_id)
t_0 = time.time()
model_random_seed = model_index + random_seed * num_models
processed_feature_dict = model_runner.process_features(
Expand All @@ -165,6 +167,7 @@ def predict_structure(
logging.info(
"Total JAX model %s on %s predict time (includes compilation time): %.1fs",
model_name,
target_id,
t_diff,
)

Expand Down Expand Up @@ -263,7 +266,7 @@ def predict_structure(
]["max_predicted_aligned_error"]

# Write out metrics
logging.info("Final timings: %s", timings)
logging.info("Final timings for %s: %s", target_id, timings)
timings_output_path = os.path.join(output_dir, "timings.json")
with open(timings_output_path, "w") as f:
f.write(json.dumps(timings, indent=4))
Expand Down Expand Up @@ -333,9 +336,12 @@ def main(argv):

random_seed = FLAGS.random_seed

target_id = FLAGS.target_id

if random_seed is None:
random_seed = random.randrange(sys.maxsize // len(model_runners))
predict_structure(
target_id=target_id,
features_path=FLAGS.features_path,
output_dir=FLAGS.output_dir,
model_runners=model_runners,
Expand All @@ -347,6 +353,7 @@ def main(argv):
if __name__ == "__main__":
flags.mark_flags_as_required(
[
"target_id",
"features_path",
"output_dir",
"model_dir",
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import argparse
import logging
# from numpy.polynomial import Polynomial
from Bio import SeqIO
import json
import re

logging.basicConfig(
format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
Expand All @@ -13,7 +15,7 @@ def write_seq_file(seq, filename):
with open(filename, "w") as out_fh:
SeqIO.write(seq, out_fh, "fasta")

def split_and_get_sequence_metrics(seq_list, output_prefix):
def split_and_get_sequence_metrics(target_id, seq_list, output_prefix):
seq_length = 0
seq_count = 0
total_length = 0
Expand All @@ -26,33 +28,43 @@ def split_and_get_sequence_metrics(seq_list, output_prefix):
for seq_record in seq_list:
seq_length += len(seq_record.seq)
seq_count += 1
# id = seq_record.id

write_seq_file(seq_list, "inputs.fasta")

total_length += seq_length
return seq_count, total_length


def check_inputs(fasta_path, output_prefix):
def check_inputs(target_id, fasta_path, output_prefix):
    """Validate an input FASTA file and record sequence metrics for a target.

    Parses all records from ``fasta_path``, delegates to
    ``split_and_get_sequence_metrics`` to write the sequence file and compute
    counts, then persists the metrics as ``seq_info.json`` in the working
    directory.

    Args:
        target_id: Identifier of the prediction target; stored in the metrics.
        fasta_path: Path to a (multi-)FASTA file to validate.
        output_prefix: Prefix forwarded to the metrics helper.

    Returns:
        Total residue length across all parsed sequences.
    """
    with open(fasta_path, "r") as in_fh:
        seq_list = list(SeqIO.parse(in_fh, "fasta"))

    # Fix: the stale pre-refactor call without `target_id` was left behind
    # here; it would raise TypeError against the 3-argument helper signature.
    seq_count, total_length = split_and_get_sequence_metrics(target_id, seq_list, output_prefix)

    # Values are stringified for downstream workflow-engine consumers that
    # expect string-typed JSON fields.
    seq_info = {
        "target_id": str(target_id),
        "total_length": str(total_length),
        "seq_count": str(seq_count)
    }

    # write the sequence info to a json file
    with open("seq_info.json", "w") as out_fh:
        json.dump(seq_info, out_fh)
    return total_length


if __name__ == "__main__":

parser = argparse.ArgumentParser()
parser.add_argument(
"--target_id",
help="The ID of the target",
type=str,
required=True
)

parser.add_argument(
"--fasta_path",
Expand All @@ -69,5 +81,5 @@ def check_inputs(fasta_path, output_prefix):
)

args = parser.parse_args()
output = check_inputs(args.fasta_path, args.output_prefix)
output = check_inputs(args.target_id, args.fasta_path, args.output_prefix)
print(output)
1 change: 1 addition & 0 deletions assets/workflows/alphafold2-multimer/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ Pick your favorite small fasta file to run your first end-to-end test. The follow

### Inputs

`target_id`: The ID of the target you wish to predict
`fasta_path`: S3 URI to a single FASTA file that is in multi-FASTA format. Currently supports 1-chain per record.

### Example params.json
Expand Down
3 changes: 3 additions & 0 deletions assets/workflows/alphafold2-multimer/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ description: "Predict multi-chain protein structures with AlphaFold2-Multimer"
engine: NEXTFLOW
main: main.nf
parameterTemplate:
target_id:
description: "The ID of the target being run."
optional: false
fasta_path:
description: "Input file in multi-FASTA format."
optional: false
Expand Down
11 changes: 7 additions & 4 deletions assets/workflows/alphafold2-multimer/main.nf
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ include {
} from './unpack'

workflow AlphaFold2Multimer {
CheckAndValidateInputsTask(params.fasta_path)
CheckAndValidateInputsTask(params.target_id, params.fasta_path)

// split fasta run parallel searches (Scatter)
split_seqs = CheckAndValidateInputsTask.out.fasta
Expand Down Expand Up @@ -70,7 +70,8 @@ workflow AlphaFold2Multimer {

// Predict. Five separate models
model_nums = Channel.of(0, 1, 2, 3, 4)
AlphaFoldMultimerInference(GenerateFeaturesTask.out.features,
AlphaFoldMultimerInference(params.target_id,
GenerateFeaturesTask.out.features,
params.alphafold_model_parameters,
model_nums, params.random_seed,
params.run_relax)
Expand All @@ -86,6 +87,7 @@ process CheckAndValidateInputsTask {
publishDir '/mnt/workflow/pubdir/inputs'

input:
val target_id
path fasta_path

output:
Expand All @@ -99,7 +101,7 @@ process CheckAndValidateInputsTask {
ls -alR
/opt/venv/bin/python \
/home/putils/src/putils/check_and_validate_inputs.py \
--fasta_path=$fasta_path
--target_id=$target_id --fasta_path=$fasta_path
"""
}

Expand Down Expand Up @@ -154,6 +156,7 @@ process AlphaFoldMultimerInference {
maxRetries 2
publishDir '/mnt/workflow/pubdir'
input:
val target_id
path features
path alphafold_model_parameters
val modelnum
Expand All @@ -171,7 +174,7 @@ process AlphaFoldMultimerInference {
export XLA_PYTHON_CLIENT_MEM_FRACTION=4.0
export TF_FORCE_UNIFIED_MEMORY=1
/opt/conda/bin/python /app/alphafold/predict.py \
--features_path=$features --model_preset=multimer \
--target_id=$target_id --features_path=$features --model_preset=multimer \
--model_dir=model --random_seed=$random_seed --output_dir=output_model_${modelnum} \
--run_relax=${run_relax} --use_gpu_relax=${run_relax} --model_num=$modelnum
Expand Down

0 comments on commit 67f798f

Please sign in to comment.