Skip to content

Commit

Permalink
Addition of threading in GroupReadsByUmi and some other performance o…
Browse files Browse the repository at this point in the history
…ptimizations.
  • Loading branch information
tfenne committed Nov 27, 2023
1 parent ff1ca67 commit 4c0c101
Showing 1 changed file with 99 additions and 26 deletions.
125 changes: 99 additions & 26 deletions src/main/scala/com/fulcrumgenomics/umi/GroupReadsByUmi.scala
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,12 @@ import htsjdk.samtools._
import htsjdk.samtools.util.SequenceUtil

import java.util.concurrent.atomic.AtomicLong
import java.util.concurrent.{Executors, ForkJoinPool}
import scala.collection.immutable.IndexedSeq
import scala.collection.mutable.ListBuffer
import scala.collection.parallel.ExecutionContextTaskSupport
import scala.collection.{BufferedIterator, Iterator, mutable}
import scala.concurrent.ExecutionContext


object GroupReadsByUmi {
Expand Down Expand Up @@ -210,9 +213,16 @@ object GroupReadsByUmi {
* Class that implements the directed adjacency graph method from umi_tools.
* See: https://github.com/CGATOxford/UMI-tools
*/
private[umi] class AdjacencyUmiAssigner(val maxMismatches: Int) extends UmiAssigner {
private[umi] class AdjacencyUmiAssigner(final val maxMismatches: Int, val threads: Int = 1) extends UmiAssigner {
private val pool = new ForkJoinPool(threads, ForkJoinPool.defaultForkJoinWorkerThreadFactory, null, false)

private val taskSupport = if (threads < 2) None else {
val ctx = ExecutionContext.fromExecutorService(Executors.newFixedThreadPool(threads))
Some(new ExecutionContextTaskSupport(ctx))
}

/** Represents a node in the adjacency graph; equality is just by UMI sequence. */
class Node(val umi: Umi, val count: Long, val children: mutable.Buffer[Node] = mutable.Buffer()) {
final class Node(val umi: Umi, val count: Int, val children: mutable.Buffer[Node] = mutable.Buffer(), var assigned: Boolean = false) {
/** Gets the full set of descendants from this node. */
def descendants: List[Node] = {
val buffer = ListBuffer[Node]()
Expand All @@ -235,12 +245,14 @@ object GroupReadsByUmi {

/** Returns whether or not a pair of UMIs match closely enough to be considered adjacent in the graph. */
protected def matches(lhs: Umi, rhs: Umi): Boolean = {
require(lhs.length == rhs.length, s"UMIs of different length detected: $lhs vs. $rhs")
val len = lhs.length
require(rhs.length == len, s"UMIs of different length detected: $lhs vs. $rhs")

var idx = 0
var mismatches = 0
val len = lhs.length
while (idx < len && mismatches <= this.maxMismatches) {
if (lhs(idx) != rhs(idx)) mismatches += 1
val tooManyMismatches = this.maxMismatches + 1
while (mismatches < tooManyMismatches && idx < len) {
if (lhs.charAt(idx) != rhs.charAt(idx)) mismatches += 1
idx += 1
}

Expand All @@ -264,35 +276,83 @@ object GroupReadsByUmi {
val roots = IndexedSeq.newBuilder[Node]

// Make a list of counts of all UMIs in order from most to least abundant; we'll consume from this buffer
var remaining = count(rawUmis).map{ case(umi,count) => new Node(umi, count) }.toBuffer.sortBy((n:Node) => -n.count)
val orderedNodes = count(rawUmis).map{ case(umi,count) => new Node(umi, count.toInt) }.toIndexedSeq.sortBy((n:Node) => -n.count)
val lookup = countIndexLookup(orderedNodes) // Seq of (count, firstIdx) pairs

// Now build one or more graphs starting with the most abundant remaining umi
while (remaining.nonEmpty) {
val nextRoot = remaining.remove(0)
roots += nextRoot
val working = mutable.Buffer[Node](nextRoot)

while (working.nonEmpty) {
val root = working.remove(0)
val (hits, misses) = remaining.partition(other => root.count >= 2 * other.count - 1 && matches(root.umi, other.umi))
root.children ++= hits
working ++= hits
remaining = misses
forloop (from=0, until=orderedNodes.length) { rootIdx =>
val nextRoot = orderedNodes(rootIdx)

if (!nextRoot.assigned) {
roots += nextRoot
val working = mutable.Queue[Node](nextRoot)

while (working.nonEmpty) {
val root = working.remove(0)
root.assigned = true
val maxChildCountPlusOne = (root.count / 2 + 1) + 1
val searchFromIdx = lookup
.find { case (count, _) => count < maxChildCountPlusOne }
.map { case (_, idx) => idx }
.getOrElse(-1)

if (searchFromIdx > 0) {
val hits = taskSupport match {
case None =>
orderedNodes
.drop(searchFromIdx + 1)
.filter(other => !other.assigned && matches(root.umi, other.umi))
case Some(ts) =>
orderedNodes
.drop(searchFromIdx + 1)
.parWith(ts)
.filter(other => !other.assigned && matches(root.umi, other.umi))
.seq
}

root.children ++= hits
working ++= hits
hits.foreach(_.assigned = true)
}
}
}
}

assignIdsToNodes(roots.result())
}

/**
* Generates an indexed seq to enable fast identification of the first index in `nodes` where
* a given UMI count is observed. Assumes that the input is sorted from most abundant to least
* abundant. The generated Seq will contain one entry for every unique count seen; the second value
* in the tuple is the first index in the list of input nodes with that count.
*
* E.g. given nodes with counts [10, 10, 10, 9, 3, 3, 3, 3, 2, 1], the output would be:
* [(10, 0), (9, 3), (3, 4) (2, 8), (1, 9)]
*/
private def countIndexLookup(nodes: IndexedSeq[Node]): IndexedSeq[(Int, Int)] = {
val builder = IndexedSeq.newBuilder[(Int, Int)]
val iter = nodes.iterator.zipWithIndex.bufferBetter

while (iter.hasNext) {
val (currNode, currIdx) = iter.next
builder += ((currNode.count, currIdx))
iter.dropWhile { case (node, _) => node.count == currNode.count}
}

builder.result()
}
}



/**
* Version of the adjacency assigner that works for paired UMIs stored as a single tag of
* the form A-B where reads with A-B and B-A are related but not identical.
*
* @param maxMismatches the maximum number of mismatches between UMIs
*/
class PairedUmiAssigner(maxMismatches: Int) extends AdjacencyUmiAssigner(maxMismatches) {
class PairedUmiAssigner(maxMismatches: Int, threads: Int = 1) extends AdjacencyUmiAssigner(maxMismatches, threads) {
/** String that is prefixed onto the UMI from the read with that maps to a lower coordinate in the genome.. */
private[umi] val lowerReadUmiPrefix: String = ("a" * (maxMismatches+1)) + ":"

Expand Down Expand Up @@ -402,27 +462,27 @@ case class TagFamilySizeMetric(family_size: Int,

/** The strategies implemented by [[GroupReadsByUmi]] to identify reads from the same source molecule.*/
sealed trait Strategy extends EnumEntry {
def newStrategy(edits: Int): UmiAssigner
def newStrategy(edits: Int, threads: Int): UmiAssigner
}
object Strategy extends FgBioEnum[Strategy] {
def values: IndexedSeq[Strategy] = findValues
/** Strategy to only reads with identical UMI sequences are grouped together. */
case object Identity extends Strategy {
def newStrategy(edits: Int = 0): UmiAssigner = {
def newStrategy(edits: Int = 0, threads: Int = 0): UmiAssigner = {
require(edits == 0, "Edits should be zero when using the identity UMI assigner.")
new IdentityUmiAssigner
}
}
/** Strategy to cluster reads into groups based on mismatches between reads in clusters. */
case object Edit extends Strategy { def newStrategy(edits: Int): UmiAssigner = new SimpleErrorUmiAssigner(edits) }
case object Edit extends Strategy { def newStrategy(edits: Int, threads: Int = 0): UmiAssigner = new SimpleErrorUmiAssigner(edits) }
/** Strategy based on the directed adjacency method described in [umi_tools](http://dx.doi.org/10.1101/051755)
* that allows for errors between UMIs but only when there is a count gradient.
*/
case object Adjacency extends Strategy { def newStrategy(edits: Int): UmiAssigner = new AdjacencyUmiAssigner(edits) }
case object Adjacency extends Strategy { def newStrategy(edits: Int, threads: Int = 1): UmiAssigner = new AdjacencyUmiAssigner(edits, threads) }
/** Strategy similar to the [[Adjacency]] strategy similar to adjacency but for methods that produce template with a
* pair of UMIs such that a read with A-B is related to but not identical to a read with B-A.
*/
case object Paired extends Strategy { def newStrategy(edits: Int): UmiAssigner = new PairedUmiAssigner(edits)}
case object Paired extends Strategy { def newStrategy(edits: Int, threads: Int = 1): UmiAssigner = new PairedUmiAssigner(edits, threads)}
}

@clp(group=ClpGroups.Umi, description =
Expand Down Expand Up @@ -491,6 +551,11 @@ object Strategy extends FgBioEnum[Strategy] {
| 1. `--min-map-q` defaults to 0 in duplicate marking mode and 1 otherwise
| 2. `--include-secondary` defaults to true in duplicate marking mode and false otherwise
| 3. `--include-supplementary` defaults to true in duplicate marking mode and false otherwise
|
|Multi-threaded operation is supported via the `--threads/-@` option. This only applies to the Adjacency and Paired
|strategies. Additionally the only operation that is multi-threaded is the comparisons of UMIs at the same genomic
|position. Running with e.g. `--threads 8` can provide a _substantial_ reduction in runtime when there are many
|UMIs observed at the same genomic location, such as can occur in amplicon sequencing or ultra-deep coverage data.
"""
)
class GroupReadsByUmi
Expand All @@ -513,13 +578,14 @@ class GroupReadsByUmi
@arg(flag='x', doc= """
|DEPRECATED: this option will be removed in future versions and inter-contig reads will be
|automatically processed.""")
@deprecated val allowInterContig: Boolean = true
@deprecated val allowInterContig: Boolean = true,
@arg(flag='@', doc="Number of threads to use when comparing UMIs. Only recommended for amplicon or similar data.") val threads: Int = 1,
)extends FgBioTool with LazyLogging {
import GroupReadsByUmi._

require(this.minUmiLength.forall(_ => this.strategy != Strategy.Paired), "Paired strategy cannot be used with --min-umi-length")

private val assigner = strategy.newStrategy(this.edits)
private val assigner = strategy.newStrategy(this.edits, this.threads)

// Give values to unset parameters that are different in duplicate marking mode
private val _minMapQ = this.minMapQ.getOrElse(if (this.markDuplicates) 0 else 1)
Expand Down Expand Up @@ -705,13 +771,20 @@ class GroupReadsByUmi
* sub-grouping into UMI groups by original molecule.
*/
def assignUmiGroups(templates: Seq[Template]): Unit = {
val startMs = System.currentTimeMillis
val umis = truncateUmis(templates.map { t => umiForRead(t) })
val rawToId = this.assigner.assign(umis)

templates.iterator.zip(umis.iterator).foreach { case (template, umi) =>
val id = rawToId(umi)
template.primaryReads.foreach(r => r(this.assignTag) = id)
}

val endMs = System.currentTimeMillis()
val durationMs = endMs - startMs
if (durationMs >= 2500) {
logger.debug(f"Grouped ${rawToId.size}%,d UMIs from ${templates.size}%,d templates in ${durationMs}%,d ms." )
}
}

/** When a minimum UMI length is specified, truncates all the UMIs to the length of the shortest UMI. For the paired
Expand Down

0 comments on commit 4c0c101

Please sign in to comment.