From db5af341b834c55332e6d88ff13cb1f8b62f882a Mon Sep 17 00:00:00 2001 From: dineshpinto Date: Fri, 4 Aug 2023 08:47:53 +0200 Subject: [PATCH] Added hyperparameter optimization and updated README --- README.md | 67 ++++++++++++++++++++++++---- poetry.lock | 40 ++++++++--------- qudi_hira_analysis/analysis_logic.py | 64 +++++++++++++++++++++++--- 3 files changed, 137 insertions(+), 34 deletions(-) diff --git a/README.md b/README.md index 5c58a57..e4a4e57 100644 --- a/README.md +++ b/README.md @@ -39,7 +39,7 @@ pip install --upgrade qudi-hira-analysis If you are publishing scientific results, you can cite this work as: https://doi.org/10.5281/zenodo.7604670 -## Examples +## Usage First set up the `DataHandler` object (henceforth referred to as `dh`) with the correct paths to the data and figure folders. @@ -89,7 +89,56 @@ The `load_measurements` function returns a dictionary containing the measurement - The dictionary values are `MeasurementDataclass` objects whose schema is shown visually [here](#measurement-dataclass-schema). -### Example 0: NV-PL measurements +### Example 0: 2D NV-ODMR measurements + +```python +odmr_measurements = dh.load_measurements(measurement_str="2d_odmr_map") +odmr_measurements = dict(sorted(odmr_measurements.items())) + +# Optional: Try and optimize the hyperparameters for the ODMR fitting +highest_min_r2, optimal_parameters = dh.optimize_hyperparameters(odmr_measurements, num_samples=100, num_params=3) + +# Perform parallel (=num CPU cores) ODMR fitting +odmr_measurements = dh.raster_odmr_fitting( + odmr_measurements, + r2_thresh=0.95, + thresh_frac=0.5, + sigma_thresh_frac=0.1, + min_thresh=0.01, +) + +# Calculate residuals and 2D ODMR map +pixels = int(np.sqrt(len(odmr_measurements))) +image = np.zeros((pixels, pixels)) +residuals = np.zeros(len(odmr_measurements)) + +for idx, odmr in enumerate(odmr_measurements.values()): + row, col = odmr.xy_position + residuals[idx] = odmr.fit_model.rsquared + + if len(odmr.fit_model.params) == 6: + # Single Lorentzian, no splitting + image[row, col] = 0 + else: + if odmr.fit_model.rsquared < 0.95: + # Bad fit, set to NaN + image[row, col] = np.nan + else: + # Calculate splitting + splitting = np.abs(odmr.fit_model.best_values["l1_center"] - odmr.fit_model.best_values["l0_center"]) + image[row, col] = splitting + +fig, (ax, ax1) = plt.subplots(ncols=2) +# Plot residuals +sns.lineplot(residuals, ax=ax) +# Plot 2D ODMR map +sns.heatmap(image, cbar_kws={"label": r"$\Delta E$ (MHz)"}, ax=ax1) + +# Save the figure to the figure folder specified earlier +dh.save_figures(filepath="2d_odmr_map_with_residuals", fig=fig, only_jpg=True) +``` + +### Example 1: NV-PL measurements ```python pixel_scanner_measurements = dh.load_measurements(measurement_str="PixelScanner") @@ -114,7 +163,7 @@ cbar.set_label("NV-PL (kcps)") dh.save_figures(filepath="nv_pl_scan", fig=fig, only_jpg=True) ``` -### Example 1: Nanonis AFM measurements +### Example 2: Nanonis AFM measurements ```python afm_measurements = dh.load_measurements(measurement_str="Scan", extension=".sxm", qudi=False) @@ -142,7 +191,7 @@ cbar.set_label("Height (nm)") dh.save_figures(filepath="afm_topo", fig=fig, only_jpg=True) ``` -### Example 2: Autocorrelation measurements (Antibunching fit) +### Example 3: Autocorrelation measurements (Antibunching fit) ```python autocorrelation_measurements = dh.load_measurements(measurement_str="Autocorrelation") @@ -163,7 +212,7 @@ for autocorrelation in autocorrelation_measurements.values(): dh.save_figures(filepath="autocorrelation_variation", fig=fig) ``` -### Example 3: ODMR measurements (double Lorentzian fit) +### Example 4: ODMR measurements (double Lorentzian fit) ```python odmr_measurements = dh.load_measurements(measurement_str="ODMR", pulsed=True) @@ -179,7 +228,7 @@ for odmr in odmr_measurements.values(): dh.save_figures(filepath="odmr_variation", fig=fig) ``` -### Example 4: Rabi measurements (sine exponential decay fit) +### Example 5: Rabi measurements (sine exponential decay fit) ```python rabi_measurements = dh.load_measurements(measurement_str="Rabi", pulsed=True) @@ -195,7 +244,7 @@ for rabi in rabi_measurements.values(): dh.save_figures(filepath="rabi_variation", fig=fig) ``` -### Example 5: Temperature data +### Example 6: Temperature data ```python temperature_measurements = dh.load_measurements(measurement_str="Temperature", qudi=False) @@ -207,7 +256,7 @@ sns.lineplot(data=temperature, x="Time", y="Temperature", ax=ax) dh.save_figures(filepath="temperature_monitoring", fig=fig) ``` -### Example 6: PYS data (pi3diamond compatibility) +### Example 7: PYS data (pi3diamond compatibility) ```python pys_measurements = dh.load_measurements(measurement_str="ndmin", extension=".pys", qudi=False) @@ -218,7 +267,7 @@ sns.lineplot(x=pys["time_bins"], y=pys["counts"], ax=ax) dh.save_figures(filepath="pys_measurement", fig=fig) ``` -### Example 7: Bruker MFM data +### Example 8: Bruker MFM data ```python bruker_measurements = dh.load_measurements(measurement_str="", extension=".001", qudi=False) diff --git a/poetry.lock b/poetry.lock index 90a3524..37cd70f 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry and should not be changed by hand. +# This file is automatically @generated by Poetry 1.4.0 and should not be changed by hand. [[package]] name = "aiofiles" @@ -604,30 +604,30 @@ files = [ [[package]] name = "debugpy" -version = "1.6.8" +version = "1.6.7" description = "An implementation of the Debug Adapter Protocol for Python" category = "dev" optional = false python-versions = ">=3.7" files = [ - {file = "debugpy-1.6.8-cp310-cp310-macosx_11_0_x86_64.whl", hash = "sha256:8c1f5a3286fb633f691c594649e9d2e8e30292c9eaf49e38d7da525151b33a83"}, - {file = "debugpy-1.6.8-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:406b3a6cb7548d73260f69a511178ec9196779cafda68e563488c6f94cc88670"}, - {file = "debugpy-1.6.8-cp310-cp310-win32.whl", hash = "sha256:6830947f68b41cd6abe20941ec3303a8452c40ff5fe3637c6efe233e395ecebc"}, - {file = "debugpy-1.6.8-cp310-cp310-win_amd64.whl", hash = "sha256:1fe3baa28f5a14d8d2a60dded9ea088e27b33f1854ae9a0a1faa1ba03a8b7e47"}, - {file = "debugpy-1.6.8-cp37-cp37m-macosx_11_0_x86_64.whl", hash = "sha256:5502e14de6b7241ecf7c4fa4ec6dd61d0824da7a09020c7ffe7be4cd09d36f24"}, - {file = "debugpy-1.6.8-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f4a7193cec3f1e188963f6e8699e1187f758a0a4bbce511b3ad40caf618fc888"}, - {file = "debugpy-1.6.8-cp37-cp37m-win32.whl", hash = "sha256:591aac0e69bc75102d9f9294f1228e5d9ff9aa17b8c88e48b1bbb3dab8a54dcc"}, - {file = "debugpy-1.6.8-cp37-cp37m-win_amd64.whl", hash = "sha256:bb27b8e08f8e60705de6cf05b5da4c21e5a0bc2ca73f06fc36646f456df18ff5"}, - {file = "debugpy-1.6.8-cp38-cp38-macosx_11_0_x86_64.whl", hash = "sha256:6ca1c92e30e2aaeca156d5bd76e1587c23e332474a7b12e1900dd632b31ce05e"}, - {file = "debugpy-1.6.8-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:959f9b8181a4c544b067daff8d881cd3ac4c7aec1a3a4f41f81c529795b3d864"}, - {file = "debugpy-1.6.8-cp38-cp38-win32.whl", hash = "sha256:4172383b961a2334d29168c7f7b24f2f99d29291a945016986c78a5683fba915"}, - {file = "debugpy-1.6.8-cp38-cp38-win_amd64.whl", hash = "sha256:05d1b288167ce3bfc8e1912ebed036207a27b9569ae4476f18287902501689c6"}, - {file = "debugpy-1.6.8-cp39-cp39-macosx_11_0_x86_64.whl", hash = "sha256:95f7ce92450b72abcf0c479539a7d00c20e68f1f1fb447eef0b08d2a635d96d7"}, - {file = "debugpy-1.6.8-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f16bb157b6018ce6a23b64653a6b1892f046cc2b0576df1794c6b22f9fd82118"}, - {file = "debugpy-1.6.8-cp39-cp39-win32.whl", hash = "sha256:f7a80c50b89d8fb49c9e5b6ee28c0bfb822fbd33fef0f2f9843724d0d1984e4e"}, - {file = "debugpy-1.6.8-cp39-cp39-win_amd64.whl", hash = "sha256:2345beced3e79fd8ac4158e839a1604d5cccd19beb45561a1ffe2e5b33465f28"}, - {file = "debugpy-1.6.8-py2.py3-none-any.whl", hash = "sha256:1ca76d3ebb0e6368e107cf2e005e848d3c7705a5b513fdf65470a6f4e49a2de7"}, - {file = "debugpy-1.6.8.zip", hash = "sha256:3b7091d908dec70022b8966c32b1e9eaf183ff05291edf1d147fee153f4cb9f8"}, + {file = "debugpy-1.6.7-cp310-cp310-macosx_11_0_x86_64.whl", hash = "sha256:b3e7ac809b991006ad7f857f016fa92014445085711ef111fdc3f74f66144096"}, + {file = "debugpy-1.6.7-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e3876611d114a18aafef6383695dfc3f1217c98a9168c1aaf1a02b01ec7d8d1e"}, + {file = "debugpy-1.6.7-cp310-cp310-win32.whl", hash = "sha256:33edb4afa85c098c24cc361d72ba7c21bb92f501104514d4ffec1fb36e09c01a"}, + {file = "debugpy-1.6.7-cp310-cp310-win_amd64.whl", hash = "sha256:ed6d5413474e209ba50b1a75b2d9eecf64d41e6e4501977991cdc755dc83ab0f"}, + {file = "debugpy-1.6.7-cp37-cp37m-macosx_10_15_x86_64.whl", hash = "sha256:38ed626353e7c63f4b11efad659be04c23de2b0d15efff77b60e4740ea685d07"}, + {file = "debugpy-1.6.7-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:279d64c408c60431c8ee832dfd9ace7c396984fd7341fa3116aee414e7dcd88d"}, + {file = "debugpy-1.6.7-cp37-cp37m-win32.whl", hash = "sha256:dbe04e7568aa69361a5b4c47b4493d5680bfa3a911d1e105fbea1b1f23f3eb45"}, + {file = "debugpy-1.6.7-cp37-cp37m-win_amd64.whl", hash = "sha256:f90a2d4ad9a035cee7331c06a4cf2245e38bd7c89554fe3b616d90ab8aab89cc"}, + {file = "debugpy-1.6.7-cp38-cp38-macosx_10_15_x86_64.whl", hash = "sha256:5224eabbbeddcf1943d4e2821876f3e5d7d383f27390b82da5d9558fd4eb30a9"}, + {file = "debugpy-1.6.7-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bae1123dff5bfe548ba1683eb972329ba6d646c3a80e6b4c06cd1b1dd0205e9b"}, + {file = "debugpy-1.6.7-cp38-cp38-win32.whl", hash = "sha256:9cd10cf338e0907fdcf9eac9087faa30f150ef5445af5a545d307055141dd7a4"}, + {file = "debugpy-1.6.7-cp38-cp38-win_amd64.whl", hash = "sha256:aaf6da50377ff4056c8ed470da24632b42e4087bc826845daad7af211e00faad"}, + {file = "debugpy-1.6.7-cp39-cp39-macosx_11_0_x86_64.whl", hash = "sha256:0679b7e1e3523bd7d7869447ec67b59728675aadfc038550a63a362b63029d2c"}, + {file = "debugpy-1.6.7-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:de86029696e1b3b4d0d49076b9eba606c226e33ae312a57a46dca14ff370894d"}, + {file = "debugpy-1.6.7-cp39-cp39-win32.whl", hash = "sha256:d71b31117779d9a90b745720c0eab54ae1da76d5b38c8026c654f4a066b0130a"}, + {file = "debugpy-1.6.7-cp39-cp39-win_amd64.whl", hash = "sha256:c0ff93ae90a03b06d85b2c529eca51ab15457868a377c4cc40a23ab0e4e552a3"}, + {file = "debugpy-1.6.7-py2.py3-none-any.whl", hash = "sha256:53f7a456bc50706a0eaabecf2d3ce44c4d5010e46dfc65b6b81a518b42866267"}, + {file = "debugpy-1.6.7.zip", hash = "sha256:c4c2f0810fa25323abfdfa36cbbbb24e5c3b1a42cb762782de64439c575d67f2"}, ] [[package]] diff --git a/qudi_hira_analysis/analysis_logic.py b/qudi_hira_analysis/analysis_logic.py index 86c4de9..c9af120 100644 --- a/qudi_hira_analysis/analysis_logic.py +++ b/qudi_hira_analysis/analysis_logic.py @@ -1,7 +1,9 @@ from __future__ import annotations import logging +import random import re +from itertools import product from typing import Tuple, TYPE_CHECKING import numpy as np @@ -287,13 +289,65 @@ def analyse_mean_norm( return signal_data, error_data - def raster_odmr_fitting( + def optimize_hyperparameters( self, + measurements: dict[str, MeasurementDataclass], + num_samples: int = 10, + num_params: int = 3, + ) -> Tuple[float, Tuple[float, float, float]]: + """ + This method optimizes the hyperparameters of the ODMR analysis. + It does so by randomly sampling a subset of the measurements and + then optimizing the hyperparameters for them. + + Args: + measurements: A dictionary of measurements to optimize the hyperparameters for. + num_params: The number of parameters to optimize. + num_samples: The number of measurements to sample. + + Returns: + The optimal hyperparameters. + """ + r2_threshs = np.around(np.linspace(start=0.9, stop=0.99, num=num_params), decimals=2) + thresh_fracs = np.around(np.linspace(start=0.5, stop=0.9, num=num_params), decimals=1) + sigma_thresh_fracs = np.around(np.linspace(start=0.1, stop=0.2, num=num_params), decimals=1) + + odmr_sample = {} + for k, v in random.sample(sorted(measurements.items()), k=num_samples): + odmr_sample[k] = v + + highest_min_r2 = 0 + optimal_params = (0, 0, 0) + + for idx, (r2_thresh, thresh_frac, sigma_thresh_frac) in enumerate( + product(r2_threshs, thresh_fracs, sigma_thresh_fracs)): + odmr_sample = self.raster_odmr_fitting( + odmr_sample, + r2_thresh=r2_thresh, + thresh_frac=thresh_frac, + sigma_thresh_frac=sigma_thresh_frac, + min_thresh=0.01, + progress_bar=False + ) + + r2s = np.zeros(len(odmr_sample)) + for _idx, odmr in enumerate(odmr_sample.values()): + r2s[_idx] = odmr.fit_model.rsquared + min_r2 = np.min(r2s) + + if highest_min_r2 < min_r2: + highest_min_r2 = min_r2 + optimal_params = (r2_thresh, thresh_frac, sigma_thresh_frac) + + return highest_min_r2, optimal_params + + @staticmethod + def raster_odmr_fitting( odmr_measurements: dict[str, MeasurementDataclass], r2_thresh: float = 0.95, - thresh_frac: float = 0.3, - min_thresh: float = 0.25, - sigma_thresh_frac: float = 0.3, + thresh_frac: float = 0.5, + sigma_thresh_frac: float = 0.15, + min_thresh: float = 0.01, extract_pixel_from_filename: bool = True, progress_bar: bool = True ) -> dict[str, MeasurementDataclass]: @@ -307,7 +361,7 @@ def raster_odmr_fitting( min_thresh: sigma_thresh_frac: extract_pixel_from_filename: Extract `(row, col)` (in this format) from filename - + progress_bar: Show progress bar Returns: List of ODMR data with fit, fit model and pixels in MeasurementDataclass """