diff --git a/pygmt/src/meca.py b/pygmt/src/meca.py
index a2cd6156b53..3b3e4cbbfec 100644
--- a/pygmt/src/meca.py
+++ b/pygmt/src/meca.py
@@ -6,7 +6,13 @@
 import pandas as pd
 from pygmt.clib import Session
 from pygmt.exceptions import GMTInvalidInput
-from pygmt.helpers import build_arg_list, fmt_docstring, kwargs_to_strings, use_alias
+from pygmt.helpers import (
+    build_arg_list,
+    data_kind,
+    fmt_docstring,
+    kwargs_to_strings,
+    use_alias,
+)
 from pygmt.src._common import _FocalMechanismConvention
 
 
@@ -25,6 +31,101 @@ def _get_focal_convention(spec, convention, component) -> _FocalMechanismConvent
     return _FocalMechanismConvention(convention=convention, component=component)
 
 
+def _preprocess_spec(spec, colnames, override_cols):
+    """
+    Preprocess the input data.
+
+    Dict input is shallow-copied first, so the caller's ``spec`` is never modified.
+
+    Parameters
+    ----------
+    spec
+        The input data to be preprocessed.
+    colnames
+        The minimum required column names of the input data.
+    override_cols
+        Dictionary of column names and values to override in the input data. Only makes
+        sense if ``spec`` is a dict or :class:`pandas.DataFrame`.
+
+    Returns
+    -------
+    spec
+        A dict with columns ordered to match the convention if the input is a dict,
+        :class:`pandas.DataFrame`, or :class:`numpy.ndarray`; any other input (e.g., a
+        file name) is returned unchanged.
+
+    Raises
+    ------
+    GMTInvalidInput
+        If ``spec`` is a :class:`numpy.ndarray` with an unexpected number of columns.
+    """
+    kind = data_kind(spec)  # Determine the kind of the input data.
+
+    # Convert pandas.DataFrame and numpy.ndarray to dict.
+    if isinstance(spec, pd.DataFrame):
+        spec = {k: v.to_numpy() for k, v in spec.items()}
+    elif isinstance(spec, np.ndarray):
+        spec = np.atleast_2d(spec)
+        # Optional columns that are not required by the convention. The key is the
+        # number of extra columns, and the value is a list of optional column names.
+        extra_cols = {
+            0: [],
+            1: ["event_name"],
+            2: ["plot_longitude", "plot_latitude"],
+            3: ["plot_longitude", "plot_latitude", "event_name"],
+        }
+        ncols = len(colnames)
+        ndiff = spec.shape[1] - ncols
+        if ndiff not in extra_cols:
+            msg = f"Input array must have {ncols} to {ncols + 3} columns."
+            raise GMTInvalidInput(msg)
+        spec = dict(zip([*colnames, *extra_cols[ndiff]], spec.T, strict=False))
+    elif isinstance(spec, dict):
+        # Shallow-copy so that the overrides and string conversions below do not
+        # modify the dict passed in by the caller.
+        spec = dict(spec)
+
+    # Now, the input data is a dict or an ASCII file.
+    if isinstance(spec, dict):
+        # The columns can be overridden by the parameters given in the function
+        # arguments. Only makes sense for dict/pandas.DataFrame input.
+        if kind != "matrix" and override_cols is not None:
+            spec.update({k: v for k, v in override_cols.items() if v is not None})
+        # Due to the internal implementation of the meca module, we need to convert the
+        # ``plot_longitude``, ``plot_latitude``, and ``event_name`` columns into strings
+        # if they exist.
+        for key in ["plot_longitude", "plot_latitude", "event_name"]:
+            if key in spec:
+                spec[key] = np.array(spec[key], dtype=str)
+
+        # Reorder columns to match convention if necessary. The expected columns are:
+        # longitude, latitude, depth, focal_parameters, [plot_longitude, plot_latitude],
+        # [event_name].
+        extra_cols = []
+        if "plot_longitude" in spec and "plot_latitude" in spec:
+            extra_cols.extend(["plot_longitude", "plot_latitude"])
+        if "event_name" in spec:
+            extra_cols.append("event_name")
+        cols = [*colnames, *extra_cols]
+        if list(spec.keys()) != cols:
+            spec = {k: spec[k] for k in cols}
+    return spec
+
+
+def _auto_offset(spec) -> bool:
+    """
+    Determine if offset should be set based on the input data.
+
+    If the input data contains ``plot_longitude`` and ``plot_latitude``, then we set the
+    ``offset`` parameter to ``True`` automatically.
+    """
+    return (
+        isinstance(spec, dict | pd.DataFrame)
+        and "plot_longitude" in spec
+        and "plot_latitude" in spec
+    )
+
+
 @fmt_docstring
 @use_alias(
     A="offset",
@@ -45,7 +146,7 @@ def _get_focal_convention(spec, convention, component) -> _FocalMechanismConvent
     t="transparency",
 )
 @kwargs_to_strings(R="sequence", c="sequence_comma", p="sequence")
-def meca(  # noqa: PLR0912, PLR0913
+def meca(  # noqa: PLR0913
     self,
     spec,
     scale,
@@ -248,78 +349,25 @@ def meca(  # noqa: PLR0912, PLR0913
     {transparency}
     """
     kwargs = self._preprocess(**kwargs)
-
     # Determine the focal mechanism convention from the input data or parameters.
     _convention = _get_focal_convention(spec, convention, component)
-
-    # Convert spec to pandas.DataFrame unless it's a file
-    if isinstance(spec, dict | pd.DataFrame):  # spec is a dict or pd.DataFrame
-        # convert dict to pd.DataFrame so columns can be reordered
-        if isinstance(spec, dict):
-            # convert values to ndarray so pandas doesn't complain about "all
-            # scalar values". See
-            # https://github.com/GenericMappingTools/pygmt/pull/2174
-            spec = pd.DataFrame(
-                {key: np.atleast_1d(value) for key, value in spec.items()}
-            )
-    elif isinstance(spec, np.ndarray):  # spec is a numpy array
-        # Convert array to pd.DataFrame and assign column names
-        spec = pd.DataFrame(np.atleast_2d(spec))
-        colnames = ["longitude", "latitude", "depth", *_convention.params]
-        # check if spec has the expected number of columns
-        ncolsdiff = len(spec.columns) - len(colnames)
-        if ncolsdiff == 0:
-            pass
-        elif ncolsdiff == 1:
-            colnames += ["event_name"]
-        elif ncolsdiff == 2:
-            colnames += ["plot_longitude", "plot_latitude"]
-        elif ncolsdiff == 3:
-            colnames += ["plot_longitude", "plot_latitude", "event_name"]
-        else:
-            msg = (
-                f"Input array must have {len(colnames)} to {len(colnames) + 3} columns."
-            )
-            raise GMTInvalidInput(msg)
-        spec.columns = colnames
-
-    # Now spec is a pd.DataFrame or a file
-    if isinstance(spec, pd.DataFrame):
-        # override the values in pd.DataFrame if parameters are given
-        for arg, name in [
-            (longitude, "longitude"),
-            (latitude, "latitude"),
-            (depth, "depth"),
-            (plot_longitude, "plot_longitude"),
-            (plot_latitude, "plot_latitude"),
-            (event_name, "event_name"),
-        ]:
-            if arg is not None:
-                spec[name] = np.atleast_1d(arg)
-
-        # Due to the internal implementation of the meca module, we need to
-        # convert the following columns to strings if they exist
-        if "plot_longitude" in spec.columns and "plot_latitude" in spec.columns:
-            spec["plot_longitude"] = spec["plot_longitude"].astype(str)
-            spec["plot_latitude"] = spec["plot_latitude"].astype(str)
-        if "event_name" in spec.columns:
-            spec["event_name"] = spec["event_name"].astype(str)
-
-        # Reorder columns in DataFrame to match convention if necessary
-        # expected columns are:
-        # longitude, latitude, depth, focal_parameters,
-        # [plot_longitude, plot_latitude] [event_name]
-        newcols = ["longitude", "latitude", "depth", *_convention.params]
-        if "plot_longitude" in spec.columns and "plot_latitude" in spec.columns:
-            newcols += ["plot_longitude", "plot_latitude"]
-            if kwargs.get("A") is None:
-                kwargs["A"] = True
-        if "event_name" in spec.columns:
-            newcols += ["event_name"]
-        # reorder columns in DataFrame
-        if spec.columns.tolist() != newcols:
-            spec = spec.reindex(newcols, axis=1)
-
+    # Preprocess the input data.
+    spec = _preprocess_spec(
+        spec,
+        # The minimum expected columns for the input data.
+        colnames=["longitude", "latitude", "depth", *_convention.params],
+        override_cols={
+            "longitude": longitude,
+            "latitude": latitude,
+            "depth": depth,
+            "plot_longitude": plot_longitude,
+            "plot_latitude": plot_latitude,
+            "event_name": event_name,
+        },
+    )
+    # Determine the offset parameter if not provided.
+    if kwargs.get("A") is None:
+        kwargs["A"] = _auto_offset(spec)
     kwargs["S"] = f"{_convention.code}{scale}"
     with Session() as lib:
         with lib.virtualfile_in(check_kind="vector", data=spec) as vintbl: