Skip to content

Commit

Permalink
Improve startup time by only constructing an interval tree from selec…
Browse files Browse the repository at this point in the history
…ted ranges.

PiperOrigin-RevId: 698481604
  • Loading branch information
danielecook authored and copybara-github committed Nov 20, 2024
1 parent 7d6e9ba commit 5e9a8ae
Show file tree
Hide file tree
Showing 3 changed files with 116 additions and 49 deletions.
1 change: 1 addition & 0 deletions deepvariant/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -443,6 +443,7 @@ py_library(
"//deepvariant/python:allelecounter",
"//deepvariant/python:direct_phasing",
"//third_party/nucleus/io/python:hts_verbose",
"//third_party/nucleus/protos:range_py_pb2",
"//third_party/nucleus/protos:reads_py_pb2",
"//third_party/nucleus/util:errors",
"//third_party/nucleus/util:proto_utils",
Expand Down
70 changes: 49 additions & 21 deletions deepvariant/make_examples_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -852,9 +852,25 @@ def dict_by_chromosome(
# ---------------------------------------------------------------------------


def read_confident_regions(options):
def read_confident_regions(
options: deepvariant_pb2.MakeExamplesOptions,
calling_regions: Optional[Sequence[range_pb2.Range]] = None,
) -> Optional[ranges.RangeSet]:
"""Reads in bed file of confident regions.
Args:
options: MakeExamplesOptions proto.
calling_regions: calling regions to intersect with confident regions.
Returns:
List of ranges from confident region option or none if option is not set.
"""
if options.confident_regions_filename:
return ranges.RangeSet.from_bed(options.confident_regions_filename)
confident_regions = ranges.RangeSet.from_bed(
options.confident_regions_filename,
intersect_ranges=calling_regions,
)
return confident_regions
else:
return None

Expand Down Expand Up @@ -1233,32 +1249,40 @@ def close_all(self):
class RegionProcessor:
"""Creates DeepVariant example protos for a single region on the genome.
This class helps us to run the very sensitive caller, pileup image creator,
and variant labeler operations on a single region in parallel across many
regions using the PoolExecutor API. In order to do this we need three separate
key operations:
(1) Collect all of the info needed to create our resources (e.g., ref reader)
at construction. We cannot actually initialize those resources in the
constructor, though, since we actually want different resources in each
worker process/thread. I.e., we need lazy resource initialization.
(2) Actually initialize these resources *after* the worker has been forked
in our process pool. This gives us a fresh resource to use in each
separate process.
(3) Process the region to find candidate variants and process those into our
tf.Example protos.
This class helps us to run the very sensitive caller, pileup image creator,
and variant labeler operations on a single region in parallel across many
regions using the PoolExecutor API. In order to do this we need three
separate
key operations:
(1) Collect all of the info needed to create our resources (e.g., ref
reader)
at construction. We cannot actually initialize those resources in the
constructor, though, since we actually want different resources in each
worker process/thread. I.e., we need lazy resource initialization.
c
(2) Actually initialize these resources *after* the worker has been forked
in our process pool. This gives us a fresh resource to use in each
separate process.
(3) Process the region to find candidate variants and process those into our
tf.Example protos.
"""

def __init__(self, options: deepvariant_pb2.MakeExamplesOptions):
def __init__(
self,
options: deepvariant_pb2.MakeExamplesOptions,
calling_regions: Optional[Sequence[range_pb2.Range]] = None,
):
"""Creates a new RegionProcess.
Args:
options: deepvariant.MakeExamplesOptions proto used to specify our
resources for calling (e.g., reference_filename).
calling_regions: A list of ranges to call variants in.
"""
self.options = options
self.calling_regions = calling_regions
self.samples = [
sample_lib.Sample(options=x) for x in self.options.sample_options
]
Expand Down Expand Up @@ -1560,7 +1584,9 @@ def _make_labeler_from_options(self):
self.options.truth_variants_filename,
excluded_format_fields=['GL', 'GQ', 'PL'],
)
confident_regions = read_confident_regions(self.options)
confident_regions = read_confident_regions(
self.options, self.calling_regions
)

if (
self.options.variant_caller
Expand Down Expand Up @@ -3076,7 +3102,9 @@ def make_examples_runner(options: deepvariant_pb2.MakeExamplesOptions):
candidates_writer = epath.Path(candidate_positions_filename).open('wb')

# Create a processor to create candidates and examples for each region.
region_processor = RegionProcessor(options)
# Replace path in calling regions with the actual calling regions.
calling_regions = list(calling_regions) if calling_regions else None
region_processor = RegionProcessor(options, calling_regions)
region_processor.initialize()

if options.candidates_filename:
Expand Down
94 changes: 66 additions & 28 deletions third_party/nucleus/util/ranges.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,9 @@ def __init__(self, ranges=None, contigs=None, quiet=False):
if contigs is not None:
self._contigs = contigs
self._contig_map = contigs_dict(contigs)
self._contig_sort_key_fn = (
lambda name: self._contig_map[name].pos_in_fasta)
self._contig_sort_key_fn = lambda name: self._contig_map[
name
].pos_in_fasta
self._is_valid_contig = lambda name: name in self._contig_map
else:
self._contigs = None
Expand All @@ -117,7 +118,8 @@ def __init__(self, ranges=None, contigs=None, quiet=False):
for i, range_ in enumerate(ranges):
if not self._is_valid_contig(range_.reference_name):
raise ValueError(
'Range {} is on an unrecognized contig.'.format(range_))
'Range {} is on an unrecognized contig.'.format(range_)
)
self._by_chr[range_.reference_name].addi(range_.start, range_.end, None)
if not quiet and i > 0 and i % _LOG_EVERY_N_RANGES_IN_RANGESET_INIT == 0:
# We do our test directly here on i > 0 so we only see the log messages
Expand All @@ -138,7 +140,8 @@ def __iter__(self):
are new range protos so can be freely modified.
"""
for refname in sorted(
six.iterkeys(self._by_chr), key=self._contig_sort_key_fn):
six.iterkeys(self._by_chr), key=self._contig_sort_key_fn
):
for start, end, _ in sorted(self._by_chr[refname]):
yield make_range(refname, start, end)

Expand Down Expand Up @@ -176,21 +179,27 @@ def from_contigs(
"""Creates a RangeSet with an interval covering each base of each contig."""
return cls(
(make_range(contig.name, 0, contig.n_bases) for contig in contigs),
contigs)
contigs,
)

@classmethod
def from_bed(cls, source, contigs=None, enable_logging=True):
def from_bed(
cls, source, contigs=None, intersect_ranges=None, enable_logging=True
):
"""Creates a RangeSet containing the intervals from source.
Args:
source: A path to a BED (or equivalent) file of intervals.
contigs: An optional list of ContigInfo proto, used by RangeSet
constructor.
intersect_ranges: An optional list of RangeSet objects to intersect with
the intervals in the BED file before creating the RangeSet.
enable_logging: Enables logging line while reading the file.
Returns:
A RangeSet.
"""
return cls(bed_parser(source, enable_logging), contigs)
return cls(bed_parser(source, intersect_ranges, enable_logging), contigs)

def intersection(self, *others: 'RangeSet') -> 'RangeSet':
"""Computes the intersection among this RangeSet and *others RangeSets.
Expand Down Expand Up @@ -234,10 +243,15 @@ def _intersect2(refname, tree1, tree2):
(bigtree, smalltree) = (tree1, tree2)
else:
(bigtree, smalltree) = (tree2, tree1)
return (make_range(refname, max(interval1.begin, overlapping.begin),
min(interval1.end, overlapping.end))
for interval1 in bigtree
for overlapping in smalltree.overlap(interval1))
return (
make_range(
refname,
max(interval1.begin, overlapping.begin),
min(interval1.end, overlapping.end),
)
for interval1 in bigtree
for overlapping in smalltree.overlap(interval1)
)

# Iteratively intersect each of our *other RangeSets with this RangeSet.
# Sort by size so we do the smallest number of element merge first.
Expand All @@ -256,7 +270,8 @@ def _intersect2(refname, tree1, tree2):
other_chr = other._by_chr.get(refname, None)
if other_chr:
intersected_intervals.extend(
_intersect2(refname, intervals, other_chr))
_intersect2(refname, intervals, other_chr)
)

# Update our intersected RangeSet with the new intervals.
intersected = RangeSet(intersected_intervals, self._contigs)
Expand Down Expand Up @@ -294,8 +309,11 @@ def __nonzero__(self):

__bool__ = __nonzero__ # Python 3 compatibility.

def variant_overlaps(self, variant: variants_pb2.Variant,
empty_set_return_value: bool = True):
def variant_overlaps(
self,
variant: variants_pb2.Variant,
empty_set_return_value: bool = True,
):
"""Returns True if the variant's range overlaps with any in this set."""
if not self:
return empty_set_return_value
Expand Down Expand Up @@ -386,7 +404,8 @@ def make_position(chrom, position, reverse_strand=False):
strand.
"""
return position_pb2.Position(
reference_name=chrom, position=position, reverse_strand=reverse_strand)
reference_name=chrom, position=position, reverse_strand=reverse_strand
)


def make_range(chrom, start, end):
Expand Down Expand Up @@ -414,8 +433,9 @@ def position_overlaps(chrom, pos, interval):
Returns:
True if interval overlaps chr:pos.
"""
return (chrom == interval.reference_name and
interval.start <= pos < interval.end)
return (
chrom == interval.reference_name and interval.start <= pos < interval.end
)


def ranges_overlap(i1, i2):
Expand All @@ -428,8 +448,11 @@ def ranges_overlap(i1, i2):
Returns:
True if and only if i1 and i2 overlap.
"""
return (i1.reference_name == i2.reference_name and i1.end > i2.start and
i1.start < i2.end)
return (
i1.reference_name == i2.reference_name
and i1.end > i2.start
and i1.start < i2.end
)


def bedpe_parser(filename: str) -> Iterable[range_pb2.Range]:
Expand All @@ -456,22 +479,34 @@ def bedpe_parser(filename: str) -> Iterable[range_pb2.Range]:
yield make_range(parts[0], int(parts[1]), int(parts[5]))


def bed_parser(filename, enable_logging=True):
def bed_parser(filename, intersect_ranges=None, enable_logging=True):
"""Parses Range objects from a BED-formatted file object.
See http://bedtools.readthedocs.org/en/latest/content/general-usage.html
for more information on the BED format.
Args:
filename: File name of a BED-formatted file.
intersect_ranges: An optional list of RangeSet objects to intersect with the
intervals in the BED file before creating the RangeSet. Requires a tabix
index.
enable_logging: Enables logging line while reading the file.
Yields:
nucleus.genomics.v1.Range protobuf objects.
"""
with bed.BedReader(filename, enable_logging) as fin:
for r in fin.iterate():
yield make_range(r.reference_name, r.start, r.end)
if not fin.has_index():
logging.warning(
'BED file does not have a tabix index. Reading full bed file.'
)
if intersect_ranges and fin.has_index():
for region in intersect_ranges:
for r in fin.query(region):
yield make_range(r.reference_name, r.start, r.end)
else:
for r in fin.iterate():
yield make_range(r.reference_name, r.start, r.end)


def from_regions(regions, contig_map=None):
Expand All @@ -494,9 +529,9 @@ def from_regions(regions, contig_map=None):
regions: iterable[str]. Converts each element of this iterable into
region(s).
contig_map: An optional dictionary mapping from contig names to ContigInfo
protobufs. If provided, allows literals of the format "contig_name",
which will be parsed into a Range with reference_name=contig_name,
start=0, end=n_bases where n_bases comes from the ContigInfo.
protobufs. If provided, allows literals of the format "contig_name", which
will be parsed into a Range with reference_name=contig_name, start=0,
end=n_bases where n_bases comes from the ContigInfo.
Yields:
A Range proto.
Expand Down Expand Up @@ -541,8 +576,9 @@ def to_literal(range_pb):
Returns:
A string representation of the Range.
"""
return '{}:{}-{}'.format(range_pb.reference_name, range_pb.start + 1,
range_pb.end)
return '{}:{}-{}'.format(
range_pb.reference_name, range_pb.start + 1, range_pb.end
)


def parse_literal(region_literal, contig_map=None):
Expand Down Expand Up @@ -599,7 +635,8 @@ def parse_position(pos_str):
'Could not parse "{}" as a region literal. Region literals '
'should have the form "chr:start-stop" or "chr:start" or '
'just "chr". A common error is to use the "chr" prefix on '
'inputs that don\'t have it, or vice-versa.'.format(region_literal))
"inputs that don't have it, or vice-versa.".format(region_literal)
)


def parse_literals(region_literals, contig_map=None):
Expand Down Expand Up @@ -644,6 +681,7 @@ def sorted_ranges(ranges, contigs=None):
def to_key(range_):
pos = contig_map[range_.reference_name].pos_in_fasta
return pos, range_.start, range_.end

else:
to_key = as_tuple

Expand Down

0 comments on commit 5e9a8ae

Please sign in to comment.