Skip to content

Commit

Permalink
Refactor common file name code.
Browse files Browse the repository at this point in the history
  • Loading branch information
We-Gold committed Aug 7, 2024
1 parent a6240db commit 797d455
Show file tree
Hide file tree
Showing 5 changed files with 73 additions and 12 deletions.
12 changes: 12 additions & 0 deletions python/ouroboros/helpers/files.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,3 +158,15 @@ def format_backproject_tempvolumes(output_name: str):

def format_backproject_resave_volume(output_name: str):
return output_name + "-temp-straightened.tif"


def format_tiff_name(i: int, num_digits: int) -> str:
return f"{str(i).zfill(num_digits)}.tif"


def parse_tiff_name(tiff_name: str) -> int:
return int(tiff_name.split(".")[0])


def num_digits_for_n_files(n: int):
return len(str(n - 1))
27 changes: 25 additions & 2 deletions python/ouroboros/helpers/slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,30 @@ def detect_color_channels(data: np.ndarray, none_value=1):
- num_color_channels (int): The number of color channels in the volume.
"""

has_color_channels = data.ndim == COLOR_CHANNELS_DIMENSIONS
num_color_channels = data.shape[-1] if has_color_channels else none_value
has_color_channels, num_color_channels = detect_color_channels_shape(
data.shape, none_value
)

return has_color_channels, num_color_channels


def detect_color_channels_shape(shape: tuple, none_value=1):
"""
Detect the number of color channels in a volume.
Parameters:
----------
shape (tuple): The shape of the volume data.
none_value (int): The value to return if the volume has no color channels.
Returns:
-------
tuple: A tuple containing the following:
- has_color_channels (bool): Whether the volume has color channels.
- num_color_channels (int): The number of color channels in the volume.
"""

has_color_channels = len(shape) == COLOR_CHANNELS_DIMENSIONS
num_color_channels = shape[-1] if has_color_channels else none_value

return has_color_channels, num_color_channels
19 changes: 11 additions & 8 deletions python/ouroboros/pipeline/backproject_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from ouroboros.helpers.memory_usage import GIGABYTE, calculate_gigabytes_from_dimensions
from ouroboros.helpers.slice import (
detect_color_channels,
detect_color_channels_shape,
generate_coordinate_grid_for_rect,
make_volume_binary,
write_slices_to_volume,
Expand All @@ -15,9 +16,12 @@
format_backproject_output_multiple,
format_backproject_resave_volume,
format_backproject_tempvolumes,
format_tiff_name,
get_sorted_tif_files,
join_path,
load_and_save_tiff_from_slices,
num_digits_for_n_files,
parse_tiff_name,
)

import concurrent.futures
Expand Down Expand Up @@ -201,7 +205,7 @@ def _process(self, input_data: any) -> tuple[any, None] | tuple[None, any]:
volume_shape = volume_cache.get_volume_shape()

# Determine the number of digits needed for the tif file names
num_digits = len(str(volume_shape[axis] - 1))
num_digits = num_digits_for_n_files(volume_shape[axis])

# Determine the number of channels in the straightened volume
temp_straightened_volume = make_tiff_memmap(straightened_volume_path, mode="r")
Expand Down Expand Up @@ -234,7 +238,7 @@ def _process(self, input_data: any) -> tuple[any, None] | tuple[None, any]:
slice_index = 0
for j in slice_range:
tifffile.imwrite(
join_path(folder_path, f"{str(j).zfill(num_digits)}.tif"),
join_path(folder_path, format_tiff_name(j, num_digits)),
np.take(chunk_volume, slice_index, axis=axis),
contiguous=True,
compression=config.backprojection_compression,
Expand Down Expand Up @@ -298,7 +302,7 @@ def _process(self, input_data: any) -> tuple[any, None] | tuple[None, any]:
slice_index = 0
for j in slice_range:
tifffile.imwrite(
join_path(folder_path, f"{str(j).zfill(num_digits)}.tif"),
join_path(folder_path, format_tiff_name(j, num_digits)),
np.take(chunk_volume, slice_index, axis=axis),
contiguous=True,
compression=config.backprojection_compression,
Expand Down Expand Up @@ -664,9 +668,9 @@ def rescale_folder_tif(
source_url, current_mip, target_mip, new_shape
)

num_digits = len(str(len(tifs)))
num_digits = num_digits_for_n_files(len(tifs))

first_index = int(tifs[0].split(".")[0])
first_index = parse_tiff_name(tifs[0])

output_index = int(first_index * resolution_factors[0])

Expand All @@ -684,7 +688,7 @@ def rescale_folder_tif(
# Write the layers to new tif files
for j in range(size):
tifffile.imwrite(
join_path(output_folder, f"{str(output_index).zfill(num_digits)}.tif"),
join_path(output_folder, format_tiff_name(output_index, num_digits)),
layers[j],
contiguous=True if compression is None else False,
compression=compression,
Expand All @@ -707,8 +711,7 @@ def calculate_scaling_factors(source_url, current_mip, target_mip, tif_shape):
for i in range(len(target_resolution))
)

has_color_channels = len(tif_shape) == 4
num_channels = tif_shape[-1] if has_color_channels else 1
has_color_channels, num_channels = detect_color_channels_shape(tif_shape)

# Determine the scaling factor for each axis as a tuple
scaling_factors = resolution_factors + (
Expand Down
6 changes: 4 additions & 2 deletions python/ouroboros/pipeline/slice_parallel_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
from ouroboros.helpers.files import (
format_slice_output_file,
format_slice_output_multiple,
format_tiff_name,
join_path,
num_digits_for_n_files,
)
from .pipeline import PipelineStep
from ouroboros.helpers.options import SliceOptions
Expand Down Expand Up @@ -107,7 +109,7 @@ def _process(self, input_data: tuple[any]) -> None | str:
return f"Error creating single tif file: {e}"

# Calculate the number of digits needed to store the number of slices
num_digits = len(str(len(slice_rects) - 1))
num_digits = num_digits_for_n_files(len(slice_rects))

# Create a queue to hold downloaded data for processing
data_queue = multiprocessing.Queue()
Expand Down Expand Up @@ -273,7 +275,7 @@ def process_worker_save_parallel(

for i, slice_i in zip(slice_indices, slices):
start = time.perf_counter()
filename = join_path(folder_name, f"{str(i).zfill(num_digits)}.tif")
filename = join_path(folder_name, format_tiff_name(i, num_digits))
futures.append(thread_executor.submit(save_thread, filename, slice_i))
durations["save"].append(time.perf_counter() - start)

Expand Down
21 changes: 21 additions & 0 deletions python/test/helpers/test_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,13 @@
format_slice_output_config_file,
format_slice_output_file,
format_slice_output_multiple,
format_tiff_name,
load_and_save_tiff_from_slices,
get_sorted_tif_files,
join_path,
combine_unknown_folder,
num_digits_for_n_files,
parse_tiff_name,
)


Expand Down Expand Up @@ -142,3 +145,21 @@ def test_format_backproject_tempvolumes():
def test_format_backproject_resave_volume():
result = format_backproject_resave_volume("test")
assert isinstance(result, str)


def test_format_tiff_name():
result = format_tiff_name(1, 3)
assert isinstance(result, str)
assert result == f"{str(1).zfill(3)}.tif"


def test_parse_tiff_name():
result = parse_tiff_name("001.tif")
assert isinstance(result, int)
assert result == 1


def test_num_digits_for_n_files():
result = num_digits_for_n_files(100)
assert isinstance(result, int)
assert result == 2

0 comments on commit 797d455

Please sign in to comment.