diff --git a/HISTORY.rst b/HISTORY.rst index 25ba7fb..2b4a9aa 100644 --- a/HISTORY.rst +++ b/HISTORY.rst @@ -345,4 +345,12 @@ History * added contour_overlayed * moved loopprocessor to multiprocessor * added printv -* added plot_marker \ No newline at end of file +* added plot_marker + +0.12.16 (2024-09-10) +------------------- +* added plt_confusion_matrix +* changed the name of imshow_series and imshow_by_subplots and plot_marker +* plot_marker is plt_mark +* fixes for plt_utils +* all tests passed! \ No newline at end of file diff --git a/lognflow/__init__.py b/lognflow/__init__.py index d281b14..3fc84cf 100644 --- a/lognflow/__init__.py +++ b/lognflow/__init__.py @@ -5,15 +5,15 @@ __version__ = '0.12.15' from .lognflow import lognflow, getLogger -from .logviewer import logviewer from .printprogress import printprogress -from .plt_utils import ( - plt_colorbar, plt_imshow, plt_violinplot, plt_imhist, transform3D_viewer) -from .utils import ( - select_directory, select_file, repr_raw, replace_all, - is_builtin_collection, text_to_collection, stack_to_frame, - stacks_to_frames, ssh_system, printv, Pyrunner) -from .multiprocessor import multiprocessor, loopprocessor +from .plt_utils import plt_imshow, plt_imhist, plt_hist2 +from .utils import (select_directory, + select_file, + is_builtin_collection, + text_to_collection, + printv) + +from .multiprocessor import multiprocessor def basicConfig(*args, **kwargs): ... \ No newline at end of file diff --git a/lognflow/lognflow.py b/lognflow/lognflow.py index 6001472..9727b46 100644 --- a/lognflow/lognflow.py +++ b/lognflow/lognflow.py @@ -52,18 +52,19 @@ from .utils import (repr_raw, replace_all, select_directory, - stack_to_frame, name_from_file, is_builtin_collection, text_to_collection, dummy_function) from .plt_utils import (plt_colorbar, + stack_to_frame, plt_hist, plt_surface, - imshow_series, - imshow_by_subplots, + plt_imshow_series, + plt_imshow_subplots, plt_imshow, - plt_scatter3) + plt_scatter3, + plt_confusion_matrix) from typing import Union @dataclass @@ -1343,7 +1344,7 @@ def imshow_subplots(self, if not self.enabled: return time_tag = self.time_tag if (time_tag is None) else time_tag - fig, ax = imshow_by_subplots(images = images, + fig, ax = plt_imshow_subplots(images = images, frame_shape = frame_shape, grid_locations = grid_locations, figsize = figsize, @@ -1422,17 +1423,17 @@ def imshow_series(self, if not self.enabled: return time_tag = self.time_tag if (time_tag is None) else time_tag - fig, ax = imshow_series(list_of_stacks, - list_of_masks = list_of_masks, - figsize = figsize, - figsize_ratio = figsize_ratio, - text_as_colorbar = text_as_colorbar, - colorbar = colorbar, - cmap = cmap, - list_of_titles_columns = list_of_titles_columns, - list_of_titles_rows = list_of_titles_rows, - fontsize = fontsize, - transpose = transpose) + fig, ax = plt_imshow_series(list_of_stacks, + list_of_masks = list_of_masks, + figsize = figsize, + figsize_ratio = figsize_ratio, + text_as_colorbar = text_as_colorbar, + colorbar = colorbar, + cmap = cmap, + list_of_titles_columns = list_of_titles_columns, + list_of_titles_rows = list_of_titles_rows, + fontsize = fontsize, + transpose = transpose) if not return_figure: fpath = self.savefig( @@ -1489,7 +1490,8 @@ def log_confusion_matrix(self, figsize = None, image_format = 'jpg', dpi = 1200, - time_tag = False): + time_tag = False, + close_plt = True): """log a confusion matrix given a sklearn confusion matrix (cm), make a nice plot @@ -1529,43 +1531,16 @@ def log_confusion_matrix(self, """ if not self.enabled: return - accuracy = np.trace(cm) / np.sum(cm).astype('float') - misclass = 1 - accuracy - - if figsize is None: - figsize = np.ceil(cm.shape[0]/3) - - if target_names is None: - target_names = [chr(x + 65) for x in range(cm.shape[0])] - - if cmap is None: - cmap = plt.get_cmap('Blues') - - plt.figure(figsize=(4*figsize, 4*figsize)) - im = plt.imshow(cm, interpolation='nearest', cmap=cmap) - - if target_names is not None: - tick_marks = np.arange(len(target_names)) - plt.xticks(tick_marks, target_names, rotation=45) - plt.yticks(tick_marks, target_names) - - for i, j in itertools_product(range(cm.shape[0]), range(cm.shape[1])): - clr = np.array([1, 1, 1, 0]) \ - * (cm[i, j] - cm.min()) \ - / (cm.max() - cm.min()) + np.array([0, 0, 0, 1]) - plt.text(j, i, f"{cm[i, j]:2.02f}", horizontalalignment="center", - color=clr) - - plt.ylabel('True label') - plt.xlabel('Predicted label\naccuracy={:0.4f}; ' \ - + 'misclass={:0.4f}'.format(accuracy, misclass)) - plt.title(title) - plt.colorbar(im, fraction=0.046, pad=0.04) + + plt_confusion_matrix( + cm, target_names=target_names, title=title, + cmap=cmap, figsize=figsize) fpath = self.savefig( parameter_name = parameter_name, image_format=image_format, dpi=dpi, - time_tag = time_tag) + time_tag = time_tag, + close_plt = close_plt) return fpath def log_animation( diff --git a/lognflow/plt_utils.py b/lognflow/plt_utils.py index a6a066d..8d75d95 100644 --- a/lognflow/plt_utils.py +++ b/lognflow/plt_utils.py @@ -9,6 +9,7 @@ from scipy.spatial.transform import Rotation as scipy_rotation from .printprogress import printprogress from itertools import cycle as itertools_cycle +from itertools import product as itertools_product matplotlib_lines_Line2D_markers_keys_cycle = itertools_cycle([ 's', '*', 'd', 'X', 'v', '.', 'x', '|', 'D', '<','^', '8','p', @@ -37,6 +38,93 @@ def complex2hsv(data_complex, vmin=None, vmax=None): return hsv_to_rgb(H), data_abs, data_angle +def stack_to_frame(stack, frame_shape : tuple = None, borders = 0): + """ turn a stack of images into a 2D frame of images + This is very useful when lots of images need to be tiled + against each other. + + Note: if the last dimension is 3, all images are RGB, if you don't wish that + you have to add another dimension at the end by np.expand_dim(arr, axis = -1) + + :param stack: np.ndarray + It must have the shape of either + n_im x n_r x n_c + n_im x n_r x 3 x 1 + n_im x n_r x n_c x 3 + + In all cases n_im will be turned into a frame + Remember if you have N images to put into a square, the input + shape should be 1 x n_r x n_c x N + :param frame_shape: tuple + The shape of the frame to put n_rows and n_colmnss of images + close to each other to form a rectangle of image. + :param borders: literal or np.inf or np.nan + When plotting images with matplotlib.pyplot.imshow, there + needs to be a border between them. This is the value for the + border elements. + + output + --------- + Since we have N channels to be laid into a square, the side + length would be ceil(N**0.5) if frame_shape is not given. + it produces an np.array of shape n_f x n_r * f_r x n_c * f_c or + n_f x n_r * f_r x n_c * f_c x 3 in case of an RGB input. + """ + is_rgb = stack.shape[-1] == 3 + + if(len(stack.shape) == 4): + if((stack.shape[2] == 3) & (stack.shape[3] == 1)): + stack = stack[..., 0] + + n_im, n_R, n_C = stack.shape[:3] + + if(len(stack.shape) == 4): + assert is_rgb, 'For a stack of images with axis 3, it should be 1 or 3.' + + assert (len(stack.shape) == 3) | (len(stack.shape) == 4), \ + f'The stack you provided can have specific shapes. it is {stack.shape}' + + if(frame_shape is None): + square_side = int(np.ceil(np.sqrt(n_im))) + frame_n_r, frame_n_c = (square_side, square_side) + else: + frame_n_r, frame_n_c = frame_shape + n_R += 2 + n_C += 2 + new_n_R = n_R * frame_n_r + new_n_C = n_C * frame_n_c + + if is_rgb: + frame = np.zeros((new_n_R, new_n_C, 3), dtype = stack.dtype) + else: + frame = np.zeros((new_n_R, new_n_C), dtype = stack.dtype) + used_ch_cnt = 0 + if(borders is not None): + frame += borders + for rcnt in range(frame_n_r): + for ccnt in range(frame_n_c): + ch_cnt = rcnt + frame_n_c*ccnt + if (ch_cnt window.width()) & (portrait is None): portrait = True @@ -381,7 +581,25 @@ def plt_imshow(img, else: ax = [fig.add_subplot(1, 2, 1), fig.add_subplot(1, 2, 2)] - if complex_type == 'abs_angle': + complex_real_imag = False + if cmap is not None: + if 'real_imag' in cmap: + complex_real_imag = True + if complex_real_imag: + cmap = cmap.split('real_imag')[0] + if len(cmap) == 0: cmap = None + else: cmap = cmap[:-1] + if angle_cmap is None: + angle_cmap = cmap + im = ax[0].imshow(np.real(img), cmap = cmap, **kwargs) + if(colorbar): + plt_colorbar(im) + ax[0].set_title('real') + im = ax[1].imshow(np.imag(img), cmap = angle_cmap, **kwargs) + if(colorbar): + plt_colorbar(im) + ax[1].set_title('imag') + else: im = ax[0].imshow(np.abs(img), cmap = cmap, **kwargs) if(colorbar): plt_colorbar(im) @@ -392,16 +610,7 @@ def plt_imshow(img, if(colorbar): plt_colorbar(im) ax[1].set_title('angle') - elif complex_type == 'real_imag': - im = ax[0].imshow(np.real(img), cmap = cmap, **kwargs) - if(colorbar): - plt_colorbar(im) - ax[0].set_title('real') - im = ax[1].imshow(np.imag(img), cmap = angle_cmap, **kwargs) - if(colorbar): - plt_colorbar(im) - ax[1].set_title('imag') - + if(remove_axis_ticks): plt.setp(ax[0], xticks=[], yticks=[]) ax[0].xaxis.set_ticks_position('none') @@ -411,6 +620,8 @@ def plt_imshow(img, ax[1].yaxis.set_ticks_position('none') if title is not None: fig.suptitle(title) + fig.canvas.manager.window.setWindowTitle(title) + return fig, ax def plt_hist(vectors_list, fig_ax = None, @@ -455,7 +666,8 @@ def plt_scatter3( data_N_by_3[:, 2], **kwargs) if title is not None: - ax.set_title(title) + ax.set_title(title) + fig.canvas.manager.window.setWindowTitle(title) try: elev_list = [int(elev_list)] except: pass @@ -691,18 +903,18 @@ def show(self, show_legend = True): def __call__(self, *args, **kwargs): self.addPlot(*args, **kwargs) -def imshow_series(list_of_stacks, - list_of_masks = None, - figsize = None, - figsize_ratio = 1, - text_as_colorbar = False, - colorbar = False, - cmap = 'viridis', - list_of_titles_columns = None, - list_of_titles_rows = None, - fontsize = None, - transpose = True, - ): +def plt_imshow_series(list_of_stacks, + list_of_masks = None, + figsize = None, + figsize_ratio = 1, + text_as_colorbar = False, + colorbar = False, + cmap = 'viridis', + list_of_titles_columns = None, + list_of_titles_rows = None, + fontsize = None, + transpose = True, + ): """ imshow a stack of images or sets of images in a shelf, input must be a list or array of images @@ -819,7 +1031,7 @@ def imshow_series(list_of_stacks, cbar.ax.tick_params(labelsize=1) return fig, None -def imshow_by_subplots( +def plt_imshow_subplots( images, grid_locations=None, frame_shape = None, title = None, titles=[], cmaps=[], colorbar=True, margin = 0.025, inter_image_margin = 0.01, colorbar_aspect=2, colorbar_pad_fraction=0.05, @@ -911,6 +1123,7 @@ def imshow_by_subplots( colorbar_pad_fraction=colorbar_pad_fraction) if title is not None: fig.suptitle(title) + fig.canvas.manager.window.setWindowTitle(title) plt.subplots_adjust(left=0, right=1, top=1, bottom=0) plt.margins(margin) @@ -1130,70 +1343,116 @@ def update_value(self, label, step_label, direction, event): step_size = float(self.params[f"{step_label}_text_box"].text) new_val = current_val + direction * step_size self.params[f"{label}_text_box"].set_val(f"{new_val:.6f}") - + class _questdiag: - def __init__(self, - question = '', - figsize=(6, 2), - buttons = {'Yes' : True, - 'No' : False, - 'Cancel' : None}, - row_spacing=0.05): - + def __init__(self, question, buttons, figsize, question_hratio): + assert isinstance(buttons, dict), \ ('buttons arg must be a dictionary of texts appearing on ' 'the buttons values to be returned.') self.buttons = buttons self.result = None - _, ax = plt.subplots(figsize=figsize) - plt.subplots_adjust(bottom=0.2) - - ax.text(0.5, 0.85, question, ha='center', va='center', fontsize=12) - plt.axis('off') - - # Calculate grid size - N = len(buttons) - n_rows = int(np.ceil(N ** 0.5)) - n_clms = int(np.ceil(N / n_rows)) - - # Button size and position - button_width = 0.8 / n_clms - button_height = 0.3 / n_rows - horizontal_spacing = (1 - button_width * n_clms) / (n_clms + 1) - vertical_spacing = (0.2 - button_height * n_rows) / (n_rows + 1) - - # Adjust vertical_spacing to add more space between rows - vertical_spacing += row_spacing - + + # Calculate the number of rows and columns for the buttons + N = len(self.buttons) + n_rows = int(np.ceil(N ** 0.5)) # Number of rows for buttons + n_cols = int(np.ceil(N / n_rows)) # Number of columns for buttons + + if N == 1: n_rows, n_cols = 1, 1 + if N == 2: n_rows, n_cols = 1, 2 + if N == 3: n_rows, n_cols = 1, 3 + if N == 6: n_rows, n_cols = 2, 3 + + if question_hratio is None: + if isinstance(question, np.ndarray): + question_hratio = 10 + else: + question_hratio = 1 + + if figsize is None: + if isinstance(question, np.ndarray): + figsize = (5, 5) + else: + figsize = (5, 1) + + # Create the figure and GridSpec layout + fig = plt.figure(figsize=figsize) + gs = matplotlib.gridspec.GridSpec(n_rows + 2, n_cols, + figure=fig, height_ratios=[question_hratio] + [1] * (n_rows + 1)) + # First row (3x height) for the question, remaining rows for buttons + + # Top section for the question (span the entire width) + ax_question = fig.add_subplot(gs[0, :]) + + # Handle different types of questions + if isinstance(question, np.ndarray): + if len(question.shape) == 1: + ax_question.plot(question) + elif len(question.shape) == 2: + plt_imshow(question, fig_ax = (fig, ax_question),) + plt.axis('on') # Keep axis on for plots and images + else: + ax_question.text(0.5, 0.5, str(question), ha='center', va='center', fontsize=12) + ax_question.set_axis_off() # No axis for text questions + + # Create buttons and place them on the grid button_objects = [] - for i, button_label in enumerate(buttons.keys()): - row = i // n_clms - col = i % n_clms - - button_ax = plt.axes([ - horizontal_spacing + col * (button_width + horizontal_spacing), - 0.2 + (n_rows - row - 1) * (button_height + vertical_spacing), # Start from top - button_width, - button_height - ]) - button = Button(button_ax, str(button_label)) + for i, (label, val) in enumerate(self.buttons.items()): + row = 2 + i // n_cols + col = i % n_cols + button_ax = fig.add_subplot(gs[row, col]) + button = Button(button_ax, str(label)) button.on_clicked(self.button_click) button_objects.append(button) plt.show() def button_click(self, event): - ind = event.inaxes.texts[0].get_text() - self.result = self.buttons[ind] # Return the corresponding output - plt.close() + ind = event.inaxes.texts[0].get_text() # Get text of the clicked button + self.result = self.buttons[ind] # Return the corresponding output + plt.close() # Close the plot after a button is clicked def question_dialog( - question = 'Yes/No/Cancel?', figsize=(6, 2), - buttons = {'Yes' : True, 'No' : False, 'Cancel' : None}): - return _questdiag(question, figsize, buttons).result + question = 'Yes/No/Cancel?', + buttons={'Yes': True, 'No': False, 'Cancel': None}, + figsize = None, + question_hratio = None): + """ Question dialog + Creates a dialog with a question displayed at the top and a grid of buttons below it. + + The function supports displaying questions as text, 1D numpy arrays (as line plots), + or 2D numpy arrays (as images). It displays buttons beneath the question, allowing the + user to select one of the provided options. The buttons are organized into a grid + layout based on the number of buttons provided. When a button is clicked, the function + returns the corresponding value associated with the button in the `buttons` dictionary. + + Parameters + ---------- + question : str, np.ndarray, optional + The question to be presented. It can be a string, a 1D numpy array (plotted as a + line), or a 2D numpy array (displayed as an image). Default is 'Yes/No/Cancel?'. + + buttons : dict, optional + A dictionary where the keys are the text labels that will appear on the buttons, + and the values are the corresponding values to return when the button is clicked. + Default is {'Yes': True, 'No': False, 'Cancel': None}. + + figsize : tuple, optional + A tuple specifying the size of the figure (width, height) in inches. Default is (6, 2). + + question_hratio: int, optional + If you are sending an image as a question, you can set the height ratio to + buttons here, we suggest 4 + Returns + ------- + result : any + The value associated with the button clicked by the user. If 'Yes' is clicked, + returns `True`; if 'No', returns `False`; and if 'Cancel', returns `None`. + """ + return _questdiag(question, buttons, figsize, question_hratio).result -def plot_marker( +def plt_mark( coords, fig_ax=None, figsize=(2, 2), marker=None, markersize = None, return_markersize = False): """ @@ -1232,8 +1491,9 @@ def plot_marker( else: return fig, ax -def plt_contours(Z_list, X_Y = None, fig_ax=None, levels=10, colors_list=None, - linestyles_list=None, title=None): +def plt_contours( + Z_list, X_Y = None, fig_ax = None, levels = 10, colors_list = None, + linestyles_list = None, linewidth = 0.5, fontsize = 3, title = None): """ Plot contours of multiple surfaces overlaid on the same plot. @@ -1247,7 +1507,7 @@ def plt_contours(Z_list, X_Y = None, fig_ax=None, levels=10, colors_list=None, - colors_list: List of colors for the contours of each surface. If None, defaults to a colormap. - linestyles_list: List of line styles for the contours of each surface. - If None, defaults to a pattern. + If None, defaults to a pattern. - title: Optional title for the plot. """ @@ -1271,13 +1531,16 @@ def plt_contours(Z_list, X_Y = None, fig_ax=None, levels=10, colors_list=None, X, Y = X_Y color = colors_list[i % len(colors_list)] linestyle = linestyles_list[i % len(linestyles_list)] - - contour = ax.contour(X, Y, Z, levels=levels, colors=[color],linestyles=linestyle) + contour = ax.contour(X, Y, Z, levels=levels, colors=[color], + linestyles=linestyle, linewidths = linewidth) # Add labels to contours - ax.clabel(contour, inline=True, fontsize=8, fmt='%.2f') + ax.clabel(contour, inline=True, fontsize=fontsize, fmt='%.2f') + ax.set_aspect('equal') + if title is not None: ax.set_title(title) + fig.canvas.manager.window.setWindowTitle(title) return fig, ax \ No newline at end of file diff --git a/lognflow/utils.py b/lognflow/utils.py index f2b7965..6bdc81e 100644 --- a/lognflow/utils.py +++ b/lognflow/utils.py @@ -171,130 +171,108 @@ def parse_node(node): tree = ast.parse(text, mode='eval') return parse_node(tree.body) +class SSHSystem: + """ + A class to handle basic SSH and SFTP operations on a remote system. -def stack_to_frame(stack, frame_shape : tuple = None, borders = 0): - """ turn a stack of images into a 2D frame of images - This is very useful when lots of images need to be tiled - against each other. - - Note: if the last dimension is 3, all images are RGB, if you don't wish that - you have to add another dimension at the end by np.expand_dim(arr, axis = -1) - - :param stack: np.ndarray - It must have the shape of either - n_im x n_r x n_c - n_im x n_r x 3 x 1 - n_im x n_r x n_c x 3 - - In all cases n_im will be turned into a frame - Remember if you have N images to put into a square, the input - shape should be 1 x n_r x n_c x N - :param frame_shape: tuple - The shape of the frame to put n_rows and n_colmnss of images - close to each other to form a rectangle of image. - :param borders: literal or np.inf or np.nan - When plotting images with matplotlib.pyplot.imshow, there - needs to be a border between them. This is the value for the - border elements. - - output - --------- - Since we have N channels to be laid into a square, the side - length would be ceil(N**0.5) if frame_shape is not given. - it produces an np.array of shape n_f x n_r * f_r x n_c * f_c or - n_f x n_r * f_r x n_c * f_c x 3 in case of an RGB input. + Attributes: + ssh_client (paramiko.SSHClient): The SSH client for executing commands on the remote system. + sftp_client (paramiko.SFTPClient): The SFTP client for file transfer operations. """ - is_rgb = stack.shape[-1] == 3 - - if(len(stack.shape) == 4): - if((stack.shape[2] == 3) & (stack.shape[3] == 1)): - stack = stack[..., 0] - - n_im, n_R, n_C = stack.shape[:3] - - if(len(stack.shape) == 4): - assert is_rgb, 'For a stack of images with axis 3, it should be 1 or 3.' - assert (len(stack.shape) == 3) | (len(stack.shape) == 4), \ - f'The stack you provided can have specific shapes. it is {stack.shape}' + def __init__(self, hostname: str, username: str, password: str): + """ + Initialize the SSHSystem by setting up the SSH and SFTP clients. - if(frame_shape is None): - square_side = int(np.ceil(np.sqrt(n_im))) - frame_n_r, frame_n_c = (square_side, square_side) - else: - frame_n_r, frame_n_c = frame_shape - n_R += 2 - n_C += 2 - new_n_R = n_R * frame_n_r - new_n_C = n_C * frame_n_c - - if is_rgb: - frame = np.zeros((new_n_R, new_n_C, 3), dtype = stack.dtype) - else: - frame = np.zeros((new_n_R, new_n_C), dtype = stack.dtype) - used_ch_cnt = 0 - if(borders is not None): - frame += borders - for rcnt in range(frame_n_r): - for ccnt in range(frame_n_c): - ch_cnt = rcnt + frame_n_c*ccnt - if (ch_cnt bool: + """ + Check if a file exists on the remote system. - def close_connection(self): - self.sftp_client.close() - self.ssh_client.close() + Args: + path (Path): The path to the file on the remote system. + + Returns: + bool: True if the file exists, False otherwise. + """ + try: + stdin, stdout, stderr = self.ssh_client.exec_command(f'test -f {path} && echo "exists"') + return "exists" in stdout.read().decode() + except Exception as e: + print(f"Error checking file {path}: {e}") + return False -def printv(var): + def close_connection(self): + """ + Close the SSH and SFTP connections to the remote system. + """ + if self.sftp_client: + self.sftp_client.close() + if self.ssh_client: + self.ssh_client.close() + +def printv(var, **kwargs): # Get the name of the variable passed to the function frame = inspect.currentframe().f_back var_name = [name for name, value in frame.f_locals.items() if value is var] @@ -328,81 +323,175 @@ def printv(var): else: var_name = 'variable' - is_array = True - toprint = f'{type(var).__name__} {var_name}:' + is_np_torch = True + var_class = type(var).__name__ + toprint = f'{var_class} {var_name}: ' try: - toprint += f'shape={var.shape} dtype={var.dtype}' + array_shape = var.shape + toprint += f'shape={array_shape}' except: - is_array = False + is_np_torch = False try: - toprint += f'device={var.device}' + array_dtype = var.dtype + toprint += f', dtype={array_dtype}' except: pass - if not is_array: + try: + toprint += f', device={var.device}' + except: pass + + if is_np_torch: + arr_size = np.prod(array_shape) + if 'array_size_threshold' in kwargs: + array_size_threshold = kwargs['array_size_threshold'] + else: + array_size_threshold = 1e+6 + if arr_size < array_size_threshold: + try: + toprint += f', min={var.min():.6f}' + except: pass + try: + toprint += f', max={var.max():.6f}' + except: pass + try: + toprint += f', mean={var.mean():.6f}' + except: pass + try: + toprint += f', std={var.std():.6f}' + except: pass + + if not is_np_torch: toprint += str(var) # Print the information print(toprint) class Pyrunner: - def __init__(self, fpath, logger = None): - """ Jupyter like runner for Python - """ + """ + A Jupyter-like Python code runner that executes code in blocks based on cell numbers, + supports saving and loading kernel states, and allows interactive execution. + + Attributes: + fpath (Path): The path to the Python file to execute. + logger_ (callable): An optional logger function to log messages. + log (str): A string containing the accumulated log messages. + saved_state (dict): A dictionary to hold saved kernel states. + exit (bool): A flag to indicate when to stop execution. + """ + + def __init__(self, fpath: str, logger=None): + """ + Initializes the Pyrunner class, runs the Python file in an interactive loop, + and allows execution of specific code blocks identified by cell numbers. + + Args: + fpath (str): The file path to the Python script to be executed. + logger (callable, optional): A logger function to log output (default is None). + """ self.logger_ = logger self.fpath = Path(fpath) - assert self.fpath.is_file() + assert self.fpath.is_file(), f"File {fpath} does not exist." self.log = '' - self.logger(f'file: {fpath}') self.saved_state = {} self.exit = False + + self.logger(f'file: {fpath}') while not self.exit: show_and_ask_result = self.show(globals()) if show_and_ask_result is None: continue globals().update(show_and_ask_result) - exec(pyrunner_code, globals()) + exec(globals().get('pyrunner_code', ''), globals()) + + def logger(self, toprint: str, end: str = '\n'): + """ + Logs the provided message. If a logger is provided, it logs the message using that function. + Otherwise, it appends the message to the internal log. - def logger(self, toprint, end = '\n'): + Args: + toprint (str): The message to log. + end (str, optional): The string appended after each message (default is '\n'). + """ toprint = str(toprint) + end self.log += toprint if self.logger_ is not None: self.logger_(toprint) - def save_or_load_kernel_state(self, globals_, saved_state = None): + def save_or_load_kernel_state(self, globals_: dict, saved_state=None): + """ + Saves or loads the kernel state using the `dill` library. If `saved_state` is provided, it loads + the state into `globals_`. If `saved_state` is None, it returns a serialized form of the current + global variables. + + Args: + globals_ (dict): The global variables to save or update. + saved_state (bytes, optional): The serialized kernel state to load (default is None). + + Returns: + bytes: A serialized version of the global variables if saving the state. + """ import dill as pickle if saved_state is None: return pickle.dumps( {k: v for k, v in globals_.items() - if not k.startswith('__') and not callable(v)}) + if not k.startswith('__') and not callable(v)} + ) else: globals_.update(pickle.loads(saved_state)) @property - def n_saves(self): + def n_saves(self) -> int: + """ + Returns the number of saved states. + + Returns: + int: The number of saved states. + """ return len(self.saved_state.keys()) - def show(self, globals_, figsize = (3, 2)) -> (str, int): + def show(self, globals_: dict, figsize: tuple = (3, 2)) -> dict: + """ + Displays available cell blocks for execution and handles user interaction to + run specific blocks or manage kernel states (save/load/delete). + + Args: + globals_ (dict): The global variables of the current session. + figsize (tuple, optional): The size of the dialog box (default is (3, 2)). + + Returns: + dict: A dictionary containing the updated global variables if a cell block is selected. + """ pyrunner_code = open(self.fpath).read() pattern = r"if\s+pyrunner_cell_no\s*==\s*(\d+):" matches = re.findall(pattern, pyrunner_code) + if len(matches) == 0: - print(f'Running the pyrunner_code in {self.fpath}') - print('No blocks found that checks pyrunner_cell_no') + self.logger(f'Running the pyrunner_code in {self.fpath}') + self.logger('No blocks found that check pyrunner_cell_no') + return + pyrunner_cell_nos = sorted(set(int(num) for num in matches)) buttons = {} + for pyrunner_cell_no in pyrunner_cell_nos: buttons[f'{pyrunner_cell_no}'] = pyrunner_cell_no + + # Add options for saved states for key in self.saved_state: buttons[f'load_{key}'] = f'load_{key}' - for key in self.saved_state: buttons[f'del_{key}'] = f'del_{key}' + buttons[f'save_{self.n_saves + 1}'] = f'save_{self.n_saves + 1}' buttons['exit'] = 'exit' + # Display dialog for user interaction show_and_ask_result = question_dialog( - question='Choose a cell number', figsize=figsize, buttons=buttons) + question='Choose a cell number', figsize=figsize, buttons=buttons + ) if show_and_ask_result is None: self.logger(f'pyrunner: closing reloads, press Exit to close.') return - elif show_and_ask_result == str(show_and_ask_result): + + # Handle user selection + if isinstance(show_and_ask_result, str): if show_and_ask_result == 'exit': self.exit = True return @@ -410,7 +499,7 @@ def show(self, globals_, figsize = (3, 2)) -> (str, int): elif 'save' in show_and_ask_result: key = show_and_ask_result.split('save_')[1] self.saved_state[key] = self.save_or_load_kernel_state(globals_) - self.logger(f'saved state: {key}') + self.logger(f'Saved state: {key}') return elif 'load' in show_and_ask_result: @@ -424,7 +513,10 @@ def show(self, globals_, figsize = (3, 2)) -> (str, int): self.saved_state.pop(key) self.logger(f'Deleted state: {key}') return - elif show_and_ask_result == int(show_and_ask_result): + + elif isinstance(show_and_ask_result, int): globals_['pyrunner_code'] = pyrunner_code globals_['pyrunner_cell_no'] = show_and_ask_result return globals_ + + diff --git a/setup.cfg b/setup.cfg index a39eb9e..49a2713 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,5 +1,5 @@ [bumpversion] -current_version = 0.12.14 +current_version = 0.12.15 commit = True tag = True diff --git a/setup.py b/setup.py index 908bfbb..703078a 100644 --- a/setup.py +++ b/setup.py @@ -6,7 +6,7 @@ __author__ = 'Alireza Sadri' __email__ = 'arsadri@gmail.com' -__version__ = '0.12.15' +__version__ = '0.12.16' with open('README.rst') as readme_file: readme = readme_file.read() diff --git a/tests/test_multiprocessor.py b/tests/test_multiprocessor.py index 3814302..299704f 100644 --- a/tests/test_multiprocessor.py +++ b/tests/test_multiprocessor.py @@ -1,5 +1,5 @@ -from lognflow import multiprocessor, loopprocessor, printprogress -from lognflow.multiprocessor import multiprocessor_gen +from lognflow import multiprocessor, printprogress +from lognflow.multiprocessor import multiprocessor_gen, loopprocessor import numpy as np import inspect import time @@ -132,8 +132,9 @@ def test_error_handling_in_multiprocessor(): error_multiprocessor_targetFunc, iterables, shareables, verbose = True) raise - except: + except Exception as e: print('Error has been raised') + print(e) def noslice_multiprocessor_targetFunc(iterables_sliced, shareables): idx = iterables_sliced @@ -160,9 +161,6 @@ def test_noslice_multiprocessor(): stats = multiprocessor( noslice_multiprocessor_targetFunc, iterables, shareables, verbose = True) - -############################################ - def compute(data, mask): for _ in range(400): res = np.median(data[mask==1]) @@ -228,12 +226,6 @@ def test_multiprocessor_gen(): print(IDs) results_mp = results_mp[0] - # results_mp = [arrivals[0] - # for arrivals in multiprocessor_gen(compute_arg_scatterer, - # iterables = (data, mask), - # shareables = None, - # verbose = True)] - results = np.zeros(N) for cnt in printprogress(range(N)): results[cnt], _ = compute(data[cnt], mask[cnt]) diff --git a/tests/test_plt_utils.py b/tests/test_plt_utils.py index d2a48a8..e52627d 100644 --- a/tests/test_plt_utils.py +++ b/tests/test_plt_utils.py @@ -7,7 +7,7 @@ import lognflow from lognflow.plt_utils import ( plt_imshow, complex2hsv_colorbar, plt_imhist,complex2hsv, - transform3D_viewer, plot_marker, plt_contours) + transform3D_viewer, plt_mark, plt_contours, question_dialog) import numpy as np def test_transform3D_viewer(): @@ -51,7 +51,6 @@ def test_numbers_as_images(): dataset = lognflow.plt_utils.numbers_as_images_4D( dataset_shape, fontsize) - ########################################################################## n_x, n_y, n_r, n_c = dataset_shape txt_width = int(np.log(np.maximum(n_x, n_y)) /np.log(np.maximum(n_x, n_y))) + 1 @@ -97,22 +96,22 @@ def test_plt_fig_to_numpy(): print(np_data.shape) plt.close() -def test_imshow_series(): +def test_plt_imshow_series(): data = [np.random.rand(10, 100, 100), np.random.rand(10, 10, 10)] - lognflow.plt_utils.imshow_series(data) + lognflow.plt_utils.plt_imshow_series(data) plt.show() -def test_imshow_by_subplots(): +def test_plt_imshow_subplots(): data = np.random.rand(15, 100, 100, 3) - lognflow.plt_utils.imshow_by_subplots(data, colorbar = False) + lognflow.plt_utils.plt_imshow_subplots(data, colorbar = False) data = [np.random.rand(100, 100), np.random.rand(100, 150), np.random.rand(50, 100)] - lognflow.plt_utils.imshow_by_subplots(data) + lognflow.plt_utils.plt_imshow_subplots(data) data = np.random.rand(15, 100, 100) grid_locations = (np.random.rand(len(data), 2)*1000).astype('int') - lognflow.plt_utils.imshow_by_subplots(data, grid_locations = grid_locations) + lognflow.plt_utils.plt_imshow_subplots(data, grid_locations = grid_locations) plt.show() @@ -139,15 +138,12 @@ def test_plt_imhist(): def test_plt_imshow_complex(): - # Define the meshgrid for testing comx, comy = np.meshgrid(np.arange(-7, 8, 1), np.arange(-7, 8, 1)) com = comx + 1j * comy print(comx) print(comy) - # Use the existing complex2hsv function to convert complex data to RGB img, data_abs, data_angle = complex2hsv(com) - # Calculate min and max angles vmin = data_abs.min() vmax = data_abs.max() try: @@ -159,30 +155,37 @@ def test_plt_imshow_complex(): except: max_angle = 0 - # Plot the complex image - fig, ax = plt.subplots(figsize=(5, 5)) - im = ax.imshow(img, extent=(-7, 8, -7, 8)) + fig, ax = plt_imshow(img, extent=(-7, 8, -7, 8), title = 'complex2hsv', + colorbar = False) - # Annotate each pixel with its corresponding comx and comy values for i in range(0, comx.shape[0], 1): for j in range(0, comx.shape[1], 1): - ax.text(j - 7+0.5, -i + 7+0.5, f'({comx[i, j]}, {comy[i, j]})', ha='center', va='center', fontsize=8, color='white') + ax.text(j - 7+0.5, -i + 7+0.5, f'({comx[i, j]}, {comy[i, j]})', + ha='center', va='center', fontsize=8, color='white') - # Create and plot the color disc as an inset - fig, ax_inset = complex2hsv_colorbar((fig, ax.inset_axes([0.79, 0.03, 0.18, 0.18], transform=ax.transAxes)), - vmin=vmin, vmax=vmax, min_angle=min_angle, max_angle=max_angle) - ax_inset.patch.set_alpha(0) # Make the background of the inset axis transparent + fig, ax_inset = complex2hsv_colorbar( + (fig, ax.inset_axes([0.79, 0.03, 0.18, 0.18], transform=ax.transAxes)), + vmin=vmin, vmax=vmax, min_angle=min_angle, max_angle=max_angle) + ax_inset.patch.set_alpha(0) + + plt_imshow(np.random.rand(100, 100) + 1j * np.random.rand(100, 100), + cmap = 'gray_real_imag') + + plt_imshow(np.random.rand(100, 100) + 1j * np.random.rand(100, 100), + cmap = 'jet_real_imag') + + plt_imshow(np.random.rand(100, 100) + 1j * np.random.rand(100, 100)) plt.show() -def test_plot_marker(): +def test_plt_mark(): coords = np.random.rand(1000, 2)*100 - fig, ax, markersize = plot_marker( + fig, ax, markersize = plt_mark( coords, fig_ax=None, figsize=None, markersize=None, return_markersize = True) for cnt in range(23): coords = np.array([np.arange(1 + 1*cnt,1+ 1*(cnt+1)), np.zeros(1)]).T - fig_ax = plot_marker(coords, fig_ax=(fig, ax), markersize=markersize) + fig_ax = plt_mark(coords, fig_ax=(fig, ax), markersize=markersize) plt.show() @@ -193,16 +196,44 @@ def test_plt_contours(): plt_contours(Z_list) plt.show() +def test_question_dialog(): + vec = np.random.rand(100) + img = np.random.rand(100, 100) + question = 'how good is it?' + + result = question_dialog(vec) + print(result) + result = question_dialog(img) + print(result) + result = question_dialog(question) + print(result) + +def test_stack_to_frame(): + data4d = np.random.rand(25, 32, 32, 3) + img = lognflow.plt_utils.stack_to_frame(data4d, borders = np.nan) + plt.figure() + plt.imshow(img) + + data4d = np.random.rand(32, 32, 16, 16, 3) + stack = data4d.reshape(-1, *data4d.shape[2:]) + frame = lognflow.plt_utils.stack_to_frame(stack, borders = np.nan) + plt.figure() + im = plt.imshow(frame) + lognflow.plt_utils.plt_colorbar(im) + plt.show() + if __name__ == '__main__': - test_plt_contours() test_plt_imshow_complex() + test_question_dialog() + test_plt_contours() test_complex2hsv_colorbar() - test_plot_marker() + test_plt_mark() test_plot_gaussian_gradient() test_transform3D_viewer() - test_imshow_by_subplots() + test_plt_imshow_subplots() test_plt_imhist() test_plt_imshow() - test_imshow_series() + test_plt_imshow_series() test_numbers_as_images() test_plt_fig_to_numpy() + test_stack_to_frame() \ No newline at end of file diff --git a/tests/test_printprogress.py b/tests/test_printprogress.py index 4fe3acf..8c36490 100644 --- a/tests/test_printprogress.py +++ b/tests/test_printprogress.py @@ -13,7 +13,7 @@ temp_dir = tempfile.gettempdir() def test_printprogress(): - for N in list([100, 200, 400, 1000, 2000, 4000, 6000]): + for N in list([100, 200, 400, 1000]): pprog = printprogress(N) for _ in range(N): time.sleep(0.01) @@ -32,7 +32,7 @@ def test_printprogress_ETA(): pprog = printprogress(N, print_function = None) for _ in range(N): ETA = pprog() - print(ETA) + print(f'ETA: {ETA:.2f}') def test_specific_timing(): logger = lognflow(temp_dir) @@ -53,7 +53,7 @@ def test_generator_type(): print(f'sum: {sum}') def test_varying_periods(): - vec = np.arange(60) + vec = np.arange(30) sum = 0 for _ in printprogress(vec): sum += _ diff --git a/tests/test_pyrunner_code.py b/tests/test_pyrunner_code.py index 0af1a11..931d8e0 100644 --- a/tests/test_pyrunner_code.py +++ b/tests/test_pyrunner_code.py @@ -1,25 +1,30 @@ from lognflow.plt_utils import plt, np, plt_imhist import numpy as np -if pyrunner_cell_no == 1: - vec = [1, 2, 3] +try: -if pyrunner_cell_no == 2: - vec = [i ** 2 for i in vec] - print("Squared vec:", vec) - -if pyrunner_cell_no == 3: - vec = [np.exp(-i ** 2) for i in vec] - print("Squared vec:", vec) - plt_imhist(np.random.randn(100, 100)) - plt.show() - -if pyrunner_cell_no == 4: - vec = [np.exp(i) for i in vec] - print("Squared vec:", vec) - -if pyrunner_cell_no == 5: - vec = [np.log(i) for i in vec] - print("Squared vec:", vec) + if pyrunner_cell_no == 1: + vec = [1, 2, 3] + + if pyrunner_cell_no == 2: + vec = [i ** 2 for i in vec] + print("Squared vec:", vec) + + if pyrunner_cell_no == 3: + vec = [np.exp(-i ** 2) for i in vec] + print("Squared vec:", vec) + plt_imhist(np.random.randn(100, 100)) + plt.show() + + if pyrunner_cell_no == 4: + vec = [np.exp(i) for i in vec] + print("Squared vec:", vec) + + if pyrunner_cell_no == 5: + vec = [np.log(i) for i in vec] + print("Squared vec:", vec) + + print(f"Current state of vec: {vec}") -print(f"Current state of vec: {vec}") \ No newline at end of file +except Exception as e: + print(e) \ No newline at end of file diff --git a/tests/test_utils.py b/tests/test_utils.py index fd10862..1fc922b 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -7,20 +7,6 @@ import lognflow import numpy as np -def test_stack_to_frame(): - data4d = np.random.rand(25, 32, 32, 3) - img = lognflow.stack_to_frame(data4d, borders = np.nan) - plt.figure() - plt.imshow(img) - - data4d = np.random.rand(32, 32, 16, 16, 3) - stack = data4d.reshape(-1, *data4d.shape[2:]) - frame = lognflow.stack_to_frame(stack, borders = np.nan) - plt.figure() - im = plt.imshow(frame) - lognflow.plt_colorbar(im) - plt.show() - def test_is_builtin_collection(): # Test the function with various types @@ -45,7 +31,7 @@ def test_ssh_system(): remote_dir = Path('/remote/folder/path') local_dir = Path('/local/folder/path') target_fname = 'intresting_file.log' - ssh.monitor_and_move(remote_dir, local_dir, target_fname) + ssh.monitor_and_remove(remote_dir, local_dir, target_fname) ssh.close_connection() except: print('SSH test not passed maybe because you did not set the credentials.') @@ -76,12 +62,11 @@ def test_save_or_load_kernel_state(): print("State restored successfully!") def test_Pyrunner(): - from lognflow import Pyrunner + from lognflow.utils import Pyrunner Pyrunner(Path('./test_pyrunner_code.py'), logger = print) if __name__ == '__main__': - test_Pyrunner() test_printv() + test_Pyrunner() test_is_builtin_collection() - test_stack_to_frame() test_ssh_system() \ No newline at end of file