From c3486050cb4ed3c42c82db79e431362db509a97c Mon Sep 17 00:00:00 2001 From: pciturri Date: Fri, 23 Aug 2024 14:47:29 +0200 Subject: [PATCH] refactor: Added type hints and docstrings to helper functions. Removed unused helper functions --- .github/workflows/build-test.yml | 2 + csep/utils/plots.py | 927 ++++++++++++------------------- requirements.yml | 1 - tests/test_plots.py | 402 +++++++------- 4 files changed, 572 insertions(+), 760 deletions(-) diff --git a/.github/workflows/build-test.yml b/.github/workflows/build-test.yml index cf407487..a71e8564 100644 --- a/.github/workflows/build-test.yml +++ b/.github/workflows/build-test.yml @@ -34,10 +34,12 @@ jobs: - name: Install pyCSEP run: | pip install --no-deps -e . + pip install rasterio python -c "import csep; print('Version: ', csep.__version__)" - name: Test with pytest run: | + pip install vcrpy==4.3.1 pytest pytest-cov pytest --cov=./ --cov-config=.coveragerc diff --git a/csep/utils/plots.py b/csep/utils/plots.py index d864b585..840ca36c 100644 --- a/csep/utils/plots.py +++ b/csep/utils/plots.py @@ -1,8 +1,7 @@ import os import shutil -import string import warnings -from typing import TYPE_CHECKING, Optional, Any, List, Union, Tuple, Sequence +from typing import TYPE_CHECKING, Optional, Any, List, Union, Tuple, Sequence, Dict import cartopy import cartopy.crs as ccrs @@ -13,8 +12,6 @@ import numpy import numpy as np import pandas as pandas -import rasterio -import scipy.stats from cartopy.io import img_tiles from cartopy.io.img_tiles import GoogleWTS from cartopy.mpl.gridliner import LONGITUDE_FORMATTER, LATITUDE_FORMATTER @@ -22,8 +19,10 @@ from matplotlib.dates import AutoDateLocator, DateFormatter from matplotlib.lines import Line2D from rasterio import DatasetReader -from rasterio import plot as rioplot +from rasterio import plot as rio_plot +from rasterio import open as rio_open from scipy.integrate import cumulative_trapezoid +from scipy.stats import poisson, nbinom, beta # PyCSEP imports import csep.utils.time_utils @@ -66,7 +65,7 @@ "secondary_color": "red", "alpha": 0.8, "linewidth": 1, - "linestyle": '-', + "linestyle": "-", "secondary_linestyle": "red", "size": 5, "marker": "o", @@ -532,7 +531,7 @@ def plot_distribution_test( linestyle="--", label=obs_label + numpy.isinf(observation) * " (-inf)", ) - else: + elif isinstance(observation, (list, np.ndarray)): observation = observation[~numpy.isnan(observation)] ax.hist( observation, @@ -592,11 +591,14 @@ def plot_calibration_test( Plots a calibration test (QQ plot) with confidence intervals. Args: - evaluation_result (EvaluationResult): The evaluation result object containing the test distribution. + evaluation_result (EvaluationResult): The evaluation result object containing the test + distribution. percentile (float): Percentile to build confidence interval - ax (Optional[matplotlib.axes.Axes]): Axes object to plot on. If None, creates a new figure. + ax (Optional[matplotlib.axes.Axes]): Axes object to plot on. If None, creates a new + figure. show (bool): If True, displays the plot. Default is False. - label (Optional[str]): Label for the plotted data. If None, uses `evaluation_result.sim_name`. + label (Optional[str]): Label for the plotted data. If None, uses + `evaluation_result.sim_name`. **kwargs: Additional keyword arguments for customizing the plot. These are merged with `DEFAULT_PLOT_ARGS`. @@ -615,8 +617,8 @@ def plot_calibration_test( # Compute confidence intervals for order statistics using beta distribution inf = (100 - percentile) / 2 sup = 100 - (100 - percentile) / 2 - ulow = scipy.stats.beta.ppf(inf / 100, k, n - k + 1) - uhigh = scipy.stats.beta.ppf(sup / 100, k, n - k + 1) + ulow = beta.ppf(inf / 100, k, n - k + 1) + uhigh = beta.ppf(sup / 100, k, n - k + 1) # Quantiles should be sorted for plotting sorted_td = numpy.sort(evaluation_result.test_distribution) @@ -935,13 +937,13 @@ def plot_consistency_test( # Alarm-based plots ################### def plot_concentration_ROC_diagram( - forecast: "GriddedForecast", - catalog: "CSEPCatalog", - linear: bool = True, - plot_uniform: bool = True, - show: bool = True, - ax: Optional[pyplot.Axes] = None, - **kwargs + forecast: "GriddedForecast", + catalog: "CSEPCatalog", + linear: bool = True, + plot_uniform: bool = True, + show: bool = True, + ax: Optional[pyplot.Axes] = None, + **kwargs, ) -> matplotlib.axes.Axes: """ Plots the Concentration ROC Diagram for a given forecast and observed catalog. @@ -951,7 +953,8 @@ def plot_concentration_ROC_diagram( catalog (CSEPCatalog): Catalog object containing observed data. linear (bool): If True, uses a linear scale for the X-axis, otherwise logarithmic. ax (Optional[pyplot.Axes]): Axes object to plot on (default: None). - plot_uniform (bool): If True, plots the uniform (random) model as a reference (default: True). + plot_uniform (bool): If True, plots the uniform (random) model as a reference + (default: True). show (bool): If True, displays the plot (default: True). ax (Optional[pyplot.Axes]): Axes object to plot on (default: None). **kwargs: Additional keyword arguments for customization. @@ -975,12 +978,12 @@ def plot_concentration_ROC_diagram( obs_counts = catalog.spatial_counts() rate = forecast.spatial_counts() - I = numpy.argsort(rate) - I = numpy.flip(I) + indices = numpy.argsort(rate) + indices = numpy.flip(indices) - fore_norm_sorted = numpy.cumsum(rate[I]) / numpy.sum(rate) - area_norm_sorted = numpy.cumsum(area_km2[I]) / numpy.sum(area_km2) - obs_norm_sorted = numpy.cumsum(obs_counts[I]) / numpy.sum(obs_counts) + fore_norm_sorted = numpy.cumsum(rate[indices]) / numpy.sum(rate) + area_norm_sorted = numpy.cumsum(area_km2[indices]) / numpy.sum(area_km2) + obs_norm_sorted = numpy.cumsum(obs_counts[indices]) / numpy.sum(obs_counts) # Plot data if plot_uniform: @@ -990,27 +993,36 @@ def plot_concentration_ROC_diagram( area_norm_sorted, fore_norm_sorted, label=forecast_label, - color=plot_args['secondary_color'] + color=plot_args["secondary_color"], ) ax.step( area_norm_sorted, obs_norm_sorted, label=observed_label, - color=plot_args['color'], - linestyle=plot_args['linestyle'], + color=plot_args["color"], + linestyle=plot_args["linestyle"], ) # Plot formatting - ax.set_title(plot_args['title'], fontsize=plot_args['title_fontsize']) + ax.set_title(plot_args["title"], fontsize=plot_args["title_fontsize"]) ax.grid(plot_args["grid"]) if not linear: ax.set_xscale("log") - ax.set_ylabel(plot_args['ylabel'] or "True Positive Rate", fontsize=plot_args['ylabel_fontsize']) - ax.set_xlabel(plot_args['xlabel'] or "False Positive Rate (Normalized Area)", fontsize=plot_args['xlabel_fontsize']) + ax.set_ylabel( + plot_args["ylabel"] or "True Positive Rate", fontsize=plot_args["ylabel_fontsize"] + ) + ax.set_xlabel( + plot_args["xlabel"] or "False Positive Rate (Normalized Area)", + fontsize=plot_args["xlabel_fontsize"], + ) if plot_args["legend"]: - ax.legend(loc=plot_args['legend_loc'], shadow=True, fontsize=plot_args['legend_fontsize'], - framealpha=plot_args['legend_framealpha']) + ax.legend( + loc=plot_args["legend_loc"], + shadow=True, + fontsize=plot_args["legend_fontsize"], + framealpha=plot_args["legend_framealpha"], + ) if plot_args["tight_layout"]: fig.tight_layout() if show: @@ -1020,28 +1032,30 @@ def plot_concentration_ROC_diagram( def plot_ROC_diagram( - forecast: "GriddedForecast", - catalog: "CSEPCatalog", - linear: bool = True, - plot_uniform: bool = True, - show: bool = True, - ax: Optional[pyplot.Axes] = None, - **kwargs + forecast: "GriddedForecast", + catalog: "CSEPCatalog", + linear: bool = True, + plot_uniform: bool = True, + show: bool = True, + ax: Optional[pyplot.Axes] = None, + **kwargs, ) -> matplotlib.pyplot.Axes: """ - Plots the ROC (Receiver Operating Characteristic) curve for a given forecast and observed catalog. - - Args: - forecast (GriddedForecast): Forecast object containing spatial forecast data. - catalog (CSEPCatalog): Catalog object containing observed data. - linear (bool): If True, uses a linear scale for the X-axis, otherwise logarithmic. - plot_uniform (bool): If True, plots the uniform (random) model as a reference (default: True). - show (bool): If True, displays the plot (default: True). - ax (Optional[pyplot.Axes]): Axes object to plot on (default: None). - **kwargs: Additional keyword arguments for customization. - - Returns: - pyplot.Axes: The Axes object with the plot. + Plots the ROC (Receiver Operating Characteristic) curve for a given forecast and observed + catalog. + + Args: + forecast (GriddedForecast): Forecast object containing spatial forecast data. + catalog (CSEPCatalog): Catalog object containing observed data. + linear (bool): If True, uses a linear scale for the X-axis, otherwise logarithmic. + plot_uniform (bool): If True, plots the uniform (random) model as a reference (default: + True). + show (bool): If True, displays the plot (default: True). + ax (Optional[pyplot.Axes]): Axes object to plot on (default: None). + **kwargs: Additional keyword arguments for customization. + + Returns: + pyplot.Axes: The Axes object with the plot. """ # Initialize plot @@ -1054,10 +1068,10 @@ def plot_ROC_diagram( rate = forecast.spatial_counts() obs_counts = catalog.spatial_counts() - I = numpy.argsort(rate)[::-1] # Sort in descending order + indices = numpy.argsort(rate)[::-1] # Sort in descending order - thresholds = (rate[I]) / numpy.sum(rate) - obs_counts = obs_counts[I] + thresholds = (rate[indices]) / numpy.sum(rate) + obs_counts = obs_counts[indices] Table_ROC = pandas.DataFrame({"Threshold": [], "H": [], "F": []}) @@ -1089,8 +1103,8 @@ def plot_ROC_diagram( Table_ROC["F"], Table_ROC["H"], label=plot_args.get("forecast_label", forecast.name or "Forecast"), - color=plot_args['color'], - linestyle=plot_args['linestyle'], + color=plot_args["color"], + linestyle=plot_args["linestyle"], ) if plot_uniform: @@ -1103,18 +1117,21 @@ def plot_ROC_diagram( ) # Plot formatting - ax.set_ylabel(plot_args['ylabel'] or "Hit Rate", fontsize=plot_args['ylabel_fontsize']) - ax.set_xlabel(plot_args['xlabel'] or "Fraction of False Alarms", fontsize=plot_args['xlabel_fontsize']) + ax.set_ylabel(plot_args["ylabel"] or "Hit Rate", fontsize=plot_args["ylabel_fontsize"]) + ax.set_xlabel( + plot_args["xlabel"] or "Fraction of False Alarms", fontsize=plot_args["xlabel_fontsize"] + ) if not linear: ax.set_xscale("log") ax.set_yscale("linear") - ax.tick_params(axis="x", labelsize=plot_args['xticks_fontsize']) - ax.tick_params(axis="y", labelsize=plot_args['yticks_fontsize']) - if plot_args['legend']: - ax.legend(loc=plot_args['legend_loc'], shadow=True, - fontsize=plot_args['legend_fontsize']) - ax.set_title(plot_args['title'], fontsize=plot_args['title_fontsize']) - if plot_args['tight_layout']: + ax.tick_params(axis="x", labelsize=plot_args["xticks_fontsize"]) + ax.tick_params(axis="y", labelsize=plot_args["yticks_fontsize"]) + if plot_args["legend"]: + ax.legend( + loc=plot_args["legend_loc"], shadow=True, fontsize=plot_args["legend_fontsize"] + ) + ax.set_title(plot_args["title"], fontsize=plot_args["title_fontsize"]) + if plot_args["tight_layout"]: fig.tight_layout() if show: @@ -1124,13 +1141,13 @@ def plot_ROC_diagram( def plot_Molchan_diagram( - forecast: "GriddedForecast", - catalog: "CSEPCatalog", - linear: bool = True, - plot_uniform: bool = True, - show: bool = True, - ax: Optional[pyplot.Axes] = None, - **kwargs + forecast: "GriddedForecast", + catalog: "CSEPCatalog", + linear: bool = True, + plot_uniform: bool = True, + show: bool = True, + ax: Optional[pyplot.Axes] = None, + **kwargs, ) -> matplotlib.axes.Axes: """ Plot the Molchan Diagram based on forecast and test catalogs using the contingency table. @@ -1139,13 +1156,15 @@ def plot_Molchan_diagram( The Molchan diagram is computed following this procedure: 1. Obtain spatial rates from the GriddedForecast and the observed events from the catalog. 2. Rank the rates in descending order (highest rates first). - 3. Sort forecasted rates by ordering found in (2), and normalize rates so their sum is equal to unity. + 3. Sort forecasted rates by ordering found in (2), and normalize rates so their sum is equal + to unity. 4. Obtain binned spatial rates from the observed catalog. 5. Sort gridded observed rates by ordering found in (2). - 6. Test each ordered and normalized forecasted rate defined in (3) as a threshold value to obtain the - corresponding contingency table. - 7. Define the "nu" (Miss rate) and "tau" (Fraction of spatial alarmed cells) for each threshold using the - information provided by the corresponding contingency table defined in (6). + 6. Test each ordered and normalized forecasted rate defined in (3) as a threshold value to + obtain the corresponding contingency table. + 7. Define the "nu" (Miss rate) and "tau" (Fraction of spatial alarmed cells) for each + threshold using the information provided by the corresponding contingency table defined in + (6). Note: 1. The testing catalog and forecast should have exactly the same time-window (duration). @@ -1182,12 +1201,12 @@ def plot_Molchan_diagram( obs_counts = catalog.spatial_counts() # Get index of rates (descending sort) - I = numpy.argsort(rate) - I = numpy.flip(I) + indices = numpy.argsort(rate) + indices = numpy.flip(indices) # Order forecast and cells rates by highest rate cells first - thresholds = (rate[I]) / numpy.sum(rate) - obs_counts = obs_counts[I] + thresholds = (rate[indices]) / numpy.sum(rate) + obs_counts = obs_counts[indices] Table_molchan = pandas.DataFrame( { @@ -1280,11 +1299,11 @@ def plot_Molchan_diagram( ) ASscore = numpy.round(Tab_as_score.loc[Tab_as_score.index[-1], "AS_score"], 2) - bin = 0.01 + bin_size = 0.01 devstd = numpy.sqrt(1 / (12 * Table_molchan["Obs_active_bins"].iloc[0])) - devstd = devstd * bin**-1 + devstd = devstd * bin_size**-1 devstd = numpy.ceil(devstd + 0.5) - devstd = devstd / bin**-1 + devstd = devstd / bin_size**-1 dev_std = numpy.round(devstd, 2) # Plot the Molchan trajectory @@ -1292,8 +1311,8 @@ def plot_Molchan_diagram( Table_molchan["tau"], Table_molchan["nu"], label=f"{forecast_label}, ASS={ASscore}±{dev_std} ", - color=plot_args['color'], - linestyle=plot_args['linestyle'], + color=plot_args["color"], + linestyle=plot_args["linestyle"], ) # Plot uniform forecast @@ -1303,14 +1322,17 @@ def plot_Molchan_diagram( ax.plot(x_uniform, y_uniform, linestyle="--", color="gray", label="Uniform") # Plot formatting - ax.set_ylabel(plot_args['ylabel'] or "Miss Rate", fontsize=plot_args['ylabel_fontsize']) - ax.set_xlabel(plot_args['xlabel'] or "Fraction of area occupied by alarms", fontsize=plot_args['xlabel_fontsize']) + ax.set_ylabel(plot_args["ylabel"] or "Miss Rate", fontsize=plot_args["ylabel_fontsize"]) + ax.set_xlabel( + plot_args["xlabel"] or "Fraction of area occupied by alarms", + fontsize=plot_args["xlabel_fontsize"], + ) if not linear: ax.set_xscale("log") - ax.tick_params(axis="x", labelsize=plot_args['xlabel_fontsize']) - ax.tick_params(axis="y", labelsize=plot_args['ylabel_fontsize']) - ax.legend(loc=plot_args['legend_loc'], shadow=True, fontsize=plot_args['legend_fontsize']) - ax.set_title(plot_args['title'] or "Molchan Diagram", fontsize=plot_args['title_fontsize']) + ax.tick_params(axis="x", labelsize=plot_args["xlabel_fontsize"]) + ax.tick_params(axis="y", labelsize=plot_args["ylabel_fontsize"]) + ax.legend(loc=plot_args["legend_loc"], shadow=True, fontsize=plot_args["legend_fontsize"]) + ax.set_title(plot_args["title"] or "Molchan Diagram", fontsize=plot_args["title_fontsize"]) if plot_args["tight_layout"]: fig.tight_layout() @@ -1392,7 +1414,7 @@ def plot_basemap( ax.add_image(basemap_obj, tile_depth) # basemap_obj is a rasterio image elif isinstance(basemap_obj, DatasetReader): - ax = rioplot.show(basemap_obj, ax=ax) + ax = rio_plot.show(basemap_obj, ax=ax) except Exception as e: print( @@ -1463,8 +1485,14 @@ def plot_catalog( ax.scatter( catalog.get_longitudes(), catalog.get_latitudes(), - s=_autosize_scatter(values=catalog.get_magnitudes(), min_size=size, max_size=max_size, - power=power, min_val=min_val, max_val=max_val), + s=_autosize_scatter( + values=catalog.get_magnitudes(), + min_size=size, + max_size=max_size, + power=power, + min_val=min_val, + max_val=max_val, + ), transform=ccrs.PlateCarree(), color=plot_args["markercolor"], edgecolors=plot_args["markeredgecolor"], @@ -1486,15 +1514,25 @@ def plot_catalog( max_size=max_size, power=power, min_val=min_val or np.min(catalog.get_magnitudes()), - max_val=max_val or np.max(catalog.get_magnitudes()) + max_val=max_val or np.max(catalog.get_magnitudes()), ) # Create custom legend handles - handles = [pyplot.Line2D([0], [0], marker='o', lw=0, label=str(m), - markersize=np.sqrt(s), markerfacecolor='gray', alpha=0.5, - markeredgewidth=0.8, - markeredgecolor='black') - for m, s in zip(mag_ticks, legend_sizes)] + handles = [ + pyplot.Line2D( + [0], + [0], + marker="o", + lw=0, + label=str(m), + markersize=np.sqrt(s), + markerfacecolor="gray", + alpha=0.5, + markeredgewidth=0.8, + markeredgecolor="black", + ) + for m, s in zip(mag_ticks, legend_sizes) + ] ax.legend( handles, @@ -1581,7 +1619,7 @@ def plot_spatial_dataset( ax = plot_basemap(basemap, extent, ax=ax, set_global=set_global, show=False, **plot_args) # Define colormap and alpha transparency - colormap, alpha = _define_colormap_and_alpha(colormap, alpha_exp, alpha) + colormap, alpha = _get_colormap(colormap, alpha_exp, alpha) # Plot spatial dataset lons, lats = numpy.meshgrid( @@ -1595,10 +1633,18 @@ def plot_spatial_dataset( # Colorbar options if colorbar: - _add_colorbar( - ax, im, clabel, plot_args["colorbar_labelsize"], plot_args["colorbar_ticksize"] + cax = ax.get_figure().add_axes( + [ + ax.get_position().x1 + 0.01, + ax.get_position().y0, + 0.025, + ax.get_position().height, + ], + label="Colorbar", ) - + cbar = ax.get_figure().colorbar(im, ax=ax, cax=cax) + cbar.set_label(clabel, fontsize=plot_args["colorbar_labelsize"]) + cbar.ax.tick_params(labelsize=plot_args["colorbar_ticksize"]) # Draw forecast's region border if plot_region and not set_global: try: @@ -1618,8 +1664,18 @@ def plot_spatial_dataset( ##################### # Plot helper functions ##################### -def _get_marker_style(obs_stat, p, one_sided_lower): - """Returns matplotlib marker style as fmt string""" +def _get_marker_style(obs_stat: float, p: Sequence[float], one_sided_lower: bool) -> str: + """ + Returns the matplotlib marker style as a format string. + + Args: + obs_stat (float): The observed statistic. + p (Sequence[float, float]): A tuple of lower and upper percentiles. + one_sided_lower (bool): Indicates if the test is one-sided lower. + + Returns: + str: A format string representing the marker style. + """ if obs_stat < p[0] or obs_stat > p[1]: # red circle fmt = "ro" @@ -1634,21 +1690,39 @@ def _get_marker_style(obs_stat, p, one_sided_lower): return fmt -def _get_marker_t_color(distribution): - """Returns matplotlib marker style as fmt string""" +def _get_marker_t_color(distribution: Sequence[float]) -> str: + """ + Returns the color for the marker based on the distribution. + + Args: + distribution (Sequence[float, float]): A tuple representing the lower and upper bounds + of the test distribution. + + Returns: + str: Marker color + """ if distribution[0] > 0.0 and distribution[1] > 0.0: - fmt = "green" + color = "green" elif distribution[0] < 0.0 and distribution[1] < 0.0: - fmt = "red" + color = "red" else: - fmt = "grey" + color = "grey" + + return color - return fmt +def _get_marker_w_color(distribution: float, percentile: float) -> bool: + """ + Returns a boolean indicating whether the distribution's percentile is below a given + threshold. -def _get_marker_w_color(distribution, percentile): - """Returns matplotlib marker style as fmt string""" + Args: + distribution (float): The value of the distribution's percentile. + percentile (float): The percentile threshold. + Returns: + bool: True if the distribution's percentile is below the threshold, False otherwise. + """ if distribution < (1 - percentile / 100): fmt = True else: @@ -1657,34 +1731,59 @@ def _get_marker_w_color(distribution, percentile): return fmt -def _get_axis_limits(pnts, border=0.05): - """Returns a tuple of x_min and x_max given points on plot.""" - x_min = numpy.min(pnts) - x_max = numpy.max(pnts) +def _get_axis_limits(points: Union[Sequence, numpy.ndarray], + border: float = 0.05) -> Tuple[float, float]: + """ + Returns a tuple of x_min and x_max given points on a plot. + + Args: + points (numpy.ndarray): An array of points. + border (float): The border fraction to apply to the limits. + + Returns: + Sequence[float, float]: The x_min and x_max values adjusted with the border. + """ + x_min = numpy.min(points) + x_max = numpy.max(points) xd = (x_max - x_min) * border return x_min - xd, x_max + xd -def _get_basemap(basemap): - last_cache = os.path.join(os.path.dirname(cartopy.config["cache_dir"]), 'last_cartopy_cache') +def _get_basemap(basemap: str) -> Union[img_tiles.GoogleTiles, DatasetReader]: + """ + Returns the basemap tiles for a given basemap type or web service. + + Args: + basemap (str): The type of basemap for cartopy, an URL for a web service or a TIF file + path. + + Returns: + Union[img_tiles.GoogleTiles, rasterio.io.DatasetReader]: The corresponding tiles or + raster object. + + """ + last_cache = os.path.join( + os.path.dirname(cartopy.config["cache_dir"]), "last_cartopy_cache" + ) def _clean_cache(basemap_): if os.path.isfile(last_cache): - with open(last_cache, 'r') as fp: + with open(last_cache, "r") as fp: cache_src = fp.read() if cache_src != basemap_: if os.path.isdir(cartopy.config["cache_dir"]): - print(f'Cleaning existing {basemap_} cache') + print(f"Cleaning existing {basemap_} cache") shutil.rmtree(cartopy.config["cache_dir"]) def _save_cache_src(basemap_): - with open(last_cache, 'w') as fp: + with open(last_cache, "w") as fp: fp.write(basemap_) cache = True - warning_message_to_suppress = ('Cartopy created the following directory to cache' - ' GoogleWTS tiles') + warning_message_to_suppress = ( + "Cartopy created the following directory to cache" " GoogleWTS tiles" + ) with warnings.catch_warnings(): warnings.filterwarnings("ignore", message=warning_message_to_suppress) if basemap == "google-satellite": @@ -1729,7 +1828,7 @@ def _save_cache_src(basemap_): _save_cache_src(basemap) elif os.path.isfile(basemap): - return rasterio.open(basemap) + return rio_open(basemap) else: try: @@ -1743,409 +1842,28 @@ def _save_cache_src(basemap_): return tiles -def _add_labels_for_publication(figure, style="bssa", labelsize=16): - """Adds publication labels too the outside of a figure.""" - all_axes = figure.get_axes() - ascii_iter = iter(string.ascii_lowercase) - for ax in all_axes: - # check for colorbar and ignore for annotations - if ax.get_label() == "Colorbar": - continue - annot = next(ascii_iter) - if style == "bssa": - ax.annotate( - f"({annot})", (0.025, 1.025), xycoords="axes fraction", fontsize=labelsize - ) - - return - - -def _plot_pvalues_and_intervals(test_results, ax, var=None): - """Plots p-values and intervals for a list of Poisson or NBD test results +def _autosize_scatter( + values: numpy.ndarray, + min_size: float = 50.0, + max_size: float = 400.0, + power: float = 3.0, + min_val: Optional[float] = None, + max_val: Optional[float] = None, +) -> numpy.ndarray: + """ + Auto-sizes scatter plot markers based on values. Args: - test_results (list): list of EvaluationResults for N-test. All tests should use the same - distribution (ie Poisson or NBD). - ax (matplotlib.axes.Axes.axis): axes to use for plot. create using matplotlib - var (float): variance of the NBD distribution. Must be used for NBD plots. + values (numpy.ndarray): The data values (e.g., magnitude) to base the sizing on. + min_size (float): The minimum marker size. + max_size (float): The maximum marker size. + power (float): The power to apply for scaling. + min_val (Optional[float]): The minimum value (e.g., magnitude) for normalization. + max_val (Optional[float]): The maximum value (e.g., magnitude) for normalization. Returns: - ax (matplotlib.axes.Axes.axis): axes handle containing this plot - - Raises: - ValueError: throws error if NBD tests are supplied without a variance + numpy.ndarray: The calculated marker sizes. """ - - variance = var - percentile = 97.5 - p_values = [] - - # Differentiate between N-tests and other consistency tests - if test_results[0].name == "NBD N-Test" or test_results[0].name == "Poisson N-Test": - legend_elements = [ - matplotlib.lines.Line2D( - [0], - [0], - marker="o", - color="red", - lw=0, - label=r"p < 10e-5", - markersize=10, - markeredgecolor="k", - ), - matplotlib.lines.Line2D( - [0], - [0], - marker="o", - color="#FF7F50", - lw=0, - label=r"10e-5 $\leq$ p < 10e-4", - markersize=10, - markeredgecolor="k", - ), - matplotlib.lines.Line2D( - [0], - [0], - marker="o", - color="gold", - lw=0, - label=r"10e-4 $\leq$ p < 10e-3", - markersize=10, - markeredgecolor="k", - ), - matplotlib.lines.Line2D( - [0], - [0], - marker="o", - color="white", - lw=0, - label=r"10e-3 $\leq$ p < 0.0125", - markersize=10, - markeredgecolor="k", - ), - matplotlib.lines.Line2D( - [0], - [0], - marker="o", - color="skyblue", - lw=0, - label=r"0.0125 $\leq$ p < 0.025", - markersize=10, - markeredgecolor="k", - ), - matplotlib.lines.Line2D( - [0], - [0], - marker="o", - color="blue", - lw=0, - label=r"p $\geq$ 0.025", - markersize=10, - markeredgecolor="k", - ), - ] - ax.legend(handles=legend_elements, loc=4, fontsize=13, edgecolor="k") - # Act on Negative binomial tests - if test_results[0].name == "NBD N-Test": - if var is None: - raise ValueError("var must not be None if N-tests use the NBD distribution.") - - for i in range(len(test_results)): - mean = test_results[i].test_distribution[1] - upsilon = 1.0 - ((variance - mean) / variance) - tau = mean**2 / (variance - mean) - phigh97 = scipy.stats.nbinom.ppf((1 - percentile / 100.0) / 2.0, tau, upsilon) - plow97 = scipy.stats.nbinom.ppf( - 1 - (1 - percentile / 100.0) / 2.0, tau, upsilon - ) - low97 = test_results[i].observed_statistic - plow97 - high97 = phigh97 - test_results[i].observed_statistic - ax.errorbar( - test_results[i].observed_statistic, - (len(test_results) - 1) - i, - xerr=numpy.array([[low97, high97]]).T, - capsize=4, - color="slategray", - alpha=1.0, - zorder=0, - ) - p_values.append( - test_results[i].quantile[1] * 2.0 - ) # Calculated p-values according to Meletti et al., (2021) - - if p_values[i] < 10e-5: - ax.plot( - test_results[i].observed_statistic, - (len(test_results) - 1) - i, - marker="o", - color="red", - markersize=8, - zorder=2, - ) - if p_values[i] >= 10e-5 and p_values[i] < 10e-4: - ax.plot( - test_results[i].observed_statistic, - (len(test_results) - 1) - i, - marker="o", - color="#FF7F50", - markersize=8, - zorder=2, - ) - if p_values[i] >= 10e-4 and p_values[i] < 10e-3: - ax.plot( - test_results[i].observed_statistic, - (len(test_results) - 1) - i, - marker="o", - color="gold", - markersize=8, - zorder=2, - ) - if p_values[i] >= 10e-3 and p_values[i] < 0.0125: - ax.plot( - test_results[i].observed_statistic, - (len(test_results) - 1) - i, - marker="o", - color="white", - markersize=8, - zorder=2, - ) - if p_values[i] >= 0.0125 and p_values[i] < 0.025: - ax.plot( - test_results[i].observed_statistic, - (len(test_results) - 1) - i, - marker="o", - color="skyblue", - markersize=8, - zorder=2, - ) - if p_values[i] >= 0.025: - ax.plot( - test_results[i].observed_statistic, - (len(test_results) - 1) - i, - marker="o", - color="blue", - markersize=8, - zorder=2, - ) - # Act on Poisson N-test - if test_results[0].name == "Poisson N-Test": - for i in range(len(test_results)): - plow97 = scipy.stats.poisson.ppf( - (1 - percentile / 100.0) / 2.0, test_results[i].test_distribution[1] - ) - phigh97 = scipy.stats.poisson.ppf( - 1 - (1 - percentile / 100.0) / 2.0, test_results[i].test_distribution[1] - ) - low97 = test_results[i].observed_statistic - plow97 - high97 = phigh97 - test_results[i].observed_statistic - ax.errorbar( - test_results[i].observed_statistic, - (len(test_results) - 1) - i, - xerr=numpy.array([[low97, high97]]).T, - capsize=4, - color="slategray", - alpha=1.0, - zorder=0, - ) - p_values.append(test_results[i].quantile[1] * 2.0) - if p_values[i] < 10e-5: - ax.plot( - test_results[i].observed_statistic, - (len(test_results) - 1) - i, - marker="o", - color="red", - markersize=8, - zorder=2, - ) - elif p_values[i] >= 10e-5 and p_values[i] < 10e-4: - ax.plot( - test_results[i].observed_statistic, - (len(test_results) - 1) - i, - marker="o", - color="#FF7F50", - markersize=8, - zorder=2, - ) - elif p_values[i] >= 10e-4 and p_values[i] < 10e-3: - ax.plot( - test_results[i].observed_statistic, - (len(test_results) - 1) - i, - marker="o", - color="gold", - markersize=8, - zorder=2, - ) - elif p_values[i] >= 10e-3 and p_values[i] < 0.0125: - ax.plot( - test_results[i].observed_statistic, - (len(test_results) - 1) - i, - marker="o", - color="white", - markersize=8, - zorder=2, - ) - elif p_values[i] >= 0.0125 and p_values[i] < 0.025: - ax.plot( - test_results[i].observed_statistic, - (len(test_results) - 1) - i, - marker="o", - color="skyblue", - markersize=8, - zorder=2, - ) - elif p_values[i] >= 0.025: - ax.plot( - test_results[i].observed_statistic, - (len(test_results) - 1) - i, - marker="o", - color="blue", - markersize=8, - zorder=2, - ) - # Operate on all other consistency tests - else: - for i in range(len(test_results)): - plow97 = numpy.percentile(test_results[i].test_distribution, 2.5) - phigh97 = numpy.percentile(test_results[i].test_distribution, 97.5) - low97 = test_results[i].observed_statistic - plow97 - high97 = phigh97 - test_results[i].observed_statistic - ax.errorbar( - test_results[i].observed_statistic, - (len(test_results) - 1) - i, - xerr=numpy.array([[low97, high97]]).T, - capsize=4, - color="slategray", - alpha=1.0, - zorder=0, - ) - p_values.append(test_results[i].quantile) - - if p_values[i] < 10e-5: - ax.plot( - test_results[i].observed_statistic, - (len(test_results) - 1) - i, - marker="o", - color="red", - markersize=8, - zorder=2, - ) - elif p_values[i] >= 10e-5 and p_values[i] < 10e-4: - ax.plot( - test_results[i].observed_statistic, - (len(test_results) - 1) - i, - marker="o", - color="#FF7F50", - markersize=8, - zorder=2, - ) - elif p_values[i] >= 10e-4 and p_values[i] < 10e-3: - ax.plot( - test_results[i].observed_statistic, - (len(test_results) - 1) - i, - marker="o", - color="gold", - markersize=8, - zorder=2, - ) - elif p_values[i] >= 10e-3 and p_values[i] < 0.025: - ax.plot( - test_results[i].observed_statistic, - (len(test_results) - 1) - i, - marker="o", - color="white", - markersize=8, - zorder=2, - ) - elif p_values[i] >= 0.025 and p_values[i] < 0.05: - ax.plot( - test_results[i].observed_statistic, - (len(test_results) - 1) - i, - marker="o", - color="skyblue", - markersize=8, - zorder=2, - ) - elif p_values[i] >= 0.05: - ax.plot( - test_results[i].observed_statistic, - (len(test_results) - 1) - i, - marker="o", - color="blue", - markersize=8, - zorder=2, - ) - - legend_elements = [ - matplotlib.lines.Line2D( - [0], - [0], - marker="o", - color="red", - lw=0, - label=r"p < 10e-5", - markersize=10, - markeredgecolor="k", - ), - matplotlib.lines.Line2D( - [0], - [0], - marker="o", - color="#FF7F50", - lw=0, - label=r"10e-5 $\leq$ p < 10e-4", - markersize=10, - markeredgecolor="k", - ), - matplotlib.lines.Line2D( - [0], - [0], - marker="o", - color="gold", - lw=0, - label=r"10e-4 $\leq$ p < 10e-3", - markersize=10, - markeredgecolor="k", - ), - matplotlib.lines.Line2D( - [0], - [0], - marker="o", - color="white", - lw=0, - label=r"10e-3 $\leq$ p < 0.025", - markersize=10, - markeredgecolor="k", - ), - matplotlib.lines.Line2D( - [0], - [0], - marker="o", - color="skyblue", - lw=0, - label=r"0.025 $\leq$ p < 0.05", - markersize=10, - markeredgecolor="k", - ), - matplotlib.lines.Line2D( - [0], - [0], - marker="o", - color="blue", - lw=0, - label=r"p $\geq$ 0.05", - markersize=10, - markeredgecolor="k", - ), - ] - - ax.legend(handles=legend_elements, loc=4, fontsize=13, edgecolor="k") - - return ax - - -def _autosize_scatter(values, min_size=50., max_size=400., power=3.0, min_val=None, - max_val=None): - min_val = min_val or np.min(values) max_val = max_val or np.max(values) normalized_values = ((values - min_val) / (max_val - min_val)) ** power @@ -2153,7 +1871,26 @@ def _autosize_scatter(values, min_size=50., max_size=400., power=3.0, min_val=No return marker_sizes -def _autoscale_histogram(ax: pyplot.Axes, bin_edges, simulated, observation, mass=99.5): +def _autoscale_histogram( + ax: matplotlib.axes.Axes, + bin_edges: numpy.ndarray, + simulated: numpy.ndarray, + observation: numpy.ndarray, + mass: float = 99.5, +) -> matplotlib.axes.Axes: + """ + Autoscale the histogram axes based on the data distribution. + + Args: + ax (matplotlib.axes.Axes): The axes to apply the scaling to. + bin_edges (numpy.ndarray): The edges of the histogram bins. + simulated (numpy.ndarray): Simulated data values. + observation (numpy.ndarray): Observed data values. + mass (float): The percentage of the data mass to consider. + + Returns: + matplotlib.axes.Axes: The scaled axes + """ upper_xlim = numpy.percentile(simulated, 100 - (100 - mass) / 2) upper_xlim = numpy.max([upper_xlim, numpy.max(observation)]) @@ -2180,10 +1917,23 @@ def _autoscale_histogram(ax: pyplot.Axes, bin_edges, simulated, observation, mas def _annotate_distribution_plot( - ax, evaluation_result, auto_annotate, plot_args + ax: matplotlib.axes.Axes, + evaluation_result: "EvaluationResult", + auto_annotate: bool, + plot_args: Dict[str, Any], ) -> matplotlib.axes.Axes: - """Returns specific plot details based on the type of evaluation_result.""" + """ + Annotates a distribution plot based on the evaluation result type. + + Args: + ax (matplotlib.axes.Axes): The axes to annotate. + evaluation_result (EvaluationResult): The evaluation result object. + auto_annotate (bool): If True, automatically annotates the plot based on result type. + plot_args (Dict[str, Any]): Additional plotting arguments. + Returns: + matplotlib.axes.Axes: The annotated axes. + """ annotation_text = None annotation_xy = None title = None @@ -2214,7 +1964,8 @@ def _annotate_distribution_plot( title = f"{evaluation_result.name}: {evaluation_result.sim_name}" annotation_xy = (0.2, 0.6) annotation_text = ( - f"$\\gamma = P(X \\leq x) = {numpy.array(evaluation_result.quantile).ravel()[-1]:.2f}$\n" + f"$\\gamma = P(X \\leq x) = " + f"{numpy.array(evaluation_result.quantile).ravel()[-1]:.2f}$\n" f"$\\omega = {evaluation_result.observed_statistic:.2f}$" ) @@ -2224,7 +1975,8 @@ def _annotate_distribution_plot( title = f"{evaluation_result.name}: {evaluation_result.sim_name}" annotation_xy = (0.55, 0.6) annotation_text = ( - f"$\\gamma = P(X \\geq x) = {numpy.array(evaluation_result.quantile).ravel()[0]:.2f}$\n" + f"$\\gamma = P(X \\geq x) = " + f"{numpy.array(evaluation_result.quantile).ravel()[0]:.2f}$\n" f"$\\omega = {evaluation_result.observed_statistic:.2f}$" ) elif evaluation_result.name == "Catalog PL-Test": @@ -2233,7 +1985,8 @@ def _annotate_distribution_plot( title = f"{evaluation_result.name}: {evaluation_result.sim_name}" annotation_xy = (0.55, 0.3) annotation_text = ( - f"$\\gamma = P(X \\leq x) = {numpy.array(evaluation_result.quantile).ravel()[-1]:.2f}$\n" + f"$\\gamma = P(X \\leq x) = " + f"{numpy.array(evaluation_result.quantile).ravel()[-1]:.2f}$\n" f"$\\omega = {evaluation_result.observed_statistic:.2f}$" ) @@ -2255,12 +2008,30 @@ def _annotate_distribution_plot( return ax -def _calculate_spatial_extent(catalog, set_global, region_border, padding_fraction=0.05): +def _calculate_spatial_extent( + element: Union["CSEPCatalog", "CartesianGrid2D"], + set_global: bool, + region_border: bool, + padding_fraction: float = 0.05, +) -> Optional[List[float]]: + """ + Calculates the spatial extent for plotting based on the catalog. + + Args: + element (CSEPCatalog), CartesianGrid2D: The catalog or region object to base the extent + on. + set_global (bool): If True, sets the extent to the global view. + region_border (bool): If True, uses the catalog's region border. + padding_fraction (float): The fraction of padding to apply to the extent. + + Returns: + Optional[List[float]]: The calculated extent or None if global view is set. + """ # todo: perhaps calculate extent also from chained ax object - bbox = catalog.get_bbox() + bbox = element.get_bbox() if region_border: try: - bbox = catalog.region.get_bbox() + bbox = element.region.get_bbox() except AttributeError: pass @@ -2272,7 +2043,24 @@ def _calculate_spatial_extent(catalog, set_global, region_border, padding_fracti return [bbox[0] - dh, bbox[1] + dh, bbox[2] - dv, bbox[3] + dv] -def _create_geo_axes(figsize, extent, projection, set_global): +def _create_geo_axes( + figsize: Optional[Tuple[float, float]], + extent: Optional[List[float]], + projection: Union[ccrs.Projection, str], + set_global: bool, +) -> pyplot.Axes: + """ + Creates and returns GeoAxes for plotting. + + Args: + figsize (Optional[Tuple[float, float]]): The size of the figure. + extent (Optional[List[float]]): The spatial extent to set. + projection (Union[ccrs.Projection, str]): The projection to use. + set_global (bool): If True, sets the global view. + + Returns: + pyplot.Axes: The created GeoAxes object. + """ if projection == "approx": fig = pyplot.figure(figsize=figsize) @@ -2291,7 +2079,15 @@ def _create_geo_axes(figsize, extent, projection, set_global): return ax -def _add_gridlines(ax, grid_labels, grid_fontsize): +def _add_gridlines(ax: matplotlib.axes.Axes, grid_labels: bool, grid_fontsize: float) -> None: + """ + Adds gridlines and optionally labels to the axes. + + Args: + ax (matplotlib.axes.Axes): The axes to add gridlines to. + grid_labels (bool): If True, labels the gridlines. + grid_fontsize (float): The font size of the grid labels. + """ gl = ax.gridlines(draw_labels=grid_labels, alpha=0.5) gl.right_labels = False gl.top_labels = False @@ -2301,18 +2097,22 @@ def _add_gridlines(ax, grid_labels, grid_fontsize): gl.yformatter = LATITUDE_FORMATTER -def _define_colormap_and_alpha(cmap, alpha_exp, alpha_0=None): +def _get_colormap( + cmap: Union[str, matplotlib.colors.Colormap], + alpha_exp: float, + alpha_0: Optional[float] = None, +) -> Tuple[matplotlib.colors.ListedColormap, Optional[float]]: """ Defines the colormap and applies alpha transparency based on the given parameters. Args: - cmap (str or matplotlib.colors.Colormap): The colormap to be used. - alpha_0 (float or None): If set, this alpha will be applied uniformly across the colormap. - alpha_exp (float): Exponent to control transparency scaling. If set to 0, no alpha scaling is applied. + cmap (Union[str, matplotlib.colors.Colormap]): The colormap to use. + alpha_exp (float): The exponent to control transparency scaling. + alpha_0 (Optional[float]): If set, applies a uniform alpha across the colormap. Returns: - cmap (matplotlib.colors.ListedColormap): The resulting colormap with applied alpha. - alpha (float or None): The alpha value used for the entire colormap, or None if alpha is scaled per color. + Tuple[matplotlib.colors.ListedColormap, Optional[float]]: The modified colormap + and the alpha value used for the entire colormap. """ # Get the colormap object if a string is provided @@ -2336,32 +2136,41 @@ def _define_colormap_and_alpha(cmap, alpha_exp, alpha_0=None): return cmap, alpha -def _add_colorbar(ax, im, clabel, clabel_fontsize, cticks_fontsize): - fig = ax.get_figure() - cax = fig.add_axes( - [ax.get_position().x1 + 0.01, ax.get_position().y0, 0.025, ax.get_position().height], - label="Colorbar", - ) - cbar = fig.colorbar(im, ax=ax, cax=cax) - cbar.set_label(clabel, fontsize=clabel_fontsize) - cbar.ax.tick_params(labelsize=cticks_fontsize) +def _process_stat_distribution( + res: "EvaluationResult", + percentile: float, + variance: Optional[float], + normalize: bool, + one_sided_lower: bool, +) -> Tuple[float, float, float, float]: + """ + Processes the statistical distribution based on its type and returns plotting values. + Args: + res (EvaluationResult): The evaluation result object containing the distribution data. + percentile (float): The percentile for calculating the confidence intervals. + variance (Optional[float]): The variance of the negative binomial distribution, if + applicable. + normalize (bool): If True, normalizes the distribution by the observed statistic. + one_sided_lower (bool): If True, performs a one-sided lower test. -def _process_stat_distribution(res, percentile, variance, normalize, one_sided_lower): - """Process the distribution based on its type and return plotting values.""" + Returns: + Tuple[float, float, float, float]: A tuple containing the lower percentile, + upper percentile, mean, and observed statistic. + """ dist_type = res.test_distribution[0] if dist_type == "poisson": mean = res.test_distribution[1] - plow = scipy.stats.poisson.ppf((1 - percentile / 100.0) / 2.0, mean) - phigh = scipy.stats.poisson.ppf(1 - (1 - percentile / 100.0) / 2.0, mean) + plow = poisson.ppf((1 - percentile / 100.0) / 2.0, mean) + phigh = poisson.ppf(1 - (1 - percentile / 100.0) / 2.0, mean) observed_statistic = res.observed_statistic elif dist_type == "negative_binomial": mean = res.test_distribution[1] upsilon = 1.0 - ((variance - mean) / variance) tau = mean**2 / (variance - mean) - plow = scipy.stats.nbinom.ppf((1 - percentile / 100.0) / 2.0, tau, upsilon) - phigh = scipy.stats.nbinom.ppf(1 - (1 - percentile / 100.0) / 2.0, tau, upsilon) + plow = nbinom.ppf((1 - percentile / 100.0) / 2.0, tau, upsilon) + phigh = nbinom.ppf(1 - (1 - percentile / 100.0) / 2.0, tau, upsilon) observed_statistic = res.observed_statistic else: diff --git a/requirements.yml b/requirements.yml index 82d6d180..be2fd220 100644 --- a/requirements.yml +++ b/requirements.yml @@ -11,7 +11,6 @@ dependencies: - pyproj - obspy - python-dateutil - - rasterio - cartopy - shapely - mercantile diff --git a/tests/test_plots.py b/tests/test_plots.py index e5de65fa..bd8139cd 100644 --- a/tests/test_plots.py +++ b/tests/test_plots.py @@ -14,9 +14,9 @@ import numpy from cartopy import crs as ccrs from matplotlib import colors -from matplotlib.text import Annotation import csep +from csep.core import catalogs from csep.core.catalog_evaluations import ( CatalogNumberTestResult, CatalogSpatialTestResult, @@ -24,7 +24,6 @@ CatalogPseudolikelihoodTestResult, CalibrationTestResult, ) -from csep.core import catalogs from csep.utils.plots import ( plot_cumulative_events_versus_time, plot_magnitude_vs_time, @@ -47,29 +46,30 @@ _get_marker_t_color, # noqa _get_marker_w_color, # noqa _get_axis_limits, # noqa - _add_labels_for_publication, # noqa _autosize_scatter, # noqa _autoscale_histogram, # noqa _annotate_distribution_plot, # noqa - _define_colormap_and_alpha, # noqa - _add_colorbar, # noqa + _get_colormap, # noqa _process_stat_distribution, # noqa ) def is_internet_available(): - """Check if internet is available by attempting to connect to a well-known host.""" try: - # Try to connect to Google's DNS server - socket.create_connection(("8.8.8.8", 53), timeout=1) + socket.create_connection(("8.8.8.8", 53), timeout=3) return True except OSError: return False -is_github_actions = os.getenv("GITHUB_ACTIONS") == "true" +def is_github_ci(): + if os.getenv("GITHUB_ACTIONS") or os.getenv("CI") or os.getenv("GITHUB_ACTION"): + return True + else: + return False + -show_plots = True +show_plots = False class TestPlots(unittest.TestCase): @@ -85,6 +85,7 @@ def savefig(self, ax, name): ax.figure.savefig(os.path.join(self.save_dir, name)) +# class TestTimeSeriesPlots(TestPlots): def setUp(self): @@ -118,13 +119,14 @@ def test_plot_magnitude_vs_time(self): self.assertEqual(ax.get_ylabel(), "$M$") # Test with custom color - ax = plot_magnitude_vs_time(catalog=self.observation_m2, color='red', show=show_plots) + ax = plot_magnitude_vs_time(catalog=self.observation_m2, color="red", show=show_plots) scatter_color = ax.collections[0].get_facecolor()[0] self.assertTrue(all(scatter_color[:3] == (1.0, 0.0, 0.0))) # Check if color is red # Test with custom marker size - ax = plot_magnitude_vs_time(catalog=self.observation_m2, size=25, max_size=600, - show=show_plots) + ax = plot_magnitude_vs_time( + catalog=self.observation_m2, size=25, max_size=600, show=show_plots + ) scatter_sizes = ax.collections[0].get_sizes() func_sizes = _autosize_scatter(self.observation_m2.data["magnitude"], 25, 600, 4) numpy.testing.assert_array_almost_equal(scatter_sizes, func_sizes) @@ -141,14 +143,15 @@ def test_plot_magnitude_vs_time(self): numpy.testing.assert_array_almost_equal(scatter_sizes, func_sizes) # # # Test with show=True (just to ensure no errors occur) - plot_magnitude_vs_time(catalog=self.observation_m2, show=True) + plot_magnitude_vs_time(catalog=self.observation_m2, show=False) + plt.close("all") def test_plot_cumulative_events_default(self): # Test with default arguments to ensure basic functionality ax = plot_cumulative_events_versus_time( catalog_forecast=self.stochastic_event_sets, observation=self.observation_m5, - show=show_plots + show=show_plots, ) self.assertIsNotNone(ax.get_title()) @@ -166,7 +169,7 @@ def test_plot_cumulative_events_hours(self): ylabel="Cumulative Event Count", title="Cumulative Event Counts by Hour", legend_loc="upper left", - show=show_plots + show=show_plots, ) self.assertEqual(ax.get_xlabel(), "Hours since Mainshock") @@ -185,7 +188,7 @@ def test_plot_cumulative_events_different_bins(self): xlabel="Days since Mainshock", ylabel="Cumulative Event Count", title="Cumulative Event Counts with More Bins", - legend_loc="best" + legend_loc="best", ) self.assertEqual(ax.get_title(), "Cumulative Event Counts with More Bins") @@ -205,7 +208,7 @@ def test_plot_cumulative_events_custom_legend(self): ylabel="Cumulative Event Count", title="Cumulative Event Counts with Custom Legend", legend_loc="lower right", - legend_fontsize=14 + legend_fontsize=14, ) self.assertEqual(ax.get_legend()._get_loc(), 4) @@ -237,7 +240,6 @@ def gr_dist(num_events, mag_min=3.0, mag_max=8.0, b_val=1.0): self.mock_cat.get_magnitudes.return_value = gr_dist(500, b_val=1.2) self.mock_cat.get_number_of_events.return_value = 500 self.mock_cat.region.magnitudes = numpy.arange(3.0, 8.0, 0.1) - self.save_dir = os.path.join(os.path.dirname(__file__), "artifacts", "plots") cat_file_m5 = os.path.join( self.artifacts, @@ -259,28 +261,19 @@ def gr_dist(num_events, mag_min=3.0, mag_max=8.0, b_val=1.0): def test_plot_magnitude_histogram_basic(self): # Test with basic arguments - ax = plot_magnitude_histogram(self.mock_forecast, - self.mock_cat, show=show_plots, - density=True) + plot_magnitude_histogram( + self.mock_forecast, self.mock_cat, show=show_plots, density=True + ) # Verify that magnitudes were retrieved for catalog in self.mock_forecast: catalog.get_magnitudes.assert_called_once() self.mock_cat.get_magnitudes.assert_called_once() self.mock_cat.get_number_of_events.assert_called_once() - ax.figure.savefig(os.path.join(self.save_dir, "magnitude_histogram.png")) def test_plot_magnitude_histogram_ucerf(self): # Test with basic arguments - ax = plot_magnitude_histogram(self.stochastic_event_sets, self.comcat, - show=show_plots) - - # # Verify that magnitudes were retrieved - # for catalog in self.stochastic_event_sets: - # catalog.get_magnitudes.assert_called_once() - # self.comcat.get_magnitudes.assert_called_once() - # self.comcat.get_number_of_events.assert_called_once() - ax.figure.savefig(os.path.join(self.save_dir, "magnitude_histogram_ucerf.png")) + plot_magnitude_histogram(self.stochastic_event_sets, self.comcat, show=show_plots) def tearDown(self): plt.close("all") @@ -404,45 +397,39 @@ def test_plot_dist_test_xlim(self): xlim=xlim, show=show_plots, ) - self.savefig(ax, "plot_dist_test_xlims.png") self.assertEqual(ax.get_xlim(), xlim) def test_plot_dist_test_autoxlim_nan(self): - ax = plot_distribution_test( + plot_distribution_test( evaluation_result=self.result_nan, percentile=95, show=show_plots, ) - self.savefig(ax, "plot_dist_test_xlims_inf.png") def test_plot_n_test(self): - ax = plot_distribution_test( + plot_distribution_test( self.n_test, show=show_plots, ) - self.savefig(ax, "plot_n_test.png") def test_plot_m_test(self): - ax = plot_distribution_test( + plot_distribution_test( self.m_test, show=show_plots, ) - self.savefig(ax, "plot_m_test.png") def test_plot_s_test(self): - ax = plot_distribution_test( + plot_distribution_test( self.s_test, show=show_plots, ) - self.savefig(ax, "plot_s_test.png") def test_plot_l_test(self): - ax = plot_distribution_test( + plot_distribution_test( self.l_test, show=show_plots, ) - self.savefig(ax, "plot_l_test.png") def tearDown(self): plt.close("all") @@ -523,7 +510,7 @@ def setUp(self): def test_plot_consistency_basic(self): ax = plot_consistency_test(eval_results=self.mock_result, show=show_plots) - self.assertEqual(ax.get_title(), '') + self.assertEqual(ax.get_title(), "") self.assertEqual(ax.get_xlabel(), "Statistic distribution") def test_plot_consistency_with_multiple_results(self): @@ -532,8 +519,9 @@ def test_plot_consistency_with_multiple_results(self): self.assertEqual(len(ax.get_yticklabels()), 5) def test_plot_consistency_with_normalization(self): - ax = plot_consistency_test(eval_results=self.mock_result, normalize=True, - show=show_plots) + ax = plot_consistency_test( + eval_results=self.mock_result, normalize=True, show=show_plots + ) # Assert that the observed statistic is plotted at 0 self.assertEqual(ax.lines[0].get_xdata(), 0) @@ -541,38 +529,47 @@ def test_plot_consistency_with_one_sided_lower(self): mock_result = copy.deepcopy(self.mock_result) # THe observed statistic is placed to the right of the model test distribution. mock_result.observed_statistic = max(self.mock_result.test_distribution) + 1 - ax = plot_consistency_test(eval_results=mock_result, one_sided_lower=True, - show=show_plots) + ax = plot_consistency_test( + eval_results=mock_result, one_sided_lower=True, show=show_plots + ) # The end of the infinite dashed line should extend way away from the plot limit self.assertGreater(ax.lines[-1].get_xdata()[-1], ax.get_xlim()[1]) def test_plot_consistency_with_custom_percentile(self): - ax = plot_consistency_test(eval_results=self.mock_result, percentile=99, - show=show_plots) + ax = plot_consistency_test( + eval_results=self.mock_result, percentile=99, show=show_plots + ) # Check that the line extent equals the lower 0.5 % percentile - self.assertAlmostEqual(ax.lines[2].get_xdata(), - numpy.percentile(self.mock_result.test_distribution, 0.5)) + self.assertAlmostEqual( + ax.lines[2].get_xdata(), numpy.percentile(self.mock_result.test_distribution, 0.5) + ) def test_plot_consistency_with_variance(self): mock_nb = copy.deepcopy(self.mock_result) mock_poisson = copy.deepcopy(self.mock_result) - mock_nb.test_distribution = ('negative_binomial', 8) - mock_poisson.test_distribution = ('poisson', 8) + mock_nb.test_distribution = ("negative_binomial", 8) + mock_poisson.test_distribution = ("poisson", 8) ax_nb = plot_consistency_test(eval_results=mock_nb, variance=16, show=show_plots) ax_p = plot_consistency_test(eval_results=mock_poisson, variance=None, show=show_plots) # Ensure the negative binomial has a larger x-axis extent than poisson self.assertTrue(ax_p.get_xlim()[1] < ax_nb.get_xlim()[1]) def test_plot_consistency_with_custom_plot_args(self): - ax = plot_consistency_test(eval_results=self.mock_result, show=show_plots, - xlabel="Custom X", ylabel="Custom Y", title="Custom Title") + ax = plot_consistency_test( + eval_results=self.mock_result, + show=show_plots, + xlabel="Custom X", + ylabel="Custom Y", + title="Custom Title", + ) self.assertEqual(ax.get_xlabel(), "Custom X") self.assertEqual(ax.get_title(), "Custom Title") def test_plot_consistency_with_mean(self): - ax = plot_consistency_test(eval_results=self.mock_result, plot_mean=True, - show=show_plots) + ax = plot_consistency_test( + eval_results=self.mock_result, plot_mean=True, show=show_plots + ) # Check for the mean line plotted as a circle self.assertTrue(any(["o" in str(line.get_marker()) for line in ax.lines])) @@ -593,7 +590,7 @@ def test_SingleNTestPlot(self): [i.get_text() for i in matplotlib.pyplot.gca().get_yticklabels()], [i.sim_name for i in [Ntest_result]], ) - self.assertEqual(matplotlib.pyplot.gca().get_title(), '') + self.assertEqual(matplotlib.pyplot.gca().get_title(), "") def test_MultiNTestPlot(self): @@ -650,7 +647,7 @@ def test_MultiTTestPlot(self): t_plots = numpy.random.randint(2, 20) t_tests = [] - def rand(limit=10, offset=0.): + def rand(limit=10, offset=0.0): return limit * (numpy.random.random() - offset) for n in range(t_plots): @@ -704,7 +701,7 @@ def test_plot_basemap_with_features(self, mock_get_basemap): mock_tiles = MagicMock() mock_get_basemap.return_value = mock_tiles - basemap = 'stock_img' + basemap = "stock_img" ax = plot_basemap( basemap=basemap, extent=self.chiloe_extent, @@ -721,7 +718,7 @@ def test_plot_basemap_with_features(self, mock_get_basemap): mock_get_basemap.assert_not_called() self.assertTrue(ax.get_legend() is None) - @unittest.skipIf(is_github_actions, "Skipping test in GitHub CI environment") + @unittest.skipIf(is_github_ci(), "Skipping test in GitHub CI environment") @unittest.skipIf(not is_internet_available(), "Skipping test due to no internet connection") def test_plot_google_satellite(self): basemap = "google-satellite" @@ -735,7 +732,7 @@ def test_plot_google_satellite(self): self.assertIsInstance(ax, plt.Axes) self.assertTrue(ax.get_legend() is None) - @unittest.skipIf(is_github_actions, "Skipping test in GitHub CI environment") + @unittest.skipIf(is_github_ci(), "Skipping test in GitHub CI environment") @unittest.skipIf(not is_internet_available(), "Skipping test due to no internet connection") def test_plot_esri(self): basemap = "ESRI_terrain" @@ -766,6 +763,7 @@ def test_plot_basemap_set_global(self, mock_get_basemap): mock_get_basemap.assert_not_called() self.assertTrue(ax.get_extent() == (-180, 180, -90, 90)) + @unittest.skipIf(is_github_ci(), "Skipping test in GitHub CI environment") def test_plot_basemap_tif_file(self): basemap = csep.datasets.basemap_california projection = ccrs.PlateCarree() @@ -786,15 +784,17 @@ def test_plot_basemap_with_custom_projection(self): def test_plot_basemap_with_custom_projection_and_features(self): projection = ccrs.Mercator() basemap = None - ax = plot_basemap(basemap=basemap, - extent=self.chiloe_extent, - projection=projection, - coastline=True, - borders=True, - grid=True, - grid_labels=True, - grid_fontsize=8, - show=show_plots) + ax = plot_basemap( + basemap=basemap, + extent=self.chiloe_extent, + projection=projection, + coastline=True, + borders=True, + grid=True, + grid_labels=True, + grid_fontsize=8, + show=show_plots, + ) self.assertIsInstance(ax, plt.Axes) self.assertEqual(ax.projection, projection) @@ -821,7 +821,7 @@ def setUp(self): numpy.min(self.mock_catalog.get_longitudes()), numpy.max(self.mock_catalog.get_longitudes()), numpy.min(self.mock_catalog.get_latitudes()), - numpy.max(self.mock_catalog.get_latitudes()) + numpy.max(self.mock_catalog.get_latitudes()), ] # Mock region if needed @@ -840,13 +840,13 @@ def test_plot_catalog_default(self): # Test plot with default settings4 ax = plot_catalog(self.mock_catalog, show=show_plots) self.assertIsInstance(ax, plt.Axes) - self.assertEqual(ax.get_title(), '') + self.assertEqual(ax.get_title(), "") def test_plot_catalog_title(self): # Test plot with default settings ax = plot_catalog(self.mock_catalog, show=show_plots, title=self.mock_catalog.name) self.assertIsInstance(ax, plt.Axes) - self.assertEqual(ax.get_title(), 'Mock Catalog') + self.assertEqual(ax.get_title(), "Mock Catalog") def test_plot_catalog_without_legend(self): # Test plot with legend @@ -856,8 +856,7 @@ def test_plot_catalog_without_legend(self): def test_plot_catalog_custom_legend(self): - ax = plot_catalog(self.mock_catalog, mag_ticks=5, - show=show_plots) + ax = plot_catalog(self.mock_catalog, mag_ticks=5, show=show_plots) legend = ax.get_legend() self.assertIsNotNone(legend) @@ -869,18 +868,19 @@ def test_plot_catalog_custom_legend(self): def test_plot_catalog_correct_sizing(self): - ax = plot_catalog(self.mock_fix, - figsize=(4,6), - mag_ticks=[4, 5, 6, 7, 8], - legend_loc='right', - show=show_plots) + ax = plot_catalog( + self.mock_fix, + figsize=(4, 6), + mag_ticks=[4, 5, 6, 7, 8], + legend_loc="right", + show=show_plots, + ) legend = ax.get_legend() self.assertIsNotNone(legend) def test_plot_catalog_custom_sizes(self): - ax = plot_catalog(self.mock_catalog, size=5, max_size=800, power=6, - show=show_plots) + ax = plot_catalog(self.mock_catalog, size=5, max_size=800, power=6, show=show_plots) legend = ax.get_legend() self.assertIsNotNone(legend) @@ -909,44 +909,48 @@ def test_plot_catalog_with_region_border(self): def test_plot_catalog_with_no_grid(self): # Test plot with grid disabled - ax = plot_catalog( - self.mock_catalog, show=show_plots, grid=False - ) + ax = plot_catalog(self.mock_catalog, show=show_plots, grid=False) gl = ax.gridlines() self.assertIsNotNone(gl) def test_plot_catalog_w_basemap(self): # Test plot with default settings - ax = plot_catalog(self.mock_catalog, basemap='stock_img', show=show_plots) + ax = plot_catalog(self.mock_catalog, basemap="stock_img", show=show_plots) self.assertIsInstance(ax, plt.Axes) - self.assertEqual(ax.get_title(), '') + self.assertEqual(ax.get_title(), "") def test_plot_catalog_w_basemap_stream_kwargs(self): projection = ccrs.Mercator() - ax = plot_catalog(self.mock_catalog, basemap=None, - projection=projection, - coastline=True, - borders=True, - grid=True, - grid_labels=True, - grid_fontsize=8, - show=show_plots) + ax = plot_catalog( + self.mock_catalog, + basemap=None, + projection=projection, + coastline=True, + borders=True, + grid=True, + grid_labels=True, + grid_fontsize=8, + show=show_plots, + ) self.assertIsInstance(ax, plt.Axes) - self.assertEqual(ax.get_title(), '') + self.assertEqual(ax.get_title(), "") def test_plot_catalog_w_approx_projection(self): - projection = 'approx' - ax = plot_catalog(self.mock_catalog, basemap='stock_img', - projection=projection, - coastline=True, - borders=True, - grid=True, - grid_labels=True, - grid_fontsize=8, - show=show_plots) + projection = "approx" + ax = plot_catalog( + self.mock_catalog, + basemap="stock_img", + projection=projection, + coastline=True, + borders=True, + grid=True, + grid_labels=True, + grid_fontsize=8, + show=show_plots, + ) self.assertIsInstance(ax, plt.Axes) - self.assertEqual(ax.get_title(), '') + self.assertEqual(ax.get_title(), "") def tearDown(self): plt.close("all") @@ -981,7 +985,7 @@ def test_default_plot(self): def test_extent_setting_w_ax(self): extent = (-30, 30, -20, 20) ax = plot_spatial_dataset( - self.gridded_data, self.region, extent=extent, show=show_plots + self.gridded_data, self.region, extent=extent, show=show_plots ) numpy.testing.assert_array_almost_equal(ax.get_extent(crs=ccrs.PlateCarree()), extent) @@ -996,17 +1000,23 @@ def test_extent_setting(self): def test_color_mapping(self): cmap = plt.get_cmap("plasma") fig, ax = plt.subplots(subplot_kw={"projection": ccrs.PlateCarree()}) - ax = plot_spatial_dataset(self.gridded_data, self.region, ax=ax, colormap=cmap, show=show_plots) + ax = plot_spatial_dataset( + self.gridded_data, self.region, ax=ax, colormap=cmap, show=show_plots + ) self.assertIsInstance(ax.collections[0].cmap, colors.ListedColormap) def test_gridlines(self): fig, ax = plt.subplots(subplot_kw={"projection": ccrs.PlateCarree()}) - ax = plot_spatial_dataset(self.gridded_data, self.region, ax=ax, grid=True, show=show_plots) + ax = plot_spatial_dataset( + self.gridded_data, self.region, ax=ax, grid=True, show=show_plots + ) self.assertTrue(ax.gridlines()) def test_alpha_transparency(self): fig, ax = plt.subplots(subplot_kw={"projection": ccrs.PlateCarree()}) - ax = plot_spatial_dataset(self.gridded_data, self.region, ax=ax, alpha=0.5, show=show_plots) + ax = plot_spatial_dataset( + self.gridded_data, self.region, ax=ax, alpha=0.5, show=show_plots + ) self.assertIsInstance(ax, plt.Axes) def test_plot_with_alpha_exp(self): @@ -1049,17 +1059,18 @@ def test_plot_spatial_dataset_w_basemap_stream_kwargs(self): grid_labels=True, grid_fontsize=8, show=show_plots, - plot_region=False + plot_region=False, ) self.assertIsInstance(ax, plt.Axes) - self.assertEqual(ax.get_title(), '') + self.assertEqual(ax.get_title(), "") def test_plot_spatial_dataset_w_approx_projection(self): - projection = 'approx' + projection = "approx" ax = plot_spatial_dataset( self.gridded_data, - self.region, basemap='stock_img', + self.region, + basemap="stock_img", extent=[-20, 40, -5, 25], projection=projection, coastline=True, @@ -1068,11 +1079,11 @@ def test_plot_spatial_dataset_w_approx_projection(self): grid_labels=True, grid_fontsize=8, show=show_plots, - plot_region=False + plot_region=False, ) self.assertIsInstance(ax, plt.Axes) - self.assertEqual(ax.get_title(), '') + self.assertEqual(ax.get_title(), "") def tearDown(self): plt.close("all") @@ -1112,50 +1123,43 @@ def test_get_axis_limits(self): expected_limits = (0.8, 5.2) self.assertEqual(_get_axis_limits(pnts, border=0.05), expected_limits) - def test_add_labels_for_publication(self): - fig = plt.figure() - ax = fig.add_subplot(111) - _add_labels_for_publication(fig) - annotations = [child for child in ax.get_children() if isinstance(child, Annotation)] - self.assertEqual(len(annotations), 1) - self.assertEqual(annotations[0].get_text(), "(a)") - - def test_autosize_scatter(self): values = numpy.array([1, 2, 3, 4, 5]) - expected_sizes = _autosize_scatter(values, min_size=50., max_size=400., power=3.0) - result = _autosize_scatter(values, min_size=50., max_size=400., power=3.0) + expected_sizes = _autosize_scatter(values, min_size=50.0, max_size=400.0, power=3.0) + result = _autosize_scatter(values, min_size=50.0, max_size=400.0, power=3.0) numpy.testing.assert_almost_equal(result, expected_sizes, decimal=2) values = numpy.array([1, 2, 3, 4, 5]) min_val = 0 max_val = 10 - expected_sizes = _autosize_scatter(values, min_size=50., max_size=400., power=3.0, - min_val=min_val, max_val=max_val) - result = _autosize_scatter(values, min_size=50., max_size=400., power=3.0, min_val=min_val, - max_val=max_val) + expected_sizes = _autosize_scatter( + values, min_size=50.0, max_size=400.0, power=3.0, min_val=min_val, max_val=max_val + ) + result = _autosize_scatter( + values, min_size=50.0, max_size=400.0, power=3.0, min_val=min_val, max_val=max_val + ) numpy.testing.assert_almost_equal(result, expected_sizes, decimal=2) values = numpy.array([1, 2, 3, 4, 5]) power = 2.0 - expected_sizes = _autosize_scatter(values, min_size=50., max_size=400., power=power) - result = _autosize_scatter(values, min_size=50., max_size=400., power=power) + expected_sizes = _autosize_scatter(values, min_size=50.0, max_size=400.0, power=power) + result = _autosize_scatter(values, min_size=50.0, max_size=400.0, power=power) numpy.testing.assert_almost_equal(result, expected_sizes, decimal=2) values = numpy.array([1, 2, 3, 4, 5]) power = 0.0 - expected_sizes = _autosize_scatter(values, min_size=50., max_size=400., power=power) - result = _autosize_scatter(values, min_size=50., max_size=400., power=power) + expected_sizes = _autosize_scatter(values, min_size=50.0, max_size=400.0, power=power) + result = _autosize_scatter(values, min_size=50.0, max_size=400.0, power=power) numpy.testing.assert_almost_equal(result, expected_sizes, decimal=2) values = numpy.array([5, 5, 5, 5, 5]) - expected_sizes = _autosize_scatter(values, min_size=50., max_size=400., power=3.0) - result = _autosize_scatter(values, min_size=50., max_size=400., power=3.0) + expected_sizes = _autosize_scatter(values, min_size=50.0, max_size=400.0, power=3.0) + result = _autosize_scatter(values, min_size=50.0, max_size=400.0, power=3.0) numpy.testing.assert_almost_equal(result, expected_sizes, decimal=2) values = numpy.array([10, 100, 1000, 10000, 100000]) - expected_sizes = _autosize_scatter(values, min_size=50., max_size=400., power=3.0) - result = _autosize_scatter(values, min_size=50., max_size=400., power=3.0) + expected_sizes = _autosize_scatter(values, min_size=50.0, max_size=400.0, power=3.0) + result = _autosize_scatter(values, min_size=50.0, max_size=400.0, power=3.0) numpy.testing.assert_almost_equal(result, expected_sizes, decimal=2) def test_autoscale_histogram(self): @@ -1175,8 +1179,8 @@ def test_autoscale_histogram(self): def test_annotate_distribution_plot(self): # Mock evaluation_result for Catalog N-Test evaluation_result = Mock() - evaluation_result.name = 'Catalog N-Test' - evaluation_result.sim_name = 'Simulated Catalog' + evaluation_result.name = "Catalog N-Test" + evaluation_result.sim_name = "Simulated Catalog" evaluation_result.quantile = [0.25, 0.75] evaluation_result.observed_statistic = 5.0 @@ -1190,19 +1194,20 @@ def test_annotate_distribution_plot(self): "title": None, } - ax = _annotate_distribution_plot(ax, evaluation_result, auto_annotate=True, - plot_args=plot_args) + ax = _annotate_distribution_plot( + ax, evaluation_result, auto_annotate=True, plot_args=plot_args + ) # Assertions to check if the annotations were correctly set - self.assertEqual(ax.get_xlabel(), 'Event Count') - self.assertEqual(ax.get_ylabel(), 'Number of Catalogs') - self.assertEqual(ax.get_title(), 'Catalog N-Test: Simulated Catalog') + self.assertEqual(ax.get_xlabel(), "Event Count") + self.assertEqual(ax.get_ylabel(), "Number of Catalogs") + self.assertEqual(ax.get_title(), "Catalog N-Test: Simulated Catalog") annotation = ax.texts[0].get_text() expected_annotation = ( - f'$\\delta_1 = P(X \\geq x) = 0.25$\n' - f'$\\delta_2 = P(X \\leq x) = 0.75$\n' - f'$\\omega = 5.00$' + f"$\\delta_1 = P(X \\geq x) = 0.25$\n" + f"$\\delta_2 = P(X \\leq x) = 0.75$\n" + f"$\\omega = 5.00$" ) self.assertEqual(annotation, expected_annotation) @@ -1228,16 +1233,18 @@ def test_calculate_spatial_extent(self): def test_create_geo_axes(self): # Test GeoAxes creation with no extent (global) - ax = _create_geo_axes(figsize=(10, 8), extent=None, projection=ccrs.PlateCarree(), - set_global=True) + ax = _create_geo_axes( + figsize=(10, 8), extent=None, projection=ccrs.PlateCarree(), set_global=True + ) self.assertIsInstance(ax, plt.Axes) self.assertAlmostEqual(ax.get_xlim(), (-180, 180)) self.assertAlmostEqual(ax.get_ylim(), (-90, 90)) # Test GeoAxes creation with a specific extent extent = (-125, -110, 25, 40) - ax = _create_geo_axes(figsize=(10, 8), extent=extent, projection=ccrs.PlateCarree(), - set_global=False) + ax = _create_geo_axes( + figsize=(10, 8), extent=extent, projection=ccrs.PlateCarree(), set_global=False + ) self.assertIsInstance(ax, plt.Axes) self.assertAlmostEqual(ax.get_extent(), extent) @@ -1263,8 +1270,8 @@ def test_get_basemap_esri_terrain(self, mock_google_tiles): tiles = _get_basemap("ESRI_terrain") mock_google_tiles.assert_called_once_with( url="https://server.arcgisonline.com/ArcGIS/rest/services/World_Terrain_Base/" - "MapServer/tile/{z}/{y}/{x}.jpg", - cache=True + "MapServer/tile/{z}/{y}/{x}.jpg", + cache=True, ) self.assertIsNotNone(tiles) @@ -1292,55 +1299,30 @@ def test_plot_basemap_no_basemap(self): self.assertIsInstance(ax, plt.Axes) def test_default_colormap(self): - cmap, alpha = _define_colormap_and_alpha("viridis", 0) + cmap, alpha = _get_colormap("viridis", 0) self.assertIsInstance(cmap, matplotlib.colors.ListedColormap) expected_cmap = plt.get_cmap("viridis") self.assertTrue(numpy.allclose(cmap.colors, expected_cmap(numpy.arange(cmap.N)))) def test_custom_colormap(self): cmap = plt.get_cmap("plasma") - cmap, alpha = _define_colormap_and_alpha(cmap, 0) + cmap, alpha = _get_colormap(cmap, 0) self.assertIsInstance(cmap, matplotlib.colors.ListedColormap) expected_cmap = plt.get_cmap("plasma") self.assertTrue(numpy.allclose(cmap.colors, expected_cmap(numpy.arange(cmap.N)))) def test_alpha_exponent(self): - cmap, alpha = _define_colormap_and_alpha("viridis", 0.5) + cmap, alpha = _get_colormap("viridis", 0.5) self.assertIsInstance(cmap, matplotlib.colors.ListedColormap) self.assertIsNone(alpha) # Check that alpha values are correctly modified self.assertTrue(numpy.all(cmap.colors[:, -1] == numpy.linspace(0, 1, cmap.N) ** 0.5)) def test_no_alpha_exponent(self): - cmap, alpha = _define_colormap_and_alpha("viridis", 0) + cmap, alpha = _get_colormap("viridis", 0) self.assertEqual(alpha, 1) self.assertTrue(numpy.all(cmap.colors[:, -1] == 1)) # No alpha modification - def test_add_colorbar(self): - fig, ax = plt.subplots() - im = ax.imshow(numpy.random.rand(10, 10), cmap="viridis") - _add_colorbar( - ax, im, clabel="Colorbar Label", clabel_fontsize=12, cticks_fontsize=10 - ) - - # Check if the colorbar is added to the figure - colorbars = [ - child - for child in fig.get_children() - if isinstance(child, plt.Axes) and "Colorbar" in child.get_label() - ] - self.assertGreater(len(colorbars), 0) - - # Check colorbar label and font sizes - cbar = colorbars[0] - self.assertEqual(cbar.get_ylabel(), "Colorbar Label") - self.assertEqual(cbar.get_ylabel(), "Colorbar Label") - self.assertEqual(cbar.yaxis.label.get_size(), 12) - - # Check tick label font size - tick_labels = cbar.get_yticklabels() - self.assertTrue(all(label.get_fontsize() == 10 for label in tick_labels)) - def tearDown(self): plt.close("all") gc.collect() @@ -1383,11 +1365,11 @@ class TestProcessDistribution(unittest.TestCase): def setUp(self): self.result_poisson = mock.Mock() - self.result_poisson.test_distribution = ['poisson', 10] + self.result_poisson.test_distribution = ["poisson", 10] self.result_poisson.observed_statistic = 8 self.result_neg_binom = mock.Mock() - self.result_neg_binom.test_distribution = ['negative_binomial', 10] + self.result_neg_binom.test_distribution = ["negative_binomial", 10] self.result_neg_binom.observed_statistic = 8 self.result_empirical = mock.Mock() @@ -1396,8 +1378,11 @@ def setUp(self): def test_process_distribution_poisson(self): plow, phigh, mean, observed_statistic = _process_stat_distribution( - self.result_poisson, percentile=95, variance=None, normalize=False, - one_sided_lower=False + self.result_poisson, + percentile=95, + variance=None, + normalize=False, + one_sided_lower=False, ) self.assertAlmostEqual(mean, 10) self.assertAlmostEqual(observed_statistic, 8) @@ -1406,8 +1391,11 @@ def test_process_distribution_poisson(self): def test_process_distribution_negative_binomial(self): variance = 12 plow, phigh, mean, observed_statistic = _process_stat_distribution( - self.result_neg_binom, percentile=95, variance=variance, normalize=False, - one_sided_lower=False + self.result_neg_binom, + percentile=95, + variance=variance, + normalize=False, + one_sided_lower=False, ) self.assertAlmostEqual(mean, 10) self.assertAlmostEqual(observed_statistic, 8) @@ -1415,8 +1403,11 @@ def test_process_distribution_negative_binomial(self): def test_process_distribution_empirical(self): plow, phigh, mean, observed_statistic = _process_stat_distribution( - self.result_empirical, percentile=95, variance=None, normalize=False, - one_sided_lower=False + self.result_empirical, + percentile=95, + variance=None, + normalize=False, + one_sided_lower=False, ) self.assertAlmostEqual(mean, numpy.mean(self.result_empirical.test_distribution)) self.assertAlmostEqual(observed_statistic, 8) @@ -1424,18 +1415,29 @@ def test_process_distribution_empirical(self): def test_process_distribution_empirical_normalized(self): plow, phigh, mean, observed_statistic = _process_stat_distribution( - self.result_empirical, percentile=95, variance=None, normalize=True, - one_sided_lower=False + self.result_empirical, + percentile=95, + variance=None, + normalize=True, + one_sided_lower=False, + ) + self.assertAlmostEqual( + mean, + numpy.mean( + self.result_empirical.test_distribution + - self.result_empirical.observed_statistic + ), ) - self.assertAlmostEqual(mean, numpy.mean(self.result_empirical.test_distribution - - self.result_empirical.observed_statistic)) self.assertAlmostEqual(observed_statistic, 0) self.assertTrue(plow < mean < phigh) def test_process_distribution_empirical_one_sided(self): plow, phigh, mean, observed_statistic = _process_stat_distribution( - self.result_empirical, percentile=95, variance=None, normalize=False, - one_sided_lower=True + self.result_empirical, + percentile=95, + variance=None, + normalize=False, + one_sided_lower=True, ) self.assertAlmostEqual(mean, numpy.mean(self.result_empirical.test_distribution)) self.assertAlmostEqual(observed_statistic, 8)