Skip to content

Commit

Permalink
Updating a unit test.
Browse files Browse the repository at this point in the history
  • Loading branch information
kaanaksit committed Dec 19, 2023
1 parent b46c96a commit 558f930
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 9 deletions.
11 changes: 7 additions & 4 deletions odak/learn/tools/matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,27 @@
import torch.nn


def quantize(image_field, bits=4):
def quantize(image_field, bits = 4, limits = [0., 1.]):
"""
Definition to quantize a image field (0-255, 8 bit) to a certain bits level.
Parameters
----------
image_field : torch.tensor
Input image field.
Input image field between any range.
bits : int
A value in between 0 to 8. Can not be zero.
limits : list
The minimum and maximum of the image_field variable.
Returns
----------
new_field : torch.tensor
Quantized image field.
"""
divider = 2**(8-bits)
new_field = image_field/divider
normalized_field = (image_field - limits[0]) / (limits[1] - limits[0])
divider = 2 ** bits
new_field = normalized_field * divider
new_field = new_field.int()
return new_field

Expand Down
7 changes: 5 additions & 2 deletions odak/learn/wave/optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numpy as np
from tqdm import tqdm
from .util import wavenumber, generate_complex_field, calculate_amplitude, calculate_phase
from ..tools import torch_load, multi_scale_total_variation_loss
from ..tools import torch_load, multi_scale_total_variation_loss, quantize
from .propagators import propagator


Expand Down Expand Up @@ -386,7 +386,7 @@ def gradient_descent(self, number_of_iterations=100, weights=[1., 1., 0., 0.]):
return hologram_phases.detach()


def optimize(self, number_of_iterations=100, weights=[1., 1., 1.]):
def optimize(self, number_of_iterations=100, weights=[1., 1., 1.], bits = 8):
"""
Function to optimize multiplane phase-only holograms.
Expand All @@ -396,6 +396,8 @@ def optimize(self, number_of_iterations=100, weights=[1., 1., 1.]):
Number of iterations.
weights : list
Loss weights.
bits : int
Quantizes the hologram using the given bits and reconstructs.
Returns
-------
Expand All @@ -409,6 +411,7 @@ def optimize(self, number_of_iterations=100, weights=[1., 1., 1.]):
number_of_iterations=number_of_iterations,
weights=weights
)
hologram_phases = quantize(hologram_phases % (2 * np.pi), bits = bits, limits = [0., 2 * np.pi]) / 2 ** bits * 2 * np.pi
torch.no_grad()
reconstruction_intensities = self.propagator.reconstruct(hologram_phases)
laser_powers = self.propagator.get_laser_powers()
Expand Down
Binary file modified test/data/sample_hologram.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
8 changes: 5 additions & 3 deletions test/test_learn_wave_propagator.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,13 @@ def test():
pixel_pitch = 3.74e-6
number_of_frames = 3
number_of_depth_layers = 3
volume_depth = 1e-2
image_location_offset = 5e-3
volume_depth = 5e-3
image_location_offset = 0.
propagation_type = 'Bandlimited Angular Spectrum'
propagator_type = 'forward'
laser_channel_power = None
aperture = None
aperture_size = None
aperture_size = 1800
method = 'conventional'
device = torch.device('cpu')
hologram_phases_filename = './test/data/sample_hologram.png'
Expand All @@ -37,6 +37,8 @@ def test():
propagation_type = propagation_type,
propagator_type = propagator_type,
laser_channel_power = laser_channel_power,
aperture_size = aperture_size,
aperture = aperture,
method = method,
device = device
)
Expand Down

0 comments on commit 558f930

Please sign in to comment.