diff --git a/meteor/tv.py b/meteor/tv.py index 8486296..3ab03d9 100644 --- a/meteor/tv.py +++ b/meteor/tv.py @@ -8,7 +8,7 @@ from typing import Sequence from .validate import negentropy -from .utils import compute_map_from_coefficients, compute_coefficients_from_map +from .utils import compute_map_from_coefficients, compute_coefficients_from_map, resolution_limits from .settings import ( TV_LAMBDA_RANGE, TV_STOP_TOLERANCE, @@ -73,14 +73,20 @@ def negentropy_objective(tv_lambda: float): optimizer_result = minimize_scalar( negentropy_objective, bracket=TV_LAMBDA_RANGE, method="golden" ) - assert optimizer_result.success + assert optimizer_result.success, "Golden minimization failed to find optimal TV lambda" optimal_lambda = optimizer_result.x final_map_array = _tv_denoise_ccp4_map(map=difference_map, weight=optimal_lambda) # TODO: verify correctness - final_map_coefficients = (np.fft.fftn(final_map_array)[::TV_MAP_SAMPLING, ::TV_MAP_SAMPLING, ::TV_MAP_SAMPLING]).flatten() - print(final_map_coefficients.shape, len(difference_map_coefficients)) + + _, high_resolution_limit = resolution_limits(difference_map_coefficients) + final_map_coefficients = compute_coefficients_from_map( + map=final_map_array, + high_resolution_limit=high_resolution_limit, + amplitude_label=TV_AMPLITUDE_LABEL, + phase_label=TV_PHASE_LABEL, + ) # TODO: need to be sure HKLs line up difference_map_coefficients[[TV_AMPLITUDE_LABEL]] = np.abs(final_map_coefficients) diff --git a/meteor/utils.py b/meteor/utils.py index 4943f8b..e69213c 100644 --- a/meteor/utils.py +++ b/meteor/utils.py @@ -1,6 +1,7 @@ import numpy as np import gemmi as gm import reciprocalspaceship as rs +from typing import overload, Literal, Union def resolution_limits(dataset: rs.DataSet) -> tuple[float, float]: @@ -17,34 +18,45 @@ def cut_resolution(dataset: rs.DataSet, *, dmax_limit: float | None = None, dmin return dataset +@overload def canonicalize_amplitudes( dataset: rs.DataSet, amplitude_label: str, phase_label: str, + inplace: Literal[False], ) -> rs.DataSet: - # TODO review and improve - # I think we can infer phase types + ... - new_phis = dataset[phase_label].copy(deep=True) - new_Fs = dataset[amplitude_label].copy(deep=True) - negs = np.where(dataset[amplitude_label] < 0) +@overload +def canonicalize_amplitudes( + dataset: rs.DataSet, + amplitude_label: str, + phase_label: str, + inplace: Literal[True], + ) -> None: + ... - dataset.canonicalize_phases(inplace=True) - for i in negs: - new_phis.iloc[i] = dataset[phase_label].iloc[i] + 180 - new_Fs.iloc[i] = np.abs(new_Fs.iloc[i]) +def canonicalize_amplitudes( + dataset: rs.DataSet, + amplitude_label: str, + phase_label: str, + inplace: bool = False, + ) -> rs.DataSet | None: - new_phis.canonicalize_phases(inplace=True) + dataset.canonicalize_phases(inplace=inplace) + if not inplace: + dataset = dataset.copy(deep=True) - df_new = dataset.copy(deep=True) - df_new[amplitude_label] = new_Fs - df_new[amplitude_label] = df_new[amplitude_label].astype("SFAmplitude") - df_new[phase_label] = new_phis - df_new[phase_label] = df_new[phase_label].astype("Phase") + negative_amplitude_indices = dataset[amplitude_label] < 0.0 + dataset[amplitude_label] = np.abs(dataset[amplitude_label]) + dataset.loc[negative_amplitude_indices, phase_label] += 180.0 - return df_new + if not inplace: + return dataset + else: + return None def compute_map_from_coefficients( @@ -79,7 +91,7 @@ def compute_coefficients_from_map( amplitude_label=amplitude_label, phase_label=phase_label ) - elif isinstance(map, np.ndarray): + elif isinstance(map, gm.Ccp4Map): return _compute_coefficients_from_ccp4_map( ccp4_map=map, high_resolution_limit=high_resolution_limit, @@ -111,7 +123,7 @@ def _compute_coefficients_from_ccp4_map( # to ensure we include the final shell of reflections, add a small buffer to the resolution high_resolution_buffer = 0.05 - gemmi_structure_factors = gm.transform_map_to_f_phi(map.grid, half_l=False) + gemmi_structure_factors = gm.transform_map_to_f_phi(ccp4_map.grid, half_l=False) data = gemmi_structure_factors.prepare_asu_data( dmin=high_resolution_limit - high_resolution_buffer, with_sys_abs=True ) diff --git a/test/conftest.py b/test/conftest.py new file mode 100644 index 0000000..bcbc16d --- /dev/null +++ b/test/conftest.py @@ -0,0 +1,31 @@ +from pytest import fixture +import reciprocalspaceship as rs +import numpy as np +import gemmi as gm + + +@fixture +def random_intensities() -> rs.DataSet: + """ + A simple 10x10x10 P1 dataset, with random intensities + """ + + params = (10.0, 10.0, 10.0, 90.0, 90.0, 90.0) + cell = gm.UnitCell(*params) + sg_1 = gm.SpaceGroup(1) + Hall = rs.utils.generate_reciprocal_asu(cell, sg_1, 1.0, anomalous=False) + + h, k, l = Hall.T + ds = rs.DataSet( + { + "H": h, + "K": k, + "L": l, + "IMEAN": np.abs(np.random.randn(len(h))), + }, + spacegroup=sg_1, + cell=cell, + ).infer_mtz_dtypes() + ds.set_index(["H", "K", "L"], inplace=True) + + return ds \ No newline at end of file diff --git a/test/unit/test_utils.py b/test/unit/test_utils.py index 877d0c2..9ad9fda 100644 --- a/test/unit/test_utils.py +++ b/test/unit/test_utils.py @@ -1,6 +1,9 @@ from meteor import utils import reciprocalspaceship as rs import pytest +import gemmi as gm +import pandas as pd +import numpy as np def test_resolution_limits(random_intensities: rs.DataSet) -> None: @@ -38,3 +41,73 @@ def test_cut_resolution(random_intensities: rs.DataSet, dmax_limit: float, dmin_ dmax, dmin = utils.resolution_limits(random_intensities) assert dmax <= expected_max_dmax assert dmin >= expected_min_dmin + +@pytest.mark.parametrize("inplace", [False, True]) +def test_canonicalize_amplitudes(inplace: bool, flat_difference_map: rs.DataSet) -> None: + amplitude_label = "DF" + phase_label = "PHIC" + + if inplace: + canonicalized = flat_difference_map.copy(deep=True) + utils.canonicalize_amplitudes( + canonicalized, + amplitude_label=amplitude_label, + phase_label=phase_label, + inplace=inplace + ) + else: + canonicalized = utils.canonicalize_amplitudes( + flat_difference_map, + amplitude_label=amplitude_label, + phase_label=phase_label, + inplace=inplace + ) + + assert (canonicalized[amplitude_label] >= 0.0).all() + assert (canonicalized[phase_label] >= -180.0).all() + assert (canonicalized[phase_label] <= 180.0).all() + + np.testing.assert_almost_equal( + np.array(np.abs(flat_difference_map[amplitude_label])), + np.array(canonicalized[amplitude_label]) + ) + +def test_compute_map_from_coefficients(flat_difference_map: rs.DataSet) -> None: + map = utils.compute_map_from_coefficients( + map_coefficients=flat_difference_map, + amplitude_label="DF", + phase_label="PHIC", + map_sampling=1, + ) + assert isinstance(map, gm.Ccp4Map) + assert map.grid.shape == (6,6,6) + + +# def test_map_round_trip_ccp4_format(flat_difference_map: rs.DataSet) -> None: +# amplitude_label = "DF" +# phase_label = "PHIC" +# map_sampling = 1 + +# flat_difference_map = utils.canonicalize_amplitudes( +# flat_difference_map, +# amplitude_label=amplitude_label, +# phase_label=phase_label +# ) + +# map = utils.compute_map_from_coefficients( +# map_coefficients=flat_difference_map, +# amplitude_label=amplitude_label, +# phase_label=phase_label, +# map_sampling=map_sampling, +# ) + +# _, dmin = utils.resolution_limits(flat_difference_map) + +# output_coefficients = utils.compute_coefficients_from_map( +# map=map, +# high_resolution_limit=dmin, +# amplitude_label=amplitude_label, +# phase_label=phase_label, +# ) + +# pd.testing.assert_frame_equal(left=flat_difference_map, right=output_coefficients, check_exact=False)