Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change rgb render functionality to match recent work #4

Merged
merged 3 commits into from
Jan 9, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 11 additions & 9 deletions dflat/GDSII/assemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,25 +213,27 @@ def assemble_standard_shapes(
if cell_fun == gdspy.Round:
if len(shape_params) == 1:
shape_params = [shape_params[0], shape_params[0]]
shape = cell_fun((xoffset, yoffset), shape_params)
shape = cell_fun(
(xoffset, yoffset), shape_params, number_of_points=number_of_points
)
elif cell_fun == gdspy.Rectangle:
shape_params += [xoffset, yoffset]
shape = cell_fun((xoffset, yoffset), shape_params)
else:
raise ValueError
cell.add(shape)

# Add lens markers
hx = cell_size[1] * pshape[1] / gds_unit
hy = cell_size[0] * pshape[0] / gds_unit
ms = marker_size / gds_unit
cell_annot = lib.new_cell(f"TEXT_{unique_id}")
add_marker_tag(cell_annot, ms, hx, hy)
# # Add lens markers
# hx = cell_size[1] * pshape[1] / gds_unit
# hy = cell_size[0] * pshape[0] / gds_unit
# ms = marker_size / gds_unit
# cell_annot = lib.new_cell(f"TEXT_{unique_id}")
# add_marker_tag(cell_annot, ms, hx, hy)

# Create top-level cell and add references
# # Create top-level cell and add references
top_cell = lib.new_cell(f"TOP_CELL_{unique_id}")
top_cell.add(gdspy.CellReference(cell))
top_cell.add(gdspy.CellReference(cell_annot))
# top_cell.add(gdspy.CellReference(cell_annot))

# Write GDS file
lib.write_gds(savepath)
Expand Down
2 changes: 1 addition & 1 deletion dflat/plot_utilities/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from .mp_format import format_plot, add_colorbar, axis_off
from .vis import plot_3d_stack, gif_from_saved_images, video_from_saved_images
from .vis import plot_3d_stack, gif_from_saved_images
48 changes: 25 additions & 23 deletions dflat/plot_utilities/vis.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,34 +3,36 @@
import numpy as np
import matplotlib.pyplot as plt
from natsort import natsorted
from moviepy.editor import ImageSequenceClip

# from moviepy.editor import ImageSequenceClip
# Moviepy should be removed. It has some bugs

def video_from_saved_images(
filepath, filetag, savename, fps, deleteFrames=True, verbose=False
):
print("Call video generator")
png_files = natsorted(
[
os.path.join(filepath, f)
for f in os.listdir(filepath)
if f.startswith(filetag) and f.endswith(".png")
]
)
if verbose:
for file in png_files:
print("Adding image file as frame: " + file)

clip = ImageSequenceClip(png_files, fps=fps)
clip.write_videofile(
os.path.join(filepath, savename) + ".mp4", codec="libx264", fps=fps
)
# def video_from_saved_images(
# filepath, filetag, savename, fps, deleteFrames=True, verbose=False
# ):
# print("Call video generator")
# png_files = natsorted(
# [
# os.path.join(filepath, f)
# for f in os.listdir(filepath)
# if f.startswith(filetag) and f.endswith(".png")
# ]
# )
# if verbose:
# for file in png_files:
# print("Adding image file as frame: " + file)

if deleteFrames:
for file in png_files:
os.remove(file)
# clip = ImageSequenceClip(png_files, fps=fps)
# clip.write_videofile(
# os.path.join(filepath, savename) + ".mp4", codec="libx264", fps=fps
# )

return
# if deleteFrames:
# for file in png_files:
# os.remove(file)

# return


def gif_from_saved_images(
Expand Down
2 changes: 1 addition & 1 deletion dflat/render/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from .fft_convolve import general_convolve, weiner_deconvolve
from .fronto_planar_renderer import Fronto_Planar_Renderer_Incoherent
from .util_meas import hsi_to_rgb, photons_to_ADU
from .util_meas import hsi_to_rgb, photons_to_ADU, rgb_to_hsi_adjoint
22 changes: 12 additions & 10 deletions dflat/render/fft_convolve.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,7 @@
from dflat.radial_tranforms import resize_with_crop_or_pad




def general_convolve(image, filter, rfft=False, mode="valid"):
def general_convolve(image, filter, rfft=False, mode="valid", adjoint=False):
"""Runs the Fourier space convolution between an image and filter, where the filter kernels may have a different size from the image shape.

Args:
Expand Down Expand Up @@ -40,7 +38,7 @@ def general_convolve(image, filter, rfft=False, mode="valid"):
filter_resh = resize_with_crop_or_pad(filter, *image_shape[-2:], radial_flag=False)

### Run the convolution (Defualt to using a checkpoint of the fourier transform)
image = checkpoint(fourier_convolve, image, filter_resh, rfft)
image = checkpoint(fourier_convolve, image, filter_resh, rfft, adjoint)
image = torch.real(image)

if mode == "valid":
Expand Down Expand Up @@ -90,27 +88,31 @@ def weiner_deconvolve(image, filter, const=1e-4, abs=False):
return image


def fourier_convolve(image, filter, rfft=False):
def fourier_convolve(image, filter, rfft=False, adjoint=False):
"""Computes the convolution of two signals (real or complex) using frequency space multiplcation. Convolution is done over the two inner-most dimensions.

Args:
`image` (float or complex): Image to apply filter to, of shape [..., Ny, Nx]
`filter` (float or complex): Filter kernel; The kernel must be the same shape as the image
`adjoint' (bool, optional): _description_. Defaults to False.

Returns:
complex: Image with filter convolved, same shape as input
"""

# Ensure inputs are complex
TORCH_ZERO = torch.tensor(0.0).to(dtype=image.dtype, device=image.device)
if rfft:
fourier_product = rfft2(ifftshift(image)) * rfft2(ifftshift(filter))
kf = rfft2(ifftshift(filter))
kf = torch.conj(kf) if adjoint else kf
fourier_product = rfft2(ifftshift(image)) * kf
fourier_product = fftshift(irfft2(fourier_product))
else:
image = torch.complex(image, TORCH_ZERO) if not image.is_complex() else image
filter = (
torch.complex(filter, TORCH_ZERO) if not filter.is_complex() else filter
)
fourier_product = fft2(ifftshift(image)) * fft2(ifftshift(filter))
kf = torch.complex(filter, TORCH_ZERO) if not filter.is_complex() else filter
kf = fft2(ifftshift(filter))
kf = torch.conj(kf) if adjoint else kf
fourier_product = fft2(ifftshift(image)) * kf
fourier_product = fftshift(ifft2(fourier_product))

return fourier_product
6 changes: 3 additions & 3 deletions dflat/render/fronto_planar_renderer.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,13 +55,13 @@ def _forward(self, psf_intensity, scene_radiance, rfft, crop_to_psf_dim):
## TO be updated and added later
return meas

def rgb_measurement(self, meas, wavelength_set_m, bayer_mosaic=False, gamma=True):
def rgb_measurement(self, meas, wavelength_set_m, process="demosaic", gamma=True):
B, P, Z, L, H, W = meas.shape
meas = hsi_to_rgb(
rearrange(meas, "B P Z L H W -> (B P Z) H W L", B=B, P=P, Z=Z),
wavelength_set_m,
bayer_mosaic,
gamma,
gamma=gamma,
process=process,
)
meas = rearrange(meas, "(B P Z) H W C -> B P Z C H W", B=B, P=P, Z=Z)
return meas
108 changes: 94 additions & 14 deletions dflat/render/util_meas.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,44 +2,62 @@
import torch.nn.functional as F
from einops import rearrange
import numpy as np
import warnings

from dflat.render.util_sensor import get_QETrans_Basler_Bayer
from dflat.render.util_spectral import get_rgb_bar_CIE1931, gamma_correction


def hsi_to_rgb(
hsi,
wavelength_set_m,
demosaic=False,
gamma=False,
tensor_ordering=False,
normalize=True,
projection="Basler_Bayer",
process="ideal",
**kwargs
):
"""Converts a batched hyperspectral datacube of shape [minibatch, Height, Width, Channels] to RGB. If tensor_ordering is true,
input may instead be passed with the more common tensor shape [B, Ch, H, W]. The CIE1931 color matching functions are used by default.

Args:
hsi (float): Hyperspectral cube with shsape [B, H, W, Ch] or [B, Ch, H, W] if tensor_ordering is True.
wavelength_set_m (float): List of wavelengths corresponding to the input channel dimension.
demosaic (bool, optional): If True, a Bayer filter mask is applied to the RGB images and then interpolation is used to match experiment. Defaults to True.
gamma (bool, optional): Applies gamma transformation to the input images. Defaults to True.
tensor_ordering (bool, optional): If True, allows passing in a HSI with the more covenient pytorch to_tensor form. Defaults to False.
normalize (bool, optional): If true, the returned projection is max normalized to 1.
projection (str, optional): Either "CIE1931" or "Basler_Bayer". Specifies the color spectral curves.

Returns:
process (str, optional): Either 'ideal', 'raw', 'demosaic'. ideal means return 3 color channels with no spatial resolution loss. Demosaic applies bayer mask and interp, raw returns 1 channel spatial mosaiced measurement.
demosaic(boolean, optional): Deprecated. Now replaced by process field. demosaic true sets process = 'demosaic'.
Returns:
RGB: Stack of images with output channels=3
"""
if "demosaic" in kwargs:
warnings.warn(
"The 'demosaic' argument is deprecated and will be removed in future versions. "
"Please use the 'process' argument instead, setting process='demosaic'.",
DeprecationWarning,
)
if kwargs["demosaic"]:
process = "demosaic"

assert projection in [
"CIE1931",
"Basler_Bayer",
], "Projection must be one of ['CIE1931', 'Basler_Bayer']."
assert process in [
"ideal",
"raw",
"demosaic",
], "Process must be one of ['ideal', 'raw', 'demosaic']."

input_tensor = torch.is_tensor(hsi)
if not input_tensor:
hsi = torch.tensor(hsi)
if tensor_ordering:
hsi = hsi.transpose(-3, -1).transpose(-3, -2).contiguous()

assert (
len(wavelength_set_m) == hsi.shape[-1]
), "List of wavelengths should match the input channel dimension."
Expand All @@ -52,20 +70,27 @@ def hsi_to_rgb(
spec = spec / np.sum(spec, axis=0, keepdims=True)
spec = torch.tensor(spec).type_as(hsi)

rgb = torch.matmul(hsi, spec)
scale = torch.amax(rgb, dim=(-3, -2, -1), keepdim=True)
if normalize:
rgb = rgb / scale
if demosaic:
rgb = bayer_interpolate(bayer_mask(rgb))
out = torch.matmul(hsi, spec)

if process == "demosaic":
out = bayer_interpolate(bayer_mask(out))
elif process == "raw":
out = bayer_mask(out)
out = torch.sum(out, axis=-1, keepdims=True)

if normalize or gamma:
out = out / torch.amax(out, dim=(-3, -2, -1), keepdim=True)

if gamma:
rgb = gamma_correction(rgb)
out = gamma_correction(out)

if tensor_ordering:
rgb = rgb.transpose(-3, -1).transpose(-2, -1).contiguous()
out = out.transpose(-3, -1).transpose(-2, -1).contiguous()

if not input_tensor:
rgb = rgb.cpu().numpy()
out = out.cpu().numpy()

return rgb
return out


def bayer_mask(rgb_img):
Expand Down Expand Up @@ -185,3 +210,58 @@ def photons_to_ADU(
return torch.clip(electrons_signal, min=0)
else:
return electrons_signal


def rgb_to_hsi_adjoint(
rgb,
wavelength_set_m,
tensor_ordering=False,
normalize=True,
projection="Basler_Bayer",
):
"""Compute the adjoint approximation of rgb to hsi (used for some algorithm initializations)

Args:
rgb (float): Three channel RGB measurement
wavelength_set_m (float): List of wavelengths corresponding to the input channel dimension.
tensor_ordering (bool, optional): If True, allows passing in a HSI with the more covenient pytorch to_tensor form. Defaults to False.
normalize (bool, optional): If true, the returned projection is max normalized to 1.
projection (str, optional): Either "CIE1931" or "Basler_Bayer". Specifies the color spectral curves.

Returns:
float: Hyperspectral initialization
"""

assert projection.lower() in [
"cie1931",
"basler_bayer",
], "Projection must be one of ['cie1931', 'basler_bayer']."

input_tensor = torch.is_tensor(rgb)
if not input_tensor:
rgb = torch.tensor(rgb)

if tensor_ordering:
rgb = rgb.transpose(-3, -1).transpose(-3, -2).contiguous() # ... h w c
assert 3 == rgb.shape[-1], "Channel dimension must be 3 for adjoint transform"

if projection.lower() == "cie1931":
spec = get_rgb_bar_CIE1931(wavelength_set_m * 1e9)
elif projection.lower() == "basler_bayer":
spec, _ = get_QETrans_Basler_Bayer(wavelength_set_m * 1e9)
spec = np.concatenate([spec[:, 0:1], spec[:, 2:]], axis=-1)
spec = spec / np.sum(spec, axis=0, keepdims=True)

spec = torch.tensor(spec).type_as(rgb) # [C, 3]
out = torch.matmul(rgb, spec.T)

if normalize:
out = out / torch.amax(out, dim=(-3, -2, -1), keepdim=True)

if tensor_ordering:
out = out.transpose(-3, -1).transpose(-2, -1).contiguous()

if not input_tensor:
out = out.cpu().numpy()

return out
3 changes: 2 additions & 1 deletion docs/api/rcwa.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,6 @@ Public Functions
:members:
:undoc-members:
:show-inheritance:
:inherited-members:

.. automethod:: forward

Loading
Loading