Skip to content

Commit

Permalink
bugfixes and improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
tjlane committed Oct 8, 2024
1 parent 4cadf16 commit 2227051
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 11 deletions.
29 changes: 19 additions & 10 deletions meteor/tv.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,12 @@ def negentropy_objective(tv_lambda: float):
else:
return final_map_coefficients

def _form_complex_sf(amplitudes: rs.DataSeries, phases_in_deg: rs.DataSeries) -> rs.DataSeries:
return amplitudes * np.exp(1j * np.deg2rad(phases_in_deg))

def _complex_argument(complex: rs.DataSeries) -> rs.DataSeries:
return complex.apply(np.angle).apply(np.rad2deg)


def _dataseries_l1_norm(
series1: rs.DataSeries,
Expand All @@ -194,19 +200,19 @@ def _projected_derivative_phase(
native_amplitudes: rs.DataSeries,
native_phases: rs.DataSeries,
) -> rs.DataSeries:
complex_difference = difference_amplitudes * np.exp(difference_phases)
complex_native = native_amplitudes * np.exp(native_phases)
complex_difference = _form_complex_sf(difference_amplitudes, difference_phases)
complex_native = _form_complex_sf(native_amplitudes, native_phases)
complex_derivative_estimate = complex_difference + complex_native
return complex_derivative_estimate.apply(np.angle).apply(np.rad2deg)
return _complex_argument(complex_derivative_estimate)


def iterative_tv_phase_retrieval(
*,
input_dataset: rs.DataSet,
native_amplitude_column: str = "F",
derivative_amplitude_column: str = "FH",
derivative_amplitude_column: str = "Fh",
calculated_phase_column: str = "PHIC",
output_derivative_phase_column: str = "PHICH",
output_derivative_phase_column: str = "PHICh",
convergence_tolerance: float = 0.01,
) -> rs.DataSet:
"""
Expand All @@ -232,7 +238,8 @@ def iterative_tv_phase_retrieval(
threshold.
"""

# TODO should these be adjustable input params?
# TODO work on below for readability
# TODO should these be adjustable input params? not returned?
difference_amplitude_column: str = "DF"
difference_phase_column: str = "DPHIC"

Check warning on line 244 in meteor/tv.py

View check run for this annotation

Codecov / codecov/patch

meteor/tv.py#L243-L244

Added lines #L243 - L244 were not covered by tests

Expand Down Expand Up @@ -290,16 +297,18 @@ def iterative_tv_phase_retrieval(
native_phases=working_ds[calculated_phase_column],
)

current_complex_native = working_ds[native_amplitude_column] * np.exp(
# TODO encapsulate block below into function
current_complex_native = _form_complex_sf(

Check warning on line 301 in meteor/tv.py

View check run for this annotation

Codecov / codecov/patch

meteor/tv.py#L301

Added line #L301 was not covered by tests
working_ds[native_amplitude_column],
working_ds[calculated_phase_column]
)
current_complex_derivative = working_ds[derivative_amplitude_column] * np.exp(
current_complex_derivative = _form_complex_sf(

Check warning on line 305 in meteor/tv.py

View check run for this annotation

Codecov / codecov/patch

meteor/tv.py#L305

Added line #L305 was not covered by tests
working_ds[derivative_amplitude_column],
working_ds[output_derivative_phase_column]
)
current_complex_difference = current_complex_derivative - current_complex_native

working_ds[difference_amplitude_column] = np.abs(current_complex_difference)
working_ds[difference_phase_column] = np.rad2deg(np.angle(current_complex_difference))
working_ds[difference_phase_column] = _complex_argument(current_complex_difference)

Check warning on line 311 in meteor/tv.py

View check run for this annotation

Codecov / codecov/patch

meteor/tv.py#L309-L311

Added lines #L309 - L311 were not covered by tests

canonicalize_amplitudes(

Check warning on line 313 in meteor/tv.py

View check run for this annotation

Codecov / codecov/patch

meteor/tv.py#L313

Added line #L313 was not covered by tests
working_ds,
Expand Down
28 changes: 27 additions & 1 deletion test/unit/test_tv.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,4 +122,30 @@ def test_dataseries_l1_norm_no_overlapping_indices() -> None:
tv._dataseries_l1_norm(series1, series2)


def test_phase_of_projection_to_experimental_set() -> None: ...
def test_projected_derivative_phase_identical_phases() -> None:
hkls = [0,1,2]
phases = rs.DataSeries([0.0, 30.0, 60.0], index=hkls)
amplitudes = rs.DataSeries([1., 1., 1.], index=hkls)

derivative_phases = tv._projected_derivative_phase(
difference_amplitudes=amplitudes,
difference_phases=phases,
native_amplitudes=amplitudes,
native_phases=phases
)
np.testing.assert_almost_equal(phases.to_numpy(), derivative_phases.to_numpy())


def test_projected_derivative_phase_opposite_phases() -> None:
hkls = [0,1,2]
native_phases = rs.DataSeries([0.0, 30.0, 60.0], index=hkls)

# if DF = 0, then derivative and native phase should be the same
derivative_phases = tv._projected_derivative_phase(
difference_amplitudes=rs.DataSeries([0., 0., 0.], index=hkls),
difference_phases=native_phases,
native_amplitudes=rs.DataSeries([1., 1., 1.], index=hkls),
native_phases=native_phases
)
np.testing.assert_almost_equal(derivative_phases.to_numpy(), native_phases.to_numpy())

0 comments on commit 2227051

Please sign in to comment.