diff --git a/odak/learn/tools/matrix.py b/odak/learn/tools/matrix.py index d6b7a385..237cc6ac 100644 --- a/odak/learn/tools/matrix.py +++ b/odak/learn/tools/matrix.py @@ -3,24 +3,27 @@ import torch.nn -def quantize(image_field, bits=4): +def quantize(image_field, bits = 4, limits = [0., 1.]): """ Definition to quantize a image field (0-255, 8 bit) to a certain bits level. Parameters ---------- image_field : torch.tensor - Input image field. + Input image field between any range. bits : int A value in between 0 to 8. Can not be zero. + limits : list + The minimum and maximum of the image_field variable. Returns ---------- new_field : torch.tensor Quantized image field. """ - divider = 2**(8-bits) - new_field = image_field/divider + normalized_field = (image_field - limits[0]) / (limits[1] - limits[0]) + divider = 2 ** bits + new_field = normalized_field * divider new_field = new_field.int() return new_field diff --git a/odak/learn/wave/optimizers.py b/odak/learn/wave/optimizers.py index e2dbae5a..c7cea6d6 100644 --- a/odak/learn/wave/optimizers.py +++ b/odak/learn/wave/optimizers.py @@ -3,7 +3,7 @@ import numpy as np from tqdm import tqdm from .util import wavenumber, generate_complex_field, calculate_amplitude, calculate_phase -from ..tools import torch_load, multi_scale_total_variation_loss +from ..tools import torch_load, multi_scale_total_variation_loss, quantize from .propagators import propagator @@ -386,7 +386,7 @@ def gradient_descent(self, number_of_iterations=100, weights=[1., 1., 0., 0.]): return hologram_phases.detach() - def optimize(self, number_of_iterations=100, weights=[1., 1., 1.]): + def optimize(self, number_of_iterations=100, weights=[1., 1., 1.], bits = 8): """ Function to optimize multiplane phase-only holograms. @@ -396,6 +396,8 @@ def optimize(self, number_of_iterations=100, weights=[1., 1., 1.]): Number of iterations. weights : list Loss weights. + bits : int + Quantizes the hologram using the given bits and reconstructs. Returns ------- @@ -409,6 +411,7 @@ def optimize(self, number_of_iterations=100, weights=[1., 1., 1.]): number_of_iterations=number_of_iterations, weights=weights ) + hologram_phases = quantize(hologram_phases % (2 * np.pi), bits = bits, limits = [0., 2 * np.pi]) / 2 ** bits * 2 * np.pi torch.no_grad() reconstruction_intensities = self.propagator.reconstruct(hologram_phases) laser_powers = self.propagator.get_laser_powers() diff --git a/test/data/sample_hologram.png b/test/data/sample_hologram.png index c5a777c1..d713b439 100644 Binary files a/test/data/sample_hologram.png and b/test/data/sample_hologram.png differ diff --git a/test/test_learn_wave_propagator.py b/test/test_learn_wave_propagator.py index 37fb8afb..203a4af9 100644 --- a/test/test_learn_wave_propagator.py +++ b/test/test_learn_wave_propagator.py @@ -11,13 +11,13 @@ def test(): pixel_pitch = 3.74e-6 number_of_frames = 3 number_of_depth_layers = 3 - volume_depth = 1e-2 - image_location_offset = 5e-3 + volume_depth = 5e-3 + image_location_offset = 0. propagation_type = 'Bandlimited Angular Spectrum' propagator_type = 'forward' laser_channel_power = None aperture = None - aperture_size = None + aperture_size = 1800 method = 'conventional' device = torch.device('cpu') hologram_phases_filename = './test/data/sample_hologram.png' @@ -37,6 +37,8 @@ def test(): propagation_type = propagation_type, propagator_type = propagator_type, laser_channel_power = laser_channel_power, + aperture_size = aperture_size, + aperture = aperture, method = method, device = device )