Skip to content

Commit

Permalink
more utils fixups
Browse files Browse the repository at this point in the history
  • Loading branch information
tjlane committed Aug 21, 2024
1 parent aca7d8b commit a761c6d
Show file tree
Hide file tree
Showing 4 changed files with 144 additions and 22 deletions.
14 changes: 10 additions & 4 deletions meteor/tv.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from typing import Sequence

from .validate import negentropy
from .utils import compute_map_from_coefficients, compute_coefficients_from_map
from .utils import compute_map_from_coefficients, compute_coefficients_from_map, resolution_limits
from .settings import (
TV_LAMBDA_RANGE,
TV_STOP_TOLERANCE,
Expand Down Expand Up @@ -73,14 +73,20 @@ def negentropy_objective(tv_lambda: float):
optimizer_result = minimize_scalar(
negentropy_objective, bracket=TV_LAMBDA_RANGE, method="golden"
)
assert optimizer_result.success
assert optimizer_result.success, "Golden minimization failed to find optimal TV lambda"
optimal_lambda = optimizer_result.x

final_map_array = _tv_denoise_ccp4_map(map=difference_map, weight=optimal_lambda)

# TODO: verify correctness
final_map_coefficients = (np.fft.fftn(final_map_array)[::TV_MAP_SAMPLING, ::TV_MAP_SAMPLING, ::TV_MAP_SAMPLING]).flatten()
print(final_map_coefficients.shape, len(difference_map_coefficients))

_, high_resolution_limit = resolution_limits(difference_map_coefficients)
final_map_coefficients = compute_coefficients_from_map(
map=final_map_array,
high_resolution_limit=high_resolution_limit,
amplitude_label=TV_AMPLITUDE_LABEL,
phase_label=TV_PHASE_LABEL,
)

# TODO: need to be sure HKLs line up
difference_map_coefficients[[TV_AMPLITUDE_LABEL]] = np.abs(final_map_coefficients)
Expand Down
48 changes: 30 additions & 18 deletions meteor/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import numpy as np
import gemmi as gm
import reciprocalspaceship as rs
from typing import overload, Literal, Union

Check failure on line 4 in meteor/utils.py

View workflow job for this annotation

GitHub Actions / build (3.11)

Ruff (F401)

meteor/utils.py:4:39: F401 `typing.Union` imported but unused


def resolution_limits(dataset: rs.DataSet) -> tuple[float, float]:
Expand All @@ -17,34 +18,45 @@ def cut_resolution(dataset: rs.DataSet, *, dmax_limit: float | None = None, dmin
return dataset


@overload
def canonicalize_amplitudes(
dataset: rs.DataSet,
amplitude_label: str,
phase_label: str,
inplace: Literal[False],
) -> rs.DataSet:
# TODO review and improve
# I think we can infer phase types
...

new_phis = dataset[phase_label].copy(deep=True)
new_Fs = dataset[amplitude_label].copy(deep=True)

negs = np.where(dataset[amplitude_label] < 0)
@overload
def canonicalize_amplitudes(
dataset: rs.DataSet,
amplitude_label: str,
phase_label: str,
inplace: Literal[True],
) -> None:
...

dataset.canonicalize_phases(inplace=True)

for i in negs:
new_phis.iloc[i] = dataset[phase_label].iloc[i] + 180
new_Fs.iloc[i] = np.abs(new_Fs.iloc[i])
def canonicalize_amplitudes(
dataset: rs.DataSet,
amplitude_label: str,
phase_label: str,
inplace: bool = False,
) -> rs.DataSet | None:

new_phis.canonicalize_phases(inplace=True)
dataset.canonicalize_phases(inplace=inplace)
if not inplace:
dataset = dataset.copy(deep=True)

df_new = dataset.copy(deep=True)
df_new[amplitude_label] = new_Fs
df_new[amplitude_label] = df_new[amplitude_label].astype("SFAmplitude")
df_new[phase_label] = new_phis
df_new[phase_label] = df_new[phase_label].astype("Phase")
negative_amplitude_indices = dataset[amplitude_label] < 0.0
dataset[amplitude_label] = np.abs(dataset[amplitude_label])
dataset.loc[negative_amplitude_indices, phase_label] += 180.0

return df_new
if not inplace:
return dataset
else:
return None


def compute_map_from_coefficients(
Expand Down Expand Up @@ -79,7 +91,7 @@ def compute_coefficients_from_map(
amplitude_label=amplitude_label,
phase_label=phase_label
)
elif isinstance(map, np.ndarray):
elif isinstance(map, gm.Ccp4Map):
return _compute_coefficients_from_ccp4_map(
ccp4_map=map,
high_resolution_limit=high_resolution_limit,
Expand Down Expand Up @@ -111,7 +123,7 @@ def _compute_coefficients_from_ccp4_map(
# to ensure we include the final shell of reflections, add a small buffer to the resolution
high_resolution_buffer = 0.05

gemmi_structure_factors = gm.transform_map_to_f_phi(map.grid, half_l=False)
gemmi_structure_factors = gm.transform_map_to_f_phi(ccp4_map.grid, half_l=False)
data = gemmi_structure_factors.prepare_asu_data(
dmin=high_resolution_limit - high_resolution_buffer, with_sys_abs=True
)
Expand Down
31 changes: 31 additions & 0 deletions test/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from pytest import fixture
import reciprocalspaceship as rs
import numpy as np
import gemmi as gm


@fixture
def random_intensities() -> rs.DataSet:
"""
A simple 10x10x10 P1 dataset, with random intensities
"""

params = (10.0, 10.0, 10.0, 90.0, 90.0, 90.0)
cell = gm.UnitCell(*params)
sg_1 = gm.SpaceGroup(1)
Hall = rs.utils.generate_reciprocal_asu(cell, sg_1, 1.0, anomalous=False)

h, k, l = Hall.T

Check failure on line 18 in test/conftest.py

View workflow job for this annotation

GitHub Actions / build (3.11)

Ruff (E741)

test/conftest.py:18:11: E741 Ambiguous variable name: `l`
ds = rs.DataSet(
{
"H": h,
"K": k,
"L": l,
"IMEAN": np.abs(np.random.randn(len(h))),
},
spacegroup=sg_1,
cell=cell,
).infer_mtz_dtypes()
ds.set_index(["H", "K", "L"], inplace=True)

return ds
73 changes: 73 additions & 0 deletions test/unit/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from meteor import utils
import reciprocalspaceship as rs
import pytest
import gemmi as gm
import pandas as pd

Check failure on line 5 in test/unit/test_utils.py

View workflow job for this annotation

GitHub Actions / build (3.11)

Ruff (F401)

test/unit/test_utils.py:5:18: F401 `pandas` imported but unused
import numpy as np


def test_resolution_limits(random_intensities: rs.DataSet) -> None:
Expand Down Expand Up @@ -38,3 +41,73 @@ def test_cut_resolution(random_intensities: rs.DataSet, dmax_limit: float, dmin_
dmax, dmin = utils.resolution_limits(random_intensities)
assert dmax <= expected_max_dmax
assert dmin >= expected_min_dmin

@pytest.mark.parametrize("inplace", [False, True])
def test_canonicalize_amplitudes(inplace: bool, flat_difference_map: rs.DataSet) -> None:
amplitude_label = "DF"
phase_label = "PHIC"

if inplace:
canonicalized = flat_difference_map.copy(deep=True)
utils.canonicalize_amplitudes(
canonicalized,
amplitude_label=amplitude_label,
phase_label=phase_label,
inplace=inplace
)
else:
canonicalized = utils.canonicalize_amplitudes(
flat_difference_map,
amplitude_label=amplitude_label,
phase_label=phase_label,
inplace=inplace
)

assert (canonicalized[amplitude_label] >= 0.0).all()
assert (canonicalized[phase_label] >= -180.0).all()
assert (canonicalized[phase_label] <= 180.0).all()

np.testing.assert_almost_equal(
np.array(np.abs(flat_difference_map[amplitude_label])),
np.array(canonicalized[amplitude_label])
)

def test_compute_map_from_coefficients(flat_difference_map: rs.DataSet) -> None:
map = utils.compute_map_from_coefficients(
map_coefficients=flat_difference_map,
amplitude_label="DF",
phase_label="PHIC",
map_sampling=1,
)
assert isinstance(map, gm.Ccp4Map)
assert map.grid.shape == (6,6,6)


# def test_map_round_trip_ccp4_format(flat_difference_map: rs.DataSet) -> None:
# amplitude_label = "DF"
# phase_label = "PHIC"
# map_sampling = 1

# flat_difference_map = utils.canonicalize_amplitudes(
# flat_difference_map,
# amplitude_label=amplitude_label,
# phase_label=phase_label
# )

# map = utils.compute_map_from_coefficients(
# map_coefficients=flat_difference_map,
# amplitude_label=amplitude_label,
# phase_label=phase_label,
# map_sampling=map_sampling,
# )

# _, dmin = utils.resolution_limits(flat_difference_map)

# output_coefficients = utils.compute_coefficients_from_map(
# map=map,
# high_resolution_limit=dmin,
# amplitude_label=amplitude_label,
# phase_label=phase_label,
# )

# pd.testing.assert_frame_equal(left=flat_difference_map, right=output_coefficients, check_exact=False)

0 comments on commit a761c6d

Please sign in to comment.