Skip to content

Commit

Permalink
test mod
Browse files Browse the repository at this point in the history
  • Loading branch information
DeanHazineh committed Jan 9, 2025
1 parent 8c6a06a commit b58fdc4
Show file tree
Hide file tree
Showing 4 changed files with 100 additions and 27 deletions.
20 changes: 11 additions & 9 deletions dflat/GDSII/assemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,25 +213,27 @@ def assemble_standard_shapes(
if cell_fun == gdspy.Round:
if len(shape_params) == 1:
shape_params = [shape_params[0], shape_params[0]]
shape = cell_fun((xoffset, yoffset), shape_params)
shape = cell_fun(
(xoffset, yoffset), shape_params, number_of_points=number_of_points
)
elif cell_fun == gdspy.Rectangle:
shape_params += [xoffset, yoffset]
shape = cell_fun((xoffset, yoffset), shape_params)
else:
raise ValueError
cell.add(shape)

# Add lens markers
hx = cell_size[1] * pshape[1] / gds_unit
hy = cell_size[0] * pshape[0] / gds_unit
ms = marker_size / gds_unit
cell_annot = lib.new_cell(f"TEXT_{unique_id}")
add_marker_tag(cell_annot, ms, hx, hy)
# # Add lens markers
# hx = cell_size[1] * pshape[1] / gds_unit
# hy = cell_size[0] * pshape[0] / gds_unit
# ms = marker_size / gds_unit
# cell_annot = lib.new_cell(f"TEXT_{unique_id}")
# add_marker_tag(cell_annot, ms, hx, hy)

# Create top-level cell and add references
# # Create top-level cell and add references
top_cell = lib.new_cell(f"TOP_CELL_{unique_id}")
top_cell.add(gdspy.CellReference(cell))
top_cell.add(gdspy.CellReference(cell_annot))
# top_cell.add(gdspy.CellReference(cell_annot))

# Write GDS file
lib.write_gds(savepath)
Expand Down
22 changes: 12 additions & 10 deletions dflat/render/fft_convolve.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,7 @@
from dflat.radial_tranforms import resize_with_crop_or_pad




def general_convolve(image, filter, rfft=False, mode="valid"):
def general_convolve(image, filter, rfft=False, mode="valid", adjoint=False):
"""Runs the Fourier space convolution between an image and filter, where the filter kernels may have a different size from the image shape.
Args:
Expand Down Expand Up @@ -40,7 +38,7 @@ def general_convolve(image, filter, rfft=False, mode="valid"):
filter_resh = resize_with_crop_or_pad(filter, *image_shape[-2:], radial_flag=False)

### Run the convolution (Defualt to using a checkpoint of the fourier transform)
image = checkpoint(fourier_convolve, image, filter_resh, rfft)
image = checkpoint(fourier_convolve, image, filter_resh, rfft, adjoint)
image = torch.real(image)

if mode == "valid":
Expand Down Expand Up @@ -90,27 +88,31 @@ def weiner_deconvolve(image, filter, const=1e-4, abs=False):
return image


def fourier_convolve(image, filter, rfft=False):
def fourier_convolve(image, filter, rfft=False, adjoint=False):
"""Computes the convolution of two signals (real or complex) using frequency space multiplcation. Convolution is done over the two inner-most dimensions.
Args:
`image` (float or complex): Image to apply filter to, of shape [..., Ny, Nx]
`filter` (float or complex): Filter kernel; The kernel must be the same shape as the image
`adjoint' (bool, optional): _description_. Defaults to False.
Returns:
complex: Image with filter convolved, same shape as input
"""

# Ensure inputs are complex
TORCH_ZERO = torch.tensor(0.0).to(dtype=image.dtype, device=image.device)
if rfft:
fourier_product = rfft2(ifftshift(image)) * rfft2(ifftshift(filter))
kf = rfft2(ifftshift(filter))
kf = torch.conj(kf) if adjoint else kf
fourier_product = rfft2(ifftshift(image)) * kf
fourier_product = fftshift(irfft2(fourier_product))
else:
image = torch.complex(image, TORCH_ZERO) if not image.is_complex() else image
filter = (
torch.complex(filter, TORCH_ZERO) if not filter.is_complex() else filter
)
fourier_product = fft2(ifftshift(image)) * fft2(ifftshift(filter))
kf = torch.complex(filter, TORCH_ZERO) if not filter.is_complex() else filter
kf = fft2(ifftshift(filter))
kf = torch.conj(kf) if adjoint else kf
fourier_product = fft2(ifftshift(image)) * kf
fourier_product = fftshift(ifft2(fourier_product))

return fourier_product
82 changes: 75 additions & 7 deletions dflat/render/util_meas.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,37 +9,42 @@
def hsi_to_rgb(
hsi,
wavelength_set_m,
demosaic=False,
gamma=False,
tensor_ordering=False,
normalize=True,
projection="Basler_Bayer",
process="ideal",
):
"""Converts a batched hyperspectral datacube of shape [minibatch, Height, Width, Channels] to RGB. If tensor_ordering is true,
input may instead be passed with the more common tensor shape [B, Ch, H, W]. The CIE1931 color matching functions are used by default.
Args:
hsi (float): Hyperspectral cube with shsape [B, H, W, Ch] or [B, Ch, H, W] if tensor_ordering is True.
wavelength_set_m (float): List of wavelengths corresponding to the input channel dimension.
demosaic (bool, optional): If True, a Bayer filter mask is applied to the RGB images and then interpolation is used to match experiment. Defaults to True.
gamma (bool, optional): Applies gamma transformation to the input images. Defaults to True.
tensor_ordering (bool, optional): If True, allows passing in a HSI with the more covenient pytorch to_tensor form. Defaults to False.
normalize (bool, optional): If true, the returned projection is max normalized to 1.
projection (str, optional): Either "CIE1931" or "Basler_Bayer". Specifies the color spectral curves.
Returns:
process (str, optional): Either 'idea', 'raw', 'demosaic'. ideal means return 3 color channels with no spatial resolution loss. Demosaic applies bayer mask and interp, raw returns 1 channel spatial mosaiced measurement.
Returns:
RGB: Stack of images with output channels=3
"""
assert projection in [
"CIE1931",
"Basler_Bayer",
], "Projection must be one of ['CIE1931', 'Basler_Bayer']."
assert process in [
"ideal",
"raw",
"demosaic",
], "Process must be one of ['ideal', 'raw', 'demosaic']."

input_tensor = torch.is_tensor(hsi)
if not input_tensor:
hsi = torch.tensor(hsi)
if tensor_ordering:
hsi = hsi.transpose(-3, -1).transpose(-3, -2).contiguous()

assert (
len(wavelength_set_m) == hsi.shape[-1]
), "List of wavelengths should match the input channel dimension."
Expand All @@ -54,14 +59,22 @@ def hsi_to_rgb(

rgb = torch.matmul(hsi, spec)
scale = torch.amax(rgb, dim=(-3, -2, -1), keepdim=True)
if normalize:

if process == "demosaic":
out = bayer_interpolate(bayer_mask(out))
elif process == "raw":
out = bayer_mask(out)
out = torch.sum(out, axis=-1, keepdims=True)

if normalize or gamma:
rgb = rgb / scale
if demosaic:
rgb = bayer_interpolate(bayer_mask(rgb))

if gamma:
rgb = gamma_correction(rgb)

if tensor_ordering:
rgb = rgb.transpose(-3, -1).transpose(-2, -1).contiguous()

if not input_tensor:
rgb = rgb.cpu().numpy()

Expand Down Expand Up @@ -185,3 +198,58 @@ def photons_to_ADU(
return torch.clip(electrons_signal, min=0)
else:
return electrons_signal


def rgb_to_hsi_adjoint(
rgb,
wavelength_set_m,
tensor_ordering=False,
normalize=True,
projection="Basler_Bayer",
):
"""Compute the adjoint approximation of rgb to hsi (used for some algorithm initializations)
Args:
rgb (float): Three channel RGB measurement
wavelength_set_m (float): List of wavelengths corresponding to the input channel dimension.
tensor_ordering (bool, optional): If True, allows passing in a HSI with the more covenient pytorch to_tensor form. Defaults to False.
normalize (bool, optional): If true, the returned projection is max normalized to 1.
projection (str, optional): Either "CIE1931" or "Basler_Bayer". Specifies the color spectral curves.
Returns:
float: Hyperspectral initialization
"""

assert projection.lower() in [
"cie1931",
"basler_bayer",
], "Projection must be one of ['cie1931', 'basler_bayer']."

input_tensor = torch.is_tensor(rgb)
if not input_tensor:
rgb = torch.tensor(rgb)

if tensor_ordering:
rgb = rgb.transpose(-3, -1).transpose(-3, -2).contiguous() # ... h w c
assert 3 == rgb.shape[-1], "Channel dimension must be 3 for adjoint transform"

if projection.lower() == "cie1931":
spec = get_rgb_bar_CIE1931(wavelength_set_m * 1e9)
elif projection.lower() == "basler_bayer":
spec, _ = get_QETrans_Basler_Bayer(wavelength_set_m * 1e9)
spec = np.concatenate([spec[:, 0:1], spec[:, 2:]], axis=-1)
spec = spec / np.sum(spec, axis=0, keepdims=True)

spec = torch.tensor(spec).type_as(rgb) # [C, 3]
out = torch.matmul(rgb, spec.T)

if normalize:
out = out / torch.amax(out, dim=(-3, -2, -1), keepdim=True)

if tensor_ordering:
out = out.transpose(-3, -1).transpose(-2, -1).contiguous()

if not input_tensor:
out = out.cpu().numpy()

return out
3 changes: 2 additions & 1 deletion docs/api/rcwa.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,6 @@ Public Functions
:members:
:undoc-members:
:show-inheritance:
:inherited-members:

.. automethod:: forward

0 comments on commit b58fdc4

Please sign in to comment.