Skip to content

Commit

Permalink
Specify an output sort order in FilterConsensusReads
Browse files Browse the repository at this point in the history
Also ReferenceSequenceIterator closes the underlying reference file
  • Loading branch information
nh13 committed Feb 26, 2022
1 parent 7512807 commit 610b100
Show file tree
Hide file tree
Showing 3 changed files with 155 additions and 35 deletions.
27 changes: 22 additions & 5 deletions src/main/scala/com/fulcrumgenomics/bam/Bams.scala
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,9 @@ import com.fulcrumgenomics.util.{Io, ProgressLogger, Sorter}
import htsjdk.samtools.SAMFileHeader.{GroupOrder, SortOrder}
import htsjdk.samtools.SamPairUtil.PairOrientation
import htsjdk.samtools._
import htsjdk.samtools.reference.ReferenceSequenceFileWalker
import htsjdk.samtools.reference.{ReferenceSequence, ReferenceSequenceFileWalker}
import htsjdk.samtools.util.{CloserUtil, CoordMath, SequenceUtil}

import scala.collection.mutable.{ArrayBuffer, ListBuffer}
import scala.math.{max, min}

/**
Expand Down Expand Up @@ -246,7 +245,7 @@ object Bams extends LazyLogging {
case (Some(Queryname), _) => new SelfClosingIterator(iterator.bufferBetter, () => CloserUtil.close(iterator))
case (_, _) =>
logger.info(parts = "Sorting into queryname order.")
val progress = ProgressLogger(this.logger, "Records", "sorted")
val progress = ProgressLogger(this.logger, "records", "sorted")
val sort = sorter(Queryname, header, maxInMemory, tmpDir)
iterator.foreach { rec =>
progress.record(rec)
Expand Down Expand Up @@ -412,20 +411,38 @@ object Bams extends LazyLogging {
* @param rec the SamRecord to update
* @param ref a reference sequence file walker to pull the reference information from
*/
def regenerateNmUqMdTags(rec: SamRecord, ref: ReferenceSequenceFileWalker): Unit = {
def regenerateNmUqMdTags(rec: SamRecord, ref: ReferenceSequenceFileWalker): SamRecord = {
if (rec.unmapped) regenerateNmUqMdTags(rec, Map.empty[Int, ReferenceSequence]) else {
val refSeq = ref.get(rec.refIndex)
regenerateNmUqMdTags(rec, Map(refSeq.getContigIndex -> refSeq))
}
rec
}

/**
* Ensures that any NM/UQ/MD tags on the read are accurate. If the read is unmapped, any existing
* values are removed. If the read is mapped all three tags will have values regenerated.
*
* @param rec the SamRecord to update
* @param ref a reference sequence file walker to pull the reference information from
*/
def regenerateNmUqMdTags(rec: SamRecord, ref: Map[Int, ReferenceSequence]): SamRecord = {
import SAMTag._
if (rec.unmapped) {
rec(NM.name()) = null
rec(UQ.name()) = null
rec(MD.name()) = null
}
else {
val refBases = ref.get(rec.refIndex).getBases
val refBases = ref.getOrElse(rec.refIndex, throw new IllegalArgumentException(
s"Record '${rec.name}' had contig index '${rec.refIndex}', but not found in the input reference map"
)).getBases
SequenceUtil.calculateMdAndNmTags(rec.asSam, refBases, true, true)
if (rec.quals != null && rec.quals.length != 0) {
rec(SAMTag.UQ.name) = SequenceUtil.sumQualitiesOfMismatches(rec.asSam, refBases, 0)
}
}
rec
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,16 @@ package com.fulcrumgenomics.fasta

import htsjdk.samtools.reference.{ReferenceSequence, ReferenceSequenceFile, ReferenceSequenceFileFactory}
import com.fulcrumgenomics.commons.CommonsDef._
import com.fulcrumgenomics.commons.collection.SelfClosingIterator

object ReferenceSequenceIterator {
/** Constructs an iterator over a reference sequence from a Path to the FASTA file. */
def apply(path: PathToFasta, stripComments: Boolean = false) = {
def apply(path: PathToFasta, stripComments: Boolean = false): ReferenceSequenceIterator = {
new ReferenceSequenceIterator(ReferenceSequenceFileFactory.getReferenceSequenceFile(path, stripComments, false))
}

/** Constructs an iterator over a reference sequence from a ReferenceSequenceFile. */
def apply(ref: ReferenceSequenceFile) = {
def apply(ref: ReferenceSequenceFile): ReferenceSequenceIterator = {
new ReferenceSequenceIterator(ref)
}
}
Expand All @@ -45,8 +46,15 @@ object ReferenceSequenceIterator {
class ReferenceSequenceIterator private(private val ref: ReferenceSequenceFile) extends Iterator[ReferenceSequence] {
// The next reference sequence; will be null if there is no next :(
private var nextSequence: ReferenceSequence = ref.nextSequence()
private var isOpen: Boolean = true

override def hasNext: Boolean = this.nextSequence != null
override def hasNext: Boolean = if (this.nextSequence != null) true else {
if (isOpen) {
isOpen = false
ref.close()
}
false
}

override def next(): ReferenceSequence = {
assert(hasNext, "next() called on empty iterator")
Expand Down
149 changes: 122 additions & 27 deletions src/main/scala/com/fulcrumgenomics/umi/FilterConsensusReads.scala
Original file line number Diff line number Diff line change
Expand Up @@ -24,20 +24,24 @@

package com.fulcrumgenomics.umi

import java.lang.Math.{max, min}

import com.fulcrumgenomics.FgBioDef._
import com.fulcrumgenomics.bam.Bams
import com.fulcrumgenomics.bam.api.{SamOrder, SamRecord, SamSource, SamWriter}
import com.fulcrumgenomics.cmdline.{ClpGroups, FgBioTool}
import com.fulcrumgenomics.commons.io.Writer
import com.fulcrumgenomics.commons.util.LazyLogging
import com.fulcrumgenomics.fasta.ReferenceSequenceIterator
import com.fulcrumgenomics.sopt.{arg, clp}
import com.fulcrumgenomics.util.NumericTypes.PhredScore
import com.fulcrumgenomics.util.{Io, ProgressLogger}
import htsjdk.samtools.SAMFileHeader.SortOrder
import htsjdk.samtools.SAMFileHeader
import htsjdk.samtools.SAMFileHeader.GroupOrder
import htsjdk.samtools.reference.ReferenceSequenceFileWalker
import htsjdk.samtools.util.SequenceUtil

import java.io.Closeable
import java.lang.Math.{max, min}

/** Filter values for filtering consensus reads */
private[umi] case class ConsensusReadFilter(minReads: Int, maxReadErrorRate: Double, maxBaseErrorRate: Double)

Expand Down Expand Up @@ -114,7 +118,9 @@ class FilterConsensusReads
@arg(flag='q', doc="The minimum mean base quality across the consensus read.")
val minMeanBaseQuality: Option[PhredScore] = None,
@arg(flag='s', doc="Mask (make `N`) consensus bases where the AB and BA consensus reads disagree (for duplex-sequencing only).")
val requireSingleStrandAgreement: Boolean = false
val requireSingleStrandAgreement: Boolean = false,
@arg(flag='S', doc="The sort order of the output, if `:none:` then the same as the input.") val sortOrder: Option[SamOrder] = Some(SamOrder.Coordinate),
@arg(flag='l', doc="Load the full reference sequence in memory") val loadFullReference: Boolean = false
) extends FgBioTool with LazyLogging {
// Baseline input validation
Io.assertReadable(input)
Expand Down Expand Up @@ -169,12 +175,9 @@ class FilterConsensusReads
private val EmptyFilterResult = FilterResult(keepRead=true, maskedBases=0)

override def execute(): Unit = {
val in = SamSource(input)
val header = in.header.clone()
header.setSortOrder(SortOrder.coordinate)
val sorter = Bams.sorter(SamOrder.Coordinate, header, maxRecordsInRam=MaxRecordsInMemoryWhenSorting)
val out = SamWriter(output, header)
val progress1 = ProgressLogger(logger, verb="Filtered and masked")
val progress = ProgressLogger(logger, verb="Filtered and masked")
val in = SamSource(input)
val out = buildOutputWriter(in.header)

// Go through the reads by template and do the filtering
val templateIterator = Bams.templateIterator(in, maxInMemory=MaxRecordsInMemoryWhenSorting)
Expand All @@ -201,34 +204,126 @@ class FilterConsensusReads
keptReads += primaryReadCount
totalBases += r1.length + template.r2.map(_.length).getOrElse(0)
maskedBases += r1Result.maskedBases + r2Result.maskedBases
sorter += r1
progress1.record(r1)
template.r2.foreach { r => sorter += r; progress1.record(r) }
out += r1
progress.record(r1)
template.r2.foreach { r => out += r; progress.record(r) }

template.allSupplementaryAndSecondary.foreach { r =>
val result = filterRecord(r)
if (result.keepRead) {
sorter += r
progress1.record(r)
out += r
progress.record(r)
}
}
}
}
progress.logLast()

// Then iterate the reads in coordinate order and re-calculate key tags
logger.info("Filtering complete; fixing tags and writing coordinate sorted reads.")
val progress2 = new ProgressLogger(logger, verb="Wrote")
val walker = new ReferenceSequenceFileWalker(ref.toFile)
sorter.foreach { rec =>
Bams.regenerateNmUqMdTags(rec, walker)
out += rec
progress2.record(rec)
}

logger.info("Finalizing the output")
in.safelyClose()
out.close()
logger.info(f"Output ${keptReads}%,d of ${totalReads}%,d primary consensus reads.")
logger.info(f"Masked ${maskedBases}%,d of ${totalBases}%,d bases in retained primary consensus reads.")
logger.info(f"Output $keptReads%,d of $totalReads%,d primary consensus reads.")
logger.info(f"Masked $maskedBases%,d of $totalBases%,d bases in retained primary consensus reads.")
}

/** Builds a method to re-generate teh NM/UQ/MD tags based on if we are loading the full reference or not. Also
* returns a method to close the underling reference */
private def buildRegenerateNmUqMdTags(): (SamRecord => SamRecord, () => Unit) = {
if (loadFullReference) {
logger.info("Loading reference into memory")
val refMap = ReferenceSequenceIterator(ref, stripComments=true).map { ref => ref.getContigIndex -> ref}.toMap
val f = (rec: SamRecord) => Bams.regenerateNmUqMdTags(rec, refMap)
(f, () => ())
}
else {
logger.warning("Will require coordinate sorting to update tags, try --load-full-reference instead")
val walker = new ReferenceSequenceFileWalker(ref)
val f = (rec: SamRecord) => Bams.regenerateNmUqMdTags(rec, walker)
(f, () => walker.safelyClose())
}
}

/** Builds the writer to which filtered records should be written.
*
* The filtered records may be sorted once, twice, or never depending on (a) if the full reference is loaded into
* memory, (b) the order after filtering, and (c) the output order.
*
* The order after filtering is determined as follows:
* 1. If the input order is Queryname, or the input is query grouped, then use the input order.
* 2. Otherwise, Queryname.
*
* The output order is determined as follows:
* 1. The order from the `--sort-order` option.
* 2. Otherwise, the order from the input file, if an order is present.
* 3. Otherwise, the order after filtering.
*
* If the full reference has not been loaded then:
* 1. the filtered records are sorted by coordinate to reset the SAM tags.
* 2. if the output order is coordinate, the records are then written directly to the output, otherwise they are
* re-sorted (for a second time) to the desired output order, and written to the output.
*
* If the full reference has been loaded then:
* 1. if the output order is the same as the order after filtering, the filtered records are written to the output,
* otherwise they re-sorted to the desired output order and written to the output.
* */
private def buildOutputWriter(header: SAMFileHeader): Closeable with Writer[SamRecord] = {
val (regenerateNmUqMdTags, refCloseMethod) = buildRegenerateNmUqMdTags()
val outHeader = header.clone()

// Check if the input will be re-sorted into QueryName, or if the input sort order will be kept
val orderAfterFiltering = SamOrder(header) match {
case Some(order) if order == SamOrder.Queryname || order.groupOrder == GroupOrder.query => order
case None => SamOrder.Queryname
}

// Get the output order
val outputOrder = this.sortOrder
.orElse(SamOrder(header)) // use the input sort order
.getOrElse(orderAfterFiltering) // use the order after filtering, so no sort occurs
outputOrder.applyTo(outHeader) // remember to apply it

// Build the writer
val sort = {
if (loadFullReference) {
// If the full reference has been loaded, we need only sort the output if the order after filtering does not
// match the output order.
if (orderAfterFiltering == outputOrder) None else Some(outputOrder)
}
else {
// If the full reference has not been loaded, we will need to coordinate sort to reset
// the tags, then only re-sort in the output order if not in coordinate order.
if (outputOrder == SamOrder.Coordinate) None else Some(outputOrder)
}
}
sort.foreach(o => logger.info(f"Output will be sorted into $o order"))
val writer = SamWriter(output, outHeader, sort=sort, maxRecordsInRam=MaxRecordsInMemoryWhenSorting)

// Create the final writer based on if the full reference has been loaded, or not
if (loadFullReference) {
new Writer[SamRecord] with Closeable {
override def write(rec: SamRecord): Unit = writer += regenerateNmUqMdTags(rec)
def close(): Unit = {
writer.close()
refCloseMethod()
}
}
}
else {
val progress = ProgressLogger(this.logger, "records", "sorted")
new Writer[SamRecord] with Closeable {
private val _sorter = Bams.sorter(order=SamOrder.Coordinate, header=header, maxRecordsInRam=MaxRecordsInMemoryWhenSorting)
override def write(rec: SamRecord): Unit = {
progress.record(rec)
this._sorter += rec
}
def close(): Unit = {
this._sorter.foreach { rec => writer += regenerateNmUqMdTags(rec) }
writer.close()
refCloseMethod()
}
}
}
}

/**
Expand Down

0 comments on commit 610b100

Please sign in to comment.