diff --git a/dflat/GDSII/assemble.py b/dflat/GDSII/assemble.py
index 8bebd4b..9856c58 100644
--- a/dflat/GDSII/assemble.py
+++ b/dflat/GDSII/assemble.py
@@ -213,7 +213,9 @@ def assemble_standard_shapes(
         if cell_fun == gdspy.Round:
             if len(shape_params) == 1:
                 shape_params = [shape_params[0], shape_params[0]]
-            shape = cell_fun((xoffset, yoffset), shape_params)
+            shape = cell_fun(
+                (xoffset, yoffset), shape_params, number_of_points=number_of_points
+            )
         elif cell_fun == gdspy.Rectangle:
             shape_params += [xoffset, yoffset]
             shape = cell_fun((xoffset, yoffset), shape_params)
@@ -221,17 +223,17 @@
             raise ValueError
         cell.add(shape)

-    # Add lens markers
-    hx = cell_size[1] * pshape[1] / gds_unit
-    hy = cell_size[0] * pshape[0] / gds_unit
-    ms = marker_size / gds_unit
-    cell_annot = lib.new_cell(f"TEXT_{unique_id}")
-    add_marker_tag(cell_annot, ms, hx, hy)
+    # # Add lens markers
+    # hx = cell_size[1] * pshape[1] / gds_unit
+    # hy = cell_size[0] * pshape[0] / gds_unit
+    # ms = marker_size / gds_unit
+    # cell_annot = lib.new_cell(f"TEXT_{unique_id}")
+    # add_marker_tag(cell_annot, ms, hx, hy)

-    # Create top-level cell and add references
+    # # Create top-level cell and add references
     top_cell = lib.new_cell(f"TOP_CELL_{unique_id}")
     top_cell.add(gdspy.CellReference(cell))
-    top_cell.add(gdspy.CellReference(cell_annot))
+    # top_cell.add(gdspy.CellReference(cell_annot))

     # Write GDS file
     lib.write_gds(savepath)
diff --git a/dflat/plot_utilities/__init__.py b/dflat/plot_utilities/__init__.py
index 9eae4ab..e9dcc2d 100644
--- a/dflat/plot_utilities/__init__.py
+++ b/dflat/plot_utilities/__init__.py
@@ -1,2 +1,2 @@
 from .mp_format import format_plot, add_colorbar, axis_off
-from .vis import plot_3d_stack, gif_from_saved_images, video_from_saved_images
+from .vis import plot_3d_stack, gif_from_saved_images
diff --git a/dflat/plot_utilities/vis.py b/dflat/plot_utilities/vis.py
index 6e4747d..effcdbc 100644
--- a/dflat/plot_utilities/vis.py
+++ b/dflat/plot_utilities/vis.py
@@ -3,34 +3,36 @@ import numpy as np
 import matplotlib.pyplot as plt
 from natsort import natsorted

-from moviepy.editor import ImageSequenceClip
+# from moviepy.editor import ImageSequenceClip
+# Moviepy has some bugs and should be removed as a dependency.

-def video_from_saved_images(
-    filepath, filetag, savename, fps, deleteFrames=True, verbose=False
-):
-    print("Call video generator")
-    png_files = natsorted(
-        [
-            os.path.join(filepath, f)
-            for f in os.listdir(filepath)
-            if f.startswith(filetag) and f.endswith(".png")
-        ]
-    )
-    if verbose:
-        for file in png_files:
-            print("Adding image file as frame: " + file)
-    clip = ImageSequenceClip(png_files, fps=fps)
-    clip.write_videofile(
-        os.path.join(filepath, savename) + ".mp4", codec="libx264", fps=fps
-    )
+# def video_from_saved_images(
+#     filepath, filetag, savename, fps, deleteFrames=True, verbose=False
+# ):
+#     print("Call video generator")
+#     png_files = natsorted(
+#         [
+#             os.path.join(filepath, f)
+#             for f in os.listdir(filepath)
+#             if f.startswith(filetag) and f.endswith(".png")
+#         ]
+#     )
+#     if verbose:
+#         for file in png_files:
+#             print("Adding image file as frame: " + file)

-    if deleteFrames:
-        for file in png_files:
-            os.remove(file)
+#     clip = ImageSequenceClip(png_files, fps=fps)
+#     clip.write_videofile(
+#         os.path.join(filepath, savename) + ".mp4", codec="libx264", fps=fps
+#     )

-    return
+#     if deleteFrames:
+#         for file in png_files:
+#             os.remove(file)
+
+#     return


 def gif_from_saved_images(
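Reviewer note: if the mp4 export path should survive the moviepy removal, a small replacement could be built on imageio instead. A minimal sketch only, assuming the optional imageio / imageio-ffmpeg packages (not currently dependencies of this repo); the function name is illustrative, not part of the library:

    import os
    import imageio
    from natsort import natsorted

    def video_from_saved_images_imageio(filepath, filetag, savename, fps):
        # Collect frames in natural sort order, mirroring the removed helper
        png_files = natsorted(
            os.path.join(filepath, f)
            for f in os.listdir(filepath)
            if f.startswith(filetag) and f.endswith(".png")
        )
        # Stream frames straight into an ffmpeg-backed mp4 writer
        out_path = os.path.join(filepath, savename) + ".mp4"
        with imageio.get_writer(out_path, fps=fps) as writer:
            for f in png_files:
                writer.append_data(imageio.imread(f))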
diff --git a/dflat/render/__init__.py b/dflat/render/__init__.py
index fd705ba..59e8dc3 100644
--- a/dflat/render/__init__.py
+++ b/dflat/render/__init__.py
@@ -1,3 +1,3 @@
 from .fft_convolve import general_convolve, weiner_deconvolve
 from .fronto_planar_renderer import Fronto_Planar_Renderer_Incoherent
-from .util_meas import hsi_to_rgb, photons_to_ADU
+from .util_meas import hsi_to_rgb, photons_to_ADU, rgb_to_hsi_adjoint
diff --git a/dflat/render/fft_convolve.py b/dflat/render/fft_convolve.py
index 0b54116..8af5b05 100644
--- a/dflat/render/fft_convolve.py
+++ b/dflat/render/fft_convolve.py
@@ -6,9 +6,7 @@
 from dflat.radial_tranforms import resize_with_crop_or_pad

-
-
-def general_convolve(image, filter, rfft=False, mode="valid"):
+def general_convolve(image, filter, rfft=False, mode="valid", adjoint=False):
     """Runs the Fourier space convolution between an image and filter, where the filter kernels may have a different size from the image shape.

     Args:
@@ -40,7 +38,7 @@ def general_convolve(image, filter, rfft=False, mode="valid"):
     filter_resh = resize_with_crop_or_pad(filter, *image_shape[-2:], radial_flag=False)

     ### Run the convolution (Default to using a checkpoint of the fourier transform)
-    image = checkpoint(fourier_convolve, image, filter_resh, rfft)
+    image = checkpoint(fourier_convolve, image, filter_resh, rfft, adjoint)
     image = torch.real(image)

     if mode == "valid":
@@ -90,27 +88,31 @@
     return image


-def fourier_convolve(image, filter, rfft=False):
+def fourier_convolve(image, filter, rfft=False, adjoint=False):
     """Computes the convolution of two signals (real or complex) using frequency space multiplication. Convolution is done over the two inner-most dimensions.

     Args:
         `image` (float or complex): Image to apply filter to, of shape [..., Ny, Nx]
         `filter` (float or complex): Filter kernel; The kernel must be the same shape as the image
+        `adjoint` (bool, optional): If True, multiplies by the complex conjugate of the filter spectrum, applying the adjoint of the convolution operator. Defaults to False.

     Returns:
         complex: Image with filter convolved, same shape as input
     """
+    # Ensure inputs are complex
     TORCH_ZERO = torch.tensor(0.0).to(dtype=image.dtype, device=image.device)
     if rfft:
-        fourier_product = rfft2(ifftshift(image)) * rfft2(ifftshift(filter))
+        kf = rfft2(ifftshift(filter))
+        kf = torch.conj(kf) if adjoint else kf
+        fourier_product = rfft2(ifftshift(image)) * kf
         fourier_product = fftshift(irfft2(fourier_product))
     else:
         image = torch.complex(image, TORCH_ZERO) if not image.is_complex() else image
-        filter = (
-            torch.complex(filter, TORCH_ZERO) if not filter.is_complex() else filter
-        )
-        fourier_product = fft2(ifftshift(image)) * fft2(ifftshift(filter))
+        kf = torch.complex(filter, TORCH_ZERO) if not filter.is_complex() else filter
+        kf = fft2(ifftshift(kf))  # transform the complex-cast kernel, not the raw filter
+        kf = torch.conj(kf) if adjoint else kf
+        fourier_product = fft2(ifftshift(image)) * kf
         fourier_product = fftshift(ifft2(fourier_product))

     return fourier_product
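Reviewer note: the new adjoint flag should make fourier_convolve satisfy <A x, y> = <x, A^H y> for the circular-convolution operator A built here (exact when H and W are even, so fftshift/ifftshift are involutions). A quick property check, as a sketch only, assuming dflat is importable with this patch applied:

    import torch
    from dflat.render.fft_convolve import fourier_convolve

    torch.manual_seed(0)
    x = torch.rand(1, 64, 64, dtype=torch.float64)  # even H, W
    y = torch.rand(1, 64, 64, dtype=torch.float64)
    k = torch.rand(1, 64, 64, dtype=torch.float64)

    Ax = torch.real(fourier_convolve(x, k, rfft=True))                 # forward: A x
    Aty = torch.real(fourier_convolve(y, k, rfft=True, adjoint=True))  # adjoint: A^H y
    assert torch.allclose(torch.sum(Ax * y), torch.sum(x * Aty))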
diff --git a/dflat/render/fronto_planar_renderer.py b/dflat/render/fronto_planar_renderer.py
index 9adfdaf..462691f 100644
--- a/dflat/render/fronto_planar_renderer.py
+++ b/dflat/render/fronto_planar_renderer.py
@@ -55,13 +55,13 @@ def _forward(self, psf_intensity, scene_radiance, rfft, crop_to_psf_dim):
         ## To be updated and added later
         return meas

-    def rgb_measurement(self, meas, wavelength_set_m, bayer_mosaic=False, gamma=True):
+    def rgb_measurement(self, meas, wavelength_set_m, process="demosaic", gamma=True):
         B, P, Z, L, H, W = meas.shape
         meas = hsi_to_rgb(
             rearrange(meas, "B P Z L H W -> (B P Z) H W L", B=B, P=P, Z=Z),
             wavelength_set_m,
-            bayer_mosaic,
-            gamma,
+            gamma=gamma,
+            process=process,
         )
         meas = rearrange(meas, "(B P Z) H W C -> B P Z C H W", B=B, P=P, Z=Z)
         return meas
diff --git a/dflat/render/util_meas.py b/dflat/render/util_meas.py
index 9d2162e..7d81a36 100644
--- a/dflat/render/util_meas.py
+++ b/dflat/render/util_meas.py
@@ -2,6 +2,8 @@ import torch.nn.functional as F
 from einops import rearrange
 import numpy as np
+import warnings
+
 from dflat.render.util_sensor import get_QETrans_Basler_Bayer
 from dflat.render.util_spectral import get_rgb_bar_CIE1931, gamma_correction
@@ -9,11 +11,12 @@
 def hsi_to_rgb(
     hsi,
     wavelength_set_m,
-    demosaic=False,
     gamma=False,
     tensor_ordering=False,
     normalize=True,
     projection="Basler_Bayer",
+    process="ideal",
+    **kwargs
 ):
     """Converts a batched hyperspectral datacube of shape [minibatch, Height, Width, Channels] to RGB. If tensor_ordering is true, input may instead be passed with the more common tensor shape [B, Ch, H, W]. The CIE1931 color matching functions are used by default.
@@ -21,25 +24,40 @@
     Args:
         hsi (float): Hyperspectral cube with shape [B, H, W, Ch] or [B, Ch, H, W] if tensor_ordering is True.
         wavelength_set_m (float): List of wavelengths corresponding to the input channel dimension.
-        demosaic (bool, optional): If True, a Bayer filter mask is applied to the RGB images and then interpolation is used to match experiment. Defaults to True.
         gamma (bool, optional): Applies gamma transformation to the input images. Defaults to False.
         tensor_ordering (bool, optional): If True, allows passing in an HSI with the more convenient pytorch to_tensor form. Defaults to False.
         normalize (bool, optional): If true, the returned projection is max normalized to 1.
         projection (str, optional): Either "CIE1931" or "Basler_Bayer". Specifies the color spectral curves.
-
-    Returns:
+        process (str, optional): One of 'ideal', 'raw', or 'demosaic'. 'ideal' returns 3 color channels with no spatial resolution loss; 'demosaic' applies a Bayer mask and then interpolates; 'raw' returns the 1-channel spatially mosaiced measurement. Defaults to 'ideal'.
+        demosaic (bool, optional): Deprecated; replaced by the process argument. Passing demosaic=True sets process='demosaic'.
+
+    Returns:
         RGB: Stack of images with output channels=3 (channels=1 when process='raw')
     """
+    if "demosaic" in kwargs:
+        warnings.warn(
+            "The 'demosaic' argument is deprecated and will be removed in future versions. "
+            "Please use the 'process' argument instead, setting process='demosaic'.",
+            DeprecationWarning,
+        )
+        if kwargs["demosaic"]:
+            process = "demosaic"
+
     assert projection in [
         "CIE1931",
         "Basler_Bayer",
     ], "Projection must be one of ['CIE1931', 'Basler_Bayer']."
+    assert process in [
+        "ideal",
+        "raw",
+        "demosaic",
+    ], "Process must be one of ['ideal', 'raw', 'demosaic']."

     input_tensor = torch.is_tensor(hsi)
     if not input_tensor:
         hsi = torch.tensor(hsi)
+
     if tensor_ordering:
         hsi = hsi.transpose(-3, -1).transpose(-3, -2).contiguous()
+
     assert (
         len(wavelength_set_m) == hsi.shape[-1]
     ), "List of wavelengths should match the input channel dimension."
@@ -52,20 +70,27 @@
         spec = spec / np.sum(spec, axis=0, keepdims=True)
     spec = torch.tensor(spec).type_as(hsi)

-    rgb = torch.matmul(hsi, spec)
-    scale = torch.amax(rgb, dim=(-3, -2, -1), keepdim=True)
-    if normalize:
-        rgb = rgb / scale
-    if demosaic:
-        rgb = bayer_interpolate(bayer_mask(rgb))
+    out = torch.matmul(hsi, spec)
+
+    if process == "demosaic":
+        out = bayer_interpolate(bayer_mask(out))
+    elif process == "raw":
+        out = bayer_mask(out)
+        out = torch.sum(out, axis=-1, keepdims=True)
+
+    if normalize or gamma:
+        out = out / torch.amax(out, dim=(-3, -2, -1), keepdim=True)
+
     if gamma:
-        rgb = gamma_correction(rgb)
+        out = gamma_correction(out)
+
     if tensor_ordering:
-        rgb = rgb.transpose(-3, -1).transpose(-2, -1).contiguous()
+        out = out.transpose(-3, -1).transpose(-2, -1).contiguous()
+
     if not input_tensor:
-        rgb = rgb.cpu().numpy()
+        out = out.cpu().numpy()

-    return rgb
+    return out


 def bayer_mask(rgb_img):
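Usage note on the new process modes, as a minimal sketch with placeholder data (a random cube with 31 bands across the visible range):

    import numpy as np
    import torch
    from dflat.render import hsi_to_rgb

    hsi = torch.rand(1, 128, 128, 31)  # [B, H, W, Ch] placeholder cube
    wavelength_set_m = np.linspace(400e-9, 700e-9, 31)

    ideal = hsi_to_rgb(hsi, wavelength_set_m, process="ideal")     # [1, 128, 128, 3], no resolution loss
    demos = hsi_to_rgb(hsi, wavelength_set_m, process="demosaic")  # [1, 128, 128, 3], Bayer mask + interpolation
    raw = hsi_to_rgb(hsi, wavelength_set_m, process="raw")         # [1, 128, 128, 1], mosaiced sensor image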
@@ -185,3 +210,58 @@
         return torch.clip(electrons_signal, min=0)
     else:
         return electrons_signal
+
+
+def rgb_to_hsi_adjoint(
+    rgb,
+    wavelength_set_m,
+    tensor_ordering=False,
+    normalize=True,
+    projection="Basler_Bayer",
+):
+    """Computes the adjoint approximation of the RGB-to-HSI projection (used for some algorithm initializations).
+
+    Args:
+        rgb (float): Three-channel RGB measurement
+        wavelength_set_m (float): List of wavelengths corresponding to the output channel dimension.
+        tensor_ordering (bool, optional): If True, allows passing the RGB input in the more convenient pytorch tensor form. Defaults to False.
+        normalize (bool, optional): If true, the returned projection is max normalized to 1.
+        projection (str, optional): Either "CIE1931" or "Basler_Bayer". Specifies the color spectral curves.
+
+    Returns:
+        float: Hyperspectral initialization
+    """
+
+    assert projection.lower() in [
+        "cie1931",
+        "basler_bayer",
+    ], "Projection must be one of ['cie1931', 'basler_bayer']."
+
+    input_tensor = torch.is_tensor(rgb)
+    if not input_tensor:
+        rgb = torch.tensor(rgb)
+
+    if tensor_ordering:
+        rgb = rgb.transpose(-3, -1).transpose(-3, -2).contiguous()  # ... h w c
+    assert 3 == rgb.shape[-1], "Channel dimension must be 3 for adjoint transform"
+
+    if projection.lower() == "cie1931":
+        spec = get_rgb_bar_CIE1931(wavelength_set_m * 1e9)
+    elif projection.lower() == "basler_bayer":
+        spec, _ = get_QETrans_Basler_Bayer(wavelength_set_m * 1e9)
+        spec = np.concatenate([spec[:, 0:1], spec[:, 2:]], axis=-1)
+    spec = spec / np.sum(spec, axis=0, keepdims=True)
+
+    spec = torch.tensor(spec).type_as(rgb)  # [C, 3]
+    out = torch.matmul(rgb, spec.T)
+
+    if normalize:
+        out = out / torch.amax(out, dim=(-3, -2, -1), keepdim=True)
+
+    if tensor_ordering:
+        out = out.transpose(-3, -1).transpose(-2, -1).contiguous()
+
+    if not input_tensor:
+        out = out.cpu().numpy()
+
+    return out
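The adjoint projection gives a cheap, spectrally plausible starting cube, e.g. for seeding an iterative HSI reconstruction from an RGB capture. A sketch with placeholder data:

    import numpy as np
    import torch
    from dflat.render import rgb_to_hsi_adjoint

    rgb = torch.rand(1, 128, 128, 3)  # [B, H, W, 3] placeholder measurement
    wavelength_set_m = np.linspace(400e-9, 700e-9, 31)

    hsi0 = rgb_to_hsi_adjoint(rgb, wavelength_set_m)  # [1, 128, 128, 31] initialization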
diff --git a/docs/api/rcwa.rst b/docs/api/rcwa.rst
index dcf06f4..a7c5d88 100644
--- a/docs/api/rcwa.rst
+++ b/docs/api/rcwa.rst
@@ -13,5 +13,6 @@ Public Functions
    :members:
    :undoc-members:
    :show-inheritance:
-   :inherited-members:
+
+   .. automethod:: forward
diff --git a/tests/test_hsi_to_rgb.py b/tests/test_hsi_to_rgb.py
index 2a236d8..4441857 100644
--- a/tests/test_hsi_to_rgb.py
+++ b/tests/test_hsi_to_rgb.py
@@ -16,15 +16,14 @@ def hyperspectral_data():


 @pytest.mark.parametrize(
-    "demosaic, gamma, normalize",
+    "process, gamma, normalize",
     [
-        (True, True, True),
-        (False, False, False),
-        (True, False, True),
-        (False, True, False),
+        ("ideal", True, True),
+        ("raw", False, False),
+        ("demosaic", False, True),
     ],
 )
-def test_hsi_to_rgb(hyperspectral_data, demosaic, gamma, normalize):
+def test_hsi_to_rgb(hyperspectral_data, process, gamma, normalize):
     hsi, wavelengths = hyperspectral_data
     hsi_tensor = torch.tensor(hsi)
     B, H, W, _ = hsi_tensor.shape  # Define B, H, and W based on tensor shape
@@ -33,25 +32,23 @@
     rgb_image = hsi_to_rgb(
         hsi_tensor,
         wavelengths,
-        demosaic=demosaic,
         gamma=gamma,
         tensor_ordering=False,
         normalize=normalize,
+        process=process,
         projection="CIE1931",  # You can switch between "CIE1931" and "Basler_Bayer" if needed
     )

     # Check the shape of the output
-    expected_shape = (B, H, W, 3)  # Expecting 3 channels for RGB
+    expected_shape = (B, H, W, 3) if process in ["ideal", "demosaic"] else (B, H, W, 1)
     assert (
         rgb_image.shape == expected_shape
     ), f"Expected RGB shape {expected_shape}, but got {rgb_image.shape}"

-    # Optional: Add additional checks such as dtype or value ranges
     assert rgb_image.dtype == torch.float32, "Output should be of type float32"

     if normalize:
         assert rgb_image.max() <= 1, "Normalized RGB values should not exceed 1"

-    # Check that the RGB conversion does not produce any unexpected extremely high or low values
     assert not torch.any(torch.isnan(rgb_image)), "RGB image should not contain NaNs"
     assert not torch.any(
         torch.isinf(rgb_image)
diff --git a/tests/test_render.py b/tests/test_render.py
index 47c7b0c..6f8f57c 100644
--- a/tests/test_render.py
+++ b/tests/test_render.py
@@ -45,9 +45,7 @@ def test_fronto_planar_renderer_rgb_conversion(
         psf_intensity, scene_radiance, rfft=True, crop_to_psf_dim=False
     )

-    rgb_output = renderer.rgb_measurement(
-        meas, wavelength_set_m, bayer_mosaic=True, gamma=True
-    )
+    rgb_output = renderer.rgb_measurement(meas, wavelength_set_m)

     expected_shape = meas.shape[:-3] + (3, meas.shape[-2], meas.shape[-1])
     assert (