diff --git a/assets/containers/alphafold-predict/predict.py b/assets/containers/alphafold-predict/predict.py index f52ff71..390691a 100644 --- a/assets/containers/alphafold-predict/predict.py +++ b/assets/containers/alphafold-predict/predict.py @@ -30,10 +30,10 @@ import matplotlib.pyplot as plt from resource import getrusage, RUSAGE_SELF -from uuid import uuid4 logging.set_verbosity(logging.INFO) +flags.DEFINE_string("target_id", None, "Target id info.") flags.DEFINE_string("features_path", None, "Path to features pkl file.") flags.DEFINE_string("model_dir", None, "Path to unzipped model dir.") flags.DEFINE_string( @@ -108,6 +108,7 @@ def plot_pae(pae, output) -> None: def predict_structure( + target_id: str, output_dir: str, features_path: str, model_runners: Dict[str, model.RunModel], @@ -115,11 +116,12 @@ def predict_structure( random_seed: int, ): """Predicts structure using AlphaFold for the given sequence.""" - logging.info("Predicting target") + logging.info("Predicting %s", target_id) metrics = { "model_name": "AlphaFold", "model_version": "2.3.1", "start_time": strftime("%d %b %Y %H:%M:%S +0000", gmtime()), + "target_id": target_id, } timings = {} metrics["timings"] = {} @@ -148,7 +150,7 @@ def predict_structure( num_models = len(model_runners) metrics["model_results"] = {} for model_index, (model_name, model_runner) in enumerate(model_runners.items()): - logging.info("Running model %s", model_name) + logging.info("Running model %s on %s", model_name, target_id) t_0 = time.time() model_random_seed = model_index + random_seed * num_models processed_feature_dict = model_runner.process_features( @@ -165,6 +167,7 @@ def predict_structure( logging.info( "Total JAX model %s on %s predict time (includes compilation time): %.1fs", model_name, + target_id, t_diff, ) @@ -263,7 +266,7 @@ def predict_structure( ]["max_predicted_aligned_error"] # Write out metrics - logging.info("Final timings: %s", timings) + logging.info("Final timings for %s: %s", target_id, 
timings) timings_output_path = os.path.join(output_dir, "timings.json") with open(timings_output_path, "w") as f: f.write(json.dumps(timings, indent=4)) @@ -333,9 +336,12 @@ def main(argv): random_seed = FLAGS.random_seed + target_id = FLAGS.target_id + if random_seed is None: random_seed = random.randrange(sys.maxsize // len(model_runners)) predict_structure( + target_id=target_id, features_path=FLAGS.features_path, output_dir=FLAGS.output_dir, model_runners=model_runners, @@ -347,6 +353,7 @@ def main(argv): if __name__ == "__main__": flags.mark_flags_as_required( [ + "target_id", "features_path", "output_dir", "model_dir", diff --git a/assets/containers/protein-utils/code/src/putils/check_and_validate_inputs.py b/assets/containers/protein-utils/code/src/putils/check_and_validate_inputs.py index 8a5ec73..3828a00 100644 --- a/assets/containers/protein-utils/code/src/putils/check_and_validate_inputs.py +++ b/assets/containers/protein-utils/code/src/putils/check_and_validate_inputs.py @@ -1,7 +1,9 @@ import argparse import logging +# from numpy.polynomial import Polynomial from Bio import SeqIO import json +import re logging.basicConfig( format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", @@ -13,7 +15,7 @@ def write_seq_file(seq, filename): with open(filename, "w") as out_fh: SeqIO.write(seq, out_fh, "fasta") -def split_and_get_sequence_metrics(seq_list, output_prefix): +def split_and_get_sequence_metrics(target_id, seq_list, output_prefix): seq_length = 0 seq_count = 0 total_length = 0 @@ -26,6 +28,7 @@ def split_and_get_sequence_metrics(seq_list, output_prefix): for seq_record in seq_list: seq_length += len(seq_record.seq) seq_count += 1 + # id = seq_record.id write_seq_file(seq_list, "inputs.fasta") @@ -33,13 +36,14 @@ def split_and_get_sequence_metrics(seq_list, output_prefix): return seq_count, total_length -def check_inputs(fasta_path, output_prefix): +def check_inputs(target_id, fasta_path, output_prefix): with open(fasta_path, "r") as in_fh: 
seq_list = list(SeqIO.parse(in_fh, "fasta")) - seq_count, total_length = split_and_get_sequence_metrics(seq_list, output_prefix) + seq_count, total_length = split_and_get_sequence_metrics(target_id, seq_list, output_prefix) seq_info = { + "target_id": str(target_id), "total_length": str(total_length), "seq_count": str(seq_count) } @@ -47,12 +51,20 @@ def check_inputs(fasta_path, output_prefix): # write the sequence info to a json file with open("seq_info.json", "w") as out_fh: json.dump(seq_info, out_fh) + # return seq_info + # return f'{total_length}\n{seq_count}\n' return total_length if __name__ == "__main__": parser = argparse.ArgumentParser() + parser.add_argument( + "--target_id", + help="The ID of the target", + type=str, + required=True + ) parser.add_argument( "--fasta_path", @@ -69,5 +81,5 @@ def check_inputs(fasta_path, output_prefix): ) args = parser.parse_args() - output = check_inputs(args.fasta_path, args.output_prefix) + output = check_inputs(args.target_id, args.fasta_path, args.output_prefix) print(output) diff --git a/assets/workflows/alphafold2-multimer/README.md b/assets/workflows/alphafold2-multimer/README.md index 99edd2a..fce0b4f 100644 --- a/assets/workflows/alphafold2-multimer/README.md +++ b/assets/workflows/alphafold2-multimer/README.md @@ -16,6 +16,7 @@ Pick your favorite small fasta file to run your fist end-to-end test. The follow ### Inputs +`target_id`: The ID of the target you wish to predict. `fasta_path`: S3 URI to a single FASTA file that is in multi-FASTA format. Currently supports 1-chain per record.
### Example params.json diff --git a/assets/workflows/alphafold2-multimer/config.yaml b/assets/workflows/alphafold2-multimer/config.yaml index 3e38011..1db3c7d 100644 --- a/assets/workflows/alphafold2-multimer/config.yaml +++ b/assets/workflows/alphafold2-multimer/config.yaml @@ -3,6 +3,9 @@ description: "Predict multi-chain protein structures with AlphaFold2-Multimer" engine: NEXTFLOW main: main.nf parameterTemplate: + target_id: + description: "The ID of the target being run." + optional: false fasta_path: description: "Input file in multi-FASTA format." optional: false diff --git a/assets/workflows/alphafold2-multimer/main.nf b/assets/workflows/alphafold2-multimer/main.nf index b7cf6e2..49da065 100644 --- a/assets/workflows/alphafold2-multimer/main.nf +++ b/assets/workflows/alphafold2-multimer/main.nf @@ -21,7 +21,7 @@ include { } from './unpack' workflow AlphaFold2Multimer { - CheckAndValidateInputsTask(params.fasta_path) + CheckAndValidateInputsTask(params.target_id, params.fasta_path) // split fasta run parallel searches (Scatter) split_seqs = CheckAndValidateInputsTask.out.fasta @@ -70,7 +70,8 @@ workflow AlphaFold2Multimer { // Predict. 
Five separate models model_nums = Channel.of(0, 1, 2, 3, 4) - AlphaFoldMultimerInference(GenerateFeaturesTask.out.features, + AlphaFoldMultimerInference(params.target_id, + GenerateFeaturesTask.out.features, params.alphafold_model_parameters, model_nums, params.random_seed, params.run_relax) @@ -86,6 +87,7 @@ process CheckAndValidateInputsTask { publishDir '/mnt/workflow/pubdir/inputs' input: + val target_id path fasta_path output: @@ -99,7 +101,7 @@ process CheckAndValidateInputsTask { ls -alR /opt/venv/bin/python \ /home/putils/src/putils/check_and_validate_inputs.py \ - --fasta_path=$fasta_path + --target_id=$target_id --fasta_path=$fasta_path """ } @@ -154,6 +156,7 @@ process AlphaFoldMultimerInference { maxRetries 2 publishDir '/mnt/workflow/pubdir' input: + val target_id path features path alphafold_model_parameters val modelnum @@ -171,7 +174,7 @@ process AlphaFoldMultimerInference { export XLA_PYTHON_CLIENT_MEM_FRACTION=4.0 export TF_FORCE_UNIFIED_MEMORY=1 /opt/conda/bin/python /app/alphafold/predict.py \ - --features_path=$features --model_preset=multimer \ + --target_id=$target_id --features_path=$features --model_preset=multimer \ --model_dir=model --random_seed=$random_seed --output_dir=output_model_${modelnum} \ --run_relax=${run_relax} --use_gpu_relax=${run_relax} --model_num=$modelnum