Skip to content

Commit

Permalink
checkpoint before moving to complex model
Browse files Browse the repository at this point in the history
  • Loading branch information
tjlane committed Oct 8, 2024
1 parent 9f06faf commit 99d99d7
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 17 deletions.
19 changes: 12 additions & 7 deletions meteor/iterative.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
from .utils import canonicalize_amplitudes


def _form_complex_sf(amplitudes: rs.DataSeries, phases_in_deg: rs.DataSeries) -> rs.DataSeries:
def _form_complex_sf(amplitudes: rs.DataSeries, phases_in_deg: rs.DataSeries) -> np.ndarray:
expi = lambda x: np.exp(1j * np.deg2rad(x)) # noqa: E731
return amplitudes * phases_in_deg.apply(expi)
return amplitudes.to_numpy().astype(np.complex128) * expi(phases_in_deg.to_numpy().astype(np.float64))

Check failure on line 10 in meteor/iterative.py

View workflow job for this annotation

GitHub Actions / build (3.11)

Ruff (E501)

meteor/iterative.py:10:101: E501 Line too long (106 > 100)


def _complex_argument(complex: rs.DataSeries) -> rs.DataSeries:
Expand Down Expand Up @@ -35,6 +35,8 @@ def _projected_derivative_phase(
complex_difference = _form_complex_sf(difference_amplitudes, difference_phases)
complex_native = _form_complex_sf(native_amplitudes, native_phases)
complex_derivative_estimate = complex_difference + complex_native
complex_derivative_estimate = rs.DataSeries(complex_derivative_estimate, index=native_amplitudes.index)

Check failure on line 38 in meteor/iterative.py

View workflow job for this annotation

GitHub Actions / build (3.11)

Ruff (E501)

meteor/iterative.py:38:101: E501 Line too long (107 > 100)
print(complex_derivative_estimate)
return _complex_argument(complex_derivative_estimate)


Expand Down Expand Up @@ -116,17 +118,17 @@ def iterative_tv_phase_retrieval(
difference_map_amplitude_column=difference_amplitude_column,
difference_map_phase_column=difference_phase_column,
lambda_values_to_scan=[
0.001,
0.1,
],
full_output=True,
)
print("***", result.optimal_lambda)

change_in_DF = _dataseries_l1_norm( # noqa: N806
working_ds[difference_amplitude_column], # previous iteration
DF_prime[difference_amplitude_column], # current iteration
)
converged = change_in_DF < convergence_tolerance
print("***", result.optimal_negentropy, change_in_DF)

# update working_ds, NB native and derivative amplitudes & native phases stay the same
working_ds[output_derivative_phase_column] = _projected_derivative_phase(
Expand All @@ -143,10 +145,13 @@ def iterative_tv_phase_retrieval(
current_complex_derivative = _form_complex_sf(
working_ds[derivative_amplitude_column], working_ds[output_derivative_phase_column]
)

current_complex_difference = current_complex_derivative - current_complex_native
working_ds[difference_amplitude_column] = np.abs(current_complex_difference).astype(
rs.StructureFactorAmplitudeDtype()
)
working_ds[difference_amplitude_column] = rs.DataSeries(
np.abs(current_complex_difference),
index=working_ds.index,
name=difference_amplitude_column
).astype(rs.StructureFactorAmplitudeDtype())
working_ds[difference_phase_column] = _complex_argument(current_complex_difference)

canonicalize_amplitudes(
Expand Down
9 changes: 3 additions & 6 deletions meteor/tv.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ class TvDenoiseResult:

def _tv_denoise_array(*, map_as_array: np.ndarray, weight: float) -> np.ndarray:
"""Closure convienence function to generate more readable code."""
if weight < 0.0:
raise ValueError("TV weight < 0 requested, something went wrong")
denoised_map = denoise_tv_chambolle(
map_as_array,
weight=weight,
Expand Down Expand Up @@ -77,7 +79,6 @@ def tv_denoise_difference_map(
2. Alternatively, an explicit list of lambda values to assess can be provided using
`lambda_values_to_scan`.
Parameters
----------
difference_map_coefficients : rs.DataSet
Expand Down Expand Up @@ -124,7 +125,6 @@ def tv_denoise_difference_map(
>>> coefficients = rs.read_mtz("./path/to/difference_map.mtz") # load dataset
>>> denoised_map, result = tv_denoise_difference_map(coefficients, full_output=True)
>>> print(f"Optimal Lambda: {result.optimal_lambda}, Negentropy: {result.optimal_negentropy}")
"""
difference_map = compute_map_from_coefficients(
map_coefficients=difference_map_coefficients,
Expand All @@ -143,10 +143,7 @@ def negentropy_objective(tv_lambda: float):
maximizer.optimize_over_explicit_values(arguments_to_scan=lambda_values_to_scan)
else:
maximizer.optimize_with_golden_algorithm(bracket=TV_LAMBDA_RANGE)

if maximizer.argument_optimum < 0.0:
raise RuntimeError("optimal TV denoising parameter is negative, something went wrong")


Check failure on line 146 in meteor/tv.py

View workflow job for this annotation

GitHub Actions / build (3.11)

Ruff (W293)

meteor/tv.py:146:1: W293 Blank line contains whitespace
# denoise using the optimized parameters and convert to an rs.DataSet
final_map = _tv_denoise_array(
map_as_array=difference_map_as_array, weight=maximizer.argument_optimum
Expand Down
8 changes: 4 additions & 4 deletions test/unit/test_iterative.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,14 +49,14 @@ def test_projected_derivative_phase_opposite_phases() -> None:
np.testing.assert_almost_equal(derivative_phases.to_numpy(), native_phases.to_numpy())


def test_iterative_tv(displaced_atom_two_datasets_noise_free: rs.DataSet) -> None:
result = iterative.iterative_tv_phase_retrieval(displaced_atom_two_datasets_noise_free)
def test_iterative_tv(displaced_atom_two_datasets_noisy: rs.DataSet) -> None:
result = iterative.iterative_tv_phase_retrieval(displaced_atom_two_datasets_noisy)
for label in ["F", "Fh"]:
pdt.assert_series_equal(
result[label], displaced_atom_two_datasets_noise_free[label], atol=1e-3
result[label], displaced_atom_two_datasets_noisy[label], atol=1e-3
)
assert_phases_allclose(
result["PHIC"], displaced_atom_two_datasets_noise_free["PHIC"], atol=1e-3
result["PHIC"], displaced_atom_two_datasets_noisy["PHIC"], atol=1e-3
)
# assert_phases_allclose(
# result["PHICh"], displaced_atom_two_datasets_noise_free["PHICh_ground_truth"], atol=1e-3
Expand Down

0 comments on commit 99d99d7

Please sign in to comment.