diff --git a/Dockerfile b/Dockerfile index 293abab..c3e2447 100644 --- a/Dockerfile +++ b/Dockerfile @@ -19,9 +19,6 @@ WORKDIR /opt/bin RUN wget --quiet https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh && \ bash Miniconda3-latest-Linux-x86_64.sh -b -p /opt/conda && \ rm Miniconda3-latest-Linux-x86_64.sh && \ - wget http://www.bio8.cs.hku.hk/clair3/clair3_models/clair3_models.tar.gz -P /opt/models && \ - tar -zxvf /opt/models/clair3_models.tar.gz -C /opt/models && \ - rm /opt/models/clair3_models.tar.gz && \ conda config --add channels defaults && \ conda config --add channels bioconda && \ conda config --add channels conda-forge && \ @@ -48,4 +45,7 @@ COPY . . RUN cd /opt/bin/preprocess/realign && \ g++ -std=c++14 -O1 -shared -fPIC -o realigner ssw_cpp.cpp ssw.c realigner.cpp && \ g++ -std=c++11 -shared -fPIC -o debruijn_graph -O3 debruijn_graph.cpp && \ + wget http://www.bio8.cs.hku.hk/clair3/clair3_models/clair3_models.tar.gz -P /opt/models && \ + tar -zxvf /opt/models/clair3_models.tar.gz -C /opt/models && \ + rm /opt/models/clair3_models.tar.gz && \ echo "source activate clair3" > ~/.bashrc \ No newline at end of file diff --git a/README.md b/README.md index 4839b79..9d71d07 100644 --- a/README.md +++ b/README.md @@ -44,7 +44,7 @@ Clair3 is the 3rd generation of [Clair](https://github.com/HKU-BAL/Cl * [Training Data](docs/training_data.md) * [VCF/GVCF Output Formats](#vcfgvcf-output-formats) * [Pileup Model Training](docs/pileup_training.md) -* [Full-Alignment Model Training](docs/full_alignment_training.md) +* [Full-Alignment Model Training](docs/full_alignment_training_r1.md) * [Representation Unification](docs/representation_unification.md) * [Visualization](docs) * [Model Input](docs/model_input_visualization.md) @@ -54,13 +54,14 @@ Clair3 is the 3rd generation of [Clair](https://github.com/HKU-BAL/Cl ## Latest Updates +*v0.1-r7 (Oct 18)* : 1. Increased `var_pct_full` in ONT mode from 0.3 to 0.7. Indel F1-score increased ~0.2%, but took ~30 minutes longer to finish calling a ~50x ONT dataset. 2. Expand fall through to next most likely variant if network prediction has insufficient read coverage ([#53](https://github.com/HKU-BAL/Clair3/pull/53) commit 09a7d185, contributor @[ftostevin-ont](https://github.com/ftostevin-ont)), accuracy improved on complex Indels. 3. Streamized pileup and full-alignment training workflows. Reduce diskspace demand in model training ([#55](https://github.com/HKU-BAL/Clair3/pull/55) commit 09a7d185, contributor @[ftostevin-ont](https://github.com/ftostevin-ont)). 4. Added `mini_epochs` option in Train.py, performance slightly improved in training a model for ONT Q20 data using mini-epochs([#60](https://github.com/HKU-BAL/Clair3/pull/60), contributor @[ftostevin-ont](https://github.com/ftostevin-ont)). 5. Massively reduced disk space demand when outputting GVCF. Now compressing GVCF intermediate files with lz4, five times smaller with little speed penalty. 6. Added `--remove_intermediate_dir`to remove intermediate files as soon as no longer needed ([#48](https://github.com/HKU-BAL/Clair3/issues/48)). 7. Renamed ONT pre-trained models with [Medaka](https://github.com/nanoporetech/medaka/blob/master/medaka/options.py#L22)'s naming convention. 8. Fixed training data spilling over to validation data ([#57](https://github.com/HKU-BAL/Clair3/issues/57)). + *ONT-provided Models (Sep 23)*: ONT also provides Clair3 models for specific chemistries and basecallers through [Rerio](https://github.com/nanoporetech/rerio). *v0.1-r6 (Sep 4)* : 1. 
Reduced memory footprint at the `SortVcf` stage([#45](https://github.com/HKU-BAL/Clair3/issues/45)). 2. Reduced `ulimit -n` (number of files simultaneously opened) requirement ([#45](https://github.com/HKU-BAL/Clair3/issues/45), [#47](https://github.com/HKU-BAL/Clair3/issues/47)). 3. Added Clair3-Illumina package in bioconda([#42](https://github.com/HKU-BAL/Clair3/issues/42)). *v0.1-r5 (July 19)* : 1. Modified data generator in model training to avoid memory exhaustion and unexpected segmentation fault by Tensorflow (contributor @[ftostevin-ont](https://github.com/ftostevin-ont) ). 2. Simplified dockerfile workflow to reuse container caching (contributor @[amblina](https://github.com/amblina)). 3. Fixed ALT output for reference calls (contributor @[wdecoster](https://github.com/wdecoster)). 4. Fixed a bug in multi-allelic AF computation (AF of [ACGT]Del variants was wrong before r5). 5. Added AD tag to the GVCF output. 6. Added the `--call_snp_only` option to only call SNP only ([#40](https://github.com/HKU-BAL/Clair3/issues/40)). 7. Added pileup and full-alignment output validity check to avoid workflow crashing ([#32](https://github.com/HKU-BAL/Clair3/issues/32), [#38](https://github.com/HKU-BAL/Clair3/issues/38)). - *v0.1-r4 (June 28)* : 1. Install via [bioconda](https://github.com/HKU-BAL/Clair3#option-3--bioconda). 2. Added an ONT Guppy2 model to the images (`ont_guppy2`). Click [here](https://github.com/HKU-BAL/Clair3/blob/main/docs/guppy2.md) for more benchmarking results. **The results show you have to use the Guppy2 model for Guppy2 or earlier data**. 3. Added [google colab notebooks](https://github.com/HKU-BAL/Clair3/blob/main/colab) for quick demo. 4. Fixed a bug when there are too few variant candidates ([#28](https://github.com/HKU-BAL/Clair3/issues/28)). *v0.1-r3 (June 9)* : 1. Added `ulimit -u` (max user processes) check (lowers the `THREADS` if the resource is insufficient) and automatic retries on failed jobs ([#20](https://github.com/HKU-BAL/Clair3/issues/20), [#23](https://github.com/HKU-BAL/Clair3/issues/23), [#24](https://github.com/HKU-BAL/Clair3/issues/24)). 2. Added an ONT Guppy5 model to the images (`ont_guppy5`). Click [here](docs/guppy5.md) for more benchmarks on the Guppy5 model and data. @@ -83,14 +84,16 @@ Clair3 is the 3rd generation of [Clair](https://github.com/HKU-BAL/Cl Download models from [here](http://www.bio8.cs.hku.hk/clair3/clair3_models/) or click on the links below. -| File | Platform | Training samples | Included in the bioconda package | Included in the docker image | Release | Date | Basecaller | Link | -| :---------------: | :---------: | :----------------------------------------------------------: | -------------------------------- | :--------------------------: | :-----: | :------: | :--------: | :----------------------------------------------------------: | -| ont.tar.gz | ONT | HG001,2,4,5 | Yes | Yes | 1 | 20210517 | Guppy3,4 | [Download](http://www.bio8.cs.hku.hk/clair3/clair3_models/ont.tar.gz) | -| ont_1235.tar.gz | ONT | HG001,2,3,5 | | | 1 | 20210517 | Guppy3,4 | [Download](http://www.bio8.cs.hku.hk/clair3/clair3_models/ont_1235.tar.gz) | -| ont_guppy5.tar.gz | ONT | Base model: HG001,2,4,5 (Guppy3,4)
Fine-tuning data: HG002 (Guppy5_sup) | Yes | Yes | 1 | 20210609 | Guppy5 | [Download](http://www.bio8.cs.hku.hk/clair3/clair3_models/ont_guppy5.tar.gz) | -| ont_guppy2.tar.gz | ONT | HG001,2,3,4 | | Yes | 1 | 20210627 | Guppy2 | [Download](http://www.bio8.cs.hku.hk/clair3/clair3_models/ont_guppy2.tar.gz) | -| hifi.tar.gz | PacBio HiFi | HG001,2,4,5 | Yes | Yes | 1 | 20210517 | NA | [Download](http://www.bio8.cs.hku.hk/clair3/clair3_models/hifi.tar.gz) | -| ilmn.tar.gz | Illumina | HG001,2,4,5 | Yes | Yes | 1 | 20210517 | NA | [Download](http://www.bio8.cs.hku.hk/clair3/clair3_models/ilmn.tar.gz) | +In a docker installation, models are in `/opt/models/`. In a bioconda installation, models are in `{CONDA_PREFIX}/bin/models/`. + +| Model name | Platform | Training samples | Included in the bioconda package | Included in the docker image | Release | Date | Basecaller | File | Link | +| :--------------------------: | :---------: | :----------------------------------------------------------: | -------------------------------- | :--------------------------: | :-----: | :------: | :--------: | ----------------------------------- | :----------------------------------------------------------: | +| r941_prom_hac_g360+g422 | ONT | HG001,2,4,5 | Yes | Yes | 1 | 20210517 | Guppy3,4 | r941_prom_hac_g360+g422.tar.gz | [Download](http://www.bio8.cs.hku.hk/clair3/clair3_models/r941_prom_hac_g360+g422.tar.gz) | +| r941_prom_hac_g360+g422_1235 | ONT | HG001,2,3,5 | | | 1 | 20210517 | Guppy3,4 | r941_prom_hac_g360+g422_1235.tar.gz | [Download](http://www.bio8.cs.hku.hk/clair3/clair3_models/r941_prom_hac_g360+g422_1235.tar.gz) | +| r941_prom_sup_g506 | ONT | Base model: HG001,2,4,5 (Guppy3,4)
Fine-tuning data: HG002 (Guppy5_sup) | Yes | Yes | 1 | 20210609 | Guppy5 | r941_prom_sup_g506.tar.gz | [Download](http://www.bio8.cs.hku.hk/clair3/clair3_models/r941_prom_sup_g506.tar.gz) | +| r941_prom_hac_g238 | ONT | HG001,2,3,4 | | Yes | 1 | 20210627 | Guppy2 | r941_prom_hac_g238.tar.gz | [Download](http://www.bio8.cs.hku.hk/clair3/clair3_models/r941_prom_hac_g238.tar.gz) | +| hifi | PacBio HiFi | HG001,2,4,5 | Yes | Yes | 1 | 20210517 | NA | hifi.tar.gz | [Download](http://www.bio8.cs.hku.hk/clair3/clair3_models/hifi.tar.gz) | +| ilmn | Illumina | HG001,2,4,5 | Yes | Yes | 1 | 20210517 | NA | ilmn.tar.gz | [Download](http://www.bio8.cs.hku.hk/clair3/clair3_models/ilmn.tar.gz) | #### ONT-provided Models @@ -147,6 +150,7 @@ A pre-built docker image is available [here](https://hub.docker.com/r/hkubal/cla INPUT_DIR="[YOUR_INPUT_FOLDER]" # e.g. /home/user1/input (absolute path needed) OUTPUT_DIR="[YOUR_OUTPUT_FOLDER]" # e.g. /home/user1/output (absolute path needed) THREADS="[MAXIMUM_THREADS]" # e.g. 8 +MODEL_NAME="[YOUR_MODEL_NAME]" # e.g. r941_prom_hac_g360+g422 docker run -it \ -v ${INPUT_DIR}:${INPUT_DIR} \ @@ -157,7 +161,7 @@ docker run -it \ --ref_fn=${INPUT_DIR}/ref.fa \ ## change your reference file name here --threads=${THREADS} \ ## maximum threads to be used --platform="ont" \ ## options: {ont,hifi,ilmn} - --model_path="/opt/models/ont" \ ## absolute model path prefix, change platform accordingly + --model_path="/opt/models/${MODEL_NAME}" \ --output=${OUTPUT_DIR} ## absolute output path prefix ``` @@ -171,6 +175,7 @@ Check [Usage](#Usage) for more options. INPUT_DIR="[YOUR_INPUT_FOLDER]" # e.g. /home/user1/input (absolute path needed) OUTPUT_DIR="[YOUR_OUTPUT_FOLDER]" # e.g. /home/user1/output (absolute path needed) THREADS="[MAXIMUM_THREADS]" # e.g. 8 +MODEL_NAME="[YOUR_MODEL_NAME]" # e.g. r941_prom_hac_g360+g422 conda config --add channels defaults conda create -n singularity-env -c conda-forge singularity -y @@ -186,7 +191,7 @@ singularity exec clair3_latest.sif \ --ref_fn=${INPUT_DIR}/ref.fa \ ## change your reference file name here --threads=${THREADS} \ ## maximum threads to be used --platform="ont" \ ## options: {ont,hifi,ilmn} - --model_path="/opt/models/ont" \ ## absolute model path prefix, change platform accordingly + --model_path="/opt/models/${MODEL_NAME}" \ --output=${OUTPUT_DIR} ## absolute output path prefix ``` @@ -206,12 +211,14 @@ conda create -n clair3 -c bioconda clair3 python=3.6.10 -y conda activate clair3 # run clair3 like this afterward +MODEL_NAME="[YOUR_MODEL_NAME]" # e.g. r941_prom_hac_g360+g422 + run_clair3.sh \ --bam_fn=input.bam \ ## change your bam file name here --ref_fn=ref.fa \ ## change your reference file name here --threads=${THREADS} \ ## maximum threads to be used --platform="ont" \ ## options: {ont,hifi,ilmn} - --model_path="${CONDA_PREFIX}/bin/models/ont" \ + --model_path="${CONDA_PREFIX}/bin/models/${MODEL_NAME}" \ --output=${OUTPUT_DIR} ## output path prefix ``` @@ -265,12 +272,13 @@ wget http://www.bio8.cs.hku.hk/clair3/clair3_models/clair3_models.tar.gz tar -zxvf clair3_models.tar.gz -C ./models # run clair3 +MODEL_NAME="[YOUR_MODEL_NAME]" # e.g. 
r941_prom_hac_g360+g422 ./run_clair3.sh \ --bam_fn=${INPUT_DIR}/input.bam \ ## change your bam file name here --ref_fn=${INPUT_DIR}/ref.fa \ ## change your reference file name here --threads=${THREADS} \ ## maximum threads to be used --platform="ont" \ ## options: {ont,hifi,ilmn} - --model_path=`pwd`"/models/ont" \ ## model path prefix, change platform accordingly + --model_path=`pwd`"/models/${MODEL_NAME}" \ --output=${OUTPUT_DIR} ## output path prefix ``` @@ -334,7 +342,7 @@ docker run -it hkubal/clair3:latest /opt/bin/run_clair3.sh --help --vcf_fn=FILE Candidate sites VCF file input, variants will only be called at the sites in the VCF file if provided. --ctg_name=STR The name of the sequence to be processed. --sample_name=STR Define the sample name to be shown in the VCF file. - --qual=INT If set, variants with >=$qual will be marked PASS, or LowQual otherwise. + --qual=INT If set, variants with >$qual will be marked PASS, or LowQual otherwise. --samtools=STR Path of samtools, samtools version >= 1.10 is required. --python=STR Path of python, python3 >= 3.6 is required. --pypy=STR Path of pypy3, pypy3 >= 3.6 is required. @@ -345,6 +353,7 @@ docker run -it hkubal/clair3:latest /opt/bin/run_clair3.sh --help --print_ref_calls Show reference calls (0/0) in vcf file, default: disable. --include_all_ctgs Call variants on all contigs, otherwise call in chr{1..22,X,Y} and {1..22,X,Y}, default: disable. --gvcf Enable GVCF output, default: disable. + --remove_intermediate_dir Remove intermediate directory, including intermediate phased BAM, pileup and full-alignment results. default: disable. --snp_min_af=FLOAT Minimum SNP AF required for a candidate variant. Lowering the value might increase a bit of sensitivity in trade of speed and accuracy, default: ont:0.08,hifi:0.08,ilmn:0.08. --indel_min_af=FLOAT Minimum INDEL AF required for a candidate variant. Lowering the value might increase a bit of sensitivity in trade of speed and accuracy, default: ont:0.15,hifi:0.08,ilmn:0.08. --var_pct_full=FLOAT EXPERIMENTAL: Specify an expected percentage of low quality 0/1 and 1/1 variants called in the pileup mode for full-alignment mode calling, default: 0.3. @@ -365,6 +374,7 @@ CONTIGS_LIST="[YOUR_CONTIGS_LIST]" # e.g "chr21" or "chr21,chr22" INPUT_DIR="[YOUR_INPUT_FOLDER]" # e.g. /home/user1/input (absolute path needed) OUTPUT_DIR="[YOUR_OUTPUT_FOLDER]" # e.g. /home/user1/output (absolute path needed) THREADS="[MAXIMUM_THREADS]" # e.g. 8 +MODEL_NAME="[YOUR_MODEL_NAME]" # e.g. r941_prom_hac_g360+g422 docker run -it \ -v ${INPUT_DIR}:${INPUT_DIR} \ @@ -375,7 +385,7 @@ docker run -it \ --ref_fn=${INPUT_DIR}/ref.fa \ ## change your reference file name here --threads=${THREADS} \ ## maximum threads to be used --platform="ont" \ ## options: {ont,hifi,ilmn} - --model_path="/opt/models/ont" \ ## absolute model path prefix, change platform accordingly + --model_path="/opt/models/${MODEL_NAME}" \ --output=${OUTPUT_DIR} \ ## absolute output path prefix --ctg_name=${CONTIGS_LIST} ``` @@ -387,6 +397,7 @@ KNOWN_VARIANTS_VCF="[YOUR_VCF_PATH]" # e.g. /home/user1/known_variants.vcf.gz INPUT_DIR="[YOUR_INPUT_FOLDER]" # e.g. /home/user1/input (absolute path needed) OUTPUT_DIR="[YOUR_OUTPUT_FOLDER]" # e.g. /home/user1/output (absolute path needed) THREADS="[MAXIMUM_THREADS]" # e.g. 8 +MODEL_NAME="[YOUR_MODEL_NAME]" # e.g. 
r941_prom_hac_g360+g422 docker run -it \ -v ${INPUT_DIR}:${INPUT_DIR} \ @@ -397,7 +408,7 @@ docker run -it \ --ref_fn=${INPUT_DIR}/ref.fa \ ## change your reference file name here --threads=${THREADS} \ ## maximum threads to be used --platform="ont" \ ## options: {ont,hifi,ilmn} - --model_path="/opt/models/ont" \ ## absolute model path prefix, change platform accordingly + --model_path="/opt/models/${MODEL_NAME}" \ --output=${OUTPUT_DIR} \ ## absolute output path prefix --vcf_fn=${KNOWN_VARIANTS_VCF} ``` @@ -421,6 +432,7 @@ BED_FILE_PATH="[YOUR_BED_FILE]" # e.g. /home/user1/tmp.bed (absolute path INPUT_DIR="[YOUR_INPUT_FOLDER]" # e.g. /home/user1/input (absolute path needed) OUTPUT_DIR="[YOUR_OUTPUT_FOLDER]" # e.g. /home/user1/output (absolute path needed) THREADS="[MAXIMUM_THREADS]" # e.g. 8 +MODEL_NAME="[YOUR_MODEL_NAME]" # e.g. r941_prom_hac_g360+g422 docker run -it \ -v ${INPUT_DIR}:${INPUT_DIR} \ @@ -431,7 +443,7 @@ docker run -it \ --ref_fn=${INPUT_DIR}/ref.fa \ ## change your reference file name here --threads=${THREADS} \ ## maximum threads to be used --platform="ont" \ ## options: {ont,hifi,ilmn} - --model_path="/opt/models/ont" \ ## absolute model path prefix, change platform accordingly + --model_path="/opt/models/${MODEL_NAME}" \ --output=${OUTPUT_DIR} \ ## absolute output path prefix --bed_fn=${BED_FILE_PATH} ``` @@ -442,6 +454,7 @@ docker run -it \ INPUT_DIR="[YOUR_INPUT_FOLDER]" # e.g. /home/user1/input (absolute path needed) OUTPUT_DIR="[YOUR_OUTPUT_FOLDER]" # e.g. /home/user1/output (absolute path needed) THREADS="[MAXIMUM_THREADS]" # e.g. 8 +MODEL_NAME="[YOUR_MODEL_NAME]" # e.g. r941_prom_hac_g360+g422 docker run -it \ -v ${INPUT_DIR}:${INPUT_DIR} \ @@ -452,7 +465,7 @@ docker run -it \ --ref_fn=${INPUT_DIR}/ref.fa \ ## change your reference file name here --threads=${THREADS} \ ## maximum threads to be used --platform="ont" \ ## options: {ont,hifi,ilmn} - --model_path="/opt/models/ont" \ ## absolute model path prefix, change platform accordingly + --model_path="/opt/models/${MODEL_NAME}" \ --output=${OUTPUT_DIR} \ --no_phasing_for_fa \ ## disable phasing for full-alignment --include_all_ctgs \ ## call variants on all contigs in the reference fasta @@ -489,6 +502,8 @@ Submodules in __`clair3/`__ are for variant calling and model training. Submodul `SortVcf` | Sort VCF file. `SplitExtendBed` | Split BED file regions according to the contig names and extend bed region by 33bp by default for variant calling. `UnifyRepresentation` | Representation unification between candidate sites and true variants. +`MergeBin` | Combine tensor binaries into a single file. +`CreateTrainingTensor` | Create tensor binaries for pileup or full-alignment training. `Tensor2Bin` | Combine the variant and non-variant tensors and convert them to a binary, using `blosc:lz4hc` meta-compressor, the overall training memory is 10~15G (pypy incompatible). ---- diff --git a/clair3/Train.py b/clair3/Train.py index 9614f18..01f596f 100644 --- a/clair3/Train.py +++ b/clair3/Train.py @@ -104,17 +104,34 @@ def on_epoch_end(self): np.random.shuffle(self.chunk_list) -def get_chunk_list(chunk_offset, train_chunk_num): +def get_chunk_list(chunk_offset, train_chunk_num, chunks_per_batch=10, training_dataset_percentage=None): """ get chunk list for training and validation data. we will randomly split training and validation dataset, all training data is directly acquired from various tensor bin files. 
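+    chunks_per_batch: number of chunks that make up one training batch; one batch worth of chunks
+    per bin file is kept aside as a buffer between the training and validation portions.
+    training_dataset_percentage: if not None, each bin file's chunks are split in order into a
+    leading training portion and a trailing validation portion (e.g. with 100 chunks, a buffer of
+    10 and a percentage of 0.9, chunks 0-80 go to training, 81-90 are the buffer, and 91-99 go to
+    validation), so no batch mixes training and validation data.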
""" + need_split_validation_data = training_dataset_percentage is not None all_shuffle_chunk_list = [] + training_chunk_list, validation_chunk_list = [], [] for bin_idx, chunk_num in enumerate(chunk_offset): - all_shuffle_chunk_list += [(bin_idx, chunk_idx) for chunk_idx in range(chunk_num)] - np.random.seed(0) - np.random.shuffle(all_shuffle_chunk_list) # keep the same random validate dataset + current_chunk_list = [(bin_idx, chunk_idx) for chunk_idx in range(chunk_num)] + all_shuffle_chunk_list += current_chunk_list + if need_split_validation_data: + buffer_chunk_num = chunks_per_batch + if chunk_num < buffer_chunk_num: + training_chunk_list += [(bin_idx, chunk_idx) for chunk_idx in range(chunk_num)] + continue + + training_chunk_num = int((chunk_num - buffer_chunk_num) * training_dataset_percentage) + validation_chunk_num = int(chunk_num - buffer_chunk_num - training_chunk_num) + if training_chunk_num > 0: + training_chunk_list += current_chunk_list[:training_chunk_num] + if validation_chunk_num > 0: + validation_chunk_list += current_chunk_list[-validation_chunk_num:] + + if need_split_validation_data: + return np.array(training_chunk_list), np.array(validation_chunk_list) + return np.array(all_shuffle_chunk_list[:train_chunk_num]), np.array(all_shuffle_chunk_list[train_chunk_num:]) @@ -145,6 +162,7 @@ def train_model(args): label_shape = param.label_shape label_shape_cum = param.label_shape_cum batch_size, chunk_size = param.trainBatchSize, param.chunk_size + assert batch_size % chunk_size == 0 chunks_per_batch = batch_size // chunk_size random.seed(param.RANDOM_SEED) np.random.seed(param.RANDOM_SEED) @@ -159,7 +177,7 @@ def populate_dataset_table(file_list, file_path): for bin_idx, bin_file in enumerate(file_list): table_dataset = tables.open_file(os.path.join(file_path, bin_file), 'r') table_dataset_list.append(table_dataset) - chunk_num = (len(table_dataset.root.label) - chunk_size) // chunk_size + chunk_num = (len(table_dataset.root.label) - batch_size) // chunk_size chunk_offset[bin_idx] = chunk_num return table_dataset_list, chunk_offset @@ -185,13 +203,17 @@ def populate_dataset_table(file_list, file_path): total_chunks = train_chunk_num + validate_chunk_num else: total_chunks = int(sum(chunk_offset)) + training_dataset_percentage = param.trainingDatasetPercentage if add_validation_dataset else None if add_validation_dataset: total_batches = total_chunks // chunks_per_batch - validate_chunk_num = int(max(1., np.floor(total_batches * (1 - param.trainingDatasetPercentage))) * chunks_per_batch) + validate_chunk_num = int(max(1., np.floor(total_batches * (1 - training_dataset_percentage))) * chunks_per_batch) + # +++++++++++++**---- + # +:training *:buffer -:validation + # distribute one batch data as buffer for each bin file, avoiding shifting training data to validation data train_chunk_num = int(total_chunks - validate_chunk_num) else: train_chunk_num = total_chunks - train_shuffle_chunk_list, validate_shuffle_chunk_list = get_chunk_list(chunk_offset, train_chunk_num) + train_shuffle_chunk_list, validate_shuffle_chunk_list = get_chunk_list(chunk_offset, train_chunk_num, chunks_per_batch, training_dataset_percentage) train_chunk_num = len(train_shuffle_chunk_list) validate_chunk_num = len(validate_shuffle_chunk_list) @@ -223,7 +245,7 @@ def populate_dataset_table(file_list, file_path): metrics=metrics, optimizer=optimizer ) - early_stop_callback = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=10, mode="min") + early_stop_callback = 
tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=10*mini_epochs, mode="min") model_save_callback = tf.keras.callbacks.ModelCheckpoint(ochk_prefix + ".{epoch:02d}", period=1, save_weights_only=False) model_best_callback = tf.keras.callbacks.ModelCheckpoint("best_val_loss", monitor='val_loss', save_best_only=True, mode="min") train_log_callback = tf.keras.callbacks.CSVLogger("training.log", separator='\t') diff --git a/clair3/utils.py b/clair3/utils.py index 9797553..acfbc7e 100644 --- a/clair3/utils.py +++ b/clair3/utils.py @@ -105,17 +105,48 @@ def item_from(row): f.wait() +def remove_common_suffix(ref_base, alt_base): + min_length = min(len(ref_base) - 1, min([len(item) - 1 for item in alt_base])) # keep at least one base + prefix = ref_base[::-1] + for string in alt_base: + string = string[::-1] + while string[:len(prefix)] != prefix and prefix: + prefix = prefix[:len(prefix) - 1] + if not prefix: + break + res_length = len(prefix) + if res_length > min_length: + return ref_base, alt_base + return ref_base[:len(ref_base) - res_length], [item[:len(item) - res_length] for item in alt_base] + + return ref_base[-min_length], [item[-min_length] for item in alt_base] + + +def decode_alt(ref_base, alt_base): + if ',' not in alt_base: + return [ref_base], [alt_base] + alt_base = alt_base.split(',') + ref_base_list, alt_base_list = [], [] + for ab in alt_base: + rb,ab = remove_common_suffix(ref_base, [ab]) + ref_base_list.append(rb) + alt_base_list.append(ab[0]) + return ref_base_list, alt_base_list + + def variant_map_from(var_fn, tree, is_tree_empty): Y = {} + truth_alt_dict = {} miss_variant_set = set() if var_fn is None: - return Y, miss_variant_set + return Y, miss_variant_set, truth_alt_dict f = subprocess_popen(shlex.split("gzip -fdc %s" % (var_fn))) for row in f.stdout: - columns = row.split() - ctg_name, position_str = columns[0], columns[1] - genotype1, genotype2 = columns[-2], columns[-1] + if row[0] == "#": + continue + columns = row.strip().split() + ctg_name, position_str, ref_base, alt_base, genotype1, genotype2 = columns key = ctg_name + ":" + position_str if genotype1 == '-1' or genotype2 == '-1': miss_variant_set.add(key) @@ -124,11 +155,41 @@ def variant_map_from(var_fn, tree, is_tree_empty): continue Y[key] = output_labels_from_vcf_columns(columns) - + ref_base_list, alt_base_list = decode_alt(ref_base, alt_base) + truth_alt_dict[int(position_str)] = (ref_base_list, alt_base_list) f.stdout.close() f.wait() - return Y, miss_variant_set - + return Y, miss_variant_set, truth_alt_dict + +def find_read_support(pos, truth_alt_dict, alt_info): + alt_info = alt_info.rstrip().split('-') + seqs = alt_info[1].split(' ') if len(alt_info) > 1 else '' + seq_alt_bases_dict = dict(zip(seqs[::2], [int(item) for item in seqs[1::2]])) if len(seqs) else {} + + pos = int(pos) + if pos not in truth_alt_dict: + # candidate position not in the truth vcf or unified truth vcf + return None + ref_base_list, alt_base_list = truth_alt_dict[pos] + found = 0 + for alt_type in seq_alt_bases_dict: + if '*' in alt_type or '#' in alt_type or 'R' in alt_type: + continue + if alt_type[0] == 'X': + if alt_type[1] in alt_base_list: + found += 1 + elif alt_type[0] == 'I': + if alt_type[1:] in alt_base_list: + found += 1 + elif alt_type[0] == 'D': + del_cigar = alt_type[1:] + for rb, ab in zip(ref_base_list, alt_base_list): + if rb[1:] == del_cigar and len(ab) == 1: + found += 1 + if found >= len(alt_base_list): + return True + # return False if we find any alternative bases missed in subsampled bam, 
then remove the position from training + return False def write_table_dict(table_dict, string, label, pos, total, alt_info, tensor_shape, pileup): """ @@ -207,7 +268,7 @@ def print_bin_size(path, prefix=None): print('[INFO] total: {}'.format(total)) -def bin_reader_generator_from(tensor_fn, Y_true_var, Y, is_tree_empty, tree, miss_variant_set, is_allow_duplicate_chr_pos=False, maximum_non_variant_ratio=None): +def bin_reader_generator_from(tensor_fn, Y_true_var, Y, is_tree_empty, tree, miss_variant_set, truth_alt_dict, is_allow_duplicate_chr_pos=False, maximum_non_variant_ratio=None): """ Bin reader generator for bin file generation. @@ -216,6 +277,7 @@ def bin_reader_generator_from(tensor_fn, Y_true_var, Y, is_tree_empty, tree, mis Y: dictionary (contig name: label information) to store all variant and non variant information. tree: dictionary(contig name : intervaltree) for quick region querying. miss_variant_set: sometimes there will have true variant missing after downsampling reads. + truth_alt_dict: unified truth reference base and alternative bases to find read support. is_allow_duplicate_chr_pos: whether allow duplicate positions when training, if there exists downsampled data, lower depth will add a random prefix character. maximum_non_variant_ratio: define a maximum non variant ratio for training, we always expect use more non variant data, while it would greatly increase training time, especially in ont data, here we usually use 1:1 or 1:2 for variant candidate: non variant candidate. @@ -224,6 +286,8 @@ def bin_reader_generator_from(tensor_fn, Y_true_var, Y, is_tree_empty, tree, mis X = {} ref_list = [] total = 0 + variant_set_with_read_support = set() + variants_without_read_support = 0 for row_idx, row in enumerate(tensor_fn): chrom, coord, seq, string, alt_info = row.split("\t") alt_info = alt_info.rstrip() @@ -238,6 +302,13 @@ def bin_reader_generator_from(tensor_fn, Y_true_var, Y, is_tree_empty, tree, mis if key in miss_variant_set: continue + have_read_support = find_read_support(pos=coord, truth_alt_dict=truth_alt_dict, alt_info=alt_info) + if have_read_support is not None and not have_read_support: + miss_variant_set.add(key) + variants_without_read_support += 1 + continue + + variant_set_with_read_support.add(key) if key not in X: X[key] = (string, alt_info, seq) if is_reference: @@ -267,6 +338,7 @@ def bin_reader_generator_from(tensor_fn, Y_true_var, Y, is_tree_empty, tree, mis if total % 100000 == 0: print("[INFO] Processed %d tensors" % total, file=sys.stderr) + print("[INFO] Variants with read support/variants without read support: {}/{}".format(len(variant_set_with_read_support), variants_without_read_support)) if maximum_non_variant_ratio is not None: _filter_non_variants(X, ref_list, maximum_non_variant_ratio) yield X, total, True @@ -306,7 +378,7 @@ def get_training_array(tensor_fn, var_fn, bed_fn, bin_fn, shuffle=True, is_allow tree = bed_tree_from(bed_file_path=bed_fn) is_tree_empty = len(tree.keys()) == 0 - Y_true_var, miss_variant_set = variant_map_from(var_fn, tree, is_tree_empty) + Y_true_var, miss_variant_set, truth_alt_dict = variant_map_from(var_fn, tree, is_tree_empty) Y = copy.deepcopy(Y_true_var) global param @@ -367,6 +439,7 @@ def get_training_array(tensor_fn, var_fn, bed_fn, bin_fn, shuffle=True, is_allow is_tree_empty=is_tree_empty, tree=tree, miss_variant_set=miss_variant_set, + truth_alt_dict=truth_alt_dict, is_allow_duplicate_chr_pos=is_allow_duplicate_chr_pos, maximum_non_variant_ratio=maximum_non_variant_ratio) diff --git 
a/docs/full_alignment_training.md b/docs/full_alignment_training.md index a6f2da3..3331651 100644 --- a/docs/full_alignment_training.md +++ b/docs/full_alignment_training.md @@ -1,4 +1,4 @@ -# Train a model for Clair3 full-alignment calling +# Train a model for Clair3 full-alignment calling (revision 0) This document shows how to train and fine-tune a deep learning model for Clair3 full-alignment calling. For training a model for pileup calling, please check [here](pileup_training.md). Clair3 needs both a pileup model and a full-alignment model to work. Compared to Clair3's pileup model training, training a full-alignment model needs much longer time. The disk space requirement also increases significantly. The training materials are grouped according to sample, coverage, and chromosome. The groups are converted into tensor binaries. The binaries are much space-efficient and easier to process. As required, multiples tensor binaries can be used together for model training and fine-tuning. diff --git a/docs/full_alignment_training_r1.md b/docs/full_alignment_training_r1.md new file mode 100644 index 0000000..56af22d --- /dev/null +++ b/docs/full_alignment_training_r1.md @@ -0,0 +1,361 @@ +# Train a model for Clair3 full-alignment calling (revision 1) + +Compare to [revision 0]((full_alignment_training.md)), revision 1 runs Representation Unification only once on the full-depth, providing trained models with similar accuracy and sensitivity, while trains faster. + +This document shows how to train and fine-tune a deep learning model for Clair3 full-alignment calling. For training a model for pileup calling, please check [here](pileup_training.md). Clair3 needs both a pileup model and a full-alignment model to work. Compared to Clair3's pileup model training, training a full-alignment model needs much longer time. The disk space requirement also increases significantly. The training materials are grouped according to sample, coverage, and chromosome. The groups are converted into tensor binaries. The binaries are much space-efficient and easier to process. As required, multiples tensor binaries can be used together for model training and fine-tuning. + +--- + +## Prerequisites + +- Clair3 installed +- GNU Parallel installed +- Sufficient hard-disk space +- Truth VCF file after representation unification (check [here](https://github.com/HKU-BAL/Clair3/blob/main/docs/representation_unification.md) on how to generate unified VCF) +- A high-end GPU (have tested on RTX Titan, and RTX 2080Ti) + +--- + +## Contents + +* [I. Training data phasing and haplotaging](#i-training-data-phasing-and-haplotaging) + - [1. Setup variables](#1-setup-variables) + - [2. Phase VCF file using WhatsHap](#2--phase-vcf-file-using-whatshap) + - [3. Haplotag read alignments using WhatsHap](#3--haplotag-read-alignments-using-whatshap) +* [II. Build compressed binary files for full-alignment model training](#ii-build-compressed-binary-files-for-full-alignment-model-training) + - [1. Run Clair3 pileup model](#1-run-Clair3-pileup-model) + - [2. Select low-quality pileup candidates](#2-Select-low-quality-pileup-candidates-using-the-SelectHetSnp-submodule) + - [3. Split and extend bed regions](#3-split-and-extend-bed-regions-using-the-splitextendbed-submodule) + - [4. Get truth variants from unified VCF file](#4-get-truth-variants-from-unified-vcf-using-the-gettruth-submodule) + - [5. Create full-alignment tensor](#5-create-full-alignment-tensor-using-the-createtrainingtensor-submodule) + - [6. 
Merge compressed binaries](#6-merge-compressed-binaries-using-the-mergebin-submodule) +* [III. Model training](#iii-model-training) + - [1. full-alignment model training](#1-full-alignment-model-training) + - [2. full-alignment model fine-tune using pre-trained model (optional)](#2-full-alignment-model-fine-tune-using-pre-trained-model-optional) + +--- + +## I. Training data phasing and haplotaging + +Full-alignment model utilizes phased alignment, phased alignments are required for training a full-alignment model. + +> - The whole procedure are breaking into blocks for better readability and error-tracing. +> - For each `parallel` command that run with the `--joblog` option, we can check the `Exitval` column from the job log output. If the column contains a non-zero value, it means error occurred; please rerun the failed block again. +> - We suggest using absolute path EVERYWHERE. +> - You can use a Truth VCF file without representation unification. You might want to do it only for testing because Clair3's performance would be significantly affected without representation unification. +> - If representation unification has applied, all phased alignment would be automatically generated in the `${OUTPUT_DIR}/phased_bam` folder, check [here](representation_unification.md#3--haplotag-read-alignment-using-whatshap) for more details. +> - WhatsHap `haplotag` submodule requires hard-disk space the same size as the input BAM. + +#### 1. Setup variables + +```bash +# Setup executable variables +CLAIR3_PATH="${CONDA_PREFIX}/bin" # clair3 installation path +CLAIR3="${CLAIR3_PATH}/clair3.py" # clair3.py +WHATSHAP="[WHATSHAP_BIN_PATH]" # e.g. whatshap +PARALLEL="[PARALLEL_BIN_PATH]" # e.g. parallel +SAMTOOLS="[SAMTOOLS_BIN_PATH]" # e.g. samtools +TABIX="[TABIX_BIN_PATH]" # e.g. tabix + +# Input parameters +PLATFORM="[SEQUENCING_PLATFORM]" # e.g. {ont, hifi, ilmn} +OUTPUT_DIR="[YOUR_OUTPUT_FOLDER_PATH]" # e.g. 
output + +ALL_UNPHASED_BAM_FILE_PATH=( +'hg002_1000.bam' +'hg002_800.bam' +'hg004_1000.bam' +) + +# Each line represents a sample, a sample can be specified multiple times to allow downsampling +ALL_SAMPLE=( +'hg002' +'hg002' +'hg004' +) + +# A downsampling numerator (1000 as denominator) for each sample in ALL_SAMPLE, 1000 means no downsampling, 800 means 80% (800/1000) +DEPTHS=( +1000 +800 +1000 +) + +# Reference genome file for each sample +ALL_REFERENCE_FILE_PATH=( +'GRCh38_no_alt.fasta' +'GRCh38_no_alt.fasta' +'GRCh38_no_alt.fasta' +) + +# High-confident BED region file for each sample +ALL_BED_FILE_PATH=( +'HG002.bed' +'HG002.bed' +'HG004.bed' +) + +# GIAB truth VCF files (without representation unification) for each sample +TRUTH_VCF_FILE_PATH=( +'HG002_GRCh38_v4.2.1.vcf.gz' +'HG002_GRCh38_v4.2.1.vcf.gz' +'HG004_GRCh38_v4.2.1.vcf.gz' +) + +# Unified truth VCF file (with representation unification) for each sample +# For a same sample with multiple downsampling depths, only the unified truth VCF done at full depth is needed +UNIFIED_VCF_FILE_PATH=( +'hg002_1000.unified.vcf.gz' +'hg002_1000.unified.vcf.gz' +'hg004_1000.unified.vcf.gz' +) + +# Chromosome prefix ("chr" if chromosome names have the "chr"-prefix) +CHR_PREFIX="chr" + +# array of chromosomes (do not include "chr"-prefix) to training in all sample +CHR=(21 22) + +# Number of threads to be used +THREADS=36 +THREADS_LOW=$((${THREADS}*3/4)) +if [[ ${THREADS_LOW} < 1 ]]; then THREADS_LOW=1; fi + +# Number of chucks to be divided into for parallel processing +chunk_num=15 +CHUNK_LIST=`seq 1 ${chunk_num}` + +# Maximum non-variant ratio for full-alignment model training, for full-alignment model training, we use variant :non-variant = 1 : 1 +MAXIMUM_NON_VARIANT_RATIO=1 + +# Temporary working directory +DATASET_FOLDER_PATH="${OUTPUT_DIR}/build" +TENSOR_CANDIDATE_PATH="${DATASET_FOLDER_PATH}/tensor_can" +BINS_FOLDER_PATH="${DATASET_FOLDER_PATH}/bins" +CANDIDATE_DETAILS_PATH="${DATASET_FOLDER_PATH}/candidate_details" +CANDIDATE_BED_PATH="${DATASET_FOLDER_PATH}/candidate_bed" +SPLIT_BED_PATH="${DATASET_FOLDER_PATH}/split_beds" +VAR_OUTPUT_PATH="${DATASET_FOLDER_PATH}/var" +PILEUP_OUTPUT_PATH="${OUTPUT_DIR}/pileup_output" +UNPHASED_TRUTH_VCF_PATH="${OUTPUT_DIR}/unphased_truth_vcf" +PHASE_VCF_PATH="${OUTPUT_DIR}/phased_vcf" +PHASE_BAM_PATH="${OUTPUT_DIR}/phased_bam" + +mkdir -p ${DATASET_FOLDER_PATH} +mkdir -p ${TENSOR_CANDIDATE_PATH} +mkdir -p ${BINS_FOLDER_PATH} +mkdir -p ${CANDIDATE_DETAILS_PATH} +mkdir -p ${SPLIT_BED_PATH} +mkdir -p ${VAR_OUTPUT_PATH} +mkdir -p ${CANDIDATE_BED_PATH} +mkdir -p ${PILEUP_OUTPUT_PATH} +mkdir -p ${UNPHASED_TRUTH_VCF_PATH} +mkdir -p ${PHASE_VCF_PATH} +mkdir -p ${PHASE_BAM_PATH} +``` + +#### 2. 
Phase VCF file using WhatsHap + +```bash +cd ${OUTPUT_DIR} + +# Remove the phasing information if the VCF input is already phased +${PARALLEL} -j${THREADS} "${WHATSHAP} unphase {3} > ${UNPHASED_TRUTH_VCF_PATH}/unphased_truth_{1}_{2}.vcf.gz" ::: ${ALL_SAMPLE[@]} :::+ ${DEPTHS[@]} :::+ ${TRUTH_VCF_FILE_PATH[@]} + +# WhatsHap phasing +${PARALLEL} --joblog ${PHASE_VCF_PATH}/phase.log -j${THREADS} \ +"${WHATSHAP} phase \ + --output ${PHASE_VCF_PATH}/phased_{2}_{3}_{1}.vcf.gz \ + --reference {5} \ + --chromosome ${CHR_PREFIX}{1} \ + --ignore-read-groups \ + --distrust-genotypes \ + ${UNPHASED_TRUTH_VCF_PATH}/unphased_truth_{2}_{3}.vcf.gz \ + {4}" ::: ${CHR[@]} ::: ${ALL_SAMPLE[@]} :::+ ${DEPTHS[@]} :::+ ${ALL_UNPHASED_BAM_FILE_PATH[@]} :::+ ${ALL_REFERENCE_FILE_PATH[@]} |& tee ${PHASE_VCF_PATH}/PHASE.log + +# Index the phased VCF files using tabix, which is neccesary for read haplotagging +${PARALLEL} -j ${THREADS} ${TABIX} -p vcf ${PHASE_VCF_PATH}/phased_{2}_{3}_{1}.vcf.gz ::: ${CHR[@]} ::: ${ALL_SAMPLE[@]} :::+ ${DEPTHS[@]} +``` + +#### 3. Haplotag read alignments using WhatsHap + +```bash +# WhatsHap haplotaging +${PARALLEL} --joblog ${PHASE_BAM_PATH}/haplotag.log -j${THREADS} \ +"${WHATSHAP} haplotag \ + --output ${PHASE_BAM_PATH}/{2}_{3}_{1}.bam \ + --reference {5} \ + --regions ${CHR_PREFIX}{1} \ + --ignore-read-groups \ + ${PHASE_VCF_PATH}/phased_{2}_{3}_{1}.vcf.gz \ + {4}" ::: ${CHR[@]} ::: ${ALL_SAMPLE[@]} :::+ ${DEPTHS[@]} :::+ ${ALL_UNPHASED_BAM_FILE_PATH[@]} :::+ ${ALL_REFERENCE_FILE_PATH[@]} |& tee ${PHASE_VCF_PATH}/HAPLOTAG.log + +# Index the phased bam files using samtools +${PARALLEL} --joblog ${PHASE_BAM_PATH}/index.log -j ${THREADS} ${SAMTOOLS} index -@12 ${PHASE_BAM_PATH}/{2}_{3}_{1}.bam ::: ${CHR[@]} ::: ${ALL_SAMPLE[@]} :::+ ${DEPTHS[@]} + +``` + + + +---- + +## II. Build compressed binary files for full-alignment model training + +This section shows how to build multiple compressed tensor binary file for multiple samples with multiple coverages. + +#### 1. Run Clair3 pileup model + +```bash +# Call variants using Clair3‘s pileup model with the --pileup_only option +# Only select the candidates in the high-confident BED regions for model training (with --bed_fn) +${PARALLEL} -j1 ${CLAIR3_PATH}/run_clair3.sh \ + --bam_fn={3} \ + --ref_fn={4} \ + --threads=${THREADS} \ + --platform="ont" \ + --model_path="${CONDA_PREFIX}/bin/models/ont" \ + --output=${PILEUP_OUTPUT_PATH}/{1}_{2} \ + --bed_fn={5} \ + --pileup_only ::: ${ALL_SAMPLE[@]} :::+ ${DEPTHS[@]} :::+ ${ALL_UNPHASED_BAM_FILE_PATH[@]} :::+ ${ALL_REFERENCE_FILE_PATH[@]} :::+ ${ALL_BED_FILE_PATH[@]} +``` + +#### 2. Select low-quality pileup candidates using the `SelectHetSnp` submodule + +```bash +# Select all pileup called variants (0/1, 1/1 and 1/2) and some pileup reference calls (0/0) for full-alignment model training +${PARALLEL} --joblog ${DATASET_FOLDER_PATH}/select_pileup_candidates.log -j${THREADS} \ +"${PYPY} ${CLAIR3} SelectHetSnp \ +--alt_fn {4} \ +--split_folder ${CANDIDATE_BED_PATH} \ +--sampleName {2} \ +--depth {3} \ +--ref_pct_full 0.15 \ +--var_pct_full 1.0 \ +--chunk_num ${chunk_num} \ +--phasing_info_in_bam \ +--phase \ +--ctgName ${CHR_PREFIX}{1}" ::: ${CHR[@]} ::: ${ALL_SAMPLE[@]} :::+ ${DEPTHS[@]} :::+ ${PILEUP_VCF_FILE_PATH[@]} +``` + +#### 3. 
Split and extend bed regions using the `SplitExtendBed` submodule + +```bash +${PARALLEL} --joblog ${DATASET_FOLDER_PATH}/split_extend_bed.log -j${THREADS} \ +"${PYPY} ${CLAIR3} SplitExtendBed \ + --bed_fn {4} \ + --output_fn ${SPLIT_BED_PATH}/{2}_{3}_{1} \ + --ctgName ${CHR_PREFIX}{1}" ::: ${CHR[@]} ::: ${ALL_SAMPLE[@]} :::+ ${DEPTHS[@]} :::+ ${ALL_BED_FILE_PATH[@]} +``` + +#### 4. Get truth variants from unified VCF using the `GetTruth` submodule + +```bash +# Convert an unified VCF file into a simplified var file +${PARALLEL} --joblog ${VAR_OUTPUT_PATH}/get_truth.log -j${THREADS} \ +"${PYPY} ${CLAIR3} GetTruth \ + --vcf_fn {4} \ + --ctgName ${CHR_PREFIX}{1} \ + --var_fn ${VAR_OUTPUT_PATH}/var_{2}_{3}_{1}" ::: ${CHR[@]} ::: ${ALL_SAMPLE[@]} :::+ ${DEPTHS[@]} :::+ ${UNIFIED_VCF_FILE_PATH[@]} +``` + +#### 5. Create full-alignment tensor using the `CreateTrainingTensor` submodule + +```bash +# Create full-alignment tensors for model training +${PARALLEL} --joblog ${DATASET_FOLDER_PATH}/create_tensor_full_alignment.log -j${THREADS_LOW} \ +"${PYPY} ${CLAIR3} CreateTrainingTensor \ + --bam_fn ${PHASE_BAM_PATH}/{2}_{3}_{1}.bam \ + --ref_fn {5} \ + --var_fn ${VAR_OUTPUT_PATH}/var_{2}_{3}_{1} \ + --bin_fn ${TENSOR_CANDIDATE_PATH}/tensor_{2}_{3}_{1}_{7} \ + --ctgName ${CHR_PREFIX}{1} \ + --samtools ${SAMTOOLS} \ + --extend_bed ${SPLIT_BED_PATH}/{2}_{3}_{1} \ + --full_aln_regions ${CANDIDATE_BED_PATH}/{2}_{3}_{1}_{7} \ + --bed_fn {6} \ + --phasing_info_in_bam \ + --add_no_phasing_data_training \ + --allow_duplicate_chr_pos \ + --platform ${PLATFORM} \ + --shuffle \ + --maximum_non_variant_ratio ${MAXIMUM_NON_VARIANT_RATIO} \ + --chunk_id {7} \ + --chunk_num ${chunk_num}" ::: ${CHR[@]} ::: ${ALL_SAMPLE[@]} :::+ ${DEPTHS[@]} :::+ ${ALL_UNPHASED_BAM_FILE_PATH[@]} :::+ ${ALL_REFERENCE_FILE_PATH[@]} :::+ ${ALL_BED_FILE_PATH[@]} ::: ${CHUNK_LIST[@]} +``` + +**Options** + + - `--phasing_info_in_bam` : enabled by default, indicating the input BAM is phased using WhatsHap's `haplotag` module, and phased alignments are having a `HP` tag with haplotype details. + - `--allow_duplicate_chr_pos` : for multiple coverages training, this option is required to to allow different coverage training samples at the same variant site. + - `--shuffle` : as the input tensors are created in the order of starting positions, this option shuffles the training data in each chunk. During the training process, we also apply index reshuffling in each epoch. + - `--maximum_non_variant_ratio` : we set a maximum non-variant ratio (variant:non-variant = 1:1) for full-alignment model training, non-variants are randomly selected from the candidate set if the ratio is exceeded, or all non-variants will be used for training otherwise. + - `--add_no_phasing_data_training` : also include unphased alignments in additional to the phased alignments for training. We found including unphased alignments increased model robustness. + - `--full_aln_regions` : provide the pileup candidate regions to be included in full-alignment based calling. + +#### 6. Merge compressed binaries using the `MergeBin` submodule + +```bash +# Merge compressed binaries +${PARALLEL} --joblog ${DATASET_FOLDER_PATH}/mergeBin.log -j${THREADS} \ +"${PYTHON3} ${CLAIR3} MergeBin \ + ${TENSOR_CANDIDATE_PATH}/tensor_{2}_{3}_{1}_* \ + --out_fn ${BINS_FOLDER_PATH}/bin_{2}_{3}_{1}" ::: ${CHR[@]} ::: ${ALL_SAMPLE[@]} :::+ ${DEPTHS[@]} +``` + +---- + +## III. 
Model training + +We provide two optional training mode: + +​ **Option1**: Train pileup model using new dataset, in this mode, we will use randomly initialized model weights and train the model until reaches max epochs(30) or converge. + +​ **Option2**: Fine-tune pileup model using pre-trained parameters and choose a smaller learning rate for better converge in new dataset. + +***We recommend using the fine-tune mode (option 2) for better robustness.*** + +#### 1. full-alignment model training + +```bash +# Full-alignment model training +MODEL_FOLDER_PATH="${OUTPUT_DIR}/train" +mkdir -p ${MODEL_FOLDER_PATH} + +cd ${MODEL_FOLDER_PATH} + +# A single GPU is used for model training +export CUDA_VISIBLE_DEVICES="0" +${PYTHON3} ${CLAIR3} Train \ + --bin_fn ${BINS_FOLDER_PATH} \ + --ochk_prefix ${MODEL_FOLDER_PATH}/full_alignment \ + --add_indel_length True \ + --random_validation \ + --platform ${PLATFORM} +``` + +**Options** + + - `--add_indel_length` : enable or disable the two indel-length tasks. In the pre-trained models, the two tasks are enabled in full-alignment calling. + - `--random_validation`: randomly holdout 10% from all candidate sites as validation data, the best-performing epoch on the validation data are selected as our pre-trained model. + +#### 2. full-alignment model fine-tune using pre-trained model (optional) + +```bash +# Full-alignment model fine-tuning using a new sample +MODEL_FOLDER_PATH="${OUTPUT_DIR}/train" +mkdir -p ${MODEL_FOLDER_PATH} + +cd ${MODEL_FOLDER_PATH} + +export CUDA_VISIBLE_DEVICES="0" +${PYTHON3} ${CLAIR3} Train \ + --bin_fn ${BINS_FOLDER_PATH} \ + --ochk_prefix ${MODEL_FOLDER_PATH}/full_alignment \ + --add_indel_length True \ + --random_validation \ + --platform ${PLATFORM} \ + --learning_rate 0.0001 \ + --chkpnt_fn "[YOUR_PRETRAINED_MODEL]" ## use pre-trained full-alignment model here +``` + +We experimentally offer full-alignment model fine tuning using a pre-trained Clair3 full-alignment model, by using a smaller `learning_rate` and a pre-trained model `chkpnt_fn`. We recommend starting with a smaller learning rate such as `1e-4` to fine-tune a pre-trained full-alignment model. \ No newline at end of file diff --git a/docs/guppy2.md b/docs/guppy2.md index ef22ae6..e46a0b8 100644 --- a/docs/guppy2.md +++ b/docs/guppy2.md @@ -28,7 +28,7 @@ docker run -it \ --ref_fn=${INPUT_DIR}/ref.fa \ ## change your reference name here --threads=${THREADS} \ ## maximum threads to be used --platform="ont" \ - --model_path="/opt/models/ont_guppy2" \ + --model_path="/opt/models/r941_prom_hac_g238" \ --output=${OUTPUT_DIR} ## absolute output path prefix ``` diff --git a/docs/guppy5.md b/docs/guppy5.md index b1d0d1f..91e1c2d 100644 --- a/docs/guppy5.md +++ b/docs/guppy5.md @@ -31,9 +31,10 @@ docker run -it \ --ref_fn=${INPUT_DIR}/ref.fa \ ## change your reference name here --threads=${THREADS} \ ## maximum threads to be used --platform="ont" \ - --model_path="/opt/models/ont_guppy5" \ + --model_path="/opt/models/r941_prom_sup_g506" \ --output=${OUTPUT_DIR} ## absolute output path prefix ``` Check [Usage](https://github.com/HKU-BAL/Clair3#Usage) for more options. + diff --git a/docs/training_data.md b/docs/training_data.md index 8a31ec0..2e17407 100644 --- a/docs/training_data.md +++ b/docs/training_data.md @@ -54,11 +54,13 @@ Download models from [here](http://www.bio8.cs.hku.hk/clair3/clair3_models/) or click on the links below. 
-| File | Platform | Training samples | Included in the bioconda package | Included in the docker image | Release | Date | Basecaller | Link | -| :---------------: | :---------: | :----------------------------------------------------------: | -------------------------------- | :--------------------------: | :-----: | :------: | :--------: | :----------------------------------------------------------: | -| ont.tar.gz | ONT | HG001,2,4,5 | Yes | Yes | 1 | 20210517 | Guppy3,4 | [Download](http://www.bio8.cs.hku.hk/clair3/clair3_models/ont.tar.gz) | -| ont_1235.tar.gz | ONT | HG001,2,3,5 | | | 1 | 20210517 | Guppy3,4 | [Download](http://www.bio8.cs.hku.hk/clair3/clair3_models/ont_1235.tar.gz) | -| ont_guppy5.tar.gz | ONT | Base model: HG001,2,4,5 (Guppy3,4)
Fine-tuning data: HG002 (Guppy5_sup) | Yes | Yes | 1 | 20210609 | Guppy5 | [Download](http://www.bio8.cs.hku.hk/clair3/clair3_models/ont_guppy5.tar.gz) | -| ont_guppy2.tar.gz | ONT | HG001,2,3,4 | | Yes | 1 | 20210627 | Guppy2 | [Download](http://www.bio8.cs.hku.hk/clair3/clair3_models/ont_guppy2.tar.gz) | -| hifi.tar.gz | PacBio HiFi | HG001,2,4,5 | Yes | Yes | 1 | 20210517 | NA | [Download](http://www.bio8.cs.hku.hk/clair3/clair3_models/hifi.tar.gz) | -| ilmn.tar.gz | Illumina | HG001,2,4,5 | Yes | Yes | 1 | 20210517 | NA | [Download](http://www.bio8.cs.hku.hk/clair3/clair3_models/ilmn.tar.gz) | +In a docker installation, models are in `/opt/models/`. In a bioconda installation, models are in `{CONDA_PREFIX}/bin/models/`. + +| Model name | Platform | Training samples | Included in the bioconda package | Included in the docker image | Release | Date | Basecaller | File | Link | +| :--------------------------: | :---------: | :----------------------------------------------------------: | -------------------------------- | :--------------------------: | :-----: | :------: | :--------: | ----------------------------------- | :----------------------------------------------------------: | +| r941_prom_hac_g360+g422 | ONT | HG001,2,4,5 | Yes | Yes | 1 | 20210517 | Guppy3,4 | r941_prom_hac_g360+g422.tar.gz | [Download](http://www.bio8.cs.hku.hk/clair3/clair3_models/r941_prom_hac_g360+g422.tar.gz) | +| r941_prom_hac_g360+g422_1235 | ONT | HG001,2,3,5 | | | 1 | 20210517 | Guppy3,4 | r941_prom_hac_g360+g422_1235.tar.gz | [Download](http://www.bio8.cs.hku.hk/clair3/clair3_models/r941_prom_hac_g360+g422_1235.tar.gz) | +| r941_prom_sup_g506 | ONT | Base model: HG001,2,4,5 (Guppy3,4)
Fine-tuning data: HG002 (Guppy5_sup) | Yes | Yes | 1 | 20210609 | Guppy5 | r941_prom_sup_g506.tar.gz | [Download](http://www.bio8.cs.hku.hk/clair3/clair3_models/r941_prom_sup_g506.tar.gz) | +| r941_prom_hac_g238 | ONT | HG001,2,3,4 | | Yes | 1 | 20210627 | Guppy2 | r941_prom_hac_g238.tar.gz | [Download](http://www.bio8.cs.hku.hk/clair3/clair3_models/r941_prom_hac_g238.tar.gz) | +| hifi | PacBio HiFi | HG001,2,4,5 | Yes | Yes | 1 | 20210517 | NA | hifi.tar.gz | [Download](http://www.bio8.cs.hku.hk/clair3/clair3_models/hifi.tar.gz) | +| ilmn | Illumina | HG001,2,4,5 | Yes | Yes | 1 | 20210517 | NA | ilmn.tar.gz | [Download](http://www.bio8.cs.hku.hk/clair3/clair3_models/ilmn.tar.gz) | \ No newline at end of file diff --git a/preprocess/CreateTensorFullAlignment.py b/preprocess/CreateTensorFullAlignment.py index 6d4fb32..b90df00 100644 --- a/preprocess/CreateTensorFullAlignment.py +++ b/preprocess/CreateTensorFullAlignment.py @@ -169,7 +169,7 @@ def get_tensor_info(base_info, bq, ref_base, read_mq): return read_channel, ins_base, query_base -def decode_pileup_bases(pileup_bases, reference_base, minimum_af_for_candidate, has_pileup_candidates): +def decode_pileup_bases(pileup_bases, reference_base, minimum_af_for_candidate, minimum_snp_af_for_candidate, minimum_indel_af_for_candidate, has_pileup_candidates, platform='ont'): """ Decode mpileup input string. pileup_bases: pileup base string for each position, include all mapping information. @@ -219,13 +219,30 @@ def decode_pileup_bases(pileup_bases, reference_base, minimum_af_for_candidate, elif len(key) > 1 and key[1] == '-': pileup_dict['D'] += count + minimum_snp_af_for_candidate = minimum_snp_af_for_candidate if minimum_snp_af_for_candidate > 0 else param.min_af + minimum_indel_af_for_candidate = minimum_indel_af_for_candidate if minimum_indel_af_for_candidate > 0 else param.min_af_dict[platform] + denominator = depth if depth > 0 else 1 pileup_list = sorted(list(pileup_dict.items()), key=lambda x: x[1], reverse=True) + + pass_af = len(pileup_list) and (pileup_list[0][0] != reference_base) + pass_snp_af = False + pass_indel_af = False + + for item, count in pileup_list: + if item == reference_base: + continue + elif item[0] in 'ID': + pass_indel_af = (pass_indel_af or (float(count) / denominator >= minimum_indel_af_for_candidate)) + continue + pass_snp_af = pass_snp_af or (float(count) / denominator >= minimum_snp_af_for_candidate) + af = (float(pileup_list[1][1]) / denominator) if len(pileup_list) > 1 else 0.0 - pass_af = len(pileup_list) and (pileup_list[0][0] != reference_base or af >= minimum_af_for_candidate) af = (float(pileup_list[0][1]) / denominator) if len(pileup_list) >= 1 and pileup_list[0][ 0] != reference_base else af + pass_af = pass_af or pass_snp_af or pass_indel_af + return base_list, depth, pass_af, af @@ -438,9 +455,12 @@ def CreateTensorFullAlignment(args): extend_bp = param.extend_bp unify_repre = args.unify_repre minimum_af_for_candidate = args.min_af + minimum_snp_af_for_candidate = args.snp_min_af + minimum_indel_af_for_candidate = args.indel_min_af min_coverage = args.minCoverage platform = args.platform confident_bed_fn = args.bed_fn + is_confident_bed_file_given = confident_bed_fn is not None phased_vcf_fn = args.phased_vcf_fn alt_fn = args.indel_fn extend_bed = args.extend_bed @@ -511,18 +531,29 @@ def CreateTensorFullAlignment(args): Whole genome calling option, acquire contig start end position from reference fasta index(.fai), then split the reference accroding to chunk id and total chunk numbers. 
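+        If a confident BED file is provided, the chunk boundaries are instead derived from the
+        extended BED region: chunk_size = (bed_end - bed_start) // chunk_num (plus one when there
+        is a remainder), ctg_start = bed_start + 1 + chunk_size * chunk_id (0-based to 1-based),
+        and ctg_end = ctg_start + chunk_size, keeping tensor extraction consistent with pileup
+        candidate generation and faster than splitting the whole contig from the .fai index.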
""" - contig_length = 0 - with open(fai_fn, 'r') as fai_fp: - for row in fai_fp: - columns = row.strip().split("\t") + if is_confident_bed_file_given: + # consistent with pileup generation, faster to extract tensor using bed region + tree, bed_start, bed_end = bed_tree_from(bed_file_path=extend_bed, + contig_name=ctg_name, + return_bed_region=True) + + chunk_size = (bed_end - bed_start) // chunk_num + 1 if (bed_end - bed_start) % chunk_num else ( + bed_end - bed_start) // chunk_num + ctg_start = bed_start + 1 + chunk_size * chunk_id # 0-base to 1-base + ctg_end = ctg_start + chunk_size + else: + contig_length = 0 + with open(fai_fn, 'r') as fai_fp: + for row in fai_fp: + columns = row.strip().split("\t") - contig_name = columns[0] - if contig_name != ctg_name: - continue - contig_length = int(columns[1]) - chunk_size = contig_length // chunk_num + 1 if contig_length % chunk_num else contig_length // chunk_num - ctg_start = chunk_size * chunk_id # 0-base to 1-base - ctg_end = ctg_start + chunk_size + contig_name = columns[0] + if contig_name != ctg_name: + continue + contig_length = int(columns[1]) + chunk_size = contig_length // chunk_num + 1 if contig_length % chunk_num else contig_length // chunk_num + ctg_start = chunk_size * chunk_id # 0-base to 1-base + ctg_end = ctg_start + chunk_size # for illumina platform, the reads alignment is acquired after reads realignment from ReadsRealign.py if platform == 'ilmn' and bam_file_path != "PIPE": @@ -658,6 +689,8 @@ def samtools_pileup_generator_from(samtools_mpileup_process): base_list, depth, pass_af, af = decode_pileup_bases(pileup_bases=pileup_bases, reference_base=reference_base, minimum_af_for_candidate=minimum_af_for_candidate, + minimum_snp_af_for_candidate=minimum_snp_af_for_candidate, + minimum_indel_af_for_candidate=minimum_indel_af_for_candidate, has_pileup_candidates=has_pileup_candidates) if phasing_info_in_bam: @@ -848,6 +881,12 @@ def main(): parser.add_argument('--min_af', type=float, default=0.08, help="Minimum allele frequency for both SNP and Indel for a site to be considered as a condidate site, default: %(default)f") + parser.add_argument('--snp_min_af', type=float, default=0.08, + help="Minimum snp allele frequency for a site to be considered as a candidate site, default: %(default)f") + + parser.add_argument('--indel_min_af', type=float, default=0.15, + help="Minimum indel allele frequency for a site to be considered as a candidate site, default: %(default)f") + parser.add_argument('--ctgName', type=str, default=None, help="The name of sequence to be processed, required if --bed_fn is not defined") diff --git a/preprocess/CreateTensorPileup.py b/preprocess/CreateTensorPileup.py index 14084f8..9c473a9 100644 --- a/preprocess/CreateTensorPileup.py +++ b/preprocess/CreateTensorPileup.py @@ -321,7 +321,7 @@ def CreateTensorPileup(args): bed_ctg_end=extend_end) - empty_pileup_flag = True + empty_pileup_flag = True for row in samtools_mpileup_process.stdout: empty_pileup_flag = False columns = row.strip().split('\t') @@ -426,7 +426,7 @@ def CreateTensorPileup(args): if args.gvcf and empty_pileup_flag: nonVariantCaller.write_empty_pileup(ctg_name,ctg_start,ctg_end) if args.gvcf: - nonVariantCaller.vcf_writer.close() + nonVariantCaller.close_vcf_writer() samtools_mpileup_process.stdout.close() samtools_mpileup_process.wait() diff --git a/preprocess/CreateTrainingTensor.py b/preprocess/CreateTrainingTensor.py index 664b709..72913bc 100644 --- a/preprocess/CreateTrainingTensor.py +++ b/preprocess/CreateTrainingTensor.py @@ -83,6 
+83,7 @@ def Run(args): var_fn = file_path_from(args.var_fn, exit_on_not_found=True) bin_fn = args.bin_fn extend_bed = file_path_from(args.extend_bed) + full_aln_regions = file_path_from(args.full_aln_regions) platform = args.platform if not platform or platform not in param.support_platform: @@ -128,6 +129,9 @@ def Run(args): CommandOption('samtools', samtoolsBin), CommandOption('bed_fn', bed_fn), CommandOption('extend_bed', extend_bed), + CommandOption('min_af', min_af), + CommandOption('snp_min_af', snp_min_af), + CommandOption('indel_min_af', indel_min_af), ctgStart, ctgEnd, chunk_id, @@ -135,12 +139,9 @@ def Run(args): ] if not pileup: - create_tensor_command_options.append(CommandOption('min_af', min_af)) create_tensor_command_options.append(phasing_info_mode) create_tensor_command_options.append(add_no_phasing_mode) - else: - create_tensor_command_options.append(CommandOption('snp_min_af', snp_min_af)) - create_tensor_command_options.append(CommandOption('indel_min_af', indel_min_af)) + create_tensor_command_options.append(CommandOption('full_aln_regions', full_aln_regions)) compress_tensor_command_options = [ pythonBin, @@ -201,7 +202,7 @@ def Run(args): def main(): - parser = ArgumentParser(description="Call variants using a trained model and a BAM file") + parser = ArgumentParser(description="Create tensor binaries for pileup or full-alignment training") parser.add_argument('--platform', type=str, default="ont", help="Sequencing platform of the input. Options: 'ont,hifi,ilmn', default: %(default)s") @@ -263,6 +264,10 @@ def main(): parser.add_argument('--pileup', action='store_true', help=SUPPRESS) + ## Provide the regions to be included in full-alignment based calling + parser.add_argument('--full_aln_regions', type=str, default=None, + help=SUPPRESS) + parser.add_argument('--phasing_info_in_bam', action='store_true', help="DEBUG: Skip phasing and use the phasing info provided in the input BAM (HP tag), default: False") diff --git a/preprocess/GetTruth.py b/preprocess/GetTruth.py index 28ca050..23e950b 100644 --- a/preprocess/GetTruth.py +++ b/preprocess/GetTruth.py @@ -2,7 +2,7 @@ import shlex from subprocess import PIPE from argparse import ArgumentParser -from shared.utils import subprocess_popen +from shared.utils import subprocess_popen, vcf_candidates_from class TruthStdout(object): def __init__(self, handle): @@ -14,10 +14,15 @@ def __del__(self): def OutputVariant(args): var_fn = args.var_fn vcf_fn = args.vcf_fn + truth_vcf_fn = args.truth_vcf_fn ctg_name = args.ctgName ctg_start = args.ctgStart ctg_end = args.ctgEnd + truth_vcf_set = set() + variant_set = set() + if args.truth_vcf_fn is not None: + truth_vcf_set = set(vcf_candidates_from(vcf_fn=truth_vcf_fn, contig_name=ctg_name)) if args.var_fn != "PIPE": var_fpo = open(var_fn, "wb") var_fp = subprocess_popen(shlex.split("gzip -c"), stdin=PIPE, stdout=var_fpo) @@ -58,9 +63,17 @@ def OutputVariant(args): # * always have a genotype 1/2 genotype_1, genotype_2 = '0', '1' + + variant_set.add(int(position)) var_fp.stdin.write(" ".join((chromosome, position, reference, alternate, genotype_1, genotype_2))) var_fp.stdin.write("\n") + for position in truth_vcf_set: + if position not in variant_set: + # miss variant set used in Tensor2Bin + var_fp.stdin.write(" ".join((chromosome, str(position), "", "", "-1", "-1"))) + var_fp.stdin.write("\n") + vcf_fp.stdout.close() vcf_fp.wait() @@ -88,6 +101,9 @@ def main(): parser.add_argument('--ctgEnd', type=int, default=None, help="The 1-based inclusive ending position of the sequence to 
be processed") + parser.add_argument('--truth_vcf_fn', type=str, default=None, + help="Truth VCF file input, only used when vcf_fn is unified vcf. Marked truth variants not in unified as missing") + args = parser.parse_args() if len(sys.argv[1:]) == 0: diff --git a/preprocess/MergeVcf.py b/preprocess/MergeVcf.py index 34db67e..26657cb 100644 --- a/preprocess/MergeVcf.py +++ b/preprocess/MergeVcf.py @@ -314,7 +314,7 @@ def main(): help="Process variant only in the provided regions prefix") parser.add_argument('--qual', type=int, default=2, - help="If set, variants with >=$qual will be marked 'PASS', or 'LowQual' otherwise, optional") + help="If set, variants with >$qual will be marked 'PASS', or 'LowQual' otherwise, optional") parser.add_argument('--sampleName', type=str, default="SAMPLE", help="Define the sample name to be shown in the VCF file") diff --git a/preprocess/SelectHetSnp.py b/preprocess/SelectHetSnp.py index c686720..55f6285 100644 --- a/preprocess/SelectHetSnp.py +++ b/preprocess/SelectHetSnp.py @@ -215,8 +215,8 @@ def FiterHeteSnp(args): candidate_positions.add(pos) - #ref_call - if ref_base == alt_base: + #ref_call was marked as '.' after v0.1-r5 + if ref_base == alt_base or alt_base == ".": ref_call_pos_list.append((pos,qual)) else: need_phasing_list.append((pos,qual)) diff --git a/preprocess/SortVcf.py b/preprocess/SortVcf.py index 84ee9e4..979bce2 100644 --- a/preprocess/SortVcf.py +++ b/preprocess/SortVcf.py @@ -1,11 +1,12 @@ import os import subprocess - +import shlex from sys import stdin, exit from argparse import ArgumentParser from collections import defaultdict -from shared.utils import log_error, log_warning, file_path_from + +from shared.utils import log_error, log_warning, file_path_from, subprocess_popen major_contigs_order = ["chr" + str(a) for a in list(range(1, 23)) + ["X", "Y"]] + [str(a) for a in list(range(1, 23)) + ["X", "Y"]] @@ -54,11 +55,25 @@ def print_calling_step(output_fn=""): subprocess.run('cp {} {}'.format(pileup_output, merge_output), shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE) +def check_header_in_gvcf(header, contigs_list): + # Only output the contigs processed to be consistent with GATK + # Contig format: ##contig= + + update_header = [] + for row_id, row in enumerate(header): + if row.startswith("##contig="): + contig = row.split(',')[0].split('=')[2] + if contig not in contigs_list: + continue + update_header.append(row) + + return update_header + def sort_vcf_from_stdin(args): """ Sort vcf file according to variants start position and contig name. 
""" - + row_count = 0 header = [] contig_dict = defaultdict(defaultdict) @@ -139,13 +154,32 @@ def sort_vcf_from(args): header = [] no_vcf_output = True need_write_header = True - output = open(output_fn, 'w') + + # only compress intermediate gvcf using lz4 output and keep final gvcf in bgzip format + output_bgzip_gvcf = vcf_fn_suffix == '.gvcf' + compress_gvcf = 'gvcf' in vcf_fn_suffix + if compress_gvcf: + lz4_path = subprocess.run("which lz4", stdout=subprocess.PIPE, shell=True).stdout.decode().rstrip() + compress_gvcf = True if lz4_path != "" else False + is_lz4_format = compress_gvcf + compress_gvcf_output = compress_gvcf and not output_bgzip_gvcf + if compress_gvcf_output: + write_fpo = open(output_fn, 'w') + write_proc = subprocess_popen(shlex.split("lz4 -c"), stdin=subprocess.PIPE, stdout=write_fpo, stderr=subprocess.DEVNULL) + output = write_proc.stdin + else: + output = open(output_fn, 'w') for contig in contigs_order_list: contig_dict = defaultdict(str) contig_vcf_fns = [fn for fn in all_files if contig in fn] for vcf_fn in contig_vcf_fns: - fn = open(os.path.join(input_dir, vcf_fn), 'r') + file = os.path.join(input_dir, vcf_fn) + if is_lz4_format: + read_proc = subprocess_popen(shlex.split("{} {}".format("lz4 -fdc", file)), stderr=subprocess.DEVNULL) + fn = read_proc.stdout + else: + fn = open(file, 'r') for row in fn: row_count += 1 if row[0] == '#': @@ -161,14 +195,24 @@ def sort_vcf_from(args): contig_dict[int(pos)] = row no_vcf_output = False fn.close() + if is_lz4_format: + read_proc.wait() if need_write_header and len(header): + if output_bgzip_gvcf: + header = check_header_in_gvcf(header=header, contigs_list=all_contigs_list) output.write(''.join(header)) need_write_header = False all_pos = sorted(contig_dict.keys()) for pos in all_pos: output.write(contig_dict[pos]) - output.close() + if compress_gvcf_output: + write_proc.stdin.close() + write_proc.wait() + write_fpo.close() + return + else: + output.close() if row_count == 0: print (log_warning("[WARNING] No vcf file found, output empty vcf file")) @@ -183,6 +227,10 @@ def sort_vcf_from(args): print_calling_step(output_fn=output_fn) return + if vcf_fn_suffix == ".tmp.gvcf": + return + if vcf_fn_suffix == ".gvcf": + print("[INFO] Need some time to compress and index GVCF file...") compress_index_vcf(output_fn) diff --git a/preprocess/utils.py b/preprocess/utils.py index 3a45112..a980776 100644 --- a/preprocess/utils.py +++ b/preprocess/utils.py @@ -3,8 +3,10 @@ import os import sys import re +import subprocess +import shlex logging.getLogger().setLevel(logging.INFO) -from shared.utils import file_path_from +from shared.utils import file_path_from, subprocess_popen use_mpmath = True try: @@ -16,6 +18,62 @@ LOG_10 = 2.3025 LOG_2 = 0.3010 +# compress intermediate gvcf row using lz4, issue:https://github.com/HKU-BAL/Clair3/issues/48 +lz4_path = subprocess.run("which lz4", stdout=subprocess.PIPE, shell=True).stdout.decode().rstrip() +COMPRESS_GVCF = True if lz4_path != "" else False +LZ4_COMPRESS = "lz4 -c" +LZ4_DECOMPRESS = "lz4 -fdc" +GVCF_SUFFIX = ".tmp.gvcf" + +class compressReaderWriter(object): + def __init__(self, input_path=None, output_path=None, compress=False): + self.input_path = input_path + self.output_path = output_path + self.compress = compress + self.read_proc = None + self.reader = None + + self.writer = None + self.write_proc = None + self.write_fpo = None + + def read_input(self): + if self.compress: + self.read_proc = subprocess_popen(shlex.split("{} {}".format(LZ4_DECOMPRESS, self.input_path)), 
stderr=subprocess.DEVNULL) + a = subprocess_popen(shlex.split("{} {}".format(LZ4_DECOMPRESS, self.input_path)), stderr=subprocess.DEVNULL) + streamdata = a.communicate()[0] + rc = a.returncode + self.reader = self.read_proc.stdout + else: + self.reader = open(self.input_path, 'r') + return self.reader + + def close_reader(self): + if self.compress: + self.read_proc.stdout.close() + self.read_proc.wait() + else: + self.reader.close() + + def write_output(self): + if self.compress: + self.write_fpo = open(self.output_path, 'w') + self.write_proc = subprocess_popen(shlex.split(LZ4_COMPRESS), stdin=subprocess.PIPE, stdout=self.write_fpo, stderr=subprocess.DEVNULL) + self.writer = self.write_proc.stdin + + else: + self.writer = open(self.output_path, 'w') + return self.writer + + def close_writer(self): + if self.compress: + self.write_proc.stdin.close() + self.write_proc.wait() + self.write_fpo.close() + else: + self.writer.close() + + class gvcfGenerator(object): def __init__(self, ref_path, samtools='samtools'): @@ -24,94 +82,68 @@ def __init__(self, ref_path, samtools='samtools'): self.samtools = samtools pass - def readCalls(self, callPath, callType='variant', ctgName=None, ctgStart=None, ctgEnd=None): + def readCalls(self, callPath, callType='variant', ctgName=None, ctgStart=None, ctgEnd=None, add_header=False, writer=None): - with open(callPath, 'r') as reader: - for line in reader: - if (line.startswith('#')): - continue - else: - if (callType == 'non-variant'): - cur_non_variant_start = int(line.strip('\n').split('\t')[1]) - cur_non_variant_end = int(re.search(r'.*END=(.*)\tGT.*', line).group(1)) - cur_non_variant_chr = line.strip('\n').split('\t')[0] - if ((ctgName and cur_non_variant_chr == ctgName) or (not ctgName)): - if ((ctgStart and cur_non_variant_start >= ctgStart) or (not ctgStart)): - if ((ctgEnd and cur_non_variant_end <= ctgEnd) or (not ctgEnd)): - yield line.strip('\n'), cur_non_variant_start, cur_non_variant_end, 'original' - else: - # for variant calls, return "pos" - # DEL and INS should be considered here - tmp = line.strip('\n').split('\t') - ref = tmp[3] - alt = tmp[4] - n_alt = len(alt.split(',')) - cur_variant_start = int(line.strip('\n').split('\t')[1]) - cur_variant_end = cur_variant_start - 1 + len(ref) - is_reference_call = (alt == '.') or (ref == alt) - if not is_reference_call: - # assuming AD is at the columns [-3], add 0 to AD for gVCF - ori_info = tmp[-1].split(':') - ori_info[-3] += ',0' - tmp[-1] = ':'.join(ori_info) - - # assumeing PL is at the last column - # add to variant calls - tmp[4] = tmp[4] + ',' - if (n_alt == 1): - - tmp[-1] = tmp[-1] + ',990,990,990' - - elif (n_alt == 2): - tmp[-1] = tmp[-1] + ',990,990,990,990' - else: - # skip reference calls - continue - new_line = '\t'.join(tmp) - - cur_variant_chr = tmp[0] - - - if ((ctgName and cur_variant_chr == ctgName) or (not ctgName)): - if ((ctgStart and cur_variant_start >= ctgStart) or (not ctgStart)): - if ((ctgEnd and cur_variant_end <= ctgEnd) or (not ctgEnd)): - yield new_line, cur_variant_start, cur_variant_end - - - - def _print_vcf_header(self,save_writer,tmp_gvcf_path,tmp_vcf_path): - ''' - merge the two headers of tmp_gvcf and tmp_vcf - - ''' - headers=[] - contigs = [] - sample_line = "" - with open(tmp_gvcf_path,'r') as reader: - for line in reader: - if(not line.startswith('#')): - break - if(not line.startswith('##')): - sample_line = line - elif(line.startswith('##contig')): - contigs.append(line) - else: - headers.append(line) + CR = compressReaderWriter(input_path=callPath, 
compress=COMPRESS_GVCF) + reader = CR.read_input() + need_write_header = True + header = [] + for line in reader: + if (line.startswith('#')): + if add_header and line not in header: + header.append(line) - with open(tmp_vcf_path,'r') as reader: - for line in reader: - if(not line.startswith('##')): - break - if(line.startswith('##contig')): + continue + if add_header and len(header) and need_write_header: + print(''.join(header).rstrip(), file=writer) + need_write_header = False + if (callType == 'non-variant'): + cur_non_variant_start = int(line.strip('\n').split('\t')[1]) + cur_non_variant_end = int(re.search(r'.*END=(.*)\tGT.*', line).group(1)) + cur_non_variant_chr = line.strip('\n').split('\t')[0] + if ((ctgName and cur_non_variant_chr == ctgName) or (not ctgName)): + if ((ctgStart and cur_non_variant_start >= ctgStart) or (not ctgStart)): + if ((ctgEnd and cur_non_variant_end <= ctgEnd) or (not ctgEnd)): + yield line.strip('\n'), cur_non_variant_start, cur_non_variant_end, 'original' + else: + # for variant calls, return "pos" + # DEL and INS should be considered here + tmp = line.strip('\n').split('\t') + ref = tmp[3] + alt = tmp[4] + n_alt = len(alt.split(',')) + cur_variant_start = int(line.strip('\n').split('\t')[1]) + cur_variant_end = cur_variant_start - 1 + len(ref) + is_reference_call = (alt == '.') or (ref == alt) + if not is_reference_call: + # assuming AD is at the columns [-3], add 0 to AD for gVCF + ori_info = tmp[-1].split(':') + ori_info[-3] += ',0' + tmp[-1] = ':'.join(ori_info) + + # assumeing PL is at the last column + # add to variant calls + tmp[4] = tmp[4] + ',' + if (n_alt == 1): + + tmp[-1] = tmp[-1] + ',990,990,990' + + elif (n_alt == 2): + tmp[-1] = tmp[-1] + ',990,990,990,990' + else: + # skip reference calls continue - elif(line not in headers): - headers.append(line) - - print(''.join(headers).strip(),file=save_writer) - print(''.join(contigs).strip(),file=save_writer) - print(sample_line.strip(),file=save_writer) - pass - + new_line = '\t'.join(tmp) + + cur_variant_chr = tmp[0] + + + if ((ctgName and cur_variant_chr == ctgName) or (not ctgName)): + if ((ctgStart and cur_variant_start >= ctgStart) or (not ctgStart)): + if ((ctgEnd and cur_variant_end <= ctgEnd) or (not ctgEnd)): + yield new_line, cur_variant_start, cur_variant_end + + CR.close_reader() def readReferenceBaseAtPos(self, pos): @@ -153,16 +185,18 @@ def mergeCalls(self, variantCallPath, nonVarCallPath, savePath, sampleName, ctgN ctgEnd=None): ''' - merge calls between variant and non-variant - ''' varCallStop = False nonVarCallStop = False - printCurVar = True + + #output writer + CW = compressReaderWriter(output_path=savePath, compress=COMPRESS_GVCF) + save_writer = CW.write_output() + varCallGenerator = self.readCalls(variantCallPath, 'variant', ctgName, ctgStart, ctgEnd) - nonVarCallGenerator = self.readCalls(nonVarCallPath, 'non-variant', ctgName, ctgStart, ctgEnd) + nonVarCallGenerator = self.readCalls(nonVarCallPath, 'non-variant', ctgName, ctgStart, ctgEnd, add_header=True, writer=save_writer) hasVar = True # in case of empty file try: @@ -174,9 +208,7 @@ def mergeCalls(self, variantCallPath, nonVarCallPath, savePath, sampleName, ctgN curNonVarCall, curNonVarStart, curNonVarEnd, curNonVarPos = next(nonVarCallGenerator) except StopIteration: nonVarCallStop = True - save_writer = open(savePath, 'w') - # print gvcf header - self._print_vcf_header(save_writer,nonVarCallPath,variantCallPath) + while True and (not varCallStop) and (not nonVarCallStop): if (curNonVarEnd < curVarStart): @@ 
-305,8 +337,7 @@ def mergeCalls(self, variantCallPath, nonVarCallPath, savePath, sampleName, ctgN for curNonVarCall, curNonVarStart, curNonVarEnd, curNonVarPos in nonVarCallGenerator: print(curNonVarCall, file=save_writer) - save_writer.close() - + CW.close_writer() class variantInfoCalculator(object): @@ -325,11 +356,14 @@ def __init__(self, gvcfWritePath, ref_path, p_err, gq_bin_size, ctgName, bp_reso self.variantMath = mathcalculator() self.constant_log10_probs = self.variantMath.normalize_log10_prob([-1.0, -1.0, -1.0]) self.gq_bin_size = gq_bin_size + self.CW = None # set by the users if (gvcfWritePath != "PIPE"): if (not os.path.exists(gvcfWritePath)): os.mkdir(gvcfWritePath) - self.vcf_writer = open(os.path.join(gvcfWritePath, sample_name + '.tmp.g.vcf'), 'w') + + self.CW = compressReaderWriter(output_path=os.path.join(gvcfWritePath, sample_name + GVCF_SUFFIX), compress=COMPRESS_GVCF) + self.vcf_writer = self.CW.write_output() else: self.vcf_writer = sys.stdout self.writePath = gvcfWritePath @@ -342,7 +376,7 @@ def __init__(self, gvcfWritePath, ref_path, p_err, gq_bin_size, ctgName, bp_reso self.normalized_prob_pool = {} self.current_block = [] - self._print_vcf_header(ctgName=ctgName) + self._print_vcf_header() self.cur_gq_bin_index = None self.cur_gt = None self.cur_min_DP = None @@ -361,7 +395,6 @@ def make_gvcf_online(self, variant_summary, push_current=False): ''' make gvcf while reading from pileup - ''' if (push_current): @@ -529,7 +562,7 @@ def _cal_reference_likelihood(self, n_ref, n_total): validPL = log10_probs[0] == max(log10_probs) return validPL, gq, binned_gq, log10_probs - def _print_vcf_header(self,ctgName): + def _print_vcf_header(self): from textwrap import dedent print(dedent("""\ @@ -555,8 +588,7 @@ def _print_vcf_header(self,ctgName): for row in fai_fp: columns = row.strip().split("\t") contig_name, contig_size = columns[0], columns[1] - if(contig_name==ctgName): - print("##contig=" % (contig_name, contig_size), file=self.vcf_writer) + print("##contig=" % (contig_name, contig_size), file=self.vcf_writer) print('#CHROM\tPOS\tID\tREF\tALT\tQUAL\tFILTER\tINFO\tFORMAT\t%s' % (self.sampleName), file=self.vcf_writer) @@ -607,6 +639,8 @@ def write_to_gvcf(self, variant_info): print(_tmpLine, file=self.vcf_writer) + def close_vcf_writer(self): + self.CW.close_writer() class mathcalculator(object): diff --git a/run_clair3.sh b/run_clair3.sh index bcc3788..bbfcbb9 100755 --- a/run_clair3.sh +++ b/run_clair3.sh @@ -1,7 +1,7 @@ #!/bin/bash SCRIPT_NAME=$(basename "$0") SCRIPT_PATH=`dirname "$0"` -VERSION='v0.1-r6' +VERSION='v0.1-r7' Usage="Usage: ./${SCRIPT_NAME} --bam_fn=BAM --ref_fn=REF --output=OUTPUT_DIR --threads=THREADS --platform=PLATFORM --model_path=MODEL_PREFIX [--bed_fn=BED] [options]" set -e @@ -14,7 +14,7 @@ print_help_messages() echo $'Required parameters:' echo $' -b, --bam_fn=FILE BAM file input. The input file must be samtools indexed.' echo $' -f, --ref_fn=FILE FASTA reference file input. The input file must be samtools indexed.' - echo $' -m, --model_path=STR The folder path containing a Clair3 model (requiring six files in the folder, including pileup.data-00000-of-00002, pileup.data-00001-of-00002 pileup.index, full_alignment.data-00000-of-00002, full_alignment.data-00001-of-00002 and full_alignment.index).' 
+ echo $' -m, --model_path=STR The folder path containing a Clair3 model (requiring six files in the folder, including pileup.data-00000-of-00002, pileup.data-00001-of-00002 pileup.index, full_alignment.data-00000-of-00002, full_alignment.data-00001-of-00002 and full_alignment.index).' echo $' -t, --threads=INT Max #threads to be used. The full genome will be divided into small chunks for parallel processing. Each chunk will use 4 threads. The #chunks being processed simultaneously is ceil(#threads/4)*3. 3 is the overloading factor.' echo $' -p, --platform=STR Select the sequencing platform of the input. Possible options: {ont,hifi,ilmn}.' echo $' -o, --output=PATH VCF/GVCF output directory.' @@ -25,7 +25,7 @@ print_help_messages() echo $' --vcf_fn=FILE Candidate sites VCF file input, variants will only be called at the sites in the VCF file if provided.' echo $' --ctg_name=STR The name of the sequence to be processed.' echo $' --sample_name=STR Define the sample name to be shown in the VCF file.' - echo $' --qual=INT If set, variants with >=$qual will be marked PASS, or LowQual otherwise.' + echo $' --qual=INT If set, variants with >$qual will be marked PASS, or LowQual otherwise.' echo $' --samtools=STR Path of samtools, samtools version >= 1.10 is required.' echo $' --python=STR Path of python, python3 >= 3.6 is required.' echo $' --pypy=STR Path of pypy3, pypy3 >= 3.6 is required.' @@ -36,6 +36,7 @@ print_help_messages() echo $' --print_ref_calls Show reference calls (0/0) in VCF file, default: disable.' echo $' --include_all_ctgs Call variants on all contigs, otherwise call in chr{1..22,X,Y} and {1..22,X,Y}, default: disable.' echo $' --gvcf Enable GVCF output, default: disable.' + echo $' --remove_intermediate_dir Remove intermediate directory, including intermediate phased BAM, pileup and full-alignment results. default: disable.' echo $' --snp_min_af=FLOAT Minimum SNP AF required for a candidate variant. Lowering the value might increase a bit of sensitivity in trade of speed and accuracy, default: ont:0.08,hifi:0.08,ilmn:0.08.' echo $' --indel_min_af=FLOAT Minimum Indel AF required for a candidate variant. Lowering the value might increase a bit of sensitivity in trade of speed and accuracy, default: ont:0.15,hifi:0.08,ilmn:0.08.' echo $' --var_pct_full=FLOAT EXPERIMENTAL: Specify an expected percentage of low quality 0/1 and 1/1 variants called in the pileup mode for full-alignment mode calling, default: 0.3.' @@ -63,7 +64,8 @@ NC="\\033[0m" ARGS=`getopt -o b:f:t:m:p:o:hv \ -l bam_fn:,ref_fn:,threads:,model_path:,platform:,output:,\ bed_fn::,vcf_fn::,ctg_name::,sample_name::,qual::,samtools::,python::,pypy::,parallel::,whatshap::,chunk_num::,chunk_size::,var_pct_full::,ref_pct_full::,\ -snp_min_af::,indel_min_af::,pileup_model_prefix::,fa_model_prefix::,fast_mode,gvcf,pileup_only,print_ref_calls,haploid_precise,haploid_sensitive,include_all_ctgs,no_phasing_for_fa,call_snp_only,help,version -n 'run_clair3.sh' -- "$@"` +snp_min_af::,indel_min_af::,pileup_model_prefix::,fa_model_prefix::,fast_mode,gvcf,pileup_only,print_ref_calls,haploid_precise,haploid_sensitive,include_all_ctgs,\ +remove_intermediate_dir,no_phasing_for_fa,call_snp_only,help,version -n 'run_clair3.sh' -- "$@"` if [ $? != 0 ] ; then echo"No input. 
Terminating...">&2 ; exit 1 ; fi eval set -- "${ARGS}" @@ -81,19 +83,20 @@ WHATSHAP='whatshap' CHUNK_NUM=0 CHUNK_SIZE=5000000 QUAL=2 -PRO=0.3 -REF_PRO=0 +PRO="0" +REF_PRO="0" GVCF=False PILEUP_ONLY=False FAST_MODE=False SHOW_REF=False -SNP_AF=0 -INDEL_AF=0 +SNP_AF="0" +INDEL_AF="0" HAP_PRE=False HAP_SEN=False SNP_ONLY=False INCLUDE_ALL_CTGS=False NO_PHASING=False +RM_TMP_DIR=False PILEUP_PREFIX="pileup" FA_PREFIX="full_alignment" @@ -132,6 +135,7 @@ while true; do --haploid_sensitive ) HAP_SEN=True; shift 1 ;; --include_all_ctgs ) INCLUDE_ALL_CTGS=True; shift 1 ;; --no_phasing_for_fa ) NO_PHASING=True; shift 1 ;; + --remove_intermediate_dir ) RM_TMP_DIR=True; shift 1 ;; -- ) shift; break; ;; -h|--help ) print_help_messages; exit 0 ;; @@ -173,8 +177,12 @@ mkdir -p ${OUTPUT_FOLDER} if [ ! -d ${OUTPUT_FOLDER} ]; then echo -e "${ERROR} Cannot create output folder ${OUTPUT_FOLDER}${NC}"; exit 1; fi # show default reference proportion 0.3 for ilmn and hifi, 0.1 for ont -if [ "${PLATFORM}" = "ont" ] && [ ! "${REF_PRO}" -gt 0 ]; then REF_PRO=0.1; fi -if [ "${PLATFORM}" != "ont" ] && [ ! "${REF_PRO}" -gt 0 ]; then REF_PRO=0.3; fi +if [ "${PLATFORM}" = "ont" ] && [ "${REF_PRO}" = "0" ]; then REF_PRO=0.1; fi +if [ "${PLATFORM}" != "ont" ] && [ "${REF_PRO}" = "0" ]; then REF_PRO=0.3; fi + +# show default variant proportion 0.3 for ilmn and hifi, 0.7 for ont +if [ "${PLATFORM}" = "ont" ] && [ "${PRO}" = "0" ]; then PRO=0.7; fi +if [ "${PLATFORM}" != "ont" ] && [ "${PRO}" = "0" ]; then PRO=0.3; fi # optional parameters should use "=" (time ( @@ -198,8 +206,8 @@ echo "[INFO] CHUNK SIZE: ${CHUNK_SIZE}" if [ ${CHUNK_NUM} -gt 0 ]; then echo "[INFO] CHUNK NUM: ${CHUNK_NUM}"; fi echo "[INFO] FULL ALIGN PROPORTION: ${PRO}" echo "[INFO] FULL ALIGN REFERENCE PROPORTION: ${REF_PRO}" -if [ ${SNP_AF} -gt 0 ]; then echo "[INFO] USER DEFINED SNP THRESHOLD: ${SNP_AF}"; fi -if [ ${INDEL_AF} -gt 0 ]; then echo "[INFO] USER DEFINED INDEL THRESHOLD: ${INDEL_AF}"; fi +if [ "${SNP_AF}" != "0" ]; then echo "[INFO] USER DEFINED SNP THRESHOLD: ${SNP_AF}"; fi +if [ "${INDEL_AF}" != "0" ]; then echo "[INFO] USER DEFINED INDEL THRESHOLD: ${INDEL_AF}"; fi echo "[INFO] ENABLE FILEUP ONLY CALLING: ${PILEUP_ONLY}" echo "[INFO] ENABLE FAST MODE CALLING: ${FAST_MODE}" echo "[INFO] ENABLE CALLING SNP CANDIDATES ONLY: ${SNP_ONLY}" @@ -209,6 +217,7 @@ echo "[INFO] ENABLE HAPLOID PRECISE MODE: ${HAP_PRE}" echo "[INFO] ENABLE HAPLOID SENSITIVE MODE: ${HAP_SEN}" echo "[INFO] ENABLE INCLUDE ALL CTGS CALLING: ${INCLUDE_ALL_CTGS}" echo "[INFO] ENABLE NO PHASING FOR FULL ALIGNMENT: ${NO_PHASING}" +echo "[INFO] ENABLE REMOVING INTERMEDIATE FILES: ${RM_TMP_DIR}" echo $'' # file check @@ -294,7 +303,8 @@ ${SCRIPT_PATH}/scripts/clair3.sh \ --include_all_ctgs=${INCLUDE_ALL_CTGS} \ --no_phasing_for_fa=${NO_PHASING} \ --pileup_model_prefix=${PILEUP_PREFIX} \ - --fa_model_prefix=${FA_PREFIX} + --fa_model_prefix=${FA_PREFIX} \ + --remove_intermediate_dir=${RM_TMP_DIR} )) |& tee ${OUTPUT_FOLDER}/run_clair3.log \ No newline at end of file diff --git a/scripts/clair3.sh b/scripts/clair3.sh index 6549c0c..cffbd5d 100755 --- a/scripts/clair3.sh +++ b/scripts/clair3.sh @@ -8,7 +8,7 @@ ARGS=`getopt -o b:f:t:m:p:o:r::c::s::h::g \ -l bam_fn:,ref_fn:,threads:,model_path:,platform:,output:,\ bed_fn::,vcf_fn::,ctg_name::,sample_name::,help::,qual::,samtools::,python::,pypy::,parallel::,whatshap::,chunk_num::,chunk_size::,var_pct_full::,\ 
snp_min_af::,indel_min_af::,ref_pct_full::,pileup_only::,fast_mode::,gvcf::,print_ref_calls::,haploid_precise::,haploid_sensitive::,include_all_ctgs::,\ -no_phasing_for_fa::,pileup_model_prefix::,fa_model_prefix::,call_snp_only:: -n 'run_clair3.sh' -- "$@"` +no_phasing_for_fa::,pileup_model_prefix::,fa_model_prefix::,call_snp_only::,remove_intermediate_dir:: -n 'run_clair3.sh' -- "$@"` if [ $? != 0 ] ; then echo"No input. Terminating...">&2 ; exit 1 ; fi eval set -- "${ARGS}" @@ -48,6 +48,7 @@ while true; do --haploid_sensitive ) HAP_SEN="$2"; shift 2 ;; --include_all_ctgs ) INCLUDE_ALL_CTGS="$2"; shift 2 ;; --no_phasing_for_fa ) NO_PHASING="$2"; shift 2 ;; + --remove_intermediate_dir ) RM_TMP_DIR="$2"; shift 2 ;; -- ) shift; break; ;; -h|--help ) print_help_messages; break ;; @@ -146,6 +147,7 @@ ${PYPY} ${CLAIR3} SortVcf \ if [ "$( gzip -fdc ${OUTPUT_FOLDER}/pileup.vcf.gz | grep -v '#' | wc -l )" -eq 0 ]; then echo "[INFO] Exit in pileup variant calling"; exit 0; fi if [ ${PILEUP_ONLY} == True ]; then + if [ ${RM_TMP_DIR} == True ]; then echo "[INFO] Removing intermediate files in ${OUTPUT_FOLDER}/tmp"; rm -rf ${OUTPUT_FOLDER}/tmp; fi echo "[INFO] Only call pileup output with --pileup_only, output file: ${OUTPUT_FOLDER}/pileup.vcf.gz" echo "[INFO] Finish calling!" exit 0; @@ -241,7 +243,16 @@ ${PYPY} ${CLAIR3} SortVcf \ --contigs_fn ${TMP_FILE_PATH}/CONTIGS if [ "$( gzip -fdc ${OUTPUT_FOLDER}/full_alignment.vcf.gz | grep -v '#' | wc -l )" -eq 0 ]; then echo "[INFO] Exit in full-alignment variant calling"; exit 0; fi -if [ ${GVCF} == True ]; then cat ${GVCF_TMP_PATH}/*.tmp.g.vcf | ${PYPY} ${CLAIR3} SortVcf --output_fn ${GVCF_TMP_PATH}/non_var.gvcf; fi +# Compress GVCF output using lz4 +if [ ${GVCF} == True ] +then + ${PYPY} ${CLAIR3} SortVcf \ + --input_dir ${GVCF_TMP_PATH} \ + --vcf_fn_suffix ".tmp.gvcf" \ + --output_fn ${GVCF_TMP_PATH}/non_var.gvcf \ + --ref_fn ${REFERENCE_FILE_PATH} \ + --contigs_fn ${TMP_FILE_PATH}/CONTIGS +fi ##Merge pileup and full alignment vcf ##----------------------------------------------------------------------------------------------------------------------- @@ -284,5 +295,7 @@ then --contigs_fn ${TMP_FILE_PATH}/CONTIGS fi +if [ ${RM_TMP_DIR} == True ]; then echo "[INFO] Removing intermediate files in ${OUTPUT_FOLDER}/tmp"; rm -rf ${OUTPUT_FOLDER}/tmp; fi + echo $'' echo "[INFO] Finish calling, output file: ${OUTPUT_FOLDER}/merge_output.vcf.gz" diff --git a/shared/param_f.py b/shared/param_f.py index a9aa7a5..ffc46eb 100644 --- a/shared/param_f.py +++ b/shared/param_f.py @@ -5,6 +5,8 @@ zstd='zstd' default_optimizer = "Radam" default_loss_function = "FocalLoss" +min_af = 0.08 +min_af_dict = {'ont':0.15, 'hifi':min_af, 'ilmn':min_af } matrix_depth_dict = {'ont': 89, 'hifi': 55, 'ilmn': 55} # Full alignment input feature list
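
For illustration only (not part of the patch or of Clair3's code): the two sketches below restate, in standalone form, two mechanisms the hunks above introduce. Function names and the toy inputs are hypothetical; only the default thresholds (--snp_min_af 0.08, --indel_min_af 0.15) and the lz4 command lines mirror the patch.

First, a minimal sketch of the split SNP/indel allele-frequency screen added to decode_pileup_bases: each non-reference allele is measured against its own cutoff, and the site is kept as a full-alignment candidate if either type passes.

    # Illustrative sketch of the per-type AF screen (not Clair3 code).
    def passes_af_screen(pileup_counts, reference_base, depth,
                         snp_min_af=0.08, indel_min_af=0.15):
        """Return True if any non-reference allele reaches its type-specific AF cutoff.

        pileup_counts: allele -> supporting read count; insertion/deletion keys are
        assumed to start with 'I' or 'D', SNP keys are single bases, as in the hunk above.
        """
        denominator = depth if depth > 0 else 1
        pass_snp_af = False
        pass_indel_af = False
        for allele, count in pileup_counts.items():
            if allele == reference_base:
                continue
            af = float(count) / denominator
            if allele[0] in 'ID':
                pass_indel_af = pass_indel_af or af >= indel_min_af
            else:
                pass_snp_af = pass_snp_af or af >= snp_min_af
        return pass_snp_af or pass_indel_af

    # 4 'A' reads out of 20 pass the SNP cutoff (0.20 >= 0.08); 2 deletion-supporting
    # reads do not pass the indel cutoff (0.10 < 0.15). The site is still a candidate.
    print(passes_af_screen({'T': 14, 'A': 4, 'D': 2}, 'T', 20))  # True

Second, a sketch of the lz4 streaming applied to the .tmp.gvcf intermediates: rows are piped through the lz4 CLI (lz4 -c to compress, lz4 -fdc to decompress), assuming the binary is on PATH; as in the patch, plain uncompressed files would be the fallback when it is not.

    # Illustrative lz4 stream writer/reader via the lz4 CLI (not Clair3 code).
    import shlex
    import subprocess

    def open_lz4_writer(path):
        # Compress whatever is written to the returned pipe into `path`.
        fpo = open(path, 'wb')
        proc = subprocess.Popen(shlex.split("lz4 -c"), stdin=subprocess.PIPE,
                                stdout=fpo, stderr=subprocess.DEVNULL)
        return proc, fpo

    def open_lz4_reader(path):
        # Decompress `path` and expose it as a line-iterable stdout pipe.
        return subprocess.Popen(shlex.split("lz4 -fdc " + path),
                                stdout=subprocess.PIPE, stderr=subprocess.DEVNULL)

    wproc, fpo = open_lz4_writer("rows.tmp.gvcf")
    wproc.stdin.write(b"chr1\t100\t.\tA\t.\t.\t.\tEND=200\tGT:DP\t0/0:30\n")
    wproc.stdin.close(); wproc.wait(); fpo.close()

    rproc = open_lz4_reader("rows.tmp.gvcf")
    for row in rproc.stdout:
        print(row.decode().rstrip())
    rproc.stdout.close(); rproc.wait()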