Automatic flags based on size of BAM file #11

Open

wants to merge 9 commits into master
195 changes: 178 additions & 17 deletions gcp_deepvariant_runner.py
@@ -62,6 +62,7 @@
import time
import urllib
import uuid
import enum

import gke_cluster
from google.api_core import exceptions as google_exceptions
@@ -161,6 +162,66 @@
}}
"""

# The following constants are used to automatically set the computational
# flags based on the size of the input BAM file.
_WGS_STANDARD = 'wgs_standard'
_WES_STANDARD = 'wes_standard'
_WES_LARGE_THR = 12 * 1024 * 1024 * 1024  # 12 GiB
_WGS_SMALL_THR = 25 * 1024 * 1024 * 1024  # 25 GiB
_WGS_LARGE_THR = 200 * 1024 * 1024 * 1024  # 200 GiB


class BamCategories(enum.Enum):
  """List of BAM categories that determine automatically assigned flags."""
  WES_SMALL = 0
  WES_LARGE = 1
  WGS_SMALL = 2
  WGS_MEDIUM = 3
  WGS_LARGE = 4


# Default optimal computational flag values, one per BAM category.
_DEFAULT_FLAGS = {}
_DEFAULT_FLAGS[BamCategories.WES_SMALL] = {
    'make_examples_workers': 8,
    'make_examples_cores_per_worker': 2,
    'call_variants_workers': 1,
    'call_variants_cores_per_worker': 2,
    'gpu': True
}
_DEFAULT_FLAGS[BamCategories.WES_LARGE] = {
    'make_examples_workers': 8,
    'make_examples_cores_per_worker': 2,
    'call_variants_workers': 2,
    'call_variants_cores_per_worker': 2,
    'gpu': True
}
_DEFAULT_FLAGS[BamCategories.WGS_SMALL] = {
    'make_examples_workers': 16,
    'make_examples_cores_per_worker': 2,
    'call_variants_workers': 2,
    'call_variants_cores_per_worker': 2,
    'gpu': True
}
_DEFAULT_FLAGS[BamCategories.WGS_MEDIUM] = {
    'make_examples_workers': 32,
    'make_examples_cores_per_worker': 2,
    'call_variants_workers': 4,
    'call_variants_cores_per_worker': 2,
    'gpu': True
}
_DEFAULT_FLAGS[BamCategories.WGS_LARGE] = {
    'make_examples_workers': 64,
    'make_examples_cores_per_worker': 2,
    'call_variants_workers': 8,
    'call_variants_cores_per_worker': 2,
    'gpu': True
}
# Common computational flag values across all BAM categories.
_RAM_PER_CORE = 4  # GB of RAM per worker core.
_MAKE_EXAMPLES_DISK_PER_WORKER = 200  # GB of disk per make_examples worker.
_CALL_VARIANTS_DISK_PER_WORKER = 50  # GB of disk per call_variants worker.
_POSTPROCESS_VARIANTS_DISK_GVCF = 200  # GB of disk for gVCF postprocessing.


def _get_staging_examples_folder_to_write(pipeline_args,
                                          make_example_worker_index):
@@ -313,22 +374,30 @@ def _is_valid_gcs_path(gcs_path):
urllib.parse.urlparse(gcs_path).netloc != '')


def _get_gcs_object_size(gcs_obj_path):
  """Returns size of a GCS object, or 0 if it is missing or access is denied.

  Args:
    gcs_obj_path: (str) a path to an object on GCS.
  """
  try:
    bucket_name = _get_gcs_bucket(gcs_obj_path)
    obj_name = _get_gcs_relative_path(gcs_obj_path)
  except ValueError as e:
    logging.error('Invalid GCS path: %s', str(e))
    return 0

  storage_client = storage.Client()
  bucket = storage_client.bucket(bucket_name)
  try:
    blob = bucket.get_blob(obj_name)
  except (google_exceptions.NotFound, google_exceptions.Forbidden) as e:
    logging.error('Unable to access GCS bucket: %s', str(e))
    return 0

  if blob is None:
    return 0
  return blob.size
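A usage sketch (illustration only, not part of this diff; the object path is a placeholder):

  size = _get_gcs_object_size('gs://my-bucket/sample.bam')
  if size == 0:
    # A missing object, an empty object, and denied access all yield 0 here;
    # the validation code below treats each of these cases as "does not exist".
    print('Input BAM is missing, empty, or inaccessible.')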


def _can_write_to_bucket(bucket_name):
@@ -386,6 +455,69 @@ def _meets_gcp_label_restrictions(label):
label) is not None


def _get_bam_category(pipeline_args):
  """Returns the category that the input BAM file belongs to.

  Assumes the model has already been validated to match exactly one of the
  WGS or WES standard models.
  """
  bam_size = _get_gcs_object_size(pipeline_args.bam)
  if bam_size == 0:
    logging.warning('Size of the input BAM file is reported as 0.')

  is_wes = pipeline_args.model.find(_WES_STANDARD) != -1
  is_wgs = pipeline_args.model.find(_WGS_STANDARD) != -1

  if is_wes:
    if bam_size < _WES_LARGE_THR:
      return BamCategories.WES_SMALL
    else:
      return BamCategories.WES_LARGE

  if is_wgs:
    if bam_size < _WGS_SMALL_THR:
      return BamCategories.WGS_SMALL
    elif bam_size > _WGS_LARGE_THR:
      return BamCategories.WGS_LARGE
    else:
      return BamCategories.WGS_MEDIUM
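For instance, under the thresholds defined above, a hypothetical 30 GiB WGS BAM falls between the two WGS cutoffs (illustration only, not part of the diff):

  bam_size = 30 * 1024 * 1024 * 1024
  assert _WGS_SMALL_THR <= bam_size <= _WGS_LARGE_THR
  # _get_bam_category would return BamCategories.WGS_MEDIUM for this input.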


def _set_computational_flags_based_on_bam_size(pipeline_args):
  """Automatically sets computational flags based on the input BAM's size."""
  bam_category = _get_bam_category(pipeline_args)
  default_flags = _DEFAULT_FLAGS[bam_category]

  pipeline_args.shards = (
      default_flags['make_examples_workers'] *
      default_flags['make_examples_cores_per_worker'])
  pipeline_args.make_examples_workers = default_flags['make_examples_workers']
  pipeline_args.make_examples_cores_per_worker = (
      default_flags['make_examples_cores_per_worker'])
  pipeline_args.make_examples_ram_per_worker_gb = (
      default_flags['make_examples_cores_per_worker'] * _RAM_PER_CORE)
  pipeline_args.make_examples_disk_per_worker_gb = (
      _MAKE_EXAMPLES_DISK_PER_WORKER)
  if 'gpu' in default_flags:
    pipeline_args.gpu = default_flags['gpu']
    pipeline_args.call_variants_workers = (
        default_flags['call_variants_workers'])
    pipeline_args.call_variants_cores_per_worker = (
        default_flags['call_variants_cores_per_worker'])
    pipeline_args.call_variants_ram_per_worker_gb = (
        default_flags['call_variants_cores_per_worker'] * _RAM_PER_CORE)
    pipeline_args.call_variants_disk_per_worker_gb = (
        _CALL_VARIANTS_DISK_PER_WORKER)
  elif 'tpu' in default_flags:
    pipeline_args.tpu = default_flags['tpu']
    pipeline_args.gke_cluster_zone = pipeline_args.zones[0]
  else:
    raise ValueError('Default flag settings must include either gpu or tpu.')
  # The following flags are independent of the BAM file category.
  pipeline_args.gcsfuse = True
  pipeline_args.preemptible = True
  pipeline_args.max_preemptible_tries = 2
  pipeline_args.max_non_preemptible_tries = 1
  if pipeline_args.gvcf_outfile:
    pipeline_args.postprocess_variants_disk_gb = _POSTPROCESS_VARIANTS_DISK_GVCF
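Worked numbers (illustration only, derived from the _DEFAULT_FLAGS table above, not new behavior): for a WGS_LARGE input the derived values come out as follows.

  flags = _DEFAULT_FLAGS[BamCategories.WGS_LARGE]
  # shards = 64 workers * 2 cores per worker = 128
  assert flags['make_examples_workers'] * flags['make_examples_cores_per_worker'] == 128
  # RAM per call_variants worker = 2 cores * 4 GB per core = 8 GB
  assert flags['call_variants_cores_per_worker'] * _RAM_PER_CORE == 8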


def _run_make_examples(pipeline_args):
  """Runs the make_examples job."""

@@ -423,6 +555,8 @@ def get_extra_args():
      extra_args.extend(['--sample_name', pipeline_args.sample_name])
    if pipeline_args.hts_block_size:
      extra_args.extend(['--hts_block_size', str(pipeline_args.hts_block_size)])
    # CRAM inputs need the reference for decoding.
    if pipeline_args.bam.endswith(_CRAM_FILE_SUFFIX):
      extra_args.extend(['--use_ref_for_cram'])
    return extra_args

command = _MAKE_EXAMPLES_COMMAND.format(
@@ -675,6 +809,23 @@ def get_extra_args():

def _validate_and_complete_args(pipeline_args):
  """Validates pipeline arguments and fills in some missing args (if any)."""
  if pipeline_args.set_optimized_flags_based_on_bam_size:
    # First, validate that all necessary flags are present.
    if not (pipeline_args.docker_image and pipeline_args.docker_image_gpu):
      raise ValueError('Both --docker_image and --docker_image_gpu must be '
                       'provided with --set_optimized_flags_based_on_bam_size.')
    is_wes = pipeline_args.model.find(_WES_STANDARD) != -1
    is_wgs = pipeline_args.model.find(_WGS_STANDARD) != -1
    if not is_wes and not is_wgs:
      raise ValueError('Unable to automatically set computational flags. The '
                       'given model is neither WGS nor WES: %s' %
                       pipeline_args.model)
    if is_wes and is_wgs:
      raise ValueError('Unable to automatically set computational flags. The '
                       'given model matches both WGS and WES: %s' %
                       pipeline_args.model)
    if not pipeline_args.bam.endswith(_BAM_FILE_SUFFIX):
      raise ValueError(
          'Computational flags can only be set automatically for BAM files.')
    _set_computational_flags_based_on_bam_size(pipeline_args)
  # Basic validation logic. More detailed validation is done by the
  # Pipelines API.
  if (pipeline_args.job_name_prefix and
      not _meets_gcp_label_restrictions(pipeline_args.job_name_prefix)):
@@ -743,20 +894,21 @@ def _validate_and_complete_args(pipeline_args):
    pipeline_args.ref_gzi = pipeline_args.ref + _GZI_FILE_SUFFIX
  if not pipeline_args.bai:
    pipeline_args.bai = pipeline_args.bam + _BAI_FILE_SUFFIX
    if _get_gcs_object_size(pipeline_args.bai) == 0:
      pipeline_args.bai = pipeline_args.bam.replace(_BAM_FILE_SUFFIX,
                                                    _BAI_FILE_SUFFIX)

  # Ensuring all input files exist...
  if _get_gcs_object_size(pipeline_args.ref) == 0:
    raise ValueError('Given reference file via --ref does not exist')
  if _get_gcs_object_size(pipeline_args.ref_fai) == 0:
    raise ValueError('Given FAI index file via --ref_fai does not exist')
  if (pipeline_args.ref_gzi and
      _get_gcs_object_size(pipeline_args.ref_gzi) == 0):
    raise ValueError('Given GZI index file via --ref_gzi does not exist')
  if _get_gcs_object_size(pipeline_args.bam) == 0:
    raise ValueError('Given BAM file via --bam does not exist')
  if _get_gcs_object_size(pipeline_args.bai) == 0:
    raise ValueError('Given BAM index file via --bai does not exist')
  # ...and we can write to output buckets.
  if not _can_write_to_bucket(_get_gcs_bucket(pipeline_args.staging)):
@@ -896,6 +1048,15 @@ def run(argv=None):
      help=('Optional. If non-zero, specifies the time interval in seconds '
            'for writing the workers\' log. Otherwise, the log is written '
            'when the job is finished.'))
  parser.add_argument(
      '--set_optimized_flags_based_on_bam_size',
      default=False,
      action='store_true',
      help=('Automatically sets optimized values for computational flags, '
            'such as the number of workers, the number of cores, and the '
            'amount of RAM and disk per worker, for both the make_examples '
            'and call_variants steps, based on the size of the input BAM '
            'file. This flag also automatically decides whether to use TPU '
            'or GPU for the call_variants stage.'))
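A hedged usage sketch (not part of this diff): invoking the runner with the new flag. All project, bucket, and model paths below are placeholders, and the exact set of required flags is an assumption; note that the model path must contain 'wgs_standard' or 'wes_standard' for the automatic flag logic to accept it.

  import gcp_deepvariant_runner

  gcp_deepvariant_runner.run([
      '--project', 'my-gcp-project',  # placeholder project id
      '--docker_image', 'gcr.io/my-gcp-project/deepvariant:latest',
      '--docker_image_gpu', 'gcr.io/my-gcp-project/deepvariant_gpu:latest',
      '--model', 'gs://my-bucket/models/wgs_standard/model.ckpt',
      '--bam', 'gs://my-bucket/input/sample.bam',
      '--ref', 'gs://my-bucket/ref/hg38.fa',
      '--staging', 'gs://my-bucket/staging',
      '--outfile', 'gs://my-bucket/output/sample.vcf',
      '--set_optimized_flags_based_on_bam_size',
  ])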

# Optional GPU args.
parser.add_argument(
@@ -1000,12 +1161,12 @@ parser.add_argument(
  parser.add_argument(
      '--postprocess_variants_cores',
      type=int,
      default=4,
      help='Number of cores to use for postprocess_variants.')
  parser.add_argument(
      '--postprocess_variants_ram_gb',
      type=int,
      default=16,
      help='RAM (in GB) to use for postprocess_variants.')
  parser.add_argument(
      '--postprocess_variants_disk_gb',