From 537eb2f53fca37ea928bf30de6632782f95c4a0b Mon Sep 17 00:00:00 2001 From: tjlane Date: Sun, 11 Aug 2024 16:46:55 +0100 Subject: [PATCH] mainline dev yolo --- meteor/TV_filtering/__init__.py | 0 meteor/TV_filtering/iterative_tv.py | 403 -------------------------- meteor/TV_filtering/tv_denoise_map.py | 116 -------- meteor/dsutils.py | 52 ---- meteor/mask.py | 0 meteor/validate.py | 27 +- pyproject.toml | 3 + test/test_dsutils.py | 48 --- test/test_io.py | 26 -- test/unit/test_validate.py | 11 + 10 files changed, 25 insertions(+), 661 deletions(-) delete mode 100755 meteor/TV_filtering/__init__.py delete mode 100755 meteor/TV_filtering/iterative_tv.py delete mode 100755 meteor/TV_filtering/tv_denoise_map.py delete mode 100644 meteor/mask.py delete mode 100644 test/test_dsutils.py delete mode 100644 test/test_io.py create mode 100644 test/unit/test_validate.py diff --git a/meteor/TV_filtering/__init__.py b/meteor/TV_filtering/__init__.py deleted file mode 100755 index e69de29..0000000 diff --git a/meteor/TV_filtering/iterative_tv.py b/meteor/TV_filtering/iterative_tv.py deleted file mode 100755 index 967a49d..0000000 --- a/meteor/TV_filtering/iterative_tv.py +++ /dev/null @@ -1,403 +0,0 @@ -import argparse -from re import I -from tqdm import tqdm -import matplotlib.pyplot as plt -import numpy as np -import os - -import seaborn as sns - -sns.set_context("notebook", font_scale=1.8) - -from meteor import io -from meteor import dsutils -from meteor import maps -from meteor import tv - -""" - -Iterative TV algorithm to improve phase estimates for low occupancy species. -Writes out a difference map file (MTZ) with improved phases. - -""" - - -def parse_arguments(): - """Parse commandline arguments""" - parser = argparse.ArgumentParser( - formatter_class=argparse.RawTextHelpFormatter, description=__doc__ - ) - - # Required arguments - parser.add_argument( - "-mtz", - "--mtz", - nargs=6, - metavar=("mtz", "F_off", "F_on", "phi_col", "SIGF_off", "SIGF_on"), - required=True, - help=( - "MTZ to be used for initial map. Specified as (filename, F_off, F_on, Phi)" - ), - ) - - parser.add_argument( - "-ref", - "--refpdb", - nargs=1, - metavar=("pdb"), - required=True, - help=( - "PDB to be used as reference ('off') structure. " "Specified as (filename)." - ), - ) - - # Optional arguments - - parser.add_argument( - "-d_h", - "--highres", - type=float, - default=None, - help="If set, high res to truncate maps", - ) - - parser.add_argument( - "-fl", - "--flags", - type=str, - default=None, - help="If set, label for Rfree flags to use as test set", - ) - - parser.add_argument( - "-a", - "--alpha", - type=float, - default=0.0, - help="alpha value for computing difference map weights (default=0.0)", - ) - - parser.add_argument( - "-l", - "--lambda_tv", - type=float, - default=0.015, - help="lambda value for TV denoising weighting (default 0.015)", - ) - - parser.add_argument( - "--plot", - help="--plot for optional plotting and saving, --no-plot to skip", - action=argparse.BooleanOptionalAction, - ) - - return parser.parse_args() - - -def main(): - - # Map writing parameters - map_res = 4 - - # Parse commandline arguments - args = parse_arguments() - - path = os.path.split(args.mtz[0])[0] - name = os.path.split(args.mtz[0])[1].split(".")[0] - cell, space_group = io.get_pdbinfo(args.refpdb[0]) - - print( - "%%%%%%%%%% ANALYZING DATASET : {n} in {p} %%%%%%%%%%%".format(n=name, p=path) - ) - print("CELL : {}".format(cell)) - print("SPACEGROUP : {}".format(space_group)) - - # Apply resolution cut if specified - if args.highres is not None: - high_res = args.highres - else: - high_res = np.min(io.load_mtz(args.mtz[0]).compute_dHKL()["dHKL"]) - - # Read in mtz file - og_mtz = io.load_mtz(args.mtz[0]) - og_mtz = og_mtz.loc[og_mtz.compute_dHKL()["dHKL"] > high_res] - # og_mtz = og_mtz.loc[og_mtz.compute_dHKL()["dHKL"] < 15] - - # for ocp - # calc = io.load_mtz("ecn_dark_SFALL.mtz") - # common = og_mtz.index.intersection(calc.index).sort_values() - - # Use own R-free flags set if specified - if args.flags is not None: - og_mtz = og_mtz.loc[:, [args.mtz[1], args.mtz[2], args.mtz[3], args.flags]] - flags = og_mtz[args.flags] == 0 - - else: - og_mtz = og_mtz.loc[ - :, [args.mtz[1], args.mtz[2], args.mtz[3], args.mtz[4], args.mtz[5]] - ] - flags = np.random.binomial(1, 0.03, og_mtz[args.mtz[1]].shape[0]).astype(bool) - - # scale second dataset ('on') and the first ('off') to FCalcs, and calculate deltaFs (weighted or not) - # og_mtz["FC"] = calc.loc[common, "FC_ALL"] - og_mtz, ws = maps.find_w_diffs( - og_mtz, - args.mtz[2], - args.mtz[1], - args.mtz[5], - args.mtz[4], - args.refpdb[0], - high_res, - path, - args.alpha, - ) - - # in case of calculated structure factors: - # og_mtz["light-phis"] = mtr.load_mtz(args.mtz[0])["light-phis"] - - proj_mags = [] - entropies = [] - phase_changes = [] - cum_phase_changes = [] - - # in case of calculated structure factors: - # ph_err_corrs = [] - - N = 50 - l = args.lambda_tv - - with tqdm(total=N) as pbar: - for i in np.arange(N) + 1: - if i == 1: - new_amps, new_phases, proj_mag, entropy, phase_change, z = ( - tv.TV_iteration( - og_mtz, - "WDF", - args.mtz[3], - "scaled_on", - "scaled_off", - args.mtz[3], - map_res, - cell, - space_group, - flags, - l, - high_res, - ws, - ) - ) - cum_phase_change = np.abs( - np.array( - dsutils.positive_Fs( - og_mtz, args.mtz[3], "WDF", "phases-pos", "diffs-pos" - )["phases-pos"] - - new_phases - ) - ) - # ph_err_corr = np.abs(new_phases - np.array(mtr.positive_Fs(og_mtz, "light-phis", "diffs", "phases-pos", "diffs-pos")["phases-pos"])) - - cum_phase_change = dsutils.adjust_phi_interval(cum_phase_change) - # ph_err_corr = mtr.adjust_phi_interval(ph_err_corr) - - og_mtz["new_amps"] = new_amps - og_mtz["new_amps"] = og_mtz["new_amps"].astype("SFAmplitude") - og_mtz["new_phases"] = new_phases - og_mtz["new_phases"] = og_mtz["new_phases"].astype("Phase") - og_mtz.write_mtz("{name}_TVit{i}_{l}.mtz".format(name=name, i=i, l=l)) - - else: - new_amps, new_phases, proj_mag, entropy, phase_change, z = ( - tv.TV_iteration( - og_mtz, - "new_amps", - "new_phases", - "scaled_on", - "scaled_off", - args.mtz[3], - map_res, - cell, - space_group, - flags, - l, - high_res, - ws, - ) - ) - cum_phase_change = np.abs( - np.array( - dsutils.positive_Fs( - og_mtz, args.mtz[3], "WDF", "phases-pos", "diffs-pos" - )["phases-pos"] - - new_phases - ) - ) - # ph_err_corr = np.abs(new_phases - np.array(mtr.positive_Fs(og_mtz, "light-phis", "diffs", "phases-pos", "diffs-pos")["phases-pos"])) - - cum_phase_change = dsutils.adjust_phi_interval(cum_phase_change) - - og_mtz["new_amps"] = new_amps - og_mtz["new_amps"] = og_mtz["new_amps"].astype("SFAmplitude") - og_mtz["new_phases"] = new_phases - og_mtz["new_phases"] = og_mtz["new_phases"].astype("Phase") - - og_mtz["new-light-phi"] = z - og_mtz["new-light-phi"] = og_mtz["new-light-phi"].astype("Phase") - - og_mtz.write_mtz("{name}_TVit{i}_{l}.mtz".format(name=name, i=i, l=l)) - - # Track projection magnitude and phase change for each iteration - proj_mags.append(proj_mag) - entropies.append(entropy) - phase_changes.append(phase_change) - cum_phase_changes.append(cum_phase_change) - # ph_err_corrs.append(ph_err_corr) - pbar.update() - - # np.save("{name}_TVit{i}_{l}-Z-mags.npy".format(name=name, i=i, l=l), z) - print("FINAL ENTROPY VALUE ", entropies[-1]) - np.save("{name}_TVit{i}_{l}-entropies.npy".format(name=name, i=i, l=l), entropies) - - print("SAVING FINAL MAP") - - og_mtz["new_amps"] = new_amps * ws - - finmap = dsutils.map_from_Fs(og_mtz, "new_amps", "new_phases", map_res) - - from meteor import validate - - fin_map_afterkweight = validate.negentropy(np.array(finmap.grid).flatten()) - print("FINAL ENTROPY AFTER K WEIGHTING ", fin_map_afterkweight) - finmap.write_ccp4_map("{name}_TVit{i}_{l}.ccp4".format(name=name, i=i, l=l)) - fin_TV_map, fin_ent_TV = tv.TV_filter( - finmap, l, finmap.grid.shape, cell, space_group - ) - print("FINAL TV FILTERED ENTROPY VALUE", fin_ent_TV) - mtz_finTV = dsutils.from_gemmi(io.map2mtz(fin_TV_map, high_res)) - mtz_finTV.write_mtz( - "{name}_TVit{i}_{l}_finalTVfiltered.mtz".format(name=name, i=i, l=l) - ) - - # Optionally plot result for errors and negentropy - if args.plot is True: - - fig, ax1 = plt.subplots(figsize=(10, 4)) - - color = "black" - ax1.set_title("$\lambda$ = {}".format(l)) - ax1.set_xlabel(r"Iteration") - ax1.set_ylabel(r"TV Projection Magnitude (TV$_\mathrm{proj}$)", color=color) - ax1.plot( - np.arange(N - 1), - np.array( - np.mean(np.array(proj_mags), axis=1) - / np.max(np.mean(np.array(proj_mags), axis=1)) - )[1:], - color=color, - linewidth=5, - ) - ax1.tick_params(axis="y", labelcolor=color) - - ax2 = ax1.twinx() - - color = "silver" - ax2.set_ylabel("Negentropy", color="grey") - ax2.plot(np.arange(N - 1), entropies[1:], color=color, linewidth=5) - ax2.tick_params(axis="y", labelcolor="grey") - # ax2.set_xlim(0,0.11) - plt.tight_layout() - fig.savefig("{p}/{n}non-cum-error-TV.png".format(p=path, n=name)) - # np.save("/Users/alisia/Desktop/negentropy.npy", entropies[1:]) - - fig, ax1 = plt.subplots(figsize=(10, 4)) - - color = "tomato" - ax1.set_title("$\lambda$ = {}".format(l)) - ax1.set_xlabel(r"Iteration") - ax1.set_ylabel( - " Iteration < $\phi_\mathrm{c}$ - $\phi_\mathrm{TV}$ > ($^\circ$)" - ) - ax1.plot(np.arange(N), phase_changes, color=color, linewidth=5) - ax1.tick_params(axis="y") - fig.set_tight_layout(True) - fig.savefig("{p}/{n}non-cum-phi-change.png".format(p=path, n=name)) - # np.save("/Users/alisia/Desktop/non-cum-phi-change.npy", phase_changes) - - fig, ax = plt.subplots(figsize=(10, 5)) - ax.set_title("$\lambda$ = {}".format(l)) - ax.set_xlabel(r"Iteration") - ax.set_ylabel( - r"Cumulative < $\phi_\mathrm{c}$ - $\phi_\mathrm{TV}$ > ($^\circ$)" - ) - ax.plot( - np.arange(N), - np.mean(cum_phase_changes, axis=1), - color="orangered", - linewidth=5, - ) - # ax.plot(np.arange(N), np.mean(ph_err_corrs, axis=1), color='orangered', linewidth=5, linestyle='--') - plt.tight_layout() - fig.savefig("{p}/{n}cum-phi-change.png".format(p=path, n=name)) - # np.save( - # "/Users/alisia/Desktop/cum-phi-change.npy", - # np.mean(cum_phase_changes, axis=1), - # ) - - fig, ax = plt.subplots(figsize=(6, 5)) - ax.set_title("$\lambda$ = {}".format(l)) - ax.set_xlabel(r"1/dHKL (${\AA}^{-1}$)") - ax.set_ylabel(r"TV Projection Magnitude (TV$_\mathrm{proj}$)") - ax.scatter( - 1 / og_mtz.compute_dHKL()["dHKL"][flags], - proj_mags[N - 1], - color="black", - alpha=0.5, - ) - - res_mean, data_mean = dsutils.resolution_shells( - proj_mags[N - 1], 1 / og_mtz.compute_dHKL()["dHKL"][flags], 15 - ) - ax.plot(res_mean, data_mean, linewidth=3, linestyle="--", color="orangered") - plt.tight_layout() - fig.savefig("{p}/{n}tv-err-dhkl.png".format(p=path, n=name)) - - fig, ax = plt.subplots(figsize=(6, 5)) - ax.set_title("$\lambda$ = {}".format(l)) - ax.set_xlabel(r"1/dHKL (${\AA}^{-1}$)") - ax.set_ylabel( - r"Cumulative < $\phi_\mathrm{c}$ - $\phi_\mathrm{TV}$ > ($^\circ$)" - ) - ax.scatter( - 1 / og_mtz.compute_dHKL()["dHKL"], - np.abs(cum_phase_changes[N - 1]), - color="orangered", - alpha=0.05, - ) - # ax.scatter(1/og_mtz.compute_dHKL()["dHKL"], np.abs(ph_err_corrs[N-1]), color='blue', alpha=0.5) - - res_mean, data_mean = dsutils.resolution_shells( - np.abs(cum_phase_changes[N - 1]), 1 / og_mtz.compute_dHKL()["dHKL"], 15 - ) - ax.plot(res_mean, data_mean, linewidth=3, linestyle="--", color="black") - plt.tight_layout() - fig.savefig("{p}/{n}cum-phase-change-dhkl.png".format(p=path, n=name)) - - fig, ax = plt.subplots(figsize=(6, 5)) - ax.set_title("$\lambda$ = {}".format(l)) - ax.set_xlabel(r"|Fobs|") - ax.set_ylabel( - r"Cumulative < $\phi_\mathrm{c}$ - $\phi_\mathrm{TV}$ > ($^\circ$)" - ) - ax.scatter( - og_mtz["WDF"], - np.abs(cum_phase_changes[N - 1]), - color="mediumpurple", - alpha=0.5, - ) - plt.tight_layout() - fig.savefig("{p}/{n}cum-phase-change-Fobs.png".format(p=path, n=name)) - - print("DONE") - - -if __name__ == "__main__": - main() diff --git a/meteor/TV_filtering/tv_denoise_map.py b/meteor/TV_filtering/tv_denoise_map.py deleted file mode 100755 index 43e6460..0000000 --- a/meteor/TV_filtering/tv_denoise_map.py +++ /dev/null @@ -1,116 +0,0 @@ -import argparse -import numpy as np -import os -import reciprocalspaceship as rs - - -from meteor import meteor_io -from meteor import dsutils -from meteor import tv - - -""" - -Apply a total variation (TV) filter to a map. -The level of filtering (determined by the regularization parameter lambda) -is chosen so as to maximize the map negentropy - - -""" - - -def parse_arguments(): - """Parse commandline arguments""" - parser = argparse.ArgumentParser( - formatter_class=argparse.RawTextHelpFormatter, description=__doc__ - ) - - # Required arguments - parser.add_argument( - "-mtz", - "--mtz", - nargs=3, - metavar=("mtz", "F_col", "phi_col"), - required=True, - help=("MTZ to be used for initial map. Specified as (filename, F, Phi)"), - ) - - # Optional arguments - parser.add_argument( - "-d_h", - "--highres", - type=float, - default=None, - help="If set, high res to truncate maps", - ) - - return parser.parse_args() - - -def main(): - - # Magic numbers - map_spacing = 4 - - # Parse commandline arguments - args = parse_arguments() - - path = os.path.split(args.mtz[0])[0] - name = os.path.split(args.mtz[0])[1].split(".")[ - 0 - ] # TODO this seems problematic when full paths are not given - - # Read in mtz file - og_mtz = meteor_io.subset_to_FandPhi( - *args.mtz, {args.mtz[1]: "F", args.mtz[2]: "Phi"} - ).dropna() - og_mtz = og_mtz.compute_dHKL() - - # Apply resolution cut if specified - if args.highres is not None: - high_res = args.highres - else: - high_res = np.min(["dHKL"]) - - og_mtz = og_mtz.compute_dHKL() - og_mtz = og_mtz.loc[og_mtz["dHKL"] > high_res] - - # Find and save denoised maps that maximizes the map negentropy - ( - TVmap_best_err, - TVmap_best_entr, - lambda_best_err, - lambda_best_entr, - errors, - entropies, - amp_change, - ph_change, - ) = tv.find_TVmap(og_mtz, "F", "Phi", name, path, map_res, cell, space_group) - - meteor_io.map2mtzfile( - TVmap_best_err, - "{n}_TV_{l}_besterror.mtz".format( - n=name, l=np.round(lambda_best_err, decimals=3) - ), - high_res, - ) - meteor_io.map2mtzfile( - TVmap_best_entr, - "{n}_TV_{l}_bestentropy.mtz".format( - n=name, l=np.round(lambda_best_entr, decimals=3) - ), - high_res, - ) - - print( - "Writing out TV denoised map with weights={lerr} and {lentr}".format( - lerr=np.round(lambda_best_err, decimals=3), - lentr=np.round(lambda_best_entr, decimals=3), - ) - ) - - print("DONE.") - - -if __name__ == "__main__": - main() diff --git a/meteor/dsutils.py b/meteor/dsutils.py index a5d4c20..8f20224 100644 --- a/meteor/dsutils.py +++ b/meteor/dsutils.py @@ -96,55 +96,3 @@ def map_from_Fs(dataset, Fs, phis, map_res): return ccp4 - -def from_gemmi(gemmi_mtz): - - """ - Construct DataSet from gemmi.Mtz object - - If the gemmi.Mtz object contains an M/ISYM column and contains duplicated - Miller indices, an unmerged DataSet will be constructed. The Miller indices - will be mapped to their observed values, and a partiality flag will be - extracted and stored as a boolean column with the label, ``PARTIAL``. - Otherwise, a merged DataSet will be constructed. - If columns are found with the ``MTZInt`` dtype and are labeled ``PARTIAL`` - or ``CENTRIC``, these will be interpreted as boolean flags used to - label partial or centric reflections, respectively. - - Parameters - ---------- - gemmi_mtz : gemmi.Mtz - gemmi Mtz object - - Returns - ------- - rs.DataSet - """ - - dataset = rs.DataSet(spacegroup=gemmi_mtz.spacegroup, cell=gemmi_mtz.cell) - - # Build up DataSet - for c in gemmi_mtz.columns: - dataset[c.label] = c.array - # Special case for CENTRIC and PARTIAL flags - if c.type == "I" and c.label in ["CENTRIC", "PARTIAL"]: - dataset[c.label] = dataset[c.label].astype(bool) - else: - dataset[c.label] = dataset[c.label].astype(c.type) - dataset.set_index(["H", "K", "L"], inplace=True) - - # Handle unmerged DataSet. Raise ValueError if M/ISYM column is not unique - m_isym = dataset.get_m_isym_keys() - if m_isym and dataset.index.duplicated().any(): - if len(m_isym) == 1: - dataset.merged = False - dataset.hkl_to_observed(m_isym[0], inplace=True) - else: - raise ValueError( - "Only a single M/ISYM column is supported for unmerged data" - ) - else: - dataset.merged = True - - return dataset - diff --git a/meteor/mask.py b/meteor/mask.py deleted file mode 100644 index e69de29..0000000 diff --git a/meteor/validate.py b/meteor/validate.py index 88fbca4..90014a5 100644 --- a/meteor/validate.py +++ b/meteor/validate.py @@ -1,26 +1,21 @@ import numpy as np from scipy.stats import differential_entropy -from . import mask - -def negentropy(X): +def negentropy(samples: np.ndarray, tolerance: float = 1e-4) -> float: """ Return negentropy (float) of X (numpy array) - """ - - # negetropy is the difference between the entropy of samples x - # and a Gaussian with same variance - # http://gregorygundersen.com/blog/2020/09/01/gaussian-entropy/ - - std = np.std(X) - # neg_e = np.log(std*np.sqrt(2*np.pi*np.exp(1))) - differential_entropy(X) - neg_e = 0.5 * np.log(2.0 * np.pi * std**2) + 0.5 - differential_entropy(X) - # assert neg_e >= 0.0 + 1e-8 - - return neg_e - + The negetropy is defined as the difference between the entropy of a distribution + and a Gaussian with same variance. + citation: + http://gregorygundersen.com/blog/2020/09/01/gaussian-entropy/ + """ + std = np.std(samples.flatten()) + neg_e = 0.5 * np.log(2.0 * np.pi * std**2) + 0.5 - differential_entropy(samples.flatten()) + if not neg_e >= -tolerance: + raise ValueError(f"negentropy is a relatively big negative number {neg_e} that exceeds the tolerance {tolerance} -- something may have gone wrong") + return neg_e diff --git a/pyproject.toml b/pyproject.toml index 263b9a1..30d319c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,6 +7,9 @@ authors = [ { name = "Thomas Lane", email = "thomas.lane@desy.de" } ] dependencies = [ + "numpy", + "scipy", + "skimage", "reciprocalspaceship", ] diff --git a/test/test_dsutils.py b/test/test_dsutils.py deleted file mode 100644 index 4189ce3..0000000 --- a/test/test_dsutils.py +++ /dev/null @@ -1,48 +0,0 @@ -import pytest -import numpy as np -import reciprocalspaceship as rs -from pathlib import Path -from meteor import dsutils - -@pytest.fixture -def dummy_mtz(): - data = np.array([1, 2, 3, 4, 5]) - dhkl = np.array([1.1, 1.2, 1.3, 1.4, 1.5]) - dataset = rs.DataSet({"data": data, "dHKL": dhkl}) - return dataset - -def test_res_cutoff(dummy_mtz): - df = dummy_mtz - h_res = 1.2 - l_res = 1.4 - - filtered_df = dsutils.res_cutoff(df, h_res, l_res) - - assert len(filtered_df) == 2 - assert filtered_df["data"].tolist() == [2, 3] - -def test_res_cutoff_raises_error(dummy_mtz): - df = dummy_mtz - h_res = 1.5 - l_res = 1.2 - - with pytest.raises(ValueError): - dsutils.res_cutoff(df, h_res, l_res) - -def test_resolution_shells(dummy_mtz): - data = dummy_mtz["data"] - dhkl = dummy_mtz["dHKL"] - n = 2 - - bin_centers, mean_data = dsutils.resolution_shells(data, dhkl, n) - - assert len(bin_centers) == n - assert len(mean_data.statistic) == n - -def test_resolution_shells_raises_error(dummy_mtz): - data = dummy_mtz["data"] - dhkl = dummy_mtz["dHKL"] - n = 10 - - with pytest.raises(ValueError): - dsutils.resolution_shells(data, dhkl, n) \ No newline at end of file diff --git a/test/test_io.py b/test/test_io.py deleted file mode 100644 index 0f6835b..0000000 --- a/test/test_io.py +++ /dev/null @@ -1,26 +0,0 @@ - -import pathlib - -from meteor import io - -# todo this is probably the right way to do this -#PDB_FILE = pathlib.Path(__file__) / 'data' / 'dark.pdb' -PDB_FILE = './data/dark.pdb' - - -def test_get_pdbinfo(): - - unit_cell, space_group = io.get_pdbinfo(PDB_FILE) - - assert unit_cell == [51.99, 62.91, 72.03, 90.0, 90.0, 90.0], 'unit cell incorrect' - assert space_group == 'P212121', 'space group incorrect' - - # todo write a file without CRYST1 - # check to make sure that fails - - return - - -if __name__ == '__main__': - test_get_pdbinfo() - diff --git a/test/unit/test_validate.py b/test/unit/test_validate.py new file mode 100644 index 0000000..c34a15c --- /dev/null +++ b/test/unit/test_validate.py @@ -0,0 +1,11 @@ + +import numpy as np +from meteor import validate + +def test_negentropy_gaussian() -> None: + n_samples = 100 + samples = np.random.normal(size=n_samples) + negentropy = validate.negentropy(samples) + + # negentropy should be small for a Gaussian sample + assert np.abs(negentropy) < 1e-5