PyTorch SVD is inaccurate. Replace pinv with inv
DeanHazineh committed Aug 6, 2024
1 parent b498fca commit ead21e9
Showing 3 changed files with 139 additions and 116 deletions.
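Context for the change, not part of the commit: `torch.linalg.pinv` is computed via an SVD, while `torch.linalg.inv` solves a square system via LU factorization. Every matrix inverted in this commit is square, so the plain inverse is valid whenever the matrix is nonsingular, and it sidesteps the complex-valued SVD that the commit message flags as inaccurate. A minimal sketch of how the two routes might be compared (the size, seed, and diagonal shift are illustrative assumptions, not from the repo):

```python
import torch

torch.manual_seed(0)
n = 64
# Diagonally dominant complex matrix: square, nonsingular, well-conditioned.
A = torch.randn(n, n, dtype=torch.complex128) + n * torch.eye(n, dtype=torch.complex128)
I = torch.eye(n, dtype=torch.complex128)

# Left-inverse residual ||X @ A - I|| for each inversion route.
err_inv = torch.linalg.norm(torch.linalg.inv(A) @ A - I).item()
err_pinv = torch.linalg.norm(torch.linalg.pinv(A) @ A - I).item()
print(f"inv  residual: {err_inv:.3e}")
print(f"pinv residual: {err_pinv:.3e}")
```

How the two residuals compare is platform- and backend-dependent; the commit message reports that the SVD route was the less accurate one for this workload.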
41 changes: 4 additions & 37 deletions dflat/rcwa/colburn_rcwa_utils.py
@@ -103,23 +103,23 @@ def redheffer_star_product(SA, SB):
     I = torch.tile(I, (batchSize, pixelsX, pixelsY, 1, 1))

     # Calculate S11.
-    S11 = thl.pinv(I - thl.matmul(SB["S11"], SA["S22"]))
+    S11 = thl.inv(I - thl.matmul(SB["S11"], SA["S22"]))
     S11 = thl.matmul(S11, SB["S11"])
     S11 = thl.matmul(SA["S12"], S11)
     S11 = SA["S11"] + thl.matmul(S11, SA["S21"])

     # Calculate S12.
-    S12 = thl.pinv(I - thl.matmul(SB["S11"], SA["S22"]))
+    S12 = thl.inv(I - thl.matmul(SB["S11"], SA["S22"]))
     S12 = thl.matmul(S12, SB["S12"])
     S12 = thl.matmul(SA["S12"], S12)

     # Calculate S21.
-    S21 = thl.pinv(I - thl.matmul(SA["S22"], SB["S11"]))
+    S21 = thl.inv(I - thl.matmul(SA["S22"], SB["S11"]))
     S21 = thl.matmul(S21, SA["S21"])
     S21 = thl.matmul(SB["S21"], S21)

     # Calculate S22.
-    S22 = thl.pinv(I - thl.matmul(SA["S22"], SB["S11"]))
+    S22 = thl.inv(I - thl.matmul(SA["S22"], SB["S11"]))
     S22 = thl.matmul(S22, SA["S22"])
     S22 = thl.matmul(SB["S21"], S22)
     S22 = SB["S22"] + thl.matmul(S22, SB["S12"])
@@ -132,36 +132,3 @@ def redheffer_star_product(SA, SB):
     S["S22"] = S22

     return S
-
-
-def complex_pseudoinverse(tensor, rcond=1e-15):
-    dtype = tensor.dtype
-    if not torch.is_complex(tensor):
-        return thl.pinv(tensor)
-
-    # Compute SVD
-    s, u, v = thl.svd(tensor, full_matrices=False)
-
-    # Compute the reciprocal of singular values
-    cutoff = rcond * torch.max(s)
-    s_inv = torch.where(s > cutoff, torch.reciprocal(s), 0.0)
-
-    # Cast the reciprocal of singular values to complex
-    s_inv_complex = s_inv.to(dtype=dtype)
-
-    # Compute the pseudoinverse
-    pseudo_inv_s = tensor_utils.diag_batched(s_inv_complex)
-    pseudo_inv = torch.matmul(v, torch.matmul(pseudo_inv_s, torch.adjoint(u)))
-
-    return pseudo_inv
-
-
-def batch_regularized_inverse(matrix, alpha=0.1):
-    if matrix.shape[-1] != matrix.shape[-2]:
-        raise ValueError("The last two dimensions of the input tensor must be square.")
-
-    matrix_shape = matrix.shape[:-2]
-    identity = torch.eye(matrix.shape[-1], dtype=matrix.dtype, device=matrix.device)
-    identity = identity.expand(*matrix_shape, matrix.shape[-2], matrix.shape[-1])
-    regularized_matrix = matrix + alpha * identity
-    return complex_pseudoinverse(regularized_matrix)
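For reference, the four blocks above implement the standard Redheffer star product for cascading two scattering matrices $S^A$ and $S^B$; the identities read directly off the code:

$$
\begin{aligned}
S_{11} &= S^{A}_{11} + S^{A}_{12}\,(I - S^{B}_{11} S^{A}_{22})^{-1} S^{B}_{11} S^{A}_{21},\\
S_{12} &= S^{A}_{12}\,(I - S^{B}_{11} S^{A}_{22})^{-1} S^{B}_{12},\\
S_{21} &= S^{B}_{21}\,(I - S^{A}_{22} S^{B}_{11})^{-1} S^{A}_{21},\\
S_{22} &= S^{B}_{22} + S^{B}_{21}\,(I - S^{A}_{22} S^{B}_{11})^{-1} S^{A}_{22} S^{B}_{12}.
\end{aligned}
$$

Both resolvent factors are square, so a true inverse is well defined whenever they are nonsingular (the generic case), and the pseudoinverse coincides with it there anyway while routing the computation through the complex SVD.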
122 changes: 92 additions & 30 deletions dflat/rcwa/colburn_tensor_utils.py
@@ -4,43 +4,73 @@


 def diag_batched(
-    diagonal, k=0, num_rows=-1, num_cols=-1, padding_value=0, align="RIGHT_LEFT"
+    diagonal, k=(0,), num_rows=None, num_cols=None, padding_value=0, align="RIGHT_LEFT"
 ):
-    batch_shape = diagonal.shape[:-1]
-    diag_len = diagonal.shape[-1]
-
-    if num_rows == -1:
-        num_rows = diag_len + max(k, 0)
-    if num_cols == -1:
-        num_cols = diag_len - min(k, 0)
+    if not torch.is_tensor(diagonal):
+        diagonal = torch.tensor(diagonal)
+
+    if isinstance(k, int):
+        k = (k,)
+
+    ddim = diagonal.dim()
+    if ddim == 1:
+        diagonal = diagonal[None]
+
+    # infer the dimensions; with no sizes given, return an enlarged square matrix
+    M = diagonal.shape[-1]
+    if num_rows is None and num_cols is None:
+        max_diag_len = M + max(abs(min(k)), abs(max(k)))
+        num_rows = num_cols = max_diag_len
+    if num_rows is None:
+        num_rows = max(M + max(k), 1)
+    if num_cols is None:
+        num_cols = max(M - min(k), 1)
+
+    # hold output
+    num_diagonals = len(k)
+    if num_diagonals == 1:
+        batch_size = diagonal.shape[:-1]
+    else:
+        batch_size = diagonal.shape[:-2]

-    result_shape = batch_shape + (num_rows, num_cols)
-    result = torch.full(
-        result_shape, padding_value, dtype=diagonal.dtype, device=diagonal.device
+    output_shape = (*batch_size, num_rows, num_cols)
+    output = torch.full(
+        output_shape, padding_value, dtype=diagonal.dtype, device=diagonal.device
     )

-    if k >= 0:
-        diag_idx = torch.arange(
-            min(diag_len, num_rows, num_cols - k), device=diagonal.device
-        )
-        if align == "RIGHT_LEFT":
-            result[..., diag_idx, diag_idx + k] = diagonal[..., : len(diag_idx)]
-        else:
-            result[..., -len(diag_idx) :, k : k + len(diag_idx)] = diagonal[
-                ..., : len(diag_idx)
-            ]
-    else:
-        diag_idx = torch.arange(
-            min(diag_len, num_rows + k, num_cols), device=diagonal.device
-        )
-        if align == "RIGHT_LEFT":
-            result[..., diag_idx - k, diag_idx] = diagonal[..., : len(diag_idx)]
-        else:
-            result[..., -len(diag_idx) - k :, : len(diag_idx)] = diagonal[
-                ..., : len(diag_idx)
-            ]
-
-    return result
+    if diagonal.dim() > 2:
+        padby = (0, max(0, num_cols - M), 0, max(0, num_rows - M))
+        diagonal = torch.nn.functional.pad(diagonal, padby)
+    else:
+        padby = (0, max(0, num_rows - M))
+        diagonal = torch.nn.functional.pad(diagonal, padby)
+
+    ## Simplify the bottom if possible ##
+    for i, d in enumerate(k):
+        diag_len = min(num_cols - max(d, 0), num_rows + min(d, 0))
+
+        if align in {"RIGHT_LEFT", "RIGHT_RIGHT"}:
+            offset = max(0, M - diag_len)
+        else:
+            offset = 0
+
+        if num_diagonals == 1:
+            diag_values = diagonal[..., offset : offset + diag_len]
+        else:
+            diag_values = diagonal[..., i, offset : offset + diag_len]
+
+        indices = torch.arange(diag_len)
+        if d >= 0:
+            output[..., indices, indices + d] = diag_values[..., :diag_len]
+        else:
+            row_indices = indices - d
+            valid_mask = row_indices < num_rows
+            output[..., row_indices[valid_mask], indices[valid_mask]] = diag_values[
+                ..., : len(indices[valid_mask])
+            ]
+
+    return output


 def expand_and_tile_np(array, batchSize, pixelsX, pixelsY):
@@ -117,7 +117,7 @@ def backward(ctx, grad_D, grad_U):
         grad_A = F.conj() * torch.matmul(torch.adjoint(U), grad_U)
         grad_A = grad_D + grad_A
         grad_A = torch.matmul(grad_A, torch.adjoint(U))
-        grad_A = torch.matmul(torch.linalg.pinv(torch.adjoint(U)), grad_A)
+        grad_A = torch.matmul(torch.linalg.inv(torch.adjoint(U)), grad_A)
         return grad_A, None
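The same swap applies cleanly here: the tail of `backward` above assembles the standard adjoint of a general eigendecomposition $A = U \Lambda U^{-1}$,

$$
\bar{A} = (U^{H})^{-1}\big(\bar{\Lambda} + \bar{F} \circ (U^{H}\bar{U})\big)\,U^{H},
$$

where $\circ$ is the elementwise product and $F$ is formed from pairwise eigenvalue differences (regularized by `eps`; its construction sits above the hunk shown). $U^{H}$ is square and invertible whenever the eigenvectors are linearly independent, so `inv` is exact here and `pinv` was an SVD-based detour.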


@@ -150,3 +150,35 @@ def eig_general(A, eps=1e-6):
     eigendecomposition of the input argument `A`.
     """
     return EigGeneralFunction.apply(A, eps)
+
+
+# if __name__ == "__main__":
+#     # thetas = [1.0 for i in range(5)]
+#     # theta = torch.tensor(thetas, dtype=torch.float32)
+#     # theta = theta[:, None, None, None, None, None]
+#     # pixelsX, pixelsY = 1, 1
+#     # theta = torch.tile(theta, dims=(1, pixelsX, pixelsY, 1, 5, 1))
+#     # kx_T = torch.permute(theta, [0, 1, 2, 3, 5, 4])
+#     # KX = diag_batched(kx_T)
+
+#     # Test 1: Main diagonal; k=(0,), diagonal: 2x4
+#     diagonal = np.array([[1, 2, 3, 4], [5, 6, 7, 8]])
+#     output = diag_batched(diagonal).cpu().numpy()
+#     print(output, output.shape, "\n")
+
+#     # Test 2: Superdiagonal; k=(1,), diagonal: 2x3
+#     diagonal = np.array([[1, 2, 3], [4, 5, 6]])
+#     output = diag_batched(diagonal, k=1).cpu().numpy()
+#     print(output, "\n")
+
+#     # Test 3: Tridiagonal band; k=(-1, 0, 1), diagonal: 2x3x3
+#     diagonals = np.array(
+#         [[[7, 8, 9], [4, 5, 6], [1, 2, 3]], [[16, 17, 18], [13, 14, 15], [10, 11, 12]]]
+#     )
+#     output = diag_batched(diagonals, k=(-1, 0, 1)).cpu().numpy()
+#     print(output, "\n")
+
+#     # Test 4: Rectangular matrix
+#     diagonals = np.array([[1, 2]])
+#     output = diag_batched(diagonals, k=-1)
+#     print(output)
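A runnable version of the commented tests above, assuming the import path follows the file location in this commit (`dflat/rcwa/colburn_tensor_utils.py`); only output shapes are checked here:

```python
import torch
from dflat.rcwa.colburn_tensor_utils import diag_batched  # path assumed from this commit

# Main diagonal (k defaults to (0,)): two batched 4x4 matrices.
d = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]])
assert diag_batched(d).shape == (2, 4, 4)

# Tridiagonal band: rows along dim -2 map to the offsets in k, in order.
band = torch.arange(1, 19).reshape(2, 3, 3)
assert diag_batched(band, k=(-1, 0, 1)).shape == (2, 4, 4)
```

With no `num_rows`/`num_cols` given, the function returns square matrices enlarged to `M + max(|k|)` so that every requested offset fits.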
