Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

phaseboost 🚀 ☄️ #47

Merged
merged 15 commits into from
Oct 28, 2024
4 changes: 1 addition & 3 deletions meteor/diffmaps.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,10 @@
import reciprocalspaceship as rs

from .rsmap import Map, _assert_is_map
from .settings import MAP_SAMPLING
from .settings import DEFAULT_KPARAMS_TO_SCAN, MAP_SAMPLING
from .utils import filter_common_indices
from .validate import ScalarMaximizer, negentropy

DEFAULT_KPARAMS_TO_SCAN = np.linspace(0.0, 1.0, 101)


def set_common_crystallographic_metadata(map1: Map, map2: Map, *, output: Map) -> None:
if hasattr(map1, "cell"):
Expand Down
32 changes: 8 additions & 24 deletions meteor/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,30 +5,14 @@
from __future__ import annotations

import re
from typing import Final

OBSERVED_INTENSITY_COLUMNS: Final[list[str]] = [
"I", # generic
"IMEAN", # CCP4
"I-obs", # phenix
]

OBSERVED_AMPLITUDE_COLUMNS: Final[list[str]] = [
"F", # generic
"FP", # CCP4 & GLPh native
r"FPH\d", # CCP4 derivative
"F-obs", # phenix
]

OBSERVED_UNCERTAINTY_COLUMNS: Final[list[str]] = [
"SIGF", # generic
"SIGFP", # CCP4 & GLPh native
r"SIGFPH\d", # CCP4
]

COMPUTED_AMPLITUDE_COLUMNS: Final[list[str]] = ["FC"]

COMPUTED_PHASE_COLUMNS: Final[list[str]] = ["PHIC"]

from .settings import (
COMPUTED_AMPLITUDE_COLUMNS,
COMPUTED_PHASE_COLUMNS,
OBSERVED_AMPLITUDE_COLUMNS,
OBSERVED_INTENSITY_COLUMNS,
OBSERVED_UNCERTAINTY_COLUMNS,
)


class AmbiguousMtzColumnError(ValueError): ...
Expand Down
44 changes: 36 additions & 8 deletions meteor/iterative.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,21 @@

import numpy as np
import pandas as pd
import structlog

from .rsmap import Map
from .settings import (
DEFAULT_TV_WEIGHTS_TO_SCAN_AT_EACH_ITERATION,
ITERATIVE_TV_CONVERGENCE_TOLERANCE,
ITERATIVE_TV_MAX_ITERATIONS,
)
from .tv import TvDenoiseResult, tv_denoise_difference_map
from .utils import (
average_phase_diff_in_degrees,
complex_array_to_rs_dataseries,
)

DEFAULT_TV_WEIGHTS_TO_SCAN = [0.001, 0.01, 0.1, 1.0]
log = structlog.get_logger()


def _project_derivative_on_experimental_set(
Expand Down Expand Up @@ -53,13 +59,14 @@ def _project_derivative_on_experimental_set(
return projected_derivative


def _complex_derivative_from_iterative_tv(
def _complex_derivative_from_iterative_tv( # noqa: PLR0913
*,
native: np.ndarray,
initial_derivative: np.ndarray,
tv_denoise_function: Callable[[np.ndarray], tuple[np.ndarray, TvDenoiseResult]],
convergence_tolerance: float = 1e-4,
max_iterations: int = 1000,
convergence_tolerance: float = ITERATIVE_TV_CONVERGENCE_TOLERANCE,
max_iterations: int = ITERATIVE_TV_MAX_ITERATIONS,
verbose: bool = False,
) -> tuple[np.ndarray, pd.DataFrame]:
"""
Estimate the derivative phases using the iterative TV algorithm.
Expand Down Expand Up @@ -87,6 +94,9 @@ def _complex_derivative_from_iterative_tv(
max_iterations: int
If this number of iterations is reached, stop early. Default 1000.

verbose: bool
Log or not.

Returns
-------
estimated_complex_derivative: np.ndarray
Expand Down Expand Up @@ -127,20 +137,28 @@ def _complex_derivative_from_iterative_tv(
"average_phase_change": phase_change,
},
)
if verbose:
log.info(
f" iteration {num_iterations:04d}", # noqa: G004
phase_change=round(phase_change, 4),
negentropy=round(tv_metadata.optimal_negentropy, 4),
tv_weight=tv_metadata.optimal_tv_weight,
)

if num_iterations > max_iterations:
break

return derivative, pd.DataFrame(metadata)


def iterative_tv_phase_retrieval(
def iterative_tv_phase_retrieval( # noqa: PLR0913
initial_derivative: Map,
native: Map,
*,
convergence_tolerance: float = 1e-4,
max_iterations: int = 1000,
tv_weights_to_scan: list[float] = DEFAULT_TV_WEIGHTS_TO_SCAN,
convergence_tolerance: float = ITERATIVE_TV_CONVERGENCE_TOLERANCE,
max_iterations: int = ITERATIVE_TV_MAX_ITERATIONS,
tv_weights_to_scan: list[float] = DEFAULT_TV_WEIGHTS_TO_SCAN_AT_EACH_ITERATION,
verbose: bool = False,
) -> tuple[Map, pd.DataFrame]:
"""
Here is a brief pseudocode sketch of the alogrithm. Structure factors F below are complex unless
Expand Down Expand Up @@ -182,6 +200,9 @@ def iterative_tv_phase_retrieval(
A list of TV regularization weights (λ values) to be scanned for optimal results,
by default [0.001, 0.01, 0.1, 1.0].

verbose: bool
Log or not.

Returns
-------
output_map: Map
Expand Down Expand Up @@ -209,12 +230,19 @@ def tv_denoise_closure(difference: np.ndarray) -> tuple[np.ndarray, TvDenoiseRes
return denoised_map.complex_amplitudes, tv_metadata

# estimate the derivative phases using the iterative TV algorithm
if verbose:
log.info(
"convergence criteria:",
phase_tolerance=convergence_tolerance,
max_iterations=max_iterations,
)
it_tv_complex_derivative, metadata = _complex_derivative_from_iterative_tv(
native=native.complex_amplitudes,
initial_derivative=initial_derivative.complex_amplitudes,
tv_denoise_function=tv_denoise_closure,
convergence_tolerance=convergence_tolerance,
max_iterations=max_iterations,
verbose=verbose,
)
_, derivative_phases = complex_array_to_rs_dataseries(
it_tv_complex_derivative,
Expand Down
156 changes: 126 additions & 30 deletions meteor/scripts/common.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,30 @@
from __future__ import annotations

import argparse
import json
import re
from dataclasses import dataclass
from enum import StrEnum, auto
from io import StringIO
from pathlib import Path
from typing import Any

import numpy as np
import pandas as pd
import reciprocalspaceship as rs
import structlog

from meteor.diffmaps import (
compute_difference_map,
compute_kweighted_difference_map,
max_negentropy_kweighted_difference_map,
)
from meteor.io import find_observed_amplitude_column, find_observed_uncertainty_column
from meteor.rsmap import Map
from meteor.scale import scale_maps
from meteor.settings import COMPUTED_MAP_RESOLUTION_LIMIT, KWEIGHT_PARAMETER_DEFAULT
from meteor.sfcalc import structure_file_to_calculated_map
from meteor.tv import TvDenoiseResult

log = structlog.get_logger()

Expand Down Expand Up @@ -64,82 +76,81 @@ class DiffmapArgParser(argparse.ArgumentParser):
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)

derivative_group = self.add_argument_group(
"derivative",
description=(
"The 'derivative' diffraction data, typically: light-triggered, ligand-bound, etc. "
),
required_group = self.add_argument_group("required")
required_group.add_argument(
"derivative_mtz",
type=Path,
help="Path to MTZ containing the `derivative` data; positional arg (order matters).",
)
required_group.add_argument(
"native_mtz",
type=Path,
help="Path to MTZ containing the `native` data; positional arg (order matters)",
)
derivative_group.add_argument("derivative_mtz", type=Path)
derivative_group.add_argument(
required_group.add_argument(
"-s",
"--structure",
type=Path,
required=True,
help="Specify CIF or PDB file path, for phases (usually a native model)",
)

labels_group = self.add_argument_group("mtz column labels (input)")
labels_group.add_argument(
"-da",
"--derivative-amplitude-column",
type=str,
default=INFER_COLUMN_NAME,
help="specify the MTZ column for the amplitudes; will try to guess if not provided",
)
derivative_group.add_argument(
labels_group.add_argument(
"-du",
"--derivative-uncertainty-column",
type=str,
default=INFER_COLUMN_NAME,
help="specify the MTZ column for the uncertainties; will try to guess if not provided",
)

native_group = self.add_argument_group(
"native",
description=("The 'native' diffraction data, typically: dark, apo, etc."),
)
native_group.add_argument("native_mtz", type=Path)
native_group.add_argument(
labels_group.add_argument(
"-na",
"--native-amplitude-column",
type=str,
default=INFER_COLUMN_NAME,
help="specify the MTZ column for the amplitudes; will try to guess if not provided",
)
native_group.add_argument(
labels_group.add_argument(
"-nu",
"--native-uncertainty-column",
type=str,
default=INFER_COLUMN_NAME,
help="specify the MTZ column for the uncertainties; will try to guess if not provided",
)

self.add_argument(
"-s",
"--structure",
type=Path,
required=True,
help="Specify CIF or PDB file path, for phases (usually a native model). Required.",
)

self.add_argument(
output_group = self.add_argument_group("output")
output_group.add_argument(
"-o",
"--mtzout",
type=Path,
default=DEFAULT_OUTPUT_MTZ,
help=f"Specify output MTZ file path. Default: {DEFAULT_OUTPUT_MTZ}.",
)

self.add_argument(
output_group.add_argument(
"-m",
"--metadataout",
type=Path,
default=DEFAULT_OUTPUT_METADATA_FILE,
help=f"Specify output metadata file path. Default: {DEFAULT_OUTPUT_METADATA_FILE}.",
)

self.add_argument(
kweight_group = self.add_argument_group("k weighting settings")
kweight_group.add_argument(
"-k",
"--kweight-mode",
type=WeightMode,
default=WeightMode.optimize,
choices=list(WeightMode),
help="How to pick the k-parameter. Optimize means max negentropy. Default: `optimize`.",
)

self.add_argument(
kweight_group.add_argument(
"-w",
"--kweight-parameter",
type=float,
Expand Down Expand Up @@ -235,3 +246,88 @@ def load_difference_maps(args: argparse.Namespace) -> DiffMapSet:

mapset.scale()
return mapset


def kweight_diffmap_according_to_mode(
*, mapset: DiffMapSet, kweight_mode: WeightMode, kweight_parameter: float | None = None
) -> tuple[Map, float | None]:
"""
Make and k-weight a difference map using a specified `WeightMode`.

Three modes are possible to pick the k-parameter:
* `WeightMode.optimize`, max-negentropy value will and picked, this may take some time
* `WeightMode.fixed`, `kweight_parameter` is used
* `WeightMode.none`, then no k-weighting is done (note this is NOT equivalent to
kweight_parameter=0.0)

Parameters
----------
mapset: DiffMapSet
The set of `derivative`, `native`, `computed` maps to use to compute the diffmap.

kweight_mode: WeightMode
How to set the k-parameter: {optimize, fixed, none}. See above. If `fixed`, then
`kweight_parameter` is required.

kweight_parameter: float | None
If kweight_mode == WeightMode.fixed, then this must be a float that specifies the
k-parameter to use.

Returns
-------
diffmap: meteor.rsmap.Map
The difference map, k-weighted if requested.

kweight_parameter: float | None
The `kweight_parameter` used. Only really interesting if WeightMode.optimize.
"""
log.info("Computing difference map.")

if kweight_mode == WeightMode.optimize:
diffmap, kweight_parameter = max_negentropy_kweighted_difference_map(
mapset.derivative, mapset.native
)
log.info(" using negentropy optimized", kparameter=kweight_parameter)
if kweight_parameter is np.nan:
msg = "determined `k-parameter` is NaN, something went wrong..."
raise RuntimeError(msg)

elif kweight_mode == WeightMode.fixed:
if not isinstance(kweight_parameter, float):
msg = f"`kweight_parameter` is type `{type(kweight_parameter)}`, must be `float`"
raise TypeError(msg)

diffmap = compute_kweighted_difference_map(
mapset.derivative, mapset.native, k_parameter=kweight_parameter
)

log.info(" using fixed", kparameter=kweight_parameter)

elif kweight_mode == WeightMode.none:
diffmap = compute_difference_map(mapset.derivative, mapset.native)
kweight_parameter = None
log.info(" requested no k-weighting")

else:
raise InvalidWeightModeError(kweight_mode)

return diffmap, kweight_parameter


def write_combined_metadata(
*, filename: Path, it_tv_metadata: pd.DataFrame, final_tv_metadata: TvDenoiseResult
) -> None:
combined_metadata = {
"iterative_tv": it_tv_metadata.to_json(orient="records", indent=4),
"final_tv_pass": final_tv_metadata.json(),
}
with filename.open("w") as f:
json.dump(combined_metadata, f, indent=4)


def read_combined_metadata(*, filename: Path) -> tuple[pd.DataFrame, TvDenoiseResult]:
with filename.open("r") as f:
combined_metadata = json.load(f)
it_tv_metadata = pd.read_json(StringIO(combined_metadata["iterative_tv"]))
final_tv_metadata = TvDenoiseResult.from_json(combined_metadata["final_tv_pass"])
return it_tv_metadata, final_tv_metadata
Loading
Loading