From e2cec3c69d7b5c378a234213589fc9d81c8954ff Mon Sep 17 00:00:00 2001 From: s2123329 Date: Fri, 7 Jun 2024 11:53:41 +0100 Subject: [PATCH] Add hist df TablePlot class --- package/ClayCode/plot/plots.py | 555 ++++++++++++++++++++------------- 1 file changed, 342 insertions(+), 213 deletions(-) diff --git a/package/ClayCode/plot/plots.py b/package/ClayCode/plot/plots.py index 9f90c18..ca17bbc 100644 --- a/package/ClayCode/plot/plots.py +++ b/package/ClayCode/plot/plots.py @@ -14,7 +14,7 @@ from collections import UserString from functools import cached_property, partialmethod, wraps from pathlib import Path -from typing import List, Literal, Optional, Tuple, Union +from typing import List, Literal, Optional, Tuple, Type, Union import matplotlib.colors as mpc import matplotlib.gridspec as gridspec @@ -24,6 +24,8 @@ import pandas as pd import scipy.optimize from ClayCode.analysis.analysisbase import AnalysisData +from ClayCode.analysis.classes import HistData +from ClayCode.analysis.consts import PE_DATA # from ClayCode.analysis.classes import AtomTypeData, HistData from ClayCode.analysis.utils import ( @@ -46,6 +48,9 @@ from tqdm import tqdm from tqdm.contrib.logging import logging_redirect_tqdm +# from ClayCode.analysis.peaks import Peaks + + tqdm.pandas(desc="pandas") __all__ = [ @@ -61,6 +66,18 @@ logger = logging.getLogger(__name__) +def plot_args_decorator(select: str): + def plot_decorator(plot_f): + def wrapper(self, **kwargs): + self.select = select + plot_f(self, **kwargs) + self.select = None + + return wrapper + + return plot_decorator + + # ch = logging.StreamHandler() # ch.setLevel(logging.DEBUG) # logger.addHandler(ch) @@ -5470,19 +5487,25 @@ def __init__( use_rel_data=False, group_all=False, max_line_width=LINE_LENGTH, + other_cutoff=None, ): """Constructor method""" logger.info(f"Initialising {self.__class__.__name__}") self.infilelist: list = [] self.bins: Bins = Bins(bins) self.cutoff: float = Cutoff(cutoff) + if other is not None: + if other_cutoff is not None: + self.other_cutoff = Cutoff(other_cutoff) + else: + self.other_cutoff = self.cutoff self.analysis: Union[str, None] = analysis self._min = Cutoff(min) if use_rel_data is True: self._arr_col = 2 else: self._arr_col = 1 - + self.use_rel_data = use_rel_data if type(indir) != Path: indir = Path(indir) @@ -5519,7 +5542,9 @@ def __init__( ) ) logger.info(f"Found {len(self.filelist)} files.") - + if len(self.filelist) == 0: + logger.error(f"No files found in {indir!r}") + sys.exit(1) if load is not False: load = Path(load.resolve()) self.df: pd.DataFrame = pkl.load(load) @@ -5570,9 +5595,12 @@ def __init__( f, delimiter="\s+", comment="#", header=None ).to_numpy() new_x = new_x[:, 0] - if not np.isclose( - np.ceil(new_x[-1]), self.cutoff.num - ) or not np.any(np.isclose(np.diff(new_x), self.bins.num)): + if ( + not np.isclose(np.ceil(new_x[-1]), self.cutoff.num) + or not np.any(np.isclose(np.diff(new_x), self.bins.num)) + ) and not np.isclose( + np.ceil(new_x[-1]), self.other_cutoff.num, atol=0.01 + ): self.filelist.pop(fi) logger.info( f"Bins ({np.unique(np.diff(new_x).round(2))}) and cutoff {np.ceil(new_x[-1])} do not match requirements!\n" @@ -6061,7 +6089,7 @@ def _get_edge_fname( atom_type: str, other: Optional[str], clay_type="all", - name: Union[Literal["pe"], Literal["edge"]] = "pe", + name: Union[Literal["pe"], Literal["edge"]] = "edges", ): if other is not None: other = f"{other}_" @@ -6073,8 +6101,10 @@ def _get_edge_fname( clay_type = f"{clay_type}_" # fname = Path.cwd() / f"edge_data/edges_{atom_type}_{self.cutoff}_{self.bins}.p" fname = ( - Path(__file__).parent.parent - / f"analysis/pe_data/{clay_type}{atom_type}_{other}{name}_data_{self._min}_{self.cutoff}_{self.bins}.p" + # Path(__file__).parent.parent + # / f"analysis/pe_data/" + PE_DATA + / f"{clay_type}{atom_type}_{other}{name}_data_{self._min}_{self.cutoff}_{self.bins}.p" ) logger.info(f"Peak/edge Filename: {fname}\n") return fname @@ -6288,6 +6318,8 @@ def _read_edge_file( ): clay_sel = np.unique([clay_type, "all"]) for clay_type in clay_sel: + if other is not None and isinstance(other, list): + other = other[0] fname = self._get_edge_fname( atom_type, name="pe", other=other, clay_type=clay_type ) @@ -6358,21 +6390,41 @@ def edges(self): else: for atom_type in self._atoms: self._edges[atom_type] = {} - self._edges[atom_type]["all"] = self._read_edge_file(atom_type) + self._edges[atom_type]["all"] = self._read_edge_file( + atom_type, other=None + ) if self.other is not None: - self._edges[atom_type][self.other] = self._read_edge_file( - atom_type, other=self.other - ) + if isinstance(self.other, list): + for other in self.other: + self._edges[atom_type][other] = {} + self._edges[atom_type][other][ + "all" + ] = self._read_edge_file(atom_type, other=other) + else: + self._edges[atom_type][self.other] = {} + self._edges[atom_type][self.other][ + "all" + ] = self._read_edge_file(atom_type, other=self.other) for clay_type in self.clays: self._edges[atom_type][clay_type] = self._read_edge_file( atom_type, clay_type=clay_type ) if self.other is not None: - self._edges[atom_type][ - self.other - ] = self._read_edge_file( - atom_type, other=self.other, clay_type=clay_type - ) + if isinstance(self.other, list): + for other in self.other: + self._edges[atom_type][other][ + clay_type + ] = self._read_edge_file( + atom_type, other=other, clay_type=clay_type + ) + else: + self._edges[atom_type][ + self.other + ] = self._read_edge_file( + atom_type, + other=self.other, + clay_type=clay_type, + ) return self._edges def get_bin_df(self): @@ -7696,7 +7748,7 @@ def save(self, savename: Union[str, Path]) -> None: pd.to_pickle(self.df, savename) @property - def z_bins(self) -> HistData: + def z_bins(self) -> Type["HistData"]: """Get z bins :return: z bins :rtype: HistData @@ -7813,8 +7865,9 @@ def _get_edge_fname( other = "" # fname = Path.cwd() / f"edge_data/edges_{atom_type}_{self.cutoff}_{self.bins}.p" fname = ( - Path.cwd() - / f"pe_data/{atom_type}_{other}{name}_data_{self.cutoff}_{self.bins}.p" + # Path.cwd() + PE_DATA + / f"{atom_type}_{other}{name}_data_{self.cutoff}_{self.bins}.p" ) logger.info(f"Peak/edge Filename: {fname}") return fname @@ -9120,28 +9173,30 @@ def get_atom_colour_codes_from_names(self, atom_name): def add_axis_labels(self, y, x, rowlabel, columnlabel): """Add labels to plot.""" for y_id in range(self.y.l): - if y_id == self.y.l - 1: - y_ax_label = f"{self.label_mod([(self.y.v[y_id], self.y)])}\n" - if self.y.l > 1: - self.fig.supylabel(f"{self.title_dict[y]}s", size=14) - else: - # if not re.match("\nctl|ctl\n", y_ax_label, flags=re.IGNORECASE): - # self.fig.supylabel(f"{self.title_dict[y]}: {y_ax_label}", size=14) - # else: - # self.fig.supylabel(f"{self.title_dict[y]}", size=14) - self.name = re.sub( - rf"(.*)_{y}_(.*)", - r"\1_" + y_ax_label.strip("\n") + r"_\2", - self.name, - ) - y_ax_label = "" - try: - self.ax[y_id, 0].set_ylabel(y_ax_label + rowlabel) - except IndexError: - self.ax[y_id].set_ylabel(y_ax_label + rowlabel) - except TypeError: - self.ax.set_ylabel(y_ax_label + rowlabel) for x_id in range(self.x.l): + if x_id == 0: + y_ax_label = ( + f"{self.label_mod([(self.y.v[y_id], self.y)])}\n" + ) + if self.y.l > 1: + self.fig.supylabel(f"{self.title_dict[y]}s", size=14) + else: + # if not re.match("\nctl|ctl\n", y_ax_label, flags=re.IGNORECASE): + # self.fig.supylabel(f"{self.title_dict[y]}: {y_ax_label}", size=14) + # else: + # self.fig.supylabel(f"{self.title_dict[y]}", size=14) + self.name = re.sub( + rf"(.*)_{y}_(.*)", + r"\1_" + y_ax_label.strip("\n") + r"_\2", + self.name, + ) + y_ax_label = "" + try: + self.ax[y_id, 0].set_ylabel(y_ax_label + rowlabel) + except IndexError: + self.ax[y_id].set_ylabel(y_ax_label + rowlabel) + except TypeError: + self.ax.set_ylabel(y_ax_label + rowlabel) if y_id == self.y.l - 1: x_ax_label = ( f"\n{self.label_mod([(self.x.v[x_id], self.x)])}" @@ -9219,21 +9274,53 @@ def add_plot_labels(self, lines): ] if ( unique_legends.ndim == 1 - and handle_colours.ndim == 1 + # and handle_colours.ndim == 1 and handle_linestyles.ndim == 1 ): if i == 0 and j == self.x.l - 1: - ncol = 1 - self.fig.legend( - labels=label, - handles=handles, - ncol=ncol, - title=self.title_dict[lines], - frameon=False, - bbox_to_anchor=(1, 1), - loc="upper left", - borderaxespad=0.0, - ) + # ncol = 1 + # self.fig.legend( + # labels=label, + # handles=handles, + # ncol=ncol, + # title=self.title_dict[lines], + # frameon=False, + # bbox_to_anchor=(1, 1), + # loc="upper left", + # borderaxespad=0.0, + # ) + if len(label) % 3 == 0: + ncol = 3 + else: + ncol = 2 + try: + self.ax[i, j].legend( + labels=label, + handles=handles, + ncol=ncol, + title=self.title_dict[lines], + frameon=False, + loc="upper right", + ) + except IndexError: + max_id = max(i, j) + self.ax[max_id].legend( + labels=label, + handles=handles, + ncol=ncol, + title=self.title_dict[lines], + frameon=False, + loc="upper right", + ) + except TypeError: + self.ax.legend( + labels=label, + handles=handles, + ncol=ncol, + title=self.title_dict[lines], + frameon=False, + loc="upper right", + ) else: if len(label) % 3 == 0: ncol = 3 @@ -9489,20 +9576,18 @@ def _get_binned_plot_colour_dfs_1d( ) # dict(zip(colour_keys, colours)) # colours = color_palette('dark').as_hex() sel = self.data.clays - # get data for plotting + bin_df = self.data.bin_df.copy() try: # clays still in columns - plot_df = self.data.bin_df[sel].copy() - + plot_df = bin_df[sel] # move clays category from columns to index idx_names = ["clays", *plot_df.index.droplevel(["x_bins"]).names] # DataFrame -> Series plot_df = plot_df.stack() except KeyError: - plot_df = self.data.bin_df.copy() + plot_df = bin_df idx_names = plot_df.index.droplevel(["x_bins"]).names - # get values for atom types (including separate ions) atoms: np.array = plot_df.index.get_level_values("_atoms").to_numpy() # atom_type_groups = np.unique(atoms) @@ -10311,16 +10396,18 @@ def plot( if plsave != False: logger.info("Saving plot") if type(plsave) == str: - outname = f"{plsave}.{format}" + outname = f"{plsave}.{format.strip('.')}" else: - outname = f"{self.name}.{format}" + outname = f"{self.name}.{format.strip('.')}" odir = Path(odir).absolute() logger.info(f"output to {odir.absolute()}") if not odir.is_dir(): os.makedirs(odir) logger.info(odir / outname) self.fig.savefig( - str(odir / outname), format=format, bbox_inches="tight" + str(odir / outname), + format=format.strip("."), + bbox_inches="tight", ) else: if type(plsave) == str: @@ -10337,7 +10424,7 @@ def plot( {"fig": self.fig, "ax": self.ax, "data": self.data.df}, pklfile, ) - self.fig.show() + # self.fig.show() self.fig.clear() def _get_bin_label(self, x_bin, bin_list): @@ -10538,46 +10625,46 @@ def __init__( # self.__setattr__(attr, new_data.__getattribute__(attr)) super().__init__(new_data) - def add_plot_labels(self, lines): - for i in range(self.y.l): - for j in range(self.x.l): - label = [ - self.label_mod([(leg, self.line)]) - for leg in self.legends[i, j] - ] - handle = np.ravel(self.handles[(i, j)]).tolist() - if len(label) % 3 == 0: - ncol = 3 - else: - ncol = 2 - try: - self.ax[i, j].legend( - labels=label, - handles=handle, - ncol=ncol, - title=self.title_dict[lines], - frameon=False, - loc="upper right", - ) - except IndexError: - max_id = max(i, j) - self.ax[max_id].legend( - labels=label, - handles=handle, - ncol=ncol, - title=self.title_dict[lines], - frameon=False, - loc="upper right", - ) - except TypeError: - self.ax.legend( - labels=label, - handles=handle, - ncol=ncol, - title=self.title_dict[lines], - frameon=False, - loc="upper right", - ) + # def add_plot_labels(self, lines): + # for i in range(self.y.l): + # for j in range(self.x.l): + # label = [ + # self.label_mod([(leg, self.line)]) + # for leg in self.legends[i, j] + # ] + # handle = np.ravel(self.handles[(i, j)]).tolist() + # if len(label) % 3 == 0: + # ncol = 3 + # else: + # ncol = 2 + # try: + # self.ax[i, j].legend( + # labels=label, + # handles=handle, + # ncol=ncol, + # title=self.title_dict[lines], + # frameon=False, + # loc="upper right", + # ) + # except IndexError: + # max_id = max(i, j) + # self.ax[max_id].legend( + # labels=label, + # handles=handle, + # ncol=ncol, + # title=self.title_dict[lines], + # frameon=False, + # loc="upper right", + # ) + # except TypeError: + # self.ax.legend( + # labels=label, + # handles=handle, + # ncol=ncol, + # title=self.title_dict[lines], + # frameon=False, + # loc="upper right", + # ) def _get_binned_plot_colour_dfs_1d( self, @@ -10819,7 +10906,7 @@ def plot_other( smooth_line=0.01, antialiased=True, colours=None, - format="png", + format=".png", ): self.select = "other" self.plot( @@ -11519,15 +11606,15 @@ def plot( if plsave != False: logger.info("Saving plot") if type(plsave) == str: - outname = f"{plsave}{format}" + outname = f"{plsave}.{format.strip('.')}" else: - outname = f"{self.name}{format}" + outname = f"{self.name}.{format.strip('.')}" odir = Path(odir).absolute() logger.info(f"output to {odir.absolute()}") if not odir.is_dir(): os.makedirs(odir) logger.info(odir / outname) - self.fig.savefig(str(odir / outname), format=format) + self.fig.savefig(str(odir / outname), format=format.strip(".")) else: if type(plsave) == str: outname = f"{plsave}.p" @@ -11543,7 +11630,7 @@ def plot( {"fig": self.fig, "ax": self.ax, "data": self.data.df}, pklfile, ) - self.fig.show() + # self.fig.show() self.fig.clear() def _get_bin_label(self, x_bin, bin_list): @@ -12135,15 +12222,15 @@ def plot( if plsave != False: logger.info("Saving plot") if type(plsave) == str: - outname = f"{plsave}{format}" + outname = f"{plsave}.{format.strip('.')}" else: - outname = f"{self.name}{format}" + outname = f"{self.name}.{format.strip('.')}" odir = Path(odir).absolute() logger.info(f"output to {odir.absolute()}") if not odir.is_dir(): os.makedirs(odir) logger.info(odir / outname) - self.fig.savefig(str(odir / outname), format=format) + self.fig.savefig(str(odir / outname), format=format.strip(".")) else: plt.show() self.fig.clear() @@ -12532,7 +12619,7 @@ def plot_ions( xpad=0.25, cmap="winter", plot_table=False, - only_table=False, + # only_table=False, ): self.plot( bars, @@ -12551,7 +12638,7 @@ def plot_ions( xpad, cmap, plot_table=plot_table, - only_table=only_table, + # only_table=only_table, ) @plot_args_decorator(select="other") @@ -12573,7 +12660,7 @@ def plot_other( xpad=0.25, cmap="winter", plot_table=None, - only_table=False, + # only_table=False, ): self.plot( bars, @@ -12592,7 +12679,7 @@ def plot_other( xpad, cmap, plot_table=plot_table, - only_table=only_table, + # only_table=only_table, ) def get_suptitle(self, pl, separate): @@ -12624,7 +12711,7 @@ def plot( cmap="winter", tab_fontsize=12, plot_table=False, - only_table=False, + # only_table=False, format="png", ): """Create stacked Histogram adsorption shell populations.""" @@ -12676,7 +12763,7 @@ def plot( yid = np.ravel(np.where(np.array(idx) == self.y))[0] if figsize is None: - if plot_table is True or only_table is True: + if plot_table is True: figsize = self.get_figsize(xmax=xmax, ymax=ymax * 2) else: figsize = self.get_figsize(xmax=xmax, ymax=ymax) @@ -12711,8 +12798,8 @@ def plot( self.name = f"{self.data.name}_{self.data.analysis}_{self.select}_{x}_{y}_{bars}{pl_str}_{self.data.cutoff}_{self.data.bins}" if plot_table is True: self.name += "_table" - elif only_table is True: - self.name += "_only_table" + # elif only_table is True: + # self.name += "_only_table" logger.info(f"plot {self.name}") # index map for y values # y_dict: dict = dict(zip(vy, np.arange(ly))) @@ -12727,7 +12814,7 @@ def plot( # print('rows ', plt_nrows) # generate figure and axes array - if plot_table is True or only_table is True: + if plot_table is True: # or only_table is True: self.fig = plt.figure(figsize=figsize, dpi=dpi) else: self.fig = plt.figure( @@ -12825,7 +12912,7 @@ def plot( tab_colours = [] tab_rows = [] if ( - plot_table is True or only_table is True + plot_table is True # or only_table is True ) and bar_num == 0: # tab_colours = [] # tab_rows = [] @@ -12913,7 +13000,9 @@ def plot( # label = f'$ > {x_bin.left}$ \AA' # if label not in table_rows and cmap == table_cmap: # table_rows.append(label) - if y_val >= 0.001 and only_table is False: + if ( + y_val >= 0.001 + ): # and only_table is False: # barwidth = bulk_edge - x_bin.left # try: # x_tick = x_ticks[-1] + barwidth @@ -12946,7 +13035,9 @@ def plot( # self.fig.subplots_adjust(left=0.2, bottom=0.2) # except IndexError: # self.ax[y_id].subplots_adjust(left=0.2, bottom=0.2) - if plot_table is True or only_table is True: + if ( + plot_table is True + ): # or only_table is True: # y_id += 1 # print(tab_colours) logger.info(f"Has table, {y_id}") @@ -13057,90 +13148,85 @@ def plot( # n_bar * bulk_edge + bulk_edge, int(bulk_edge)) for n_bar in range(lbars)] # x_ticks = np.ravel(x_ticks) # x_labels = np.tile(np.arange(0, bulk_edge, 1), lbars) - if only_table is False: - for i in range(self.y.l): - # if plot_table is True: - # ax_multi = 2 - # else: - # ax_multi = 1 - # ax_i = i * ax_multi - # print(f"Axis index: {ax_i}, multi: {ax_multi}") + # if only_table is False: + for i in range(self.y.l): + # if plot_table is True: + # ax_multi = 2 + # else: + # ax_multi = 1 + # ax_i = i * ax_multi + # print(f"Axis index: {ax_i}, multi: {ax_multi}") - try: - # self.ax[] - # self.ax[i, 0].set_ylabel( + try: + # self.ax[] + # self.ax[i, 0].set_ylabel( + # f"{self.label_mod([(self.y.v[i], self.y)])}\n" + rowlabel + # ) + for j in range(self.x.l): + # if plot_table is True: + # self.ax[i, j].subplots_adjust(bottom=0.2) # f"{self.label_mod([(self.y.v[i], self.y)])}\n" + rowlabel # ) - for j in range(self.x.l): - # if plot_table is True: - # self.ax[i, j].subplots_adjust(bottom=0.2) - # f"{self.label_mod([(self.y.v[i], self.y)])}\n" + rowlabel - # ) - if j == 0: - self.ax[i, 0].set_ylabel( - f"{self.label_mod([(self.y.v[i], self.y)])}\n" - + rowlabel - ) - self.ax[i, j].set_yticks( - np.arange(0.0, 1.1, 0.2) - ) - else: - self.ax[i, j].set_yticks( - np.arange(0.0, 1.1, 0.2) - ) - self.ax[i, j].set_yticklabels( - [] - ) # np.arange(0.0, 1.1, 0.2)) - self.ax[i, j].spines[["top", "right"]].set_visible( - False - ) - self.ax[i, j].hlines( - 1.0, - -xpad, - self.bars.l * (barwidth + xpad) + xpad, - linestyle="--", - ) - # self.ax[i, j].legend(ncol=2, loc='lower center')#[leg for leg in legends[i, j]], ncol=3) - # if xlim != None: - self.ax[i, j].set_xlim( - (-xpad, self.bars.l * (barwidth + xpad)) - ) - self.ax[i, j].set_xticks([], []) - # if ylim != None: - self.ax[i, j].set_ylim((0.0, 1.2)) - self.ax[self.y.l - 1, j].set_xticks( - np.array(x_ticks) + 0.5 * barwidth, x_labels + if j == 0: + self.ax[i, 0].set_ylabel( + f"{self.label_mod([(self.y.v[i], self.y)])}\n" + + rowlabel ) - self.ax[self.y.l - 1, j].set_xlabel( - bars - + f"\n{self.label_mod([(self.x.v[j], self.x)])}" - ) # self.ax[i, j].set_yticklabels(np.arange(0.0, 1.1, 0.2)) - except IndexError: - self.ax[i].set_ylabel( - f"{self.label_mod([(self.y.v[i], y)])}\n" - + rowlabel + self.ax[i, j].set_yticks(np.arange(0.0, 1.1, 0.2)) + else: + self.ax[i, j].set_yticks(np.arange(0.0, 1.1, 0.2)) + self.ax[i, j].set_yticklabels( + [] + ) # np.arange(0.0, 1.1, 0.2)) + self.ax[i, j].spines[["top", "right"]].set_visible( + False ) - # self.ax[i].legend([self.label_mod([(leg, self.label_key)]) for leg in self.legends[i, 0]], ncol=3) - self.ax[self.y.l - 1].set_xlabel( - columnlabel - + f"\n{self.label_mod([(self.x.v[0], self.x)])}" + self.ax[i, j].hlines( + 1.0, + -xpad, + self.bars.l * (barwidth + xpad) + xpad, + linestyle="--", ) - # # # - self.fig.supxlabel(f"{self.title_dict[x]}s", size=14) - self.fig.supylabel(f"{self.title_dict[y]}s", size=14) + # self.ax[i, j].legend(ncol=2, loc='lower center')#[leg for leg in legends[i, j]], ncol=3) + # if xlim != None: + self.ax[i, j].set_xlim( + (-xpad, self.bars.l * (barwidth + xpad)) + ) + self.ax[i, j].set_xticks([], []) + # if ylim != None: + self.ax[i, j].set_ylim((0.0, 1.2)) + self.ax[self.y.l - 1, j].set_xticks( + np.array(x_ticks) + 0.5 * barwidth, x_labels + ) + self.ax[self.y.l - 1, j].set_xlabel( + bars + + f"\n{self.label_mod([(self.x.v[j], self.x)])}" + ) # self.ax[i, j].set_yticklabels(np.arange(0.0, 1.1, 0.2)) + except IndexError: + self.ax[i].set_ylabel( + f"{self.label_mod([(self.y.v[i], y)])}\n" + rowlabel + ) + # self.ax[i].legend([self.label_mod([(leg, self.label_key)]) for leg in self.legends[i, 0]], ncol=3) + self.ax[self.y.l - 1].set_xlabel( + columnlabel + + f"\n{self.label_mod([(self.x.v[0], self.x)])}" + ) + # # # + self.fig.supxlabel(f"{self.title_dict[x]}s", size=14) + self.fig.supylabel(f"{self.title_dict[y]}s", size=14) # if plsave is True: logger.info("Saving plot") if type(plsave) == str and plsave != "": - outname = f"{plsave}{format}" + outname = f"{plsave}.{format.strip('.')}" else: - outname = f"{self.name}{format}" + outname = f"{self.name}.{format.strip('.')}" odir = Path(odir).absolute() logger.info(f"output to {odir.absolute()}") if not odir.is_dir(): os.makedirs(odir) logger.info(odir) logger.info(outname) - self.fig.savefig(str(odir / outname), format=format) + self.fig.savefig(str(odir / outname), format=format.strip(".")) # else: # plt.show() self.fig.clear() @@ -13158,8 +13244,44 @@ class TablePlot(HistPlot): def __init__(self, data, add_missing_bulk=True, **kwargs): super().__init__(data, add_missing_bulk, **kwargs) - def df(self): - print(self.plot_df) + @plot_args_decorator(select="ions") + def plot_ions(self, odir, plsave=True): + self.plot(plsave=plsave, odir=odir) + + @plot_args_decorator(select="other") + def plot_other(self, odir, plsave=True): + self.plot(plsave=plsave, odir=odir) + + def plot(self, odir, plsave=True): + abs_str = {False: "absoute", True: "relative"}[self.data.use_rel_data] + self.name = ( + f"{self.data.name}_{self.data.analysis}_{self.select}_{abs_str}" + f"_{self.data.cutoff}_{self.data.bins}" + ) + if plsave is not False: + logger.info("Saving data") + if type(plsave) == str: + outname = f"{plsave}.csv" + else: + outname = f"{self.name}.csv" + odir = Path(odir).absolute() + logger.info(f"output to {odir.absolute()}") + if not odir.is_dir(): + os.makedirs(odir) + logger.info(odir / outname) + self.plot_df.to_csv(odir / outname) + else: + outname = f"{self.name}.p" + odir = Path(odir).absolute() + logger.info(f"output to {odir.absolute()}") + if not odir.is_dir(): + os.makedirs(odir) + logger.info(odir / outname) + with open(odir / outname, "wb") as pklfile: + pkl.dump( + {"data": self.data.df, "plot_df": self.plot_df}, + pklfile, + ) class GaussHistPlot(Plot): @@ -14321,16 +14443,16 @@ def plot( if plsave is not False: logger.info("Saving plot") if type(plsave) == str and plsave != "": - outname = f"{plsave}{format}" + outname = f"{plsave}.{format}.strip('.')" else: - outname = f"{self.name}{format}" + outname = f"{self.name}.{format.strip('.')}" odir = Path(odir).absolute() logger.info(f"output to {odir.absolute()}") if not odir.is_dir(): os.makedirs(odir) logger.info(odir) logger.info(outname) - self.fig.savefig(str(odir / outname), format=format) + self.fig.savefig(str(odir / outname), format=format.strip(".")) else: plt.show() self.fig.clear() @@ -16438,7 +16560,7 @@ def make_hist2d(self, other=None, bins=None, z_bins=None): ylim=args.ymax, plot_table=args.table, format=args.format, - only_table=args.only_table, + # only_table=args.only_table, ) elif args.lines: plot = LinePlot(data) @@ -16461,22 +16583,29 @@ def make_hist2d(self, other=None, bins=None, z_bins=None): format=args.format, ) elif args.bars: - plot = HistPlot(data) # , add_missing_bulk=args.add_missing_bulk) - for s in args.sel: - plot_method[s](plot)( - x=args.x, - y=args.y, - bars=args.plsel, - dpi=200, - columnlabel=rf"{args.xlabel}", - # r"distance from surface (\AA)", - rowlabel=rf"{args.ylabel}", # 'closest atom type',# r"$\rho_z$ ()", - plsave=plsave, - odir=args.odir, # "/storage/plots/aadist_u/", - ylim=args.ymax, - plot_table=args.table, - only_table=args.only_table, - ) + if args.only_table is True: + plot = TablePlot(data) + for s in args.sel: + plot_method[s](plot)(plsave=plsave, odir=args.odir) + else: + plot = HistPlot( + data + ) # , add_missing_bulk=args.add_missing_bulk) + for s in args.sel: + plot_method[s](plot)( + x=args.x, + y=args.y, + bars=args.plsel, + dpi=200, + columnlabel=rf"{args.xlabel}", + # r"distance from surface (\AA)", + rowlabel=rf"{args.ylabel}", # 'closest atom type',# r"$\rho_z$ ()", + plsave=plsave, + odir=args.odir, # "/storage/plots/aadist_u/", + ylim=args.ymax, + plot_table=args.table, + # only_table=args.only_table, + ) else: plot = HistPlot2D(data, col_sel=["rdens"], select="ions") plot.plot_ions(