diff --git a/rubin_sim/maf/plots/special_plotters.py b/rubin_sim/maf/plots/special_plotters.py index c9e579cf..767964bc 100644 --- a/rubin_sim/maf/plots/special_plotters.py +++ b/rubin_sim/maf/plots/special_plotters.py @@ -64,11 +64,11 @@ def __call__(self, metric_value, slicer, user_plot_dict, fig=None): plot_dict["scale"] = hp.nside2pixarea(slicer.nside, degrees=True) / 1000.0 if fig is None: - fig = plt.Figure(figsize=plot_dict["figsize"]) + fig, ax = plt.subplots(figsize=plot_dict["figsize"]) # Expect metric_value to be something like number of visits cumulative_area = np.arange(1, metric_value.compressed().size + 1)[::-1] * plot_dict["scale"] - plt.plot( + ax.plot( np.sort(metric_value.compressed()), cumulative_area, "k-", @@ -83,13 +83,13 @@ def __call__(self, metric_value, slicer, user_plot_dict, fig=None): f_o_area = metrics.FOArea(col="fO", n_visit=n_visits, norm=False, nside=slicer.nside).run(rarr) f_o_nv = metrics.FONv(col="fO", asky=asky, norm=False, nside=slicer.nside).run(rarr) - plt.axvline(x=n_visits, linewidth=plot_dict["reflinewidth"], color="b", linestyle=":") - plt.axhline(y=asky / 1000.0, linewidth=plot_dict["reflinewidth"], color="r", linestyle=":") + ax.axvline(x=n_visits, linewidth=plot_dict["reflinewidth"], color="b", linestyle=":") + ax.axhline(y=asky / 1000.0, linewidth=plot_dict["reflinewidth"], color="r", linestyle=":") # Add lines for nvis_median and f_o_area: # note if these are -666 (badval), they will 'disappear' nvis_median = f_o_nv["value"][np.where(f_o_nv["name"] == "MedianNvis")][0] - plt.axvline( + ax.axvline( x=nvis_median, linewidth=plot_dict["reflinewidth"], color="b", @@ -97,7 +97,7 @@ def __call__(self, metric_value, slicer, user_plot_dict, fig=None): linestyle="-", label=f"f$_0$ Med. Nvis. (@ {asky/1000 :.0f}K sq deg) = {nvis_median :.0f} visits", ) - plt.axhline( + ax.axhline( y=f_o_area / 1000.0, linewidth=plot_dict["reflinewidth"], color="r", @@ -105,20 +105,20 @@ def __call__(self, metric_value, slicer, user_plot_dict, fig=None): linestyle="-", label=f"f$_0$ Area (@ {n_visits :.0f} visits) = {f_o_area/1000 :.01f}K sq deg", ) - plt.legend(loc="upper right", fontsize="small", numpoints=1, framealpha=1.0) + ax.legend(loc="upper right", fontsize="small", numpoints=1, framealpha=1.0) - plt.xlabel(plot_dict["xlabel"]) - plt.ylabel(plot_dict["ylabel"]) - plt.title(plot_dict["title"]) + ax.set_xlabel(plot_dict["xlabel"]) + ax.set_ylabel(plot_dict["ylabel"]) + ax.set_title(plot_dict["title"]) x_min = plot_dict["x_min"] x_max = plot_dict["x_max"] y_min = plot_dict["y_min"] y_max = plot_dict["y_max"] if (x_min is not None) or (x_max is not None): - plt.xlim([x_min, x_max]) + ax.set_xlim([x_min, x_max]) if (y_min is not None) or (y_max is not None): - plt.ylim([y_min, y_max]) + ax.set_ylim([y_min, y_max]) return fig