From c3a883affe023208bb1c540c7b6a24da9e65b1a3 Mon Sep 17 00:00:00 2001
From: Sam Welborn
Date: Fri, 13 Sep 2024 11:35:11 -0400
Subject: [PATCH] run `ruff format .` and `ruff check . --fix`

- also fix things that couldn't be fixed automatically
---
 docs/source/conf.py               |  53 ++--
 python/stempy/__init__.py         |   2 +-
 python/stempy/contrib/__init__.py |  20 +-
 python/stempy/image/__init__.py   | 404 +++++++++++++++++-------------
 python/stempy/io/__init__.py      | 105 ++++----
 python/stempy/io/compatibility.py |  15 +-
 python/stempy/io/sparse_array.py  | 272 ++++++++++----------
 python/stempy/utils.py            |   8 +-
 setup.py                          |  38 +--
 tests/conftest.py                 |  73 +++---
 tests/test_image.py               |  25 +-
 tests/test_simple.py              |  10 +-
 tests/test_sparse_array.py        | 189 +++++------
 13 files changed, 620 insertions(+), 594 deletions(-)

diff --git a/docs/source/conf.py b/docs/source/conf.py
index 89e72477..8f78a3b9 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -12,21 +12,19 @@
 import os
 import sys
-
-import sphinx_rtd_theme
 from unittest import mock
 
-sys.path.insert(0, os.path.abspath('../../python'))
+sys.path.insert(0, os.path.abspath("../../python"))
 
 # -- Project information -----------------------------------------------------
 
-project = 'stempy'
-copyright = '2020, Kitware, INC'
-author = 'Kitware, INC'
+project = "stempy"
+copyright = "2020, Kitware, INC"
+author = "Kitware, INC"
 
 # The full version, including alpha/beta/rc tags
-release = '1.0'
+release = "1.0"
 
 # -- General configuration ---------------------------------------------------
@@ -34,20 +32,17 @@
 # Add any Sphinx extension module names here, as strings. They can be
 # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
 # ones.
-extensions = [
-    'sphinx.ext.autodoc',
-    'sphinx_rtd_theme',
-    'recommonmark'
-]
+extensions = ["sphinx.ext.autodoc", "sphinx_rtd_theme", "recommonmark"]
 
 # Add any paths that contain templates here, relative to this directory.
-templates_path = ['_templates']
+templates_path = ["_templates"]
 
 # List of patterns, relative to source directory, that match files and
 # directories to ignore when looking for source files.
 # This pattern also affects html_static_path and html_extra_path.
 exclude_patterns = []
 
+
 # Unfortunately, if we have a class that inherits sphinx's mock
 # object, and we use that class in a default argument, we run into
 # attribute errors. Fix it by defining an explicit mock of the class.
@@ -55,8 +50,9 @@ class MockReader:
     class H5Format:
         Frame = 1
 
-sys.modules['stempy._io'] = mock.Mock()
-_io_mock = sys.modules['stempy._io']
+
+sys.modules["stempy._io"] = mock.Mock()
+_io_mock = sys.modules["stempy._io"]
 
 # We have to override these so we don't get conflicting metaclasses
 _io_mock._sector_reader = MockReader
@@ -70,36 +66,33 @@ class H5Format:
 # The theme to use for HTML and HTML Help pages.  See the documentation for
 # a list of builtin themes.
 #
-html_theme = 'sphinx_rtd_theme'
+html_theme = "sphinx_rtd_theme"
 
 # Add any paths that contain custom static files (such as style sheets) here,
 # relative to this directory. They are copied after the builtin static files,
 # so a file named "default.css" will overwrite the builtin "default.css".
-html_static_path = ['_static'] +html_static_path = ["_static"] # Get autodoc to mock these imports, because we don't actually need them # for generating the docs -autodoc_mock_imports = [ - 'stempy._image', - 'numpy', - 'h5py' -] +autodoc_mock_imports = ["stempy._image", "numpy", "h5py"] + # Modify this function to customize which classes/functions are skipped def autodoc_skip_member_handler(app, what, name, obj, skip, options): # Exclude any names that start with a string in this list - exclude_startswith_list = ['_'] + exclude_startswith_list = ["_"] - if any([name.startswith(x) for x in exclude_startswith_list]): + if any(name.startswith(x) for x in exclude_startswith_list): return True # Exclude any names that match a string in this list exclude_names = [ - 'ReaderMixin', - 'PyReader', - 'SectorReader', - 'get_hdf5_reader', - 'SectorThreadedMultiPassReader', + "ReaderMixin", + "PyReader", + "SectorReader", + "get_hdf5_reader", + "SectorThreadedMultiPassReader", ] return name in exclude_names @@ -109,4 +102,4 @@ def autodoc_skip_member_handler(app, what, name, obj, skip, options): def setup(app): # Connect the autodoc-skip-member event from apidoc to the callback # This allows us to customize which classes/functions are skipped - app.connect('autodoc-skip-member', autodoc_skip_member_handler) + app.connect("autodoc-skip-member", autodoc_skip_member_handler) diff --git a/python/stempy/__init__.py b/python/stempy/__init__.py index 9ebc766b..b775bd3a 100644 --- a/python/stempy/__init__.py +++ b/python/stempy/__init__.py @@ -3,5 +3,5 @@ __version__ = get_version() __all__ = [ - '__version__', + "__version__", ] diff --git a/python/stempy/contrib/__init__.py b/python/stempy/contrib/__init__.py index a63d6316..32065cb1 100644 --- a/python/stempy/contrib/__init__.py +++ b/python/stempy/contrib/__init__.py @@ -18,9 +18,7 @@ class FileSuffix(Enum): def check_only_one_filepath(file_path): if len(file_path) > 1: - raise ValueError( - "Multiple files match that input. Add scan_id to be more specific." - ) + raise ValueError("Multiple files match that input. Add scan_id to be more specific.") elif len(file_path) == 0: raise FileNotFoundError("No file with those parameters can be found.") @@ -125,13 +123,9 @@ def get_scan_path_version_1( if file_suffix == FileSuffix.STANDARD: file_pattern = f"FOURD_*_{str(scan_id).zfill(5)}_?????.{filetype}" else: - file_pattern = ( - f"FOURD_*_{str(scan_id).zfill(5)}_?????{file_suffix.value}.{filetype}" - ) + file_pattern = f"FOURD_*_{str(scan_id).zfill(5)}_?????{file_suffix.value}.{filetype}" elif scan_num is not None: - file_pattern = ( - f"FOURD_*_*_{str(scan_num).zfill(5)}{file_suffix.value}.{filetype}" - ) + file_pattern = f"FOURD_*_*_{str(scan_num).zfill(5)}{file_suffix.value}.{filetype}" file_path = list(directory.glob(file_pattern)) check_only_one_filepath(file_path) @@ -188,9 +182,7 @@ def get_scan_path( th=th, file_suffix=file_suffix, ), - 1: lambda: get_scan_path_version_1( - directory, scan_num=scan_num, scan_id=scan_id, file_suffix=file_suffix - ), + 1: lambda: get_scan_path_version_1(directory, scan_num=scan_num, scan_id=scan_id, file_suffix=file_suffix), } for ver in versions_to_try: @@ -202,6 +194,4 @@ def get_scan_path( except FileNotFoundError: continue - raise FileNotFoundError( - "No file with those parameters can be found for any version." 
- ) + raise FileNotFoundError("No file with those parameters can be found for any version.") diff --git a/python/stempy/image/__init__.py b/python/stempy/image/__init__.py index 1909c89f..87c497d3 100644 --- a/python/stempy/image/__init__.py +++ b/python/stempy/image/__init__.py @@ -1,17 +1,28 @@ import warnings +import h5py +import numpy as np + from stempy import _image -from stempy import _io from stempy.io import ( - get_hdf5_reader, ReaderMixin, PyReader, SectorThreadedReader, - SectorThreadedMultiPassReader, SparseArray + PyReader, + ReaderMixin, + SectorThreadedMultiPassReader, + SectorThreadedReader, + SparseArray, + get_hdf5_reader, ) -import h5py -import numpy as np -def create_stem_images(input, inner_radii, outer_radii, scan_dimensions=(0, 0), - center=(-1, -1), frame_dimensions=None, frame_offset=0): +def create_stem_images( + input, + inner_radii, + outer_radii, + scan_dimensions=(0, 0), + center=(-1, -1), + frame_dimensions=None, + frame_offset=0, +): """Create a series of stem images from the input. :param input: the file reader that has already opened the data, or an @@ -57,52 +68,75 @@ def create_stem_images(input, inner_radii, outer_radii, scan_dimensions=(0, 0), outer_radii = [outer_radii] # Electron counted data attributes - ecd_attrs = ['data', 'scan_dimensions', 'frame_dimensions'] + ecd_attrs = ["data", "scan_dimensions", "frame_dimensions"] if isinstance(input, h5py._hl.files.File): # Handle h5py file input = get_hdf5_reader(input) - imgs = _image.create_stem_images(input.begin(), input.end(), - inner_radii, outer_radii, - scan_dimensions, center) + imgs = _image.create_stem_images( + input.begin(), + input.end(), + inner_radii, + outer_radii, + scan_dimensions, + center, + ) elif issubclass(type(input), ReaderMixin): # Handle standard reader - imgs = _image.create_stem_images(input.begin(), input.end(), - inner_radii, outer_radii, - scan_dimensions, center) - elif hasattr(input, '_electron_counted_data'): + imgs = _image.create_stem_images( + input.begin(), + input.end(), + inner_radii, + outer_radii, + scan_dimensions, + center, + ) + elif hasattr(input, "_electron_counted_data"): # Handle electron counted data with C++ object # This could also be a SparseArray created via electron_count() - imgs = _image.create_stem_images(input._electron_counted_data, - inner_radii, outer_radii, center) + imgs = _image.create_stem_images(input._electron_counted_data, inner_radii, outer_radii, center) elif isinstance(input, SparseArray): - imgs = _image.create_stem_images(input.data, inner_radii, outer_radii, - input.scan_shape[::-1], - input.frame_shape, center) - elif all([hasattr(input, x) for x in ecd_attrs]): + imgs = _image.create_stem_images( + input.data, + inner_radii, + outer_radii, + input.scan_shape[::-1], + input.frame_shape, + center, + ) + elif all(hasattr(input, x) for x in ecd_attrs): # Handle electron counted data without C++ object # Assume this is v1 data and that each frame is one scan position data = input.data if data.ndim == 1: data = data[:, np.newaxis] - imgs = _image.create_stem_images(input.data, inner_radii, outer_radii, - input.scan_dimensions, - input.frame_dimensions, center) + imgs = _image.create_stem_images( + input.data, + inner_radii, + outer_radii, + input.scan_dimensions, + input.frame_dimensions, + center, + ) elif isinstance(input, np.ndarray): # The presence of frame dimensions implies it is sparse data if frame_dimensions is not None: # Handle sparse data if input.ndim == 1: input = input[:, np.newaxis] - imgs = 
_image.create_stem_images(input, inner_radii, outer_radii, - scan_dimensions, frame_dimensions, - center) + imgs = _image.create_stem_images( + input, + inner_radii, + outer_radii, + scan_dimensions, + frame_dimensions, + center, + ) else: # Handle raw data # Make sure the scan dimensions were passed if not scan_dimensions or scan_dimensions == (0, 0): - msg = ('scan_dimensions must be provided for np.ndarray ' - 'raw data input') + msg = "scan_dimensions must be provided for np.ndarray " "raw data input" raise Exception(msg) # Should have shape (num_images, frame_height, frame_width) @@ -111,19 +145,22 @@ def create_stem_images(input, inner_radii, outer_radii, scan_dimensions=(0, 0), block_size = 32 reader = PyReader(input, image_numbers, scan_dimensions, block_size, num_images) - imgs = _image.create_stem_images(reader.begin(), reader.end(), - inner_radii, outer_radii, - scan_dimensions, center) + imgs = _image.create_stem_images( + reader.begin(), + reader.end(), + inner_radii, + outer_radii, + scan_dimensions, + center, + ) else: - raise Exception('Type of input, ' + str(type(input)) + - ', is not known to stempy.image.create_stem_images()') + raise Exception("Type of input, " + str(type(input)) + ", is not known to stempy.image.create_stem_images()") images = [np.array(img, copy=False) for img in imgs] return np.array(images, copy=False) -def create_stem_histogram(numBins, reader, inner_radii, - outer_radii, scan_dimensions=(0, 0), - center=(-1, -1)): + +def create_stem_histogram(numBins, reader, inner_radii, outer_radii, scan_dimensions=(0, 0), center=(-1, -1)): """Create a histogram of the stem images generated from the input. :param numBins: the number of bins the histogram should have. @@ -149,9 +186,7 @@ def create_stem_histogram(numBins, reader, inner_radii, :rtype: a tuple of length 2 of lists """ # create stem images - imgs = _image.create_stem_images(reader.begin(), reader.end(), - inner_radii, outer_radii, - scan_dimensions, center) + imgs = _image.create_stem_images(reader.begin(), reader.end(), inner_radii, outer_radii, scan_dimensions, center) allBins = [] allFreqs = [] @@ -165,6 +200,7 @@ def create_stem_histogram(numBins, reader, inner_radii, return allBins, allFreqs + class ImageArray(np.ndarray): def __new__(cls, array, dtype=None, order=None): obj = np.asarray(array, dtype=dtype, order=order).view(cls) @@ -172,8 +208,9 @@ def __new__(cls, array, dtype=None, order=None): return obj def __array_finalize__(self, obj): - if obj is None: return - self._image = getattr(obj, '_image', None) + if obj is None: + return + self._image = getattr(obj, "_image", None) def calculate_average(reader): @@ -185,18 +222,26 @@ def calculate_average(reader): :return: The average image. 
:rtype: stempy.image.ImageArray """ - image = _image.calculate_average(reader.begin(), reader.end()) - img = ImageArray(np.array(image, copy = False)) + image = _image.calculate_average(reader.begin(), reader.end()) + img = ImageArray(np.array(image, copy=False)) img._image = image return img -def electron_count(reader, darkreference=None, number_of_samples=40, - background_threshold_n_sigma=4, xray_threshold_n_sigma=10, - threshold_num_blocks=1, scan_dimensions=(0, 0), - verbose=False, gain=None, apply_row_dark=False, - apply_row_dark_use_mean=True): +def electron_count( + reader, + darkreference=None, + number_of_samples=40, + background_threshold_n_sigma=4, + xray_threshold_n_sigma=10, + threshold_num_blocks=1, + scan_dimensions=(0, 0), + verbose=False, + gain=None, + apply_row_dark=False, + apply_row_dark_use_mean=True, +): """Generate a list of coordinates of electron hits. :param reader: the file reader that has already opened the data. @@ -238,11 +283,11 @@ def electron_count(reader, darkreference=None, number_of_samples=40, # Invert, as we will multiply in C++ # It also must be a float32 gain = np.power(gain, -1) - gain = _safe_cast(gain, np.float32, 'gain') + gain = _safe_cast(gain, np.float32, "gain") if isinstance(darkreference, np.ndarray): # Must be float32 for correct conversions - darkreference = _safe_cast(darkreference, np.float32, 'dark reference') + darkreference = _safe_cast(darkreference, np.float32, "dark reference") # Special case for threaded reader if isinstance(reader, (SectorThreadedReader, SectorThreadedMultiPassReader)): @@ -262,25 +307,28 @@ def electron_count(reader, darkreference=None, number_of_samples=40, data = _image.electron_count(reader, options) else: deprecation_message = ( - 'Using a reader in electron_count() that is not a ' - 'SectorThreadedReader or a SectorThreadedMultiPassReader is ' - 'deprecated in stempy==1.1 and will be removed in stempy==1.2' + "Using a reader in electron_count() that is not a " + "SectorThreadedReader or a SectorThreadedMultiPassReader is " + "deprecated in stempy==1.1 and will be removed in stempy==1.2" ) - warnings.warn(deprecation_message, category=DeprecationWarning, - stacklevel=2) + warnings.warn(deprecation_message, category=DeprecationWarning, stacklevel=2) blocks = [] - for i in range(threshold_num_blocks): + for _ in range(threshold_num_blocks): blocks.append(next(reader)) - if darkreference is not None and hasattr(darkreference, '_image'): + if darkreference is not None and hasattr(darkreference, "_image"): darkreference = darkreference._image args = [[b._block for b in blocks]] if darkreference is not None: args.append(darkreference) - args = args + [number_of_samples, background_threshold_n_sigma, xray_threshold_n_sigma] + args = args + [ + number_of_samples, + background_threshold_n_sigma, + xray_threshold_n_sigma, + ] if gain is not None: args.append(gain) @@ -291,21 +339,20 @@ def electron_count(reader, darkreference=None, number_of_samples=40, xray_threshold = res.xray_threshold if verbose: - print('****Statistics for calculating electron thresholds****') - print('number of samples:', res.number_of_samples) - print('min sample:', res.min_sample) - print('max sample:', res.max_sample) - print('mean:', res.mean) - print('variance:', res.variance) - print('std dev:', res.std_dev) - print('number of bins:', res.number_of_bins) - print('x-ray threshold n sigma:', res.xray_threshold_n_sigma) - print('background threshold n sigma:', - res.background_threshold_n_sigma) - print('optimized mean:', 
res.optimized_mean) - print('optimized std dev:', res.optimized_std_dev) - print('background threshold:', background_threshold) - print('xray threshold:', xray_threshold) + print("****Statistics for calculating electron thresholds****") + print("number of samples:", res.number_of_samples) + print("min sample:", res.min_sample) + print("max sample:", res.max_sample) + print("mean:", res.mean) + print("variance:", res.variance) + print("std dev:", res.std_dev) + print("number of bins:", res.number_of_bins) + print("x-ray threshold n sigma:", res.xray_threshold_n_sigma) + print("background threshold n sigma:", res.background_threshold_n_sigma) + print("optimized mean:", res.optimized_mean) + print("optimized std dev:", res.optimized_std_dev) + print("background threshold:", background_threshold) + print("xray threshold:", xray_threshold) # Reset the reader reader.reset() @@ -336,10 +383,10 @@ def electron_count(reader, darkreference=None, number_of_samples=40, metadata = _electron_counted_metadata_to_dict(data.metadata) kwargs = { - 'data': np_data, - 'scan_shape': data.scan_dimensions[::-1], - 'frame_shape': data.frame_dimensions, - 'metadata': {'electron_counting': metadata}, + "data": np_data, + "scan_shape": data.scan_dimensions[::-1], + "frame_shape": data.frame_dimensions, + "metadata": {"electron_counting": metadata}, } array = SparseArray(**kwargs) @@ -348,6 +395,7 @@ def electron_count(reader, darkreference=None, number_of_samples=40, return array + def radial_sum(reader, center=(-1, -1), scan_dimensions=(0, 0)): """Generate a radial sum from which STEM images can be generated. @@ -367,8 +415,7 @@ def radial_sum(reader, center=(-1, -1), scan_dimensions=(0, 0)): :rtype: numpy.ndarray """ - sum = _image.radial_sum(reader.begin(), reader.end(), scan_dimensions, - center) + sum = _image.radial_sum(reader.begin(), reader.end(), scan_dimensions, center) return np.array(sum, copy=False) @@ -398,26 +445,28 @@ def maximum_diffraction_pattern(reader, darkreference=None): def com_sparse(array, crop_to=None, init_center=None, replace_nans=True): """Compute center of mass (COM) for counted data directly from sparse (single) - electron data. Empty frames will have COM set to NAN. There is an option to crop to a - smaller region around the initial full frame COM to improve finding the center of the zero beam. If the - cropped region has no counts in it then the frame is considered empty and COM will be NAN. - - :param array: A SparseArray of the electron counted data - :type array: stempy.io.SparseArray - :param crop_to: optional; The size of the region to crop around initial full frame COM for improved COM near - the zero beam - :type crop_to: tuple of ints of length 2 - :param init_center: optional; The initial center to use before cropping. If this is not set then cropping will be applie around - the center of mass of the each full frame. - :type init_center: tuple of ints of length 2 - :param replace_nans: If true (default) empty frames will have their center of mass set as the mean of the center of mass - of the the entire data set. If this is False they will be set to np.NAN. - :type replace_nans: bool - :return: The center of mass in X and Y for each scan position. If a position - has no data (len(electron_counts) == 0) then the center of the - frame is used as the center of mass. - :rtype: numpy.ndarray (2D) - """ + electron data. Empty frames will have COM set to NAN. 
There is an option to crop to a
+    smaller region around the initial full frame COM to improve finding the center of the zero beam. If the
+    cropped region has no counts in it then the frame is considered empty and COM will be NAN.
+
+    :param array: A SparseArray of the electron counted data
+    :type array: stempy.io.SparseArray
+    :param crop_to: optional; The size of the region to crop around initial full frame COM for improved COM near
+        the zero beam
+    :type crop_to: tuple of ints of length 2
+    :param init_center: optional; The initial center to use before cropping.
+        If this is not set then cropping will be applied around
+        the center of mass of each full frame.
+    :type init_center: tuple of ints of length 2
+    :param replace_nans: If True (default) empty frames will have their center
+        of mass set as the mean of the center of mass
+        of the entire data set. If this is False they will be set to np.NAN.
+    :type replace_nans: bool
+    :return: The center of mass in X and Y for each scan position. If a position
+        has no data (len(electron_counts) == 0) then the center of the
+        frame is used as the center of mass.
+    :rtype: numpy.ndarray (2D)
+    """
     com = np.zeros((2, array.num_scans), np.float32)
     for scan_position, frames in enumerate(array.data):
         # Combine the sparse arrays into one array.
@@ -428,7 +477,7 @@
             x = ev // array.frame_shape[0]
             y = ev % array.frame_shape[1]
             mm0 = len(ev)
-
+
             if init_center is None:
                 # Initialize center as full frame COM
                 comx0 = np.sum(x) / mm0
@@ -436,11 +485,15 @@
             else:
                 comx0 = init_center[0]
                 comy0 = init_center[1]
-
+
             if crop_to is not None:
                 # Crop around the initial center
-                keep = (x > comx0 - crop_to[0]) & (x <= comx0 + crop_to[0]) & (y > comy0 - crop_to[1]) & (
-                    y <= comy0 + crop_to[1])
+                keep = (
+                    (x > comx0 - crop_to[0])
+                    & (x <= comx0 + crop_to[0])
+                    & (y > comy0 - crop_to[1])
+                    & (y <= comy0 + crop_to[1])
+                )
                 x = x[keep]
                 y = y[keep]
                 mm = len(x)
@@ -457,23 +510,26 @@
                 # Center of mass of the full frame
                 comx = np.sum(x) / mm0
                 comy = np.sum(y) / mm0
-
-            com[:, scan_position] = (comy, comx)  # save the comx and comy. Needs to be reversed (comy, comx)
+
+            com[:, scan_position] = (
+                comy,
+                comx,
+            )  # save the comx and comy. Needs to be reversed (comy, comx)
         else:
            com[:, scan_position] = (np.nan, np.nan)  # empty frame
-
+
     com = com.reshape((2, *array.scan_shape))
-
+
     if replace_nans:
-        com_mean = np.nanmean(com, axis=(1,2))
-        np.nan_to_num(com[0,:,:], nan=com_mean[0], copy=False)
-        np.nan_to_num(com[1,:,:], nan=com_mean[1], copy=False)
-
+        com_mean = np.nanmean(com, axis=(1, 2))
+        np.nan_to_num(com[0, :, :], nan=com_mean[0], copy=False)
+        np.nan_to_num(com[1, :, :], nan=com_mean[1], copy=False)
+
     return com
 
 
 def filter_bad_sectors(com, cut_off):
-    """ Missing sectors of data can greatly affect the center of mass calculation. This
+    """Missing sectors of data can greatly affect the center of mass calculation. This
     function attempts to fix that issue. Usually, the problem will be evident by a
     bimodal histogram for the horizontal center of mass. Set the cut_off to a value
     between the real values and the outliers.
cut_off is a tuple where values below cut_off[0]
@@ -493,10 +549,11 @@
     x, y = np.where(_)
 
     for ii, jj in zip(x, y):
-        com[0, ii, jj] = np.nanmedian(com[0, ii - 1:ii + 2, jj - 1:jj + 2])
-        com[1, ii, jj] = np.nanmedian(com[1, ii - 1:ii + 2, jj - 1:jj + 2])
+        com[0, ii, jj] = np.nanmedian(com[0, ii - 1 : ii + 2, jj - 1 : jj + 2])
+        com[1, ii, jj] = np.nanmedian(com[1, ii - 1 : ii + 2, jj - 1 : jj + 2])
     return com
 
+
 def com_dense(frames):
     """Compute the center of mass for a set of dense 2D frames.
@@ -511,7 +568,7 @@
     if frames.ndim == 2:
         frames = frames[None, :, :]
 
-    YY, XX = np.mgrid[0:frames.shape[1], 0:frames.shape[2]]
+    YY, XX = np.mgrid[0 : frames.shape[1], 0 : frames.shape[2]]
     com = np.zeros((2, frames.shape[0]), dtype=np.float64)
     for ii, dp in enumerate(frames):
         mm = dp.sum()
@@ -534,7 +591,7 @@
     :return: A summed diffraction pattern.
     :rtype: numpy.ndarray
     """
-    dp = np.zeros((frame_dimensions[0] * frame_dimensions[1]), '
[...]
     >>> sp = stempy.io.load_electron_counts('file.h5')
     >>> df2 = stempy.image.virtual_darkfield(sp, (288, 260), (288, 160), (10, 10))  # 2 apertures
     >>> df1 = stempy.image.virtual_darkfield(sp, 260, 160, 10)  # 1 aperture
-
+
     """
-
+
     # Change to iterable if single value
     if isinstance(centers_x, (int, float)):
-        centers_x = (centers_x,)
+        centers_x = (centers_x,)
     if isinstance(centers_y, (int, float)):
-        centers_y = (centers_y,)
+        centers_y = (centers_y,)
     if isinstance(radii, (int, float)):
-        radii = (radii,)
+        radii = (radii,)
 
     rs_image = np.zeros((array.shape[0] * array.shape[1],), dtype=array.dtype)
     for ii, events in enumerate(array.data):
@@ -698,33 +756,35 @@
             ev_rows = ev // array.frame_shape[0]
             ev_cols = ev % array.frame_shape[1]
             for cc_0, cc_1, rr in zip(centers_x, centers_y, radii):
-                dist = np.sqrt((ev_rows - cc_1)**2 + (ev_cols - cc_0)**2)
+                dist = np.sqrt((ev_rows - cc_1) ** 2 + (ev_cols - cc_0) ** 2)
                 rs_image[ii] += len(np.where(dist < rr)[0])
     rs_image = rs_image.reshape(array.scan_shape)
-
+
     return rs_image
 
+
 def plot_virtual_darkfield(image, centers_x, centers_y, radii, axes=None):
     """Plot circles on the diffraction pattern corresponding to the position and size of virtual dark field apertures.
-    This has the same center and radii inputs as stempy.image.virtual_darkfield so users can check their input is physically correct.
-
+    This has the same center and radii inputs as stempy.image.virtual_darkfield
+    so users can check their input is physically correct.
+
     :param image: The diffraction pattern to plot over
     :type image: np.ndarray, 2D
-
+
     :param centers_x: The center of each round aperture as the row locations
     :type centers_x: iterable
-
+
     :param centers_y: The center of each round aperture as the column locations
     :type centers_y: iterable
-
+
     :param radii: The radius of each aperture.
     :type radii: iterable
-
+
     :param axes: A matplotlib axes instance to use for the plotting. If None then a new plot is created.
:type axes: matplotlib.axes._subplots.AxesSubplot
-
+
     :rtype: matplotlib.axes._subplots.AxesSubplot
-
+
     :example:
     >>> sp = stempy.io.load_electron_counts('file.h5')
    >>> stempy.image.plot_virtual_darkfield(sp.sum(axis=(0, 1)), 260, 160, 10)  # 1 aperture
@@ -735,33 +795,35 @@
     # Change to iterable if single value
     if isinstance(centers_x, (int, float)):
-        centers_x = (centers_x,)
+        centers_x = (centers_x,)
     if isinstance(centers_y, (int, float)):
-        centers_y = (centers_y,)
+        centers_y = (centers_y,)
     if isinstance(radii, (int, float)):
-        radii = (radii,)
-
+        radii = (radii,)
+
     if not axes:
         fg, axes = plt.subplots(1, 1)
-
-    axes.imshow(image, cmap='magma', norm=LogNorm())
-
+
+    axes.imshow(image, cmap="magma", norm=LogNorm())
+
     # Place a circle at each aperture location
     for cc_0, cc_1, rr in zip(centers_x, centers_y, radii):
-        C = Circle((cc_0, cc_1), rr, fc='none', ec='c')
+        C = Circle((cc_0, cc_1), rr, fc="none", ec="c")
         axes.add_patch(C)
-
+
     return axes
 
+
 def mask_real_space(array, mask):
     """Calculate a diffraction pattern from an arbitrary set of positions defined in a mask in real space
-
+
     :param array: The sparse dataset
     :type array: SparseArray
-
-    :param mask: The mask to apply with 0 for probe positions to ignore and 1 for probe positions to include in the sum. Must have the same scan shape as array
+
+    :param mask: The mask to apply with 0 for probe positions to ignore and 1 for probe positions to include in the sum.
+        Must have the same scan shape as array
     :type mask: np.ndarray
-
+
     :rtype: np.ndarray
     """
     assert array.scan_shape[0] == mask.shape[0] and array.scan_shape[1] == mask.shape[1]

diff --git a/python/stempy/io/__init__.py b/python/stempy/io/__init__.py
index 22ffc9a3..214ae51e 100644
--- a/python/stempy/io/__init__.py
+++ b/python/stempy/io/__init__.py
@@ -1,26 +1,25 @@
 from collections import namedtuple
-import numpy as np
+
 import h5py
+import numpy as np
 
-from stempy._io import (
-    _reader, _sector_reader, _pyreader, _threaded_reader,
-    _threaded_multi_pass_reader
-)
+from stempy._io import _pyreader, _reader, _sector_reader, _threaded_multi_pass_reader, _threaded_reader
 
 # For exporting SparseArray
 from .sparse_array import SparseArray
 
-COMPILED_WITH_HDF5 = hasattr(_sector_reader, 'H5Format')
+COMPILED_WITH_HDF5 = hasattr(_sector_reader, "H5Format")
 
-class FileVersion(object):
+
+class FileVersion:
     VERSION1 = 1
     VERSION2 = 2
     VERSION3 = 3
     VERSION4 = 4
     VERSION5 = 5
 
-class ReaderMixin(object):
+
+class ReaderMixin:
     def __iter__(self):
         return self
 
@@ -37,16 +36,16 @@ def read(self):
 
         :return: The block of data that was read. Includes the header also.
:rtype: Block (named tuple with fields 'header' and 'data') """ - b = super(ReaderMixin, self).read() + b = super().read() # We are at the end of the stream if b.header.version == 0: return None - block = namedtuple('Block', ['header', 'data']) + block = namedtuple("Block", ["header", "data"]) block._block = b block.header = b.header - block.data = np.array(b, copy = False) + block.data = np.array(b, copy=False) return block @@ -54,15 +53,19 @@ def read(self): class Reader(ReaderMixin, _reader): pass + class SectorReader(ReaderMixin, _sector_reader): pass + class PyReader(ReaderMixin, _pyreader): pass + class SectorThreadedReader(ReaderMixin, _threaded_reader): pass + class SectorThreadedMultiPassReader(ReaderMixin, _threaded_multi_pass_reader): @property def num_frames_per_scan(self): @@ -100,12 +103,8 @@ def read_frames(self, scan_position, frames_slice=None): if isinstance(scan_position, (list, tuple)): # Unravel the scan position scan_shape = self.scan_shape - if (any(not 0 <= scan_position[i] < scan_shape[i] - for i in range(len(scan_position)))): - raise IndexError( - f'Invalid position {scan_position} ' - f'for scan_shape {scan_shape}' - ) + if any(not 0 <= scan_position[i] < scan_shape[i] for i in range(len(scan_position))): + raise IndexError(f"Invalid position {scan_position} " f"for scan_shape {scan_shape}") image_number = scan_position[0] * scan_shape[1] + scan_position[1] else: @@ -128,18 +127,15 @@ def read_frames(self, scan_position, frames_slice=None): # Slice into the frames object try: frames = frames[frames_slice] - except IndexError: - msg = ( - f'frames_slice "{frames_slice}" is invalid for ' - f'num_frames "{num_frames}"' - ) - raise IndexError(msg) + except IndexError as e: + msg = f'frames_slice "{frames_slice}" is invalid for ' f'num_frames "{num_frames}"' + raise IndexError(msg) from e blocks = [] raw_blocks = self._load_frames(image_number, frames) for b in raw_blocks: - block = namedtuple('Block', ['header', 'data']) + block = namedtuple("Block", ["header", "data"]) block._block = b block.header = b.header block.data = np.array(b, copy=False)[0] @@ -150,22 +146,24 @@ def read_frames(self, scan_position, frames_slice=None): def get_hdf5_reader(h5file): # the initialization is at the io.cpp - dset_frame=h5file['frames'] - dset_frame_shape=dset_frame.shape - totalImgNum=dset_frame_shape[0] - scan_dimensions = dset_frame.attrs.get('scan_dimensions') + dset_frame = h5file["frames"] + dset_frame_shape = dset_frame.shape + totalImgNum = dset_frame_shape[0] + scan_dimensions = dset_frame.attrs.get("scan_dimensions") if scan_dimensions is None: # Must be an older file. Give a warning and fall back to the shape. 
-        print('WARNING: "scan_dimensions" not found on "/frames"',
-              '(which may imply an older file is being loaded).',
-              'Falling back to the shape of "stem/images"')
-        dset_stem_shape = h5file['stem/images'].shape
+        print(
+            'WARNING: "scan_dimensions" not found on "/frames"',
+            "(which may imply an older file is being loaded).",
+            'Falling back to the shape of "stem/images"',
+        )
+        dset_stem_shape = h5file["stem/images"].shape
         scan_dimensions = (dset_stem_shape[2], dset_stem_shape[1])
 
-    blocksize=32
+    blocksize = 32
     # construct the consecutive image_numbers if there is no scan_positions data set in hdf5 file
-    if("scan_positions" in h5file):
-        image_numbers = h5file['scan_positions']
+    if "scan_positions" in h5file:
+        image_numbers = h5file["scan_positions"]
     else:
         image_numbers = np.arange(totalImgNum)
 
@@ -188,14 +186,14 @@ def reader(path, version=FileVersion.VERSION1, backend=None, **options):
     :rtype: stempy.io.Reader, stempy.io.SectorReader, or stempy.io.PyReader
     """
     # check if the input is the hdf5 dataset
-    if(isinstance(path, h5py._hl.files.File)):
+    if isinstance(path, h5py._hl.files.File):
         reader = get_hdf5_reader(path)
     elif version in [FileVersion.VERSION4, FileVersion.VERSION5]:
-        if backend == 'thread':
+        if backend == "thread":
             reader = SectorThreadedReader(path, version, **options)
-        elif backend == 'multi-pass':
+        elif backend == "multi-pass":
             if version != FileVersion.VERSION5:
-                raise Exception('The multi pass threaded reader only support file verison 5')
+                raise Exception("The multi pass threaded reader only supports file version 5")
 
             reader = SectorThreadedMultiPassReader(path, **options)
         elif backend is None:
@@ -208,8 +206,8 @@
 
     return reader
 
-def save_raw_data(path, data, scan_dimensions=None, scan_positions=None,
-                  zip_data=False):
+
+def save_raw_data(path, data, scan_dimensions=None, scan_positions=None, zip_data=False):
     """Save the raw data to an HDF5 file.
 
     :param path: path to the HDF5 file.
@@ -234,20 +232,20 @@
     if rdcc_nbytes < chunk_size:
         rdcc_nbytes = chunk_size
 
-    with h5py.File(path, 'a', rdcc_nbytes=rdcc_nbytes) as f:
+    with h5py.File(path, "a", rdcc_nbytes=rdcc_nbytes) as f:
         if zip_data:
             # Make each chunk the size of a frame
             chunk_shape = (1, data.shape[1], data.shape[2])
-            frames = f.create_dataset('frames', data=data, compression='gzip',
-                                      chunks=chunk_shape)
+            frames = f.create_dataset("frames", data=data, compression="gzip", chunks=chunk_shape)
         else:
-            frames = f.create_dataset('frames', data=data)
+            frames = f.create_dataset("frames", data=data)
 
         if scan_dimensions is not None:
-            frames.attrs['scan_dimensions'] = scan_dimensions
+            frames.attrs["scan_dimensions"] = scan_dimensions
 
         if scan_positions is not None:
-            f.create_dataset('scan_positions', data=scan_positions)
+            f.create_dataset("scan_positions", data=scan_positions)
+
 
 def save_electron_counts(path, array):
     """Save the electron counted data to an HDF5 file.
@@ -259,6 +257,7 @@
     """
     array.write_to_hdf5(path)
 
+
 def load_electron_counts(path, keep_flyback=True):
     """Load electron counted data from an HDF5 file.
@@ -271,6 +270,7 @@
     """
     return SparseArray.from_hdf5(path, keep_flyback=keep_flyback)
 
+
 def save_stem_images(outputFile, images, names):
     """Save STEM images to an HDF5 file.
@@ -283,15 +283,16 @@ def save_stem_images(outputFile, images, names): :type names: a list of strings """ if len(images) != len(names): - raise Exception('`images` and `names` must be the same length!') + raise Exception("`images` and `names` must be the same length!") - with h5py.File(outputFile, 'a') as f: - stem_group = f.require_group('stem') - dataset = stem_group.create_dataset('images', data=images) - dataset.attrs['names'] = names + with h5py.File(outputFile, "a") as f: + stem_group = f.require_group("stem") + dataset = stem_group.create_dataset("images", data=images) + dataset.attrs["names"] = names if COMPILED_WITH_HDF5: + def write_hdf5(path, reader, format=SectorReader.H5Format.Frame): """write_hdf5(path, reader, format=SectorReader.H5Format.Frame) diff --git a/python/stempy/io/compatibility.py b/python/stempy/io/compatibility.py index 32ccc85e..3781cf94 100644 --- a/python/stempy/io/compatibility.py +++ b/python/stempy/io/compatibility.py @@ -1,8 +1,7 @@ import numpy as np -def convert_data_format(data, scan_positions, scan_shape, frame_shape, - from_version, to_version): +def convert_data_format(data, scan_positions, scan_shape, frame_shape, from_version, to_version): """Convert the data and return new (data, scan_positions) as a tuple""" if from_version == to_version: if to_version < 3: @@ -18,7 +17,7 @@ def convert_data_format(data, scan_positions, scan_shape, frame_shape, key = (from_version, to_version) if key not in func_map: - msg = f'Conversion not implemented: {from_version} => {to_version}' + msg = f"Conversion not implemented: {from_version} => {to_version}" raise NotImplementedError(msg) return func_map[key](data, scan_positions, scan_shape) @@ -52,10 +51,9 @@ def convert_data_v1_to_v2(data, scan_positions, scan_shape): # Resize the data and add on the extra rows new_size = ret.shape[0] + len(extra_rows) new_data = np.empty(new_size, dtype=object) - new_data[:ret.shape[0]] = ret - new_data[ret.shape[0]:] = extra_rows - new_scan_positions = np.append(scan_positions, - extra_scan_positions) + new_data[: ret.shape[0]] = ret + new_data[ret.shape[0] :] = extra_rows + new_scan_positions = np.append(scan_positions, extra_scan_positions) ret = new_data scan_positions = new_scan_positions @@ -64,8 +62,7 @@ def convert_data_v1_to_v2(data, scan_positions, scan_shape): def convert_data_v1_to_v3(data, scan_positions, scan_shape): - data, scan_positions = convert_data_v1_to_v2(data, scan_positions, - scan_shape) + data, scan_positions = convert_data_v1_to_v2(data, scan_positions, scan_shape) return convert_data_v2_to_v3(data, scan_positions, scan_shape) diff --git a/python/stempy/io/sparse_array.py b/python/stempy/io/sparse_array.py index cec4a04b..1671b807 100644 --- a/python/stempy/io/sparse_array.py +++ b/python/stempy/io/sparse_array.py @@ -1,8 +1,8 @@ -from collections.abc import Sequence import copy -from functools import wraps import inspect import sys +from collections.abc import Sequence +from functools import wraps import numpy as np @@ -12,9 +12,9 @@ def _format_axis(func): @wraps(func) def wrapper(self, axis=None, *args, **kwargs): - if axis == 'scan': + if axis == "scan": axis = self.scan_axes - elif axis == 'frame': + elif axis == "frame": axis = self.frame_axes if axis is None: @@ -23,6 +23,7 @@ def wrapper(self, axis=None, *args, **kwargs): axis = (axis,) axis = tuple(sorted(axis)) return func(self, axis, *args, **kwargs) + return wrapper @@ -35,7 +36,7 @@ def wrapper(self, *args, **kwargs): if ret is not None: return ret - _warning(f'performing full expansion 
for {name}') + _warning(f"performing full expansion for {name}") prev_sparse_slicing = self.sparse_slicing self.sparse_slicing = False try: @@ -43,6 +44,7 @@ def wrapper(self, *args, **kwargs): finally: self.sparse_slicing = prev_sparse_slicing return getattr(expanded, name)(*args, **kwargs) + return wrapper @@ -57,6 +59,7 @@ def wrapper(*args, **kwargs): _warning(f'"{key}" is not implemented for {name}') return func(*args, **kwargs) + return wrapper @@ -86,10 +89,19 @@ class SparseArray: and stride. Slicing returns either a new sparse array or a dense array depending on the user-defined settings. """ + VERSION = 3 - def __init__(self, data, scan_shape, frame_shape, dtype=np.uint32, - sparse_slicing=True, allow_full_expand=False, metadata=None): + def __init__( + self, + data, + scan_shape, + frame_shape, + dtype=np.uint32, + sparse_slicing=True, + allow_full_expand=False, + metadata=None, + ): """Initialize a sparse array. :param data: the sparse array data, where the outer array represents @@ -144,18 +156,18 @@ def __init__(self, data, scan_shape, frame_shape, dtype=np.uint32, def _validate(self): if self.data.ndim != 2: msg = ( - 'The data must have 2 dimensions, but instead has ' - f'{self.data.ndim}. The outer dimension should be the ' - 'scan position, and the inner dimension should be the ' - 'frame index at that scan position' + "The data must have 2 dimensions, but instead has " + f"{self.data.ndim}. The outer dimension should be the " + "scan position, and the inner dimension should be the " + "frame index at that scan position" ) raise Exception(msg) if len(self.data) != self.num_scans: msg = ( - f'The length of the data array ({len(self.data)}) must be ' - f'equal to the number of scans ({self.num_scans}), which ' - f'is computed via np.prod(scan_shape)' + f"The length of the data array ({len(self.data)}) must be " + f"equal to the number of scans ({self.num_scans}), which " + f"is computed via np.prod(scan_shape)" ) raise Exception(msg) @@ -163,9 +175,9 @@ def _validate(self): for j, frame in enumerate(scan_frames): if not np.issubdtype(frame.dtype, np.integer): msg = ( - f'All frames of data must have integral dtype, but ' - f'scan_position={i} frame={j} has a dtype of ' - f'{frame.dtype}' + f"All frames of data must have integral dtype, but " + f"scan_position={i} frame={j} has a dtype of " + f"{frame.dtype}" ) raise Exception(msg) @@ -185,38 +197,38 @@ def from_hdf5(cls, filepath, keep_flyback=True, **init_kwargs): """ import h5py - with h5py.File(filepath, 'r') as f: - version = f.attrs.get('version', 1) + with h5py.File(filepath, "r") as f: + version = f.attrs.get("version", 1) + + frames = f["electron_events/frames"] + scan_positions_group = f["electron_events/scan_positions"] + scan_shape = [scan_positions_group.attrs[x] for x in ["Nx", "Ny"]] + frame_shape = [frames.attrs[x] for x in ["Nx", "Ny"]] - frames = f['electron_events/frames'] - scan_positions_group = f['electron_events/scan_positions'] - scan_shape = [scan_positions_group.attrs[x] for x in ['Nx', 'Ny']] - frame_shape = [frames.attrs[x] for x in ['Nx', 'Ny']] - if keep_flyback: - data = frames[()] # load the full data set + data = frames[()] # load the full data set scan_positions = scan_positions_group[()] else: # Generate the original scan indices from the scan_shape - orig_indices = np.ravel_multi_index([ii.ravel() for ii in np.indices(scan_shape)],scan_shape) - # Remove the indices of the last column - crop_indices = np.delete(orig_indices, orig_indices[scan_shape[0]-1::scan_shape[0]]) + orig_indices = 
np.ravel_multi_index([ii.ravel() for ii in np.indices(scan_shape)], scan_shape) + # Remove the indices of the last column + crop_indices = np.delete(orig_indices, orig_indices[scan_shape[0] - 1 :: scan_shape[0]]) # Load only the data needed data = frames[crop_indices] # Reduce the column shape by 1 scan_shape[0] = scan_shape[0] - 1 # Create the proper scan_positions without the flyback column - scan_positions = np.ravel_multi_index([ii.ravel() for ii in np.indices(scan_shape)],scan_shape) - + scan_positions = np.ravel_multi_index([ii.ravel() for ii in np.indices(scan_shape)], scan_shape) + # Load any metadata metadata = {} - if 'metadata' in f: - load_h5_to_dict(f['metadata'], metadata) + if "metadata" in f: + load_h5_to_dict(f["metadata"], metadata) scan_shape = scan_shape[::-1] - + if version >= 3: - # Convert to int to avoid integer division that results in + # Convert to int to avoid integer division that results in # a float frames_per_scan = len(data) // np.prod(scan_shape, dtype=int) # Need to reshape the data, as it was flattened before saving @@ -225,20 +237,20 @@ def from_hdf5(cls, filepath, keep_flyback=True, **init_kwargs): # We may need to convert the version of the data if version != cls.VERSION: kwargs = { - 'data': data, - 'scan_positions': scan_positions, - 'scan_shape': scan_shape, - 'frame_shape': frame_shape, - 'from_version': version, - 'to_version': cls.VERSION, + "data": data, + "scan_positions": scan_positions, + "scan_shape": scan_shape, + "frame_shape": frame_shape, + "from_version": version, + "to_version": cls.VERSION, } data = convert_data_format(**kwargs) kwargs = { - 'data': data, - 'scan_shape': scan_shape, - 'frame_shape': frame_shape, - 'metadata': metadata, + "data": data, + "scan_shape": scan_shape, + "frame_shape": frame_shape, + "metadata": metadata, } kwargs.update(init_kwargs) return cls(**kwargs) @@ -252,29 +264,27 @@ def write_to_hdf5(self, path): import h5py data = self.data - with h5py.File(path, 'a') as f: - f.attrs['version'] = self.VERSION + with h5py.File(path, "a") as f: + f.attrs["version"] = self.VERSION - group = f.require_group('electron_events') - scan_positions = group.create_dataset('scan_positions', - data=self.scan_positions) - scan_positions.attrs['Nx'] = self.scan_shape[1] - scan_positions.attrs['Ny'] = self.scan_shape[0] + group = f.require_group("electron_events") + scan_positions = group.create_dataset("scan_positions", data=self.scan_positions) + scan_positions.attrs["Nx"] = self.scan_shape[1] + scan_positions.attrs["Ny"] = self.scan_shape[0] # We can't store a 2D array that contains variable length arrays # in h5py currently. So flatten the data, and we will reshape it # when we read it back in. See h5py/h5py#876 coordinates_type = h5py.special_dtype(vlen=np.uint32) - frames = group.create_dataset('frames', (np.prod(data.shape),), - dtype=coordinates_type) + frames = group.create_dataset("frames", (np.prod(data.shape),), dtype=coordinates_type) # Add the frame dimensions as attributes - frames.attrs['Nx'] = self.frame_shape[0] - frames.attrs['Ny'] = self.frame_shape[1] + frames.attrs["Nx"] = self.frame_shape[0] + frames.attrs["Ny"] = self.frame_shape[1] frames[...] 
= data.ravel() # Write out the metadata - save_dict_to_h5(self.metadata, f.require_group('metadata')) + save_dict_to_h5(self.metadata, f.require_group("metadata")) def to_dense(self): """Create and return a fully dense version of the sparse array @@ -344,19 +354,17 @@ def reshape(self, *shape): shape = shape[0] if len(shape) not in (3, 4): - raise ValueError('Shape must be length 3 or 4') + raise ValueError("Shape must be length 3 or 4") scan_shape = shape[:-2] frame_shape = shape[-2:] # Make sure the shapes are valid if np.prod(scan_shape) != np.prod(self.scan_shape): - raise ValueError('Cannot reshape scan array of size ' - f'{self.scan_shape} into shape {scan_shape}') + raise ValueError("Cannot reshape scan array of size " f"{self.scan_shape} into shape {scan_shape}") if np.prod(frame_shape) != np.prod(self.frame_shape): - raise ValueError('Cannot reshape frame array of size ' - f'{self.frame_shape} into shape {frame_shape}') + raise ValueError("Cannot reshape frame array of size " f"{self.frame_shape} into shape {frame_shape}") # If we made it here, we can perform the reshaping... self.scan_shape = scan_shape @@ -568,18 +576,18 @@ def bin_scans(self, bin_factor, in_place=False): :rtype: SparseArray """ if not all(x % bin_factor == 0 for x in self.scan_shape): - raise ValueError(f'scan_shape must be equally divisible by ' - f'bin_factor {bin_factor}') + raise ValueError(f"scan_shape must be equally divisible by " f"bin_factor {bin_factor}") # No need to modify the data, just the scan shape and the scan # positions. new_scan_shape = tuple(x // bin_factor for x in self.scan_shape) all_positions = self.scan_positions - all_positions_reshaped = all_positions.reshape( - new_scan_shape[0], bin_factor, new_scan_shape[1], bin_factor) + all_positions_reshaped = all_positions.reshape(new_scan_shape[0], bin_factor, new_scan_shape[1], bin_factor) - new_data_shape = (np.prod(new_scan_shape), - bin_factor**2 * self.num_frames_per_scan) + new_data_shape = ( + np.prod(new_scan_shape), + bin_factor**2 * self.num_frames_per_scan, + ) new_data = np.empty(new_data_shape, dtype=object) for i in range(new_scan_shape[0]): for j in range(new_scan_shape[1]): @@ -594,12 +602,12 @@ def bin_scans(self, bin_factor, in_place=False): return self kwargs = { - 'data': new_data, - 'scan_shape': new_scan_shape, - 'frame_shape': self.frame_shape, - 'dtype': self.dtype, - 'sparse_slicing': self.sparse_slicing, - 'allow_full_expand': self.allow_full_expand, + "data": new_data, + "scan_shape": new_scan_shape, + "frame_shape": self.frame_shape, + "dtype": self.dtype, + "sparse_slicing": self.sparse_slicing, + "allow_full_expand": self.allow_full_expand, } return SparseArray(**kwargs) @@ -619,8 +627,7 @@ def bin_frames(self, bin_factor, in_place=False): :rtype: SparseArray """ if not all(x % bin_factor == 0 for x in self.frame_shape): - raise ValueError(f'frame_shape must be equally divisible by ' - f'bin_factor {bin_factor}') + raise ValueError(f"frame_shape must be equally divisible by " f"bin_factor {bin_factor}") # Ravel the data for easier manipulation in this function data = self.data.ravel() @@ -628,13 +635,12 @@ def bin_frames(self, bin_factor, in_place=False): rows = data // self.frame_shape[0] // bin_factor cols = data % self.frame_shape[1] // bin_factor - rows *= (self.frame_shape[0] // bin_factor) + rows *= self.frame_shape[0] // bin_factor rows += cols new_data = rows frames_per_scan = self.data.shape[1] - raveled_scan_positions = np.repeat(self.scan_positions, - frames_per_scan) + raveled_scan_positions = 
np.repeat(self.scan_positions, frames_per_scan) new_scan_positions = raveled_scan_positions # Place duplicates in separate frames @@ -655,10 +661,9 @@ def bin_frames(self, bin_factor, in_place=False): # Resize the new data and add on the extra rows new_size = rows.shape[0] + len(extra_rows) new_data = np.empty(new_size, dtype=object) - new_data[:rows.shape[0]] = rows - new_data[rows.shape[0]:] = extra_rows - new_scan_positions = np.append(new_scan_positions, - extra_scan_positions) + new_data[: rows.shape[0]] = rows + new_data[rows.shape[0] :] = extra_rows + new_scan_positions = np.append(new_scan_positions, extra_scan_positions) # Find the max number of extra frames per scan position and use # that. @@ -691,12 +696,12 @@ def bin_frames(self, bin_factor, in_place=False): return self kwargs = { - 'data': new_data, - 'scan_shape': self.scan_shape, - 'frame_shape': new_frame_shape, - 'dtype': self.dtype, - 'sparse_slicing': self.sparse_slicing, - 'allow_full_expand': self.allow_full_expand, + "data": new_data, + "scan_shape": self.scan_shape, + "frame_shape": new_frame_shape, + "dtype": self.dtype, + "sparse_slicing": self.sparse_slicing, + "allow_full_expand": self.allow_full_expand, } return SparseArray(**kwargs) @@ -781,6 +786,7 @@ def _split_slices(self, slices): Returns `scan_slices, frame_slices`. """ + def validate_if_advanced_indexing(obj, max_ndim, is_scan): """Check if the obj is advanced indexing (and convert to ndarray if it is) @@ -799,12 +805,12 @@ def validate_if_advanced_indexing(obj, max_ndim, is_scan): # If it is a numpy array, it is advanced indexing. # Ensure that there are not too many dimensions. if obj.ndim > max_ndim: - msg = 'Too many advanced indexing dimensions.' + msg = "Too many advanced indexing dimensions." if is_scan: msg += ( - ' Cannot perform advanced indexing on both the ' - 'scan positions and the frame positions ' - 'simultaneously' + " Cannot perform advanced indexing on both the " + "scan positions and the frame positions " + "simultaneously" ) raise IndexError(msg) @@ -816,62 +822,55 @@ def validate_if_advanced_indexing(obj, max_ndim, is_scan): if not slices: # It's an empty tuple - return tuple(), tuple() + return (), () # Figure out which slices belong to which parts. # The first slice should definitely be part of the scan. first_slice = slices[0] - first_slice = validate_if_advanced_indexing(first_slice, - len(self.scan_shape), - is_scan=True) + first_slice = validate_if_advanced_indexing(first_slice, len(self.scan_shape), is_scan=True) frame_start = 1 if len(self.scan_shape) > 1: # We might have 2 slices for the scan shape - if not isinstance(first_slice, - np.ndarray) or first_slice.ndim == 1: + if not isinstance(first_slice, np.ndarray) or first_slice.ndim == 1: # We have another scan slice frame_start += 1 if frame_start == 2 and len(slices) > 1: # Validate the second scan slice. 
- second_slice = validate_if_advanced_indexing(slices[1], 1, - is_scan=True) + second_slice = validate_if_advanced_indexing(slices[1], 1, is_scan=True) scan_slices = (first_slice, second_slice) else: scan_slices = (first_slice,) # If there are frame indices, validate them too - frame_slices = tuple() + frame_slices = () for i in range(frame_start, len(slices)): max_ndim = frame_start + 2 - i if max_ndim == 0: - raise IndexError('Too many indices for frame positions') - frame_slices += (validate_if_advanced_indexing(slices[i], max_ndim, - is_scan=False),) + raise IndexError("Too many indices for frame positions") + frame_slices += (validate_if_advanced_indexing(slices[i], max_ndim, is_scan=False),) # Verify that we are not doing advanced indexing on both the scan # positions and frame positions simultaneously. - if (any(isinstance(x, np.ndarray) for x in scan_slices) and - any(isinstance(x, np.ndarray) for x in frame_slices)): - msg = ( - 'Cannot perform advanced indexing on both scan positions ' - 'and frame positions simultaneously' - ) + if any(isinstance(x, np.ndarray) for x in scan_slices) and any(isinstance(x, np.ndarray) for x in frame_slices): + msg = "Cannot perform advanced indexing on both scan positions " "and frame positions simultaneously" raise IndexError(msg) # Verify that if there are any 2D advanced indexing arrays, they # must be boolean and of the same shape. - first_frame_slice = (slice(None) if not frame_slices - else frame_slices[0]) + first_frame_slice = slice(None) if not frame_slices else frame_slices[0] for i, to_check in enumerate((first_slice, first_frame_slice)): req_shape = self.scan_shape if i == 0 else self.frame_shape - if (isinstance(to_check, np.ndarray) and to_check.ndim == 2 and - (to_check.dtype != np.bool_ or to_check.shape != req_shape)): + if ( + isinstance(to_check, np.ndarray) + and to_check.ndim == 2 + and (to_check.dtype != np.bool_ or to_check.shape != req_shape) + ): msg = ( - '2D advanced indexing is only allowed for boolean arrays ' - 'that match either the scan shape or the frame shape ' - '(whichever it is indexing into)' + "2D advanced indexing is only allowed for boolean arrays " + "that match either the scan shape or the frame shape " + "(whichever it is indexing into)" ) raise IndexError(msg) @@ -882,8 +881,7 @@ def _sparse_frames(self, indices): indices = (indices,) if len(indices) != len(self.scan_shape): - msg = (f'{indices} shape does not match scan shape' - f'{self.scan_shape}') + msg = f"{indices} shape does not match scan shape" f"{self.scan_shape}" raise ValueError(msg) if len(indices) == 1: @@ -894,15 +892,13 @@ def _sparse_frames(self, indices): return self.data[scan_ind] def _slice_dense(self, scan_slices, frame_slices): - new_scan_shape = np.empty(self.scan_shape, - dtype=bool)[scan_slices].shape - new_frame_shape = np.empty(self.frame_shape, - dtype=bool)[frame_slices].shape + new_scan_shape = np.empty(self.scan_shape, dtype=bool)[scan_slices].shape + new_frame_shape = np.empty(self.frame_shape, dtype=bool)[frame_slices].shape result_shape = new_scan_shape + new_frame_shape if result_shape == self.shape and not self.allow_full_expand: - raise FullExpansionDenied('Full expansion is not allowed') + raise FullExpansionDenied("Full expansion is not allowed") # Create the result result = np.zeros(result_shape, dtype=self.dtype) @@ -944,10 +940,8 @@ def _slice_sparse(self, scan_slices, frame_slices): new_frames = self.data if frame_shape_modified: - # Map old frame indices to new ones. Invalid values will be -1. 
- frame_indices = np.arange(self._frame_shape_flat[0]).reshape( - self.frame_shape) + frame_indices = np.arange(self._frame_shape_flat[0]).reshape(self.frame_shape) sliced = frame_indices[frame_slices] new_frame_shape = sliced.shape @@ -975,12 +969,12 @@ def _slice_sparse(self, scan_slices, frame_slices): new_data = copy.deepcopy(new_frames) kwargs = { - 'data': new_data, - 'scan_shape': new_scan_shape, - 'frame_shape': new_frame_shape, - 'dtype': self.dtype, - 'sparse_slicing': self.sparse_slicing, - 'allow_full_expand': self.allow_full_expand, + "data": new_data, + "scan_shape": new_scan_shape, + "frame_shape": new_frame_shape, + "dtype": self.dtype, + "sparse_slicing": self.sparse_slicing, + "allow_full_expand": self.allow_full_expand, } return SparseArray(**kwargs) @@ -989,7 +983,7 @@ def _expand(self, indices): indices = (indices,) if len(indices) >= len(self.scan_shape): - scan_indices = indices[:len(self.scan_shape)] + scan_indices = indices[: len(self.scan_shape)] sparse_frames = self._sparse_frames(scan_indices) if len(indices) < len(self.scan_shape): @@ -1019,12 +1013,11 @@ def _expand(self, indices): elif len(indices) == len(self.scan_shape) + 2: # The last two indices are frame indices frame_indices = indices[-2:] - flat_frame_index = (frame_indices[0] * self.frame_shape[1] + - frame_indices[1]) + flat_frame_index = frame_indices[0] * self.frame_shape[1] + frame_indices[1] return sum(flat_frame_index in frame for frame in sparse_frames) max_length = len(self.shape) + 1 - raise ValueError(f'0 < len(indices) < {max_length} is required') + raise ValueError(f"0 < len(indices) < {max_length} is required") @staticmethod def _is_advanced_indexing(obj): @@ -1053,8 +1046,7 @@ def is_int_or_bool_ndarray(x): return True if isinstance(obj, tuple): - return any(isinstance(x, Sequence) or is_int_or_bool_ndarray(x) - for x in obj) + return any(isinstance(x, Sequence) or is_int_or_bool_ndarray(x) for x in obj) return False @@ -1064,7 +1056,7 @@ def _is_none_slice(x): def _warning(msg): - print(f'Warning: {msg}', file=sys.stderr) + print(f"Warning: {msg}", file=sys.stderr) class FullExpansionDenied(Exception): diff --git a/python/stempy/utils.py b/python/stempy/utils.py index ce5b3dbb..fa5dac57 100644 --- a/python/stempy/utils.py +++ b/python/stempy/utils.py @@ -1,12 +1,14 @@ try: # Use importlib metadata if available (python >=3.8) - from importlib.metadata import version, PackageNotFoundError + from importlib.metadata import PackageNotFoundError, version except ImportError: # No importlib metadata. Try to use pkg_resources instead. from pkg_resources import ( - get_distribution, DistributionNotFound as PackageNotFoundError, ) + from pkg_resources import ( + get_distribution, + ) def version(x): return get_distribution(x).version @@ -14,7 +16,7 @@ def version(x): def get_version(): try: - return version('stempy') + return version("stempy") except PackageNotFoundError: # package is not installed pass diff --git a/setup.py b/setup.py index d55fe476..fc7f356f 100644 --- a/setup.py +++ b/setup.py @@ -7,42 +7,42 @@ def extra_cmake_args(): # Warning: if you use paths on Windows, you should use "\\" # for the path delimiter to work on CI. - env = os.getenv('EXTRA_CMAKE_ARGS') - return env.split(';') if env else [] + env = os.getenv("EXTRA_CMAKE_ARGS") + return env.split(";") if env else [] cmake_args = [] + extra_cmake_args() -if os.name == 'nt': +if os.name == "nt": # Need to export all headers on windows... 
-    cmake_args.append('-DCMAKE_WINDOWS_EXPORT_ALL_SYMBOLS=TRUE')
+    cmake_args.append("-DCMAKE_WINDOWS_EXPORT_ALL_SYMBOLS=TRUE")
 
-if os.getenv('USE_PYTHON_IN_PATH'):
-    python_exe = shutil.which('python')
+if os.getenv("USE_PYTHON_IN_PATH"):
+    python_exe = shutil.which("python")
     if python_exe:
         # For this program, we use find_package(Python3 ...)
-        cmake_args.append(f'-DPython3_EXECUTABLE={python_exe}')
+        cmake_args.append(f"-DPython3_EXECUTABLE={python_exe}")
 
-with open('requirements.txt') as f:
+with open("requirements.txt") as f:
     install_requires = f.read()
 
 setup(
-    name='stempy',
+    name="stempy",
     use_scm_version=True,
-    description='A package for the ingestion of 4D STEM data.',
-    long_description='A package for the ingestion of 4D STEM data.',
-    url='https://github.com/OpenChemistry/stempy',
-    author='Kitware Inc',
-    license='BSD 3-Clause',
+    description="A package for the ingestion of 4D STEM data.",
+    long_description="A package for the ingestion of 4D STEM data.",
+    url="https://github.com/OpenChemistry/stempy",
+    author="Kitware Inc",
+    license="BSD 3-Clause",
     classifiers=[
-        'Development Status :: 3 - Alpha',
-        'License :: OSI Approved :: BSD License',
-        'Programming Language :: Python :: 3',
+        "Development Status :: 3 - Alpha",
+        "License :: OSI Approved :: BSD License",
+        "Programming Language :: Python :: 3",
    ],
-    keywords='',
-    packages=['stempy'],
+    keywords="",
+    packages=["stempy"],
     install_requires=install_requires,
     cmake_args=cmake_args,
 )
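For context on the `extra_cmake_args()` helper touched above: it forwards extra CMake flags through a single `EXTRA_CMAKE_ARGS` environment variable, split on semicolons. A minimal sketch of that behavior; the flag values below are illustrative, not taken from this project's CI:

    import os

    # Hypothetical values; any semicolon-separated CMake flags behave the same way.
    os.environ["EXTRA_CMAKE_ARGS"] = "-DCMAKE_BUILD_TYPE=Release;-DBUILD_TESTING=ON"

    env = os.getenv("EXTRA_CMAKE_ARGS")
    args = env.split(";") if env else []
    assert args == ["-DCMAKE_BUILD_TYPE=Release", "-DBUILD_TESTING=ON"]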
diff --git a/tests/conftest.py b/tests/conftest.py
index d8cc8f25..a9075e51 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -6,13 +6,12 @@
 
 from stempy.io.sparse_array import SparseArray
 
-
 DATA_URLS = {
-    'electron_small': 'https://data.kitware.com/api/v1/file/6065f00d2fa25629b93bdabe/download',  # noqa
-    'electron_large': 'https://data.kitware.com/api/v1/file/6065f2792fa25629b93c0303/download',  # noqa
-    'cropped_multi_frames_v1': 'https://data.kitware.com/api/v1/file/623119814acac99f4261aa59/download',  # noqa
-    'cropped_multi_frames_v2': 'https://data.kitware.com/api/v1/file/6244e9944acac99f424743df/download',  # noqa
-    'cropped_multi_frames_v3': 'https://data.kitware.com/api/v1/file/624601cd4acac99f425c73f6/download',  # noqa
+    "electron_small": "https://data.kitware.com/api/v1/file/6065f00d2fa25629b93bdabe/download",  # noqa
+    "electron_large": "https://data.kitware.com/api/v1/file/6065f2792fa25629b93c0303/download",  # noqa
+    "cropped_multi_frames_v1": "https://data.kitware.com/api/v1/file/623119814acac99f4261aa59/download",  # noqa
+    "cropped_multi_frames_v2": "https://data.kitware.com/api/v1/file/6244e9944acac99f424743df/download",  # noqa
+    "cropped_multi_frames_v3": "https://data.kitware.com/api/v1/file/624601cd4acac99f425c73f6/download",  # noqa
 }
 
 DATA_RESPONSES = {}
@@ -34,34 +33,35 @@ def io_object(key):
 
 @pytest.fixture
 def electron_data_small():
-    return io_object('electron_small')
+    return io_object("electron_small")
 
 
 @pytest.fixture
 def electron_data_large():
-    return io_object('electron_large')
+    return io_object("electron_large")
 
 
 @pytest.fixture
 def cropped_multi_frames_data_v1():
-    return io_object('cropped_multi_frames_v1')
+    return io_object("cropped_multi_frames_v1")
 
 
 @pytest.fixture
 def cropped_multi_frames_data_v2():
-    return io_object('cropped_multi_frames_v2')
+    return io_object("cropped_multi_frames_v2")
 
 
 @pytest.fixture
 def cropped_multi_frames_data_v3():
-    return io_object('cropped_multi_frames_v3')
+    return io_object("cropped_multi_frames_v3")
 
 
 # SparseArray fixtures
 
+
 @pytest.fixture
 def sparse_array_10x10():
-    aa = np.zeros((4, 3), dtype='object')
+    aa = np.zeros((4, 3), dtype="object")
     for ii in range(4):
         aa[ii, 0] = np.array((0, 48, 72, 65))
         aa[ii, 1] = np.array((0, 48, 72, 65))
@@ -73,7 +73,7 @@ def sparse_array_10x10():
 @pytest.fixture
 def sparse_array_small(electron_data_small):
     kwargs = {
-        'dtype': np.uint64,
+        "dtype": np.uint64,
     }
 
     array = SparseArray.from_hdf5(electron_data_small, **kwargs)
@@ -109,54 +109,53 @@ def cropped_multi_frames_v2(cropped_multi_frames_data_v2):
 def cropped_multi_frames_v3(cropped_multi_frames_data_v3):
     return SparseArray.from_hdf5(cropped_multi_frames_data_v3, dtype=np.uint16)
 
+
 @pytest.fixture
 def simulate_sparse_array():
-    """Make a ndarray with sparse disk.
-    scan_size: Real space scan size (pixels)
-    frame_size: detector size (pixels)
-    center: the center of the disk in diffraction space
-    how_sparse: Percent sparseness (0-1). Large number means fewer electrons per frame
-    disk_size: The radius in pixels of the diffraction disk
-
-    """
-
+    """Make a ndarray with sparse disk.
+
+    scan_size: Real space scan size (pixels)
+    frame_size: detector size (pixels)
+    center: the center of the disk in diffraction space
+    how_sparse: Percent sparseness (0-1). Large number means fewer electrons per frame
+    disk_size: The radius in pixels of the diffraction disk
+    """
+
     scan_size = (100, 100)
-    frame_size = (100,100)
+    frame_size = (100, 100)
     center = (30, 70)
-    how_sparse = .8  # 0-1; larger is less electrons
-    disk_size = 10  # disk radius in pixels
+    how_sparse = 0.8  # 0-1; larger is less electrons
+    disk_size = 10  # disk radius in pixels
     num_frames = 2
-
-    YY, XX = np.mgrid[0:frame_size[0], 0:frame_size[1]]
-    RR = np.sqrt((YY-center[1])**2 + (XX-center[0])**2)
+
+    YY, XX = np.mgrid[0 : frame_size[0], 0 : frame_size[1]]
+    RR = np.sqrt((YY - center[1]) ** 2 + (XX - center[0]) ** 2)
     RR[RR <= disk_size] = 1
     RR[RR > disk_size] = 0
 
-    sparse = np.random.rand(scan_size[0],scan_size[1],frame_size[0],frame_size[1], num_frames) * RR[:,:,None]
+    sparse = np.random.rand(scan_size[0], scan_size[1], frame_size[0], frame_size[1], num_frames) * RR[:, :, None]
     sparse[sparse >= how_sparse] = 1
     sparse[sparse < how_sparse] = 0
     sparse = sparse.astype(np.uint8)
-
-    new_data = np.empty((sparse.shape[0],sparse.shape[1], num_frames), dtype=object)
+
+    new_data = np.empty((sparse.shape[0], sparse.shape[1], num_frames), dtype=object)
     for j in range(sparse.shape[0]):
         for k in range(sparse.shape[1]):
             for i in range(num_frames):
-                events = np.ravel_multi_index(np.where(sparse[j,k,:,:,i] == 1),frame_size)
-                new_data[j,k,i] = events
+                events = np.ravel_multi_index(np.where(sparse[j, k, :, :, i] == 1), frame_size)
+                new_data[j, k, i] = events
 
-    new_data = new_data.reshape(new_data.shape[0]*new_data.shape[1],num_frames)
+    new_data = new_data.reshape(new_data.shape[0] * new_data.shape[1], num_frames)
 
     kwargs = {
-        'data': new_data,
-        'scan_shape': sparse.shape[0:2],
-        'frame_shape': sparse.shape[2:4],
-        'dtype': sparse.dtype,
-        'sparse_slicing': True,
-        'allow_full_expand': False,
+        "data": new_data,
+        "scan_shape": sparse.shape[0:2],
+        "frame_shape": sparse.shape[2:4],
+        "dtype": sparse.dtype,
+        "sparse_slicing": True,
+        "allow_full_expand": False,
     }
-
-    return SparseArray(**kwargs)
\ No newline at end of file
+
+    return SparseArray(**kwargs)
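The conftest fixtures above all assemble a `SparseArray` from the same ingredients: an object-dtype array of shape (scan positions, frames per position) whose entries hold the flattened pixel indices of the electron events in each frame, plus a `scan_shape` and `frame_shape`. A minimal sketch with made-up events (the indices below are illustrative only):

    import numpy as np
    from stempy.io.sparse_array import SparseArray

    # 2x2 scan, one frame per scan position, 2x2 detector frames.
    data = np.empty((4, 1), dtype=object)
    for i in range(4):
        # One electron at flattened pixel 0 -> (0, 0) and one at 3 -> (1, 1).
        data[i, 0] = np.array([0, 3])

    array = SparseArray(data=data, scan_shape=(2, 2), frame_shape=(2, 2))
    assert array.to_dense().shape == (2, 2, 2, 2)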
diff --git a/tests/test_image.py b/tests/test_image.py
index 4c18a907..d3a16acc 100644
--- a/tests/test_image.py
+++ b/tests/test_image.py
@@ -1,5 +1,3 @@
-import pytest
-
 import numpy as np
 
 from stempy.image import com_dense, com_sparse, radial_sum_sparse
@@ -15,9 +13,9 @@ def test_com_sparse(sparse_array_small, full_array_small):
     data[1][0] = np.array([0])
     data[1][1] = np.array([0, 1, 2])
     kwargs = {
-        'data': data,
-        'scan_shape': (2, 1),
-        'frame_shape': (2, 2),
+        "data": data,
+        "scan_shape": (2, 1),
+        "frame_shape": (2, 2),
     }
     array = SparseArray(**kwargs)
 
@@ -40,24 +38,23 @@ def test_com_sparse(sparse_array_small, full_array_small):
 
 
 def test_radial_sum_sparse(sparse_array_10x10):
-
     rr = radial_sum_sparse(sparse_array_10x10, center=(5, 5))
 
     assert np.array_equal(rr[0, 0, :], [0, 6, 0, 0, 3])
 
+
 def test_com_sparse_parameters(simulate_sparse_array):
-
-    sp = simulate_sparse_array #((100,100), (100,100), (30,70), (0.8), (10))
-
+    sp = simulate_sparse_array  # ((100,100), (100,100), (30,70), (0.8), (10))
+
     # Test no inputs. This should be the full frame COM
     com0 = com_sparse(sp)
 
     assert round(com0[0,].mean()) == 30
-
+
     # Test crop_to input. Initial COM should be full frame COM
-    com1 = com_sparse(sp, crop_to=(10,10))
+    com1 = com_sparse(sp, crop_to=(10, 10))
 
     assert round(com1[0,].mean()) == 30
-
+
     # Test crop_to and init_center input.
     # No counts will be in the center so all positions will be np.nan
-    com2 = com_sparse(sp, crop_to=(10,10), init_center=(1,1))
-    assert np.isnan(com2[0,0,0])
+    com2 = com_sparse(sp, crop_to=(10, 10), init_center=(1, 1))
+    assert np.isnan(com2[0, 0, 0])
diff --git a/tests/test_simple.py b/tests/test_simple.py
index f51cf447..e79119ef 100644
--- a/tests/test_simple.py
+++ b/tests/test_simple.py
@@ -12,24 +12,24 @@ def test_simple():
     # io.get_hdf5_reader() is hard-coded to that.
     shape = (32, 576, 576)
     data = np.arange(np.prod(shape)).reshape(shape)
-    dataset_name = 'frames'
+    dataset_name = "frames"
 
     # Stempy uses uint16. Let's prevent overflow...
     data %= np.iinfo(np.uint16).max
 
     mean = data.mean(axis=0)
 
     # Fake stuff that io.get_hdf5_reader() needs...
-    fake_scan_name = 'stem/images'
+    fake_scan_name = "stem/images"
     fake_scan = np.zeros((4, 8, 1))
 
     temp = tempfile.NamedTemporaryFile(delete=False)
     file_name = temp.name
 
     try:
-        with h5py.File(file_name, 'w') as wf:
+        with h5py.File(file_name, "w") as wf:
             wf.create_dataset(dataset_name, data=data)
             wf.create_dataset(fake_scan_name, data=fake_scan)
 
-        with h5py.File(file_name, 'r') as rf:
+        with h5py.File(file_name, "r") as rf:
             reader = io.get_hdf5_reader(rf)
             average = image.calculate_average(reader)
     finally:
@@ -41,5 +41,5 @@ def test_simple():
     assert np.array_equal(average, mean)
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     test_simple()
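The `tests/test_sparse_array.py` changes below mostly reflow one recurring pattern: slice the `SparseArray` with `sparse_slicing` enabled (the result stays sparse) and disabled (the result is dense), then compare both against the equivalent slice of a dense reference. A minimal sketch of that pattern on a toy array, with made-up events as above:

    import numpy as np
    from stempy.io.sparse_array import SparseArray

    data = np.empty((4, 1), dtype=object)
    for i in range(4):
        data[i, 0] = np.array([0, 3])
    array = SparseArray(data=data, scan_shape=(2, 2), frame_shape=(2, 2))
    full = array.to_dense()

    array.sparse_slicing = True
    sliced = array[0:1, 0:2]  # stays sparse: another SparseArray
    assert np.array_equal(sliced.to_dense(), full[0:1, 0:2])

    array.sparse_slicing = False
    assert np.array_equal(array[0:1, 0:2], full[0:1, 0:2])  # dense result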
diff --git a/tests/test_sparse_array.py b/tests/test_sparse_array.py
index e088884c..6a465c56 100644
--- a/tests/test_sparse_array.py
+++ b/tests/test_sparse_array.py
@@ -1,12 +1,10 @@
 import os
 import tempfile
 
-import pytest
-
 import numpy as np
+import pytest
 
-from stempy.io.sparse_array import FullExpansionDenied
-from stempy.io.sparse_array import SparseArray
+from stempy.io.sparse_array import FullExpansionDenied, SparseArray
 
 
 def test_full_expansion(sparse_array_small, full_array_small):
@@ -56,7 +54,7 @@ def test_dense_slicing(sparse_array_small, full_array_small):
     for i in range(len(slices)):
         # Try using just the first slice, then the first and the second,
         # etc. until using all slices.
-        s = slices[:i + 1]
+        s = slices[: i + 1]
 
         if len(s) == 1 and s[0] == slice(None):
             # Skip over full expansions
             continue
@@ -74,9 +72,24 @@ def test_sparse_slicing(sparse_array_small, full_array_small):
 
     # Just test a few indices
     test_frames = [
-        (0, 0), (0, 1), (0, 2), (0, 3), (0, 4), (0, 7), (0, 9), (0, 11),
-        (1, 0), (1, 3), (2, 2), (3, 1), (8, 4), (9, 6), (10, 13), (17, 17),
-        (14, 16), (18, 15),
+        (0, 0),
+        (0, 1),
+        (0, 2),
+        (0, 3),
+        (0, 4),
+        (0, 7),
+        (0, 9),
+        (0, 11),
+        (1, 0),
+        (1, 3),
+        (2, 2),
+        (3, 1),
+        (8, 4),
+        (9, 6),
+        (10, 13),
+        (17, 17),
+        (14, 16),
+        (18, 15),
     ]
     for slices in TEST_SLICES:
         sliced = array[slices]
@@ -113,8 +126,7 @@ def test_frame_slicing(sparse_array_small, full_array_small):
 
     # Now test the dense slicing version
     array.sparse_slicing = False
-    assert np.array_equal(array[:, :, 100:200, 200:300],
-                          full[:, :, 100:200, 200:300])
+    assert np.array_equal(array[:, :, 100:200, 200:300], full[:, :, 100:200, 200:300])
 
 
 def test_scan_axes_arithmetic(sparse_array_small, full_array_small):
@@ -161,9 +173,9 @@ def test_bin_scans(sparse_array_small, full_array_small):
             continue
 
         binned = array.bin_scans(factor)
-        full_binned = full.reshape(shape[0] // factor, factor,
-                                   shape[1] // factor, factor,
-                                   shape[2], shape[3]).sum(axis=(1, 3))
+        full_binned = full.reshape(shape[0] // factor, factor, shape[1] // factor, factor, shape[2], shape[3]).sum(
+            axis=(1, 3)
+        )
 
         assert np.array_equal(binned.to_dense(), full_binned)
 
@@ -192,9 +204,9 @@ def test_bin_frames(sparse_array_small, full_array_small):
             continue
 
         binned = array.bin_frames(factor)
-        full_binned = full.reshape(shape[0], shape[1], shape[2] // factor,
-                                   factor, shape[3] // factor,
-                                   factor).sum(axis=(3, 5))
+        full_binned = full.reshape(shape[0], shape[1], shape[2] // factor, factor, shape[3] // factor, factor).sum(
+            axis=(3, 5)
+        )
 
         assert np.array_equal(binned.to_dense(), full_binned)
 
@@ -295,21 +307,16 @@ def test_slice_sum(sparse_array_small, full_array_small):
     data = sparse_array_small
     full = full_array_small
 
-    assert np.array_equal(data[0, 0:1, :, :].sum(axis='scan'),
-                          full[0, 0:1, :, :].sum(axis=0))
-    assert np.array_equal(data[0:1, 0:1, :, :].sum(axis='scan'),
-                          full[0:1, 0:1, :, :].sum(axis=(0, 1)))
-    assert np.array_equal(data[0:2, 0:2, :, :].sum(axis='scan'),
-                          full[0:2, 0:2, :, :].sum(axis=(0, 1)))
-    assert np.array_equal(data[0, 0, :, :].sum(),
-                          full[0, 0, :, :].sum())
+    assert np.array_equal(data[0, 0:1, :, :].sum(axis="scan"), full[0, 0:1, :, :].sum(axis=0))
+    assert np.array_equal(data[0:1, 0:1, :, :].sum(axis="scan"), full[0:1, 0:1, :, :].sum(axis=(0, 1)))
+    assert np.array_equal(data[0:2, 0:2, :, :].sum(axis="scan"), full[0:2, 0:2, :, :].sum(axis=(0, 1)))
+    assert np.array_equal(data[0, 0, :, :].sum(), full[0, 0, :, :].sum())
 
-    assert np.array_equal(data[:, :, 0:1, 0:2].sum(axis='frame'),
-                          full[:, :, 0:1, 0:2].sum(axis=(2, 3)))
-    assert np.array_equal(data[:, :, 100:200, 100:200].sum(axis='frame'),
-                          full[:, :, 100:200, 100:200].sum(axis=(2, 3)))
-    assert np.array_equal(data[:, :, 100:200:-3, 100:200:-3].sum(axis='frame'),
-                          full[:, :, 100:200:-3, 100:200:-3].sum(axis=(2, 3)))
+    assert np.array_equal(data[:, :, 0:1, 0:2].sum(axis="frame"), full[:, :, 0:1, 0:2].sum(axis=(2, 3)))
+    assert np.array_equal(data[:, :, 100:200, 100:200].sum(axis="frame"), full[:, :, 100:200, 100:200].sum(axis=(2, 3)))
+    assert np.array_equal(
+        data[:, :, 100:200:-3, 100:200:-3].sum(axis="frame"), full[:, :, 100:200:-3, 100:200:-3].sum(axis=(2, 3))
+    )
 
 
 def test_scan_ravel(sparse_array_small, full_array_small):
@@ -322,8 +329,7 @@ def test_scan_ravel(sparse_array_small, full_array_small):
 
     # Perform a few simple tests on the raveled SparseArray
     assert np.array_equal(array.sum(axis=0), full.sum(axis=(0, 1)))
-    assert np.array_equal(array.sum(axis=(1, 2)),
-                          full.sum(axis=(2, 3)).ravel())
+    assert np.array_equal(array.sum(axis=(1, 2)), full.sum(axis=(2, 3)).ravel())
 
     full = full.reshape(expected_shape)
     for i in range(expected_shape[0]):
@@ -335,14 +341,16 @@ def test_round_trip_hdf5(sparse_array_small):
     original = sparse_array_small
 
     # Store some metadata that we will test for
-    original.metadata.update({
-        'nested': {
-            'string': 's',
-        },
-        'array1d': np.arange(3, dtype=int),
-        'array2d': np.arange(9, dtype=np.float32).reshape((3, 3)),
-        'float': np.float64(7.21),
-    })
+    original.metadata.update(
+        {
+            "nested": {
+                "string": "s",
+            },
+            "array1d": np.arange(3, dtype=int),
+            "array2d": np.arange(9, dtype=np.float32).reshape((3, 3)),
+            "float": np.float64(7.21),
+        }
+    )
 
     temp = tempfile.NamedTemporaryFile(delete=False)
     try:
@@ -361,7 +369,7 @@ def test_round_trip_hdf5(sparse_array_small):
     for a, b in zip(sfa, sfb):
         assert np.array_equal(a, b)
 
-    original.metadata['float'] = np.float64(6.5)
+    original.metadata["float"] = np.float64(6.5)
 
     # Now, the metadata shouldn't be equal
     assert not dicts_equal(original.metadata, array.metadata)
@@ -373,7 +381,7 @@ def dicts_equal(d1, d2):
     for k, v1 in d1.items():
         v2 = d2[k]
 
-        if type(v1) != type(v2):
+        if type(v1) is not type(v2):
             return False
 
         if isinstance(v1, np.ndarray):
@@ -400,11 +408,11 @@ def test_multiple_frames_per_scan_position():
     data[1][0] = np.array([0])
     data[1][1] = np.array([0, 1])
     kwargs = {
-        'data': data,
-        'scan_shape': (2, 1),
-        'frame_shape': (2, 2),
-        'sparse_slicing': False,
-        'allow_full_expand': True,
+        "data": data,
+        "scan_shape": (2, 1),
+        "frame_shape": (2, 2),
+        "sparse_slicing": False,
+        "allow_full_expand": True,
     }
     array = SparseArray(**kwargs)
 
@@ -447,10 +455,8 @@ def test_multiple_frames_per_scan_position():
     assert np.array_equal(array[:, :, 0][1, 0], full[:, :, 0][1, 0])
     assert np.array_equal(array[:, :, 0, 1][0, 0], full[:, :, 0, 1][0, 0])
     assert np.array_equal(array[:, :, 0, 1:][0, 0], full[:, :, 0, 1:][0, 0])
-    assert np.array_equal(array[:, :, 0:2, 0:2][0, 0],
-                          full[:, :, 0:2, 0:2][0, 0])
-    assert np.array_equal(array[:, :, 0:2, 0:2][1, 0],
-                          full[:, :, 0:2, 0:2][1, 0])
+    assert np.array_equal(array[:, :, 0:2, 0:2][0, 0], full[:, :, 0:2, 0:2][0, 0])
+    assert np.array_equal(array[:, :, 0:2, 0:2][1, 0], full[:, :, 0:2, 0:2][1, 0])
     assert np.array_equal(array[0, :, :, 0][0], full[0, :, :, 0][0])
     assert np.array_equal(array[0, :, :, 1][0], full[0, :, :, 1][0])
     assert np.array_equal(array[0, :, 0:2, 0:2][0], full[0, :, 0:2, 0:2][0])
@@ -460,13 +466,13 @@ def test_multiple_frames_per_scan_position():
 
     # Test arithmetic
     array.allow_full_expand = False
-    axis = 'scan'
+    axis = "scan"
 
     assert np.array_equal(array.sum(axis=axis), [[4, 1], [1, 0]])
     assert np.array_equal(array.max(axis=axis), [[2, 1], [1, 0]])
     assert np.array_equal(array.min(axis=axis), [[2, 0], [0, 0]])
     assert np.array_equal(array.mean(axis=axis), [[2, 0.5], [0.5, 0]])
 
-    axis = 'frame'
+    axis = "frame"
     assert np.array_equal(array.sum(axis=axis), full.sum(axis=(2, 3)))
     assert np.array_equal(array.mean(axis=axis), full.mean(axis=(2, 3)))
     assert np.array_equal(array.min(axis=axis), full.min(axis=(2, 3)))
@@ -480,10 +486,10 @@ def test_multiple_frames_per_scan_position():
     data[1][1] = np.array([], dtype=int)
 
     kwargs = {
-        'data': data,
-        'scan_shape': (2, 1),
-        'frame_shape': (4, 4),
-        'sparse_slicing': False,
+        "data": data,
+        "scan_shape": (2, 1),
+        "frame_shape": (4, 4),
+        "sparse_slicing": False,
     }
     test_bin_frames_array = SparseArray(**kwargs)
     binned = test_bin_frames_array.bin_frames(2)
@@ -501,11 +507,11 @@ def test_multiple_frames_per_scan_position():
         data[i] = np.array([i % 4])
 
     kwargs = {
-        'data': data,
-        'scan_shape': (4, 4),
-        'frame_shape': (2, 2),
-        'sparse_slicing': False,
-        'allow_full_expand': True,
+        "data": data,
+        "scan_shape": (4, 4),
+        "frame_shape": (2, 2),
+        "sparse_slicing": False,
+        "allow_full_expand": True,
     }
     test_bin_scans_array = SparseArray(**kwargs)
     binned = test_bin_scans_array.bin_scans(2)
@@ -527,8 +533,7 @@ def test_multiple_frames_per_scan_position():
     assert np.array_equal(binned_4x[:], binned_twice[:])
 
 
-def test_data_conversion(cropped_multi_frames_v1, cropped_multi_frames_v2,
-                         cropped_multi_frames_v3):
+def test_data_conversion(cropped_multi_frames_v1, cropped_multi_frames_v2, cropped_multi_frames_v3):
     # These datasets were in older formats, but when they were converted
     # into a SparseArray via the fixture, they should have automatically
     # been converted to the latest version. Confirm this is the case.
@@ -618,7 +623,7 @@ def compare_with_sparse(full, sparse):
     rng = np.random.default_rng(0)
 
     num_tries = 3
-    for i in range(num_tries):
+    for _ in range(num_tries):
         scan_mask = rng.choice([True, False], array.scan_shape)
         frame_mask = rng.choice([True, False], array.frame_shape)
 
@@ -653,23 +658,15 @@ def compare_with_sparse(full, sparse):
 
     # Now test a few arithmetic operations combined with the indexing
     array.sparse_slicing = True
-    assert np.array_equal(array[scan_mask].sum(axis='scan'),
-                          full[scan_mask].sum(axis=(0,)))
-    assert np.array_equal(array[scan_mask].min(axis='scan'),
-                          full[scan_mask].min(axis=(0,)))
-    assert np.array_equal(array[scan_mask].max(axis='scan'),
-                          full[scan_mask].max(axis=(0,)))
-    assert np.allclose(array[scan_mask].mean(axis='scan'),
-                       full[scan_mask].mean(axis=(0,)))
-
-    assert np.array_equal(array[:, :, frame_mask].sum(axis='frame'),
-                          full[:, :, frame_mask].sum(axis=(2,)))
-    assert np.allclose(array[:, :, frame_mask].mean(axis='frame'),
-                       full[:, :, frame_mask].mean(axis=(2,)))
-    assert np.array_equal(array[:, :, frame_mask].min(axis='frame'),
-                          full[:, :, frame_mask].min(axis=(2,)))
-    assert np.array_equal(array[:, :, frame_mask].max(axis='frame'),
-                          full[:, :, frame_mask].max(axis=(2,)))
+    assert np.array_equal(array[scan_mask].sum(axis="scan"), full[scan_mask].sum(axis=(0,)))
+    assert np.array_equal(array[scan_mask].min(axis="scan"), full[scan_mask].min(axis=(0,)))
+    assert np.array_equal(array[scan_mask].max(axis="scan"), full[scan_mask].max(axis=(0,)))
+    assert np.allclose(array[scan_mask].mean(axis="scan"), full[scan_mask].mean(axis=(0,)))
+
+    assert np.array_equal(array[:, :, frame_mask].sum(axis="frame"), full[:, :, frame_mask].sum(axis=(2,)))
+    assert np.allclose(array[:, :, frame_mask].mean(axis="frame"), full[:, :, frame_mask].mean(axis=(2,)))
+    assert np.array_equal(array[:, :, frame_mask].min(axis="frame"), full[:, :, frame_mask].min(axis=(2,)))
+    assert np.array_equal(array[:, :, frame_mask].max(axis="frame"), full[:, :, frame_mask].max(axis=(2,)))
 
     # Now do some basic tests with multiple frames per scan
     data = np.empty((2, 2), dtype=object)
@@ -678,11 +675,11 @@ def compare_with_sparse(full, sparse):
     data[0][0] = np.array([0, 2])
     data[0][1] = np.array([0, 1])
     data[1][0] = np.array([0])
     data[1][1] = np.array([0, 1])
     kwargs = {
-        'data': data,
-        'scan_shape': (2, 1),
-        'frame_shape': (2, 2),
-        'sparse_slicing': False,
-        'allow_full_expand': True,
+        "data": data,
+        "scan_shape": (2, 1),
+        "frame_shape": (2, 2),
+        "sparse_slicing": False,
+        "allow_full_expand": True,
     }
     # The multiple frames per scan array
     m_array = SparseArray(**kwargs)
@@ -702,10 +699,8 @@ def compare_with_sparse(full, sparse):
     assert np.array_equal(m_array[[1, 0], [0]][0], position_one)
     assert np.array_equal(m_array[[1, 0], [0]][1], position_zero)
     assert np.array_equal(m_array[0, 0, [0, 1], [0, 0]], [2, 1])
-    assert np.array_equal(m_array[0, 0, [[True, False], [True, False]]],
-                          [2, 1])
-    assert np.array_equal(m_array[0, 0, [[False, True], [False, True]]],
-                          [0, 0])
+    assert np.array_equal(m_array[0, 0, [[True, False], [True, False]]], [2, 1])
+    assert np.array_equal(m_array[0, 0, [[False, True], [False, True]]], [0, 0])
     assert np.array_equal(m_array[[True, False], 0][0], position_zero)
     assert np.array_equal(m_array[[False, True], 0][0], position_one)
 
@@ -715,7 +710,8 @@ def test_keep_flyback(electron_data_small):
     assert flyback.scan_shape[1] == 50
     no_flyback = SparseArray.from_hdf5(electron_data_small, keep_flyback=False)
     assert no_flyback.scan_shape[1] == 49
-
+
+
 # Test binning until this number
 TEST_BINNING_UNTIL = 33
 
@@ -736,10 +732,7 @@ def test_keep_flyback(electron_data_small):
     (slice(4), slice(7), slice(3), slice(5)),
     (slice(2, 8), slice(5, 9), slice(2, 20), slice(50, 200)),
     (slice(2, 18, 2), slice(3, 20, 3), slice(2, 50, 2), slice(20, 100)),
-    (slice(None, None, 2), slice(None, None, 3), slice(None, None, 4),
-     slice(None, None, 5)),
-    (slice(3, None, 2), slice(5, None, 5), slice(4, None, 4),
-     slice(20, None, 5)),
-    (slice(20, 4, -2), slice(None, None, -1), slice(4, None, -3),
-     slice(100, 3, -5)),
+    (slice(None, None, 2), slice(None, None, 3), slice(None, None, 4), slice(None, None, 5)),
+    (slice(3, None, 2), slice(5, None, 5), slice(4, None, 4), slice(20, None, 5)),
+    (slice(20, 4, -2), slice(None, None, -1), slice(4, None, -3), slice(100, 3, -5)),
 ]
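The boolean-mask tests above follow one recipe: a mask shaped like the scan selects whole frames, while a mask shaped like the frame selects pixels within every frame. A sketch of both cases on the same kind of toy array (mask values and events are illustrative):

    import numpy as np
    from stempy.io.sparse_array import SparseArray

    data = np.empty((4, 1), dtype=object)
    for i in range(4):
        data[i, 0] = np.array([0, 3])
    array = SparseArray(data=data, scan_shape=(2, 2), frame_shape=(2, 2))

    scan_mask = np.array([[True, False], [False, True]])
    subset = array[scan_mask]  # two scan positions remain, still sparse
    assert subset.to_dense().shape == (2, 2, 2)

    frame_mask = np.array([[True, False], [False, True]])
    pixels = array[:, :, frame_mask]  # two pixels kept from every frame
    assert pixels.to_dense().shape == (2, 2, 2)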