From 26f902cdbe8da44b30430a8a08133adaa1af9e12 Mon Sep 17 00:00:00 2001 From: Carl Johnsen Date: Wed, 18 Sep 2024 08:24:40 +0200 Subject: [PATCH] #34 Added docstrings to 1050 --- src/processing_steps/1050_compute_ridges.py | 346 ++++++++++++++++++++ 1 file changed, 346 insertions(+) diff --git a/src/processing_steps/1050_compute_ridges.py b/src/processing_steps/1050_compute_ridges.py index 0c0ad69..ea9f6f8 100644 --- a/src/processing_steps/1050_compute_ridges.py +++ b/src/processing_steps/1050_compute_ridges.py @@ -1,3 +1,7 @@ +#! /usr/bin/python3 +''' +This script computes the connected lines in a 2D histogram. It can either be run in GUI mode, where one tries to find the optimal configuration parameters, or batch mode, where it processes one or more histograms into images. In GUI mode, one can specify the bounding box by dragging a box on one of the images, specify the line to highlight by left clicking and reset the bounding box by middle clicking. +''' import sys sys.path.append(sys.path[0]+"/../") import matplotlib @@ -16,6 +20,12 @@ import time def batch(): + ''' + Processes the histograms in batch mode. The histograms are stored in the output folder specified in command line arguments. + + The histograms are stored in the output folder as npz files, where each field is stored as a separate image. The fields are stored as uint8 images, where each pixel is assigned a unique value, corresponding to the connected line it belongs to. The fields are stored as the field name, e.g. "x", "y", "z", "r" or the field name from the histogram file. + ''' + global args, config if not os.path.exists(args.output): @@ -54,6 +64,30 @@ def batch(): print (f'Processed {sample}') def load_config(filename): + ''' + Loads the configuration file if it exists, otherwise it returns the default configuration. + + The configuration file is a json file, which stores the following parameters: + - min peak height: The minimum height of a peak in the line plot to be considered a peak. + - close kernel y: The y-size of the kernel used for closing the mask. + - close kernel x: The x-size of the kernel used for closing the mask. + - line smooth: The size of the gaussian filter used for smoothing the line plot. + - iter dilate: The number of iterations used for dilating the mask. + - iter erode: The number of iterations used for eroding the mask. + - min contour size: The minimum size of a contour to be considered a contour. + - joint kernel size: The size of the kernel used for finding joints in the skeleton. + + Parameters + ---------- + `filename` : str + The filename of the configuration file. + + Returns + ------- + `config` : dict[str, int] + The configuration parameters as a dictionary. + ''' + if os.path.exists(filename): with open(filename, 'r') as f: config = json.load(f) @@ -68,19 +102,68 @@ def load_config(filename): 'min contour size': 1500, 'joint kernel size': 2 } + return config def load_hists(filename): + ''' + Loads the histograms from the specified file. + + The histograms are stored in a npz file, where each field is stored as a separate array. The fields are stored as the field name, e.g. "x", "y", "z", "r" or the field name from the histogram file. + + Parameters + ---------- + `filename` : str + The filename of the histogram file. + + Returns + ------- + `hists` : dict[str, numpy.array[uint64]] + The histograms as a dictionary, where the keys are the field names and the values are the histograms. + ''' + hists = np.load(filename) results = {} for name, hist in hists.items(): hist_sum = np.sum(hist, axis=1) hist_sum[hist_sum==0] = 1 results[name] = hist / hist_sum[:,np.newaxis] + return hists class _range: + ''' + A class for storing the range of the bounding box. + + The class stores the start and stop values for the x and y axis of the bounding box. It also has a method for updating the values from the trackbars in the GUI. + + Attributes + ---------- + `x` : inner_range + The range for the x-axis. + `y` : inner_range + The range for the y-axis. + + Methods + ------- + `update()` + Updates the values of the range from the trackbars in the GUI. + `__init__(range_x_start=0, range_x_stop=0, range_y_start=0, range_y_stop=0)` + Initializes the range with the specified values. + ''' + class inner_range: + ''' + A class for storing the range of one axis. + + Attributes + ---------- + `start` : int + The start value of the range. + `stop` : int + The stop value of the range. + ''' + start = 0 stop = 0 @@ -88,6 +171,15 @@ class inner_range: y = inner_range() def update(self): + ''' + Updates the values of the range from the trackbars in the GUI. + + Returns + ------- + `changed` : bool + Whether the values have changed. + ''' + changed = False tmp = self.x.start self.x.start = cv2.getTrackbarPos('range start x', 'Histogram lines') @@ -101,15 +193,40 @@ def update(self): tmp = self.y.stop self.y.stop = cv2.getTrackbarPos('range stop y', 'Histogram lines') changed |= tmp == self.y.stop + return changed, self def __init__(self, range_x_start=0, range_x_stop=0, range_y_start=0, range_y_stop=0): + ''' + Initializes the range with the specified values. + + Parameters + ---------- + `range_x_start` : int + The start value of the x-axis. + `range_x_stop` : int + The stop value of the x-axis. + `range_y_start` : int + The start value of the y-axis. + `range_y_stop` : int + The stop value of the y-axis. + ''' + self.x.start = range_x_start self.x.stop = range_x_stop self.y.start = range_y_start self.y.stop = range_y_stop def parse_args(): + ''' + Parses the command line arguments. + + Returns + ------- + `args` : argparse.Namespace + The parsed command line arguments. + ''' + parser = argparse.ArgumentParser(description="""Computes the connected lines in a 2D histogram. It can either be run in GUI mode, where one tries to find the optimal configuration parameters, or batch mode, where it processes one or more histograms into images. In GUI mode, one can specify the bounding box by dragging a box on one of the images, specify the line to highlight by left clicking and reset the bounding box by middle clicking. Example command for running with default configuration: @@ -139,6 +256,24 @@ def parse_args(): return args def plot_line(line, rng: _range): + ''' + Plots the line and the meaned line with the peaks highlighted. + + The line is plotted as a line plot, where the meaned line is plotted on top of it. The peaks are highlighted with red dots. + + Parameters + ---------- + `line` : numpy.array[uint64] + The line to plot. + `rng` : _range + The range of the bounding box. + + Returns + ------- + `line_plot` : numpy.array[uint8] + The line plot as an image. + ''' + global config, disable_matplotlib if disable_matplotlib: @@ -156,6 +291,26 @@ def plot_line(line, rng: _range): return line_plot def process_closing(mask, config): + ''' + Processes the mask by closing (dilate followed by erode). + + The mask is first dilated and then eroded with a cross kernel of the specified size. The number of iterations for dilating and eroding can also be specified in the configuration. + + Parameters + ---------- + `mask` : numpy.array[uint8] + The mask to process. + `config` : dict[str, int] + The configuration parameters. + + Returns + ------- + `dilated` : numpy.array[uint8] + The dilated mask. + `eroded` : numpy.array[uint8] + The eroded mask (after dilation). + ''' + close_kernel = cv2.getStructuringElement(cv2.MORPH_CROSS, (config['close kernel x'], config['close kernel y'])) dilated = cv2.dilate(mask, close_kernel, iterations=config['iter dilate']) eroded = cv2.erode(dilated, close_kernel, iterations=config['iter erode']) @@ -163,6 +318,26 @@ def process_closing(mask, config): return dilated, eroded def process_contours(hist, rng: _range, config): + ''' + Processes the histogram by finding the contours. + + The histogram is first thresholded to find the contours. The contours are then filtered by size, where the minimum size can be specified in the configuration. The contours are then sorted by x-coordinate and drawn on a blank image. + + Parameters + ---------- + `hist` : numpy.array[uint64] + The histogram to process. + `rng` : _range + The range of the bounding box. + `config` : dict[str, int] + The configuration parameters. + + Returns + ------- + `result` : numpy.array[uint8] + The image with the contours drawn. + ''' + contours, _ = cv2.findContours(hist, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE) contours_sizes = [cv2.arcLength(c, True) for c in contours] contours_filtered = [contours[i] for i, size in enumerate(contours_sizes) if size > config['min contour size']] @@ -176,6 +351,22 @@ def process_contours(hist, rng: _range, config): return result, [i+1 for i in range(len(contours_filtered))] def process_joints(hist): + ''' + Processes the histogram by finding the joints in the skeleton. + + The histogram is first thresholded to find the skeleton. The skeleton is then dilated with a horizontal and vertical kernel. The joints are then found by taking the intersection of the horizontal and vertical dilated skeletons. + + Parameters + ---------- + `hist` : numpy.array[uint64] + The histogram to process. + + Returns + ------- + `joints` : numpy.array[uint8] + The image with the joints drawn. + ''' + global config tmp = cv2.threshold(hist, 1, 1, cv2.THRESH_BINARY)[1] @@ -193,12 +384,56 @@ def process_joints(hist): return joints def process_line(line, config): + ''' + Processes the line by smoothing it and finding the peaks. + + The line is first smoothed with a gaussian filter. The peaks are then found by the scipy function `find_peaks`. + + Parameters + ---------- + `line` : numpy.array[uint64] + The line to process. + `config` : dict[str, int] + The configuration parameters. + + Returns + ------- + `meaned` : numpy.array[float64] + The smoothed line. + `peaks` : list[int] + The indices of the peaks. + ''' + meaned = gaussian_filter1d(line, config['line smooth']) peaks, _ = signal.find_peaks(meaned, .01*config['min peak height']*line.max()) return meaned, peaks def process_scatter_peaks(hist, rng: _range, peaks=None): + ''' + Processes the histogram by finding the peaks and plotting them. + + The histogram is first smoothed with a gaussian filter. The peaks are then found by the scipy function `find_peaks`. The peaks are then scattered on the cutout of the original histogram. + + Parameters + ---------- + `hist` : numpy.array[uint64] + The histogram to process. + `rng` : _range + The range of the bounding box. + `peaks` : list[int] + The peaks to scatter. If None, the peaks are found. + + Returns + ------- + `scatter_plot` : numpy.array[uint8] + The image with the peaks scattered. + `px` : list[int] + The x-coordinates of the peaks. + `py` : list[int] + The y-coordinates of the peaks. + ''' + global config, disable_matplotlib # Show the peaks scattered on the cutout of the original @@ -217,6 +452,32 @@ def process_scatter_peaks(hist, rng: _range, peaks=None): return scatter_plot, px, py def process_with_box(hist, rng: _range, selected_line, scale_x, scale_y, partial_size): + ''' + Processes the histogram by plotting the histogram with the bounding box and the selected line. + + The histogram is first normalized and scaled to 255. The bounding box is then drawn on the histogram. The selected line is also drawn on the histogram. + + Parameters + ---------- + `hist` : numpy.array[uint64] + The histogram to process. + `rng` : _range + The range of the bounding box. + `selected_line` : int + The selected line to highlight. + `scale_x` : float + The scaling factor for the x-axis. + `scale_y` : float + The scaling factor for the y-axis. + `partial_size` : tuple[int] + The size of the partial image. + + Returns + ------- + `colored` : numpy.array[uint8] + The image with the histogram, bounding box and selected line drawn. + ''' + display = (row_normalize(hist) * 255).astype(np.uint8) box_x_start = int(rng.x.start * scale_x) box_y_start = int(rng.y.start * scale_y) @@ -231,21 +492,58 @@ def process_with_box(hist, rng: _range, selected_line, scale_x, scale_y, partial return colored def save_config(): + ''' + Saves the configuration to the configuration file. + + The configuration is saved as a json file at the location specified in the command line arguments. + ''' + global config, args with open(args.config, 'w') as f: json.dump(config, f) def scatter_peaks(hist, config): + ''' + Scatters the peaks on the histogram. + + The histogram is first smoothed with a gaussian filter. The peaks are then found by the scipy function `find_peaks`. The peaks are then scattered on the histogram. + + Parameters + ---------- + `hist` : numpy.array[uint64] + The histogram to process. + `config` : dict[str, int] + The configuration parameters. + + Returns + ------- + `px` : list[int] + The x-coordinates of the peaks. + `py` : list[int] + The y-coordinates of the peaks. + ''' + meaned = gaussian_filter1d(hist, config['line smooth'], axis=1) maxs = meaned.max(axis=1)*.01*config['min peak height'] fp = lambda x,y: signal.find_peaks(x, y)[0] peaks_x = [fp(row, maxval) for row, maxval in zip(meaned, maxs)] peaks_y = [[i]*len(peaks) for i, peaks in enumerate(peaks_x)] flatten = lambda x: [elem for li in x for elem in li] + return flatten(peaks_x), flatten(peaks_y) def gui(): + ''' + Runs the GUI for finding the optimal configuration parameters. + + The GUI consists of several trackbars, where the user can adjust the parameters. The user can also drag a box on the histogram to specify the bounding box, left click to specify the line to highlight and middle click to reset the bounding box. The user can also save the configuration by pressing the 's' key. + + The GUI is updated whenever the trackbars are changed. + + The GUI is closed by pressing the escape key, which also saves the configuration if the 'dry run' commandline argument is not set. + ''' + global last, args, ready ready = False @@ -263,6 +561,29 @@ def update_image(_): update(True) def update_line(event, x, y, flags, param): + ''' + Updates the line and bounding box when the user interacts with the GUI. + + The line is updated when the user left clicks on the histogram. The bounding box is updated when the user drags a box on the histogram. The bounding box is reset when the user middle clicks on the histogram. + + Parameters + ---------- + `event` : int + The event type. + `x` : int + The x-coordinate of the event. + `y` : int + The y-coordinate of the event. + `flags` : int + The flags of the event, passed by OpenCV. + `param` : int + The parameter of the event, passed by OpenCV. + + Returns + ------- + `None` + ''' + global mx, my, scale_x, scale_y, disable_matplotlib width_factor = 3 @@ -310,6 +631,21 @@ def update_line(event, x, y, flags, param): # TODO Only do recomputation, whenever parameters that have an effect are changed. def update(force=False): # Note: all colors are in BGR format, as this is what OpenCV uses + ''' + Updates the GUI. + + The GUI is updated by processing the histogram with the specified parameters. The histogram is then displayed in the GUI. + + Parameters + ---------- + `force` : bool + Whether the GUI should be updated even if the parameters haven't changed. + + Returns + ------- + `None` + ''' + global last, config, scale_x, scale_y, disable_matplotlib, ready times = [] @@ -453,6 +789,15 @@ def update(force=False): # Note: all colors are in BGR format, as this is what O print (label, tim - times[i-1][1]) def update_config(): + ''' + Updates the configuration from the trackbars. + + Returns + ------- + `changed` : bool + Whether the configuration has changed. + ''' + global config, ready if not ready: @@ -533,6 +878,7 @@ def update_config(): args = parse_args() config = load_config(args.config) disable_matplotlib = args.disable_matplotlib + if args.batch: batch() else: