diff --git a/sed/calibrator/momentum.py b/sed/calibrator/momentum.py index 7fb585de..59546946 100644 --- a/sed/calibrator/momentum.py +++ b/sed/calibrator/momentum.py @@ -6,6 +6,7 @@ from typing import Any from typing import Dict from typing import List +from typing import Sequence from typing import Tuple from typing import Union @@ -107,7 +108,10 @@ def __init__( self.correction: Dict[Any, Any] = {"applied": False} self.adjust_params: Dict[Any, Any] = {"applied": False} self.calibration: Dict[Any, Any] = {} - + self.division_model_params: Dict[str, Any] = self._config["momentum"].get( + "division_model_params", + {}, + ) self.x_column = self._config["dataframe"]["x_column"] self.y_column = self._config["dataframe"]["y_column"] self.corrected_x_column = self._config["dataframe"]["corrected_x_column"] @@ -1837,6 +1841,122 @@ def gather_calibration_metadata(self, calibration: dict = None) -> dict: return metadata + def calibrate_k_division_model( + self, + df: Union[pd.DataFrame, dask.dataframe.DataFrame], + warp_params: Union[Dict[str, Any], Sequence[float]] = None, + x_column: str = None, + y_column: str = None, + kx_column: str = None, + ky_column: str = None, + ) -> Tuple[Union[pd.DataFrame, dask.dataframe.DataFrame], dict]: + """Use the division model to calibrate the momentum axis. + + This function returns the distorted coordinates given the undistorted ones + a little complicated by the fact that gamma needs to go to (0,0). + it uses a radial distortion model called division model + (https://en.wikipedia.org/wiki/Distortion_(optics)#Software_correction) + commonly used to correct for lens artifacts. + + The radial distortion parameters k0, k1, k2 are defined as follows: + .. math:: + K_n; rk = rpx/(K_0 + K_1*rpx^2 + K_2*rpx^4) + where rpx is the distance from the center of distortion in pixels and rk is the + distance from the center of distortion in k space. + + Args: + df (Union[pd.DataFrame, dask.dataframe.DataFrame]): Dataframe to apply the + distotion correction to. + warp_params (Sequence[float], optional): Parameters of the division model. + Either a dictionary containing the parameters or a sequence of the + parameters in the order ['center','k0','k1','k2','gamma']. + center and gamma are both 2D vectors, k0, k1 and k2 are scalars. + Defaults to config["momentum"]["division_model_params"]. + x_column (str, optional): Label of the source 'X' column. + Defaults to config["momentum"]["x_column"]. + y_column (str, optional): Label of the source 'Y' column. + Defaults to config["momentum"]["y_column"]. + kx_column (str, optional): Label of the destination 'X' column after + momentum calibration. Defaults to config["momentum"]["kx_column"]. + ky_column (str, optional): Label of the destination 'Y' column after + momentum calibration. Defaults to config["momentum"]["ky_column"]. + + Returns: + df (Union[pd.DataFrame, dask.dataframe.DataFrame]): Dataframe with added columns + metadata (dict): momentum calibration metadata dictionary. + """ + if x_column is None: + x_column = self.x_column + if y_column is None: + y_column = self.y_column + if kx_column is None: + kx_column = self.kx_column + if ky_column is None: + ky_column = self.ky_column + + if warp_params is None: + warp_params = self.division_model_params + + if isinstance(warp_params, Sequence): + if len(warp_params) != 7: + raise ValueError( + f"Warp parameters must be a sequence of 7 floats! (center, k0, k1, k2, gamma)\n" + f"Got {len(warp_params)} instead", + ) + warp_params = { + "center": np.asarray(warp_params[0:2]), + "k0": warp_params[2], + "k1": warp_params[3], + "k2": warp_params[4], + "gamma": np.asarray(warp_params[5:7]), + } + elif isinstance(warp_params, dict): + if not all(key in warp_params for key in ["center", "k0", "k1", "k2", "gamma"]): + raise ValueError( + f"Warp parameters must be a dictionary containing the keys " + "'center', 'k0', 'k1', 'k2', 'gamma'!\n" + f"Got {warp_params.keys()} instead", + ) + if len(warp_params["center"]) != 2: + raise ValueError( + f"Warp parameter 'center' must be a 2D vector!\n" + f"Got {warp_params['center']} instead", + ) + if len(warp_params["gamma"]) != 2: + raise ValueError( + f"Warp parameter 'gamma' must be a 2D vector!\n" + f"Got {warp_params['gamma']} instead", + ) + if not all( + isinstance(value, (int, float, np.integer, np.floating)) + for value in [warp_params[k] for k in ["k0", "k1", "k2"]] + ): + raise ValueError( + f"Warp parameters 'k0', 'k1' and 'k2' must be floats!\n" + f"Got {warp_params['k0']}, {warp_params['k1']} and {warp_params['k2']} instead", + ) + else: + raise TypeError("Warp parameters must be a dictionary or a sequence of floats!") + + df = calibrate_k_division_model( + df, + x_column=x_column, + y_column=y_column, + kx_column=kx_column, + ky_column=ky_column, + **warp_params, + ) + + metadata = { + "applied": True, + "warp_params": warp_params, + "x_column": x_column, + "y_column": y_column, + "kx_column": kx_column, + "ky_column": ky_column, + } + return df, metadata + def cm2palette(cmap_name: str) -> list: """Convert certain matplotlib colormap (cm) to bokeh palette. @@ -2091,3 +2211,83 @@ def load_dfield(file: str) -> Tuple[np.ndarray, np.ndarray]: pass return rdeform_field, cdeform_field + + +def calibrate_k_division_model( + df: Union[pd.DataFrame, dask.dataframe.DataFrame], + center: Tuple[float, float] = None, + k0: float = None, + k1: float = None, + k2: float = None, + rot: float = None, + gamma: Tuple[float, float] = None, + x_column: str = None, + y_column: str = None, + kx_column: str = None, + ky_column: str = None, +) -> dask.dataframe.DataFrame: + """K calibration based on the division model + + This function returns the distorted coordinates given the undistorted ones + a little complicated by the fact that gamma needs to go to (0,0). + it uses a radial distortion model called division model + (https://en.wikipedia.org/wiki/Distortion_(optics)#Software_correction) + commonly used to correct for lens artifacts. + + The radial distortion parameters k0, k1, k2 are defined as follows: + .. math:: + K_n; rk = rpx/(K_0 + K_1*rpx^2 + K_2*rpx^4) + where rpx is the distance from the center of distortion in pixels and rk is the + distance from the center of distortion in k space. + + Args: + df (Union[pd.DataFrame, dask.dataframe.DataFrame]): Dataframe to apply the + distotion correction to. + center (Tuple[float, float]): center of distortion in px + k0 (float): radial distortion parameter + k1 (float): radial distortion parameter + k2 (float): radial distortion parameter + rot (float): rotation in rad + gamma (Tuple[float, float]): normal emission (Gamma) in px + x_column (str): Name of the column containing the x steps. + y_column (str): Name of the column containing the y steps. + kx_column (str, optional): Name of the target calibrated x column. + If None, defaults to x_column. + ky_column (str, optional): Name of the target calibrated x column. + If None, defaults to y_column. + + Returns: + df (dask.dataframe.DataFrame): Dataframe with added columns + """ + if kx_column is None: + kx_column = x_column + if ky_column is None: + ky_column = y_column + + def convert_to_kx(x): + """Converts the x steps to kx.""" + x_diff = x[x_column] - center[0] + y_diff = x[y_column] - center[1] + dist = np.sqrt(x_diff**2 + y_diff**2) + den = k0 + k1 * dist**2 + k2 * dist**4 + angle = np.arctan2(y_diff, x_diff) - rot + warp_diff = np.sqrt((gamma[0] - center[0]) ** 2 + (gamma[1] - center[1]) ** 2) + warp_den = k0 + k1 * (gamma[0] - center[0]) ** 2 + k2 * (gamma[1] - center[1]) ** 2 + warp_angle = np.arctan2(gamma[1] - center[1], gamma[0] - center[0]) - rot + return (dist / den) * np.cos(angle) - (warp_diff / warp_den) * np.cos(warp_angle) + + def convert_to_ky(x): + x_diff = x[x_column] - center[0] + y_diff = x[y_column] - center[1] + dist = np.sqrt(x_diff**2 + y_diff**2) + den = k0 + k1 * dist**2 + k2 * dist**4 + angle = np.arctan2(y_diff, x_diff) - rot + warp_diff = np.sqrt((gamma[0] - center[0]) ** 2 + (gamma[1] - center[1]) ** 2) + warp_den = k0 + k1 * (gamma[0] - center[0]) ** 2 + k2 * (gamma[1] - center[1]) ** 2 + warp_angle = np.arctan2(gamma[1] - center[1], gamma[0] - center[0]) - rot + return (dist / den) * np.sin(angle) - (warp_diff / warp_den) * np.sin(warp_angle) + + df[kx_column] = df.map_partitions(convert_to_kx, meta=(kx_column, np.float64)) + df[ky_column] = df.map_partitions(convert_to_ky, meta=(ky_column, np.float64)) + + return df diff --git a/sed/core/processor.py b/sed/core/processor.py index c01bbd5c..e6a2db59 100644 --- a/sed/core/processor.py +++ b/sed/core/processor.py @@ -1507,6 +1507,86 @@ def save_delay_calibration( } save_config(config, filename, overwrite) + def calibrate_k_division_model( + self, + warp_params: Sequence[float] = None, + **kwargs, + ) -> None: + """Use the division model to calibrate the momentum axis. + + This function returns the distorted coordinates given the undistorted ones + a little complicated by the fact that gamma needs to go to (0,0). + it uses a radial distortion model called division model + (https://en.wikipedia.org/wiki/Distortion_(optics)#Software_correction) + commonly used to correct for lens artifacts. + + The radial distortion parameters k0, k1, k2 are defined as follows: + .. math:: + K_n; rk = rpx/(K_0 + K_1*rpx^2 + K_2*rpx^4) + where rpx is the distance from the center of distortion in pixels and rk is the + distance from the center of distortion in k space. + + Args: + df (Union[pd.DataFrame, dask.dataframe.DataFrame]): Dataframe to apply the + distotion correction to. + warp_params (Sequence[float], optional): Parameters of the division model. + Either a dictionary containing the parameters or a sequence of the + parameters in the order ['center','k0','k1','k2','gamma']. + center and gamma are both 2D vectors, k0, k1 and k2 are scalars. + Center is the center of distortion in pixels, gamma is the center of + the image in k space. k0, k1 and k2 are the radial distortion parameters. + Defaults to config["momentum"]["division_model_params"]. + kwargs: Keyword arguments passed to ``calibrate_k_division_model``: + x_column (str, optional): Label of the source 'X' column. + Defaults to config["momentum"]["x_column"]. + y_column (str, optional): Label of the source 'Y' column. + Defaults to config["momentum"]["y_column"]. + kx_column (str, optional): Label of the destination 'X' column after + momentum calibration. Defaults to config["momentum"]["kx_column"]. + ky_column (str, optional): Label of the destination 'Y' column after + momentum calibration. Defaults to config["momentum"]["ky_column"]. + """ + self._dataframe, metadata = self.mc.calibrate_k_division_model( + df=self._dataframe, + warp_params=warp_params, + **kwargs, + ) + self._attributes.add( + metadata, + "k_division_model", + duplicate_policy="raise", + ) + + def save_k_division_model( + self, + filename: str = None, + overwrite: bool = False, + ) -> None: + """save the generated k division model parameters to the folder config file. + + + + Args: + filename (str, optional): Filename of the config dictionary to save to. + Defaults to "sed_config.yaml" in the current folder. + overwrite (bool, optional): Option to overwrite the present dictionary. + Defaults to False. + """ + if filename is None: + filename = "sed_config.yaml" + params = {} + try: + for key in ["center", "k0", "k1", "k2", "gamma"]: + params[key] = self.mc.division_model_params[key] + except KeyError as exc: + raise KeyError( + "k division model parameters not found, need to generate parameters first!", + ) from exc + + config: Dict[str, Any] = {"momentum": {"k_division_model": params}} + save_config(config, filename, overwrite) + print(f"Saved k division model parameters to {filename}") + def add_delay_offset( self, constant: float = None, @@ -1529,7 +1609,6 @@ def add_delay_offset( of dask.dataframe.Series. For example "mean". In this case the function is applied to the column to generate a single value for the whole dataset. If None, the shift is applied per-dataframe-row. Defaults to None. Currently only "mean" is supported. - Returns: None """ @@ -1618,6 +1697,7 @@ def save_workflow_params( self.save_energy_offset, self.save_delay_calibration, self.save_delay_offsets, + self.save_k_division_model, ]: try: method(filename, overwrite)