Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Better dencode #145

Merged
merged 5 commits into from
Apr 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
# 0.0.7 2024-04-30
- Change on-disk format of distributed encode and simplify
- Check for all partitions nominally completed encoding before doing
anything destructive in dencode-finalise

# 0.0.6 2024-04-24

- Only use NOSHUFFLE by default on ``call_genotype`` and bool arrays.
Expand Down
117 changes: 64 additions & 53 deletions bio2zarr/vcf.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,17 +294,18 @@ def scan_vcf(path, target_num_partitions):

def check_overlap(partitions):
for i in range(1, len(partitions)):
prev_partition = partitions[i - 1]
current_partition = partitions[i]
if (
prev_partition.region.contig == current_partition.region.contig
and prev_partition.region.end > current_partition.region.start
):
raise ValueError(
f"Multiple VCFs have the region "
f"{prev_partition.region.contig}:{prev_partition.region.start}-"
f"{current_partition.region.end}"
)
prev_region = partitions[i - 1].region
current_region = partitions[i].region
if prev_region.contig == current_region.contig:
if prev_region.end is None:
logger.warning("Cannot check overlaps; issue #146")
continue
if prev_region.end > current_region.start:
raise ValueError(
f"Multiple VCFs have the region "
f"{prev_region.contig}:{prev_region.start}-"
f"{current_region.end}"
)


def scan_vcfs(paths, show_progress, target_num_partitions, worker_processes=1):
Expand Down Expand Up @@ -453,7 +454,7 @@ def sanitise_value_float_2d(buff, j, value):

def sanitise_int_array(value, ndmin, dtype):
if isinstance(value, tuple):
value = [VCF_INT_MISSING if x is None else x for x in value] # NEEDS TEST
value = [VCF_INT_MISSING if x is None else x for x in value] # NEEDS TEST
value = np.array(value, ndmin=ndmin, copy=False)
value[value == VCF_INT_MISSING] = -1
value[value == VCF_INT_FILL] = -2
Expand Down Expand Up @@ -1548,10 +1549,8 @@ def parse_max_memory(max_memory):

@dataclasses.dataclass
class VcfZarrPartition:
start_index: int
stop_index: int
start_chunk: int
stop_chunk: int
start: int
stop: int

@staticmethod
def generate_partitions(num_records, chunk_size, num_partitions, max_chunks=None):
Expand All @@ -1565,9 +1564,7 @@ def generate_partitions(num_records, chunk_size, num_partitions, max_chunks=None
stop_chunk = int(chunk_slice[-1]) + 1
start_index = start_chunk * chunk_size
stop_index = min(stop_chunk * chunk_size, num_records)
partitions.append(
VcfZarrPartition(start_index, stop_index, start_chunk, stop_chunk)
)
partitions.append(VcfZarrPartition(start_index, stop_index))
return partitions


Expand All @@ -1590,7 +1587,7 @@ def asdict(self):
def fromdict(d):
if d["format_version"] != VZW_METADATA_FORMAT_VERSION:
raise ValueError(
"VcfZarrWriter format version mismatch: "
"VcfZarrWriter format version mismatch: "
f"{d['format_version']} != {VZW_METADATA_FORMAT_VERSION}"
)
ret = VcfZarrWriterMetadata(**d)
Expand Down Expand Up @@ -1675,7 +1672,7 @@ def init(
root = zarr.group(store=store)

for column in self.schema.columns.values():
self.init_array(root, column, partitions[-1].stop_index)
self.init_array(root, column, partitions[-1].stop)

logger.info("Writing WIP metadata")
with open(self.wip_path / "metadata.json", "w") as f:
Expand Down Expand Up @@ -1762,28 +1759,42 @@ def load_metadata(self):
def partition_path(self, partition_index):
return self.partitions_path / f"p{partition_index}"

def wip_partition_path(self, partition_index):
return self.partitions_path / f"wip_p{partition_index}"

def wip_partition_array_path(self, partition_index, name):
return self.partition_path(partition_index) / f"wip_{name}"
return self.wip_partition_path(partition_index) / name

def partition_array_path(self, partition_index, name):
return self.partition_path(partition_index) / name

def encode_partition(self, partition_index):
self.load_metadata()
partition_path = self.partition_path(partition_index)
if partition_index < 0 or partition_index >= self.num_partitions:
raise ValueError(
"Partition index must be in the range 0 <= index < num_partitions"
)
partition_path = self.wip_partition_path(partition_index)
partition_path.mkdir(exist_ok=True)
logger.info(f"Encoding partition {partition_index} to {partition_path}")

self.encode_alleles_partition(partition_index)
self.encode_id_partition(partition_index)
self.encode_filters_partition(partition_index)
self.encode_contig_partition(partition_index)
self.encode_alleles_partition(partition_index)
for col in self.schema.columns.values():
if col.vcf_field is not None:
self.encode_array_partition(col, partition_index)
if "call_genotype" in self.schema.columns:
self.encode_genotypes_partition(partition_index)

final_path = self.partition_path(partition_index)
logger.info(f"Finalising {partition_index} at {final_path}")
if final_path.exists():
logger.warning("Removing existing partition at {final_path}")
shutil.rmtree(final_path)
os.rename(partition_path, final_path)

def init_partition_array(self, partition_index, name):
wip_path = self.wip_partition_array_path(partition_index, name)
# Create an empty array like the definition
Expand All @@ -1795,27 +1806,17 @@ def init_partition_array(self, partition_index, name):
return array

def finalise_partition_array(self, partition_index, name):
wip_path = self.wip_partition_array_path(partition_index, name)
final_path = self.partition_array_path(partition_index, name)
if final_path.exists():
# NEEDS TEST
logger.warning(f"Removing existing {final_path}")
shutil.rmtree(final_path)
# Atomic swap
os.rename(wip_path, final_path)
logger.debug(f"Encoded {name} partition {partition_index}")

def encode_array_partition(self, column, partition_index):
array = self.init_partition_array(partition_index, column.name)

partition = self.metadata.partitions[partition_index]
ba = core.BufferedArray(array, partition.start_index)
ba = core.BufferedArray(array, partition.start)
source_col = self.icf.columns[column.vcf_field]
sanitiser = source_col.sanitiser_factory(ba.buff.shape)

for value in source_col.iter_values(
partition.start_index, partition.stop_index
):
for value in source_col.iter_values(partition.start, partition.stop):
# We write directly into the buffer in the sanitiser function
# to make it easier to reason about dimension padding
j = ba.next_buffer_row()
Expand All @@ -1831,14 +1832,12 @@ def encode_genotypes_partition(self, partition_index):
)

partition = self.metadata.partitions[partition_index]
gt = core.BufferedArray(gt_array, partition.start_index)
gt_mask = core.BufferedArray(gt_mask_array, partition.start_index)
gt_phased = core.BufferedArray(gt_phased_array, partition.start_index)
gt = core.BufferedArray(gt_array, partition.start)
gt_mask = core.BufferedArray(gt_mask_array, partition.start)
gt_phased = core.BufferedArray(gt_phased_array, partition.start)

source_col = self.icf.columns["FORMAT/GT"]
for value in source_col.iter_values(
partition.start_index, partition.stop_index
):
for value in source_col.iter_values(partition.start, partition.stop):
j = gt.next_buffer_row()
sanitise_value_int_2d(gt.buff, j, value[:, :-1])
j = gt_phased.next_buffer_row()
Expand All @@ -1859,13 +1858,13 @@ def encode_alleles_partition(self, partition_index):
array_name = "variant_allele"
alleles_array = self.init_partition_array(partition_index, array_name)
partition = self.metadata.partitions[partition_index]
alleles = core.BufferedArray(alleles_array, partition.start_index)
alleles = core.BufferedArray(alleles_array, partition.start)
ref_col = self.icf.columns["REF"]
alt_col = self.icf.columns["ALT"]

for ref, alt in zip(
ref_col.iter_values(partition.start_index, partition.stop_index),
alt_col.iter_values(partition.start_index, partition.stop_index),
ref_col.iter_values(partition.start, partition.stop),
alt_col.iter_values(partition.start, partition.stop),
):
j = alleles.next_buffer_row()
alleles.buff[j, :] = STR_FILL
Expand All @@ -1879,11 +1878,11 @@ def encode_id_partition(self, partition_index):
vid_array = self.init_partition_array(partition_index, "variant_id")
vid_mask_array = self.init_partition_array(partition_index, "variant_id_mask")
partition = self.metadata.partitions[partition_index]
vid = core.BufferedArray(vid_array, partition.start_index)
vid_mask = core.BufferedArray(vid_mask_array, partition.start_index)
vid = core.BufferedArray(vid_array, partition.start)
vid_mask = core.BufferedArray(vid_mask_array, partition.start)
col = self.icf.columns["ID"]

for value in col.iter_values(partition.start_index, partition.stop_index):
for value in col.iter_values(partition.start, partition.stop):
j = vid.next_buffer_row()
k = vid_mask.next_buffer_row()
assert j == k
Expand All @@ -1904,10 +1903,10 @@ def encode_filters_partition(self, partition_index):
array_name = "variant_filter"
array = self.init_partition_array(partition_index, array_name)
partition = self.metadata.partitions[partition_index]
var_filter = core.BufferedArray(array, partition.start_index)
var_filter = core.BufferedArray(array, partition.start)

col = self.icf.columns["FILTERS"]
for value in col.iter_values(partition.start_index, partition.stop_index):
for value in col.iter_values(partition.start, partition.stop):
j = var_filter.next_buffer_row()
var_filter.buff[j] = False
for f in value:
Expand All @@ -1926,10 +1925,10 @@ def encode_contig_partition(self, partition_index):
array_name = "variant_contig"
array = self.init_partition_array(partition_index, array_name)
partition = self.metadata.partitions[partition_index]
contig = core.BufferedArray(array, partition.start_index)
contig = core.BufferedArray(array, partition.start)
col = self.icf.columns["CHROM"]

for value in col.iter_values(partition.start_index, partition.stop_index):
for value in col.iter_values(partition.start, partition.stop):
j = contig.next_buffer_row()
# Note: because we are using the indexes to define the lookups
# and we always have an index, it seems that we the contig lookup
Expand All @@ -1950,7 +1949,7 @@ def finalise_array(self, name):
if final_path.exists():
# NEEDS TEST
raise ValueError(f"Array {name} already exists")
for partition in range(len(self.metadata.partitions)):
for partition in range(self.num_partitions):
# Move all the files in partition dir to dest dir
src = self.partition_array_path(partition, name)
if not src.exists():
Expand All @@ -1977,6 +1976,15 @@ def finalise_array(self, name):
def finalise(self, show_progress=False):
self.load_metadata()

logger.info("Scanning {self.num_partitions} partitions")
missing = []
# TODO may need a progress bar here
for partition_id in range(self.num_partitions):
if not self.partition_path(partition_id).exists():
missing.append(partition_id)
if len(missing) > 0:
raise FileNotFoundError(f"Partitions not encoded: {missing}")

progress_config = core.ProgressConfig(
total=len(self.schema.columns),
title="Finalise",
Expand All @@ -1994,6 +2002,9 @@ def finalise(self, show_progress=False):
with core.ParallelWorkManager(0, progress_config) as pwm:
for name in self.schema.columns:
pwm.submit(self.finalise_array, name)
logger.debug(f"Removing {self.wip_path}")
shutil.rmtree(self.wip_path)
logger.info("Consolidating Zarr metadata")
zarr.consolidate_metadata(self.path)

######################
Expand Down
Loading
Loading