|
1 | 1 | """
|
2 | 2 | grdtrack - Sample grids at specified (x,y) locations.
|
3 | 3 | """
|
| 4 | +import warnings |
| 5 | + |
| 6 | +import numpy as np |
4 | 7 | import pandas as pd
|
5 | 8 | from pygmt.clib import Session
|
6 | 9 | from pygmt.exceptions import GMTInvalidInput
|
7 |
| -from pygmt.helpers import ( |
8 |
| - GMTTempFile, |
9 |
| - build_arg_string, |
10 |
| - fmt_docstring, |
11 |
| - kwargs_to_strings, |
12 |
| - use_alias, |
13 |
| -) |
| 10 | +from pygmt.helpers import build_arg_string, fmt_docstring, kwargs_to_strings, use_alias |
14 | 11 |
|
15 | 12 | __doctest_skip__ = ["grdtrack"]
|
16 | 13 |
|
|
43 | 40 | w="wrap",
|
44 | 41 | )
|
45 | 42 | @kwargs_to_strings(R="sequence", S="sequence", i="sequence_comma", o="sequence_comma")
|
46 |
| -def grdtrack(grid, points=None, newcolname=None, outfile=None, **kwargs): |
| 43 | +def grdtrack( |
| 44 | + grid, points=None, output_type="pandas", outfile=None, newcolname=None, **kwargs |
| 45 | +): |
47 | 46 | r"""
|
48 | 47 | Sample grids at specified (x,y) locations.
|
49 | 48 |
|
@@ -292,29 +291,44 @@ def grdtrack(grid, points=None, newcolname=None, outfile=None, **kwargs):
|
292 | 291 | if hasattr(points, "columns") and newcolname is None:
|
293 | 292 | raise GMTInvalidInput("Please pass in a str to 'newcolname'")
|
294 | 293 |
|
295 |
| - with GMTTempFile(suffix=".csv") as tmpfile: |
296 |
| - with Session() as lib: |
297 |
| - with lib.virtualfile_from_data( |
298 |
| - check_kind="raster", data=grid |
299 |
| - ) as grdfile, lib.virtualfile_from_data( |
300 |
| - check_kind="vector", data=points, required_data=False |
301 |
| - ) as csvfile: |
302 |
| - kwargs["G"] = grdfile |
303 |
| - if outfile is None: # Output to tmpfile if outfile is not set |
304 |
| - outfile = tmpfile.name |
305 |
| - lib.call_module( |
306 |
| - module="grdtrack", |
307 |
| - args=build_arg_string(kwargs, infile=csvfile, outfile=outfile), |
308 |
| - ) |
| 294 | + if output_type not in ["numpy", "pandas", "file"]: |
| 295 | + raise GMTInvalidInput( |
| 296 | + "Must specify 'output_type' either as 'numpy', 'pandas' or 'file'." |
| 297 | + ) |
| 298 | + |
| 299 | + if outfile is not None and output_type != "file": |
| 300 | + msg = ( |
| 301 | + f"Changing 'output_type' from '{output_type}' to 'file' " |
| 302 | + "since 'outfile' parameter is set. Please use output_type='file' " |
| 303 | + "to silence this warning." |
| 304 | + ) |
| 305 | + warnings.warn(message=msg, category=RuntimeWarning, stacklevel=2) |
| 306 | + output_type = "file" |
| 307 | + elif outfile is None and output_type == "file": |
| 308 | + raise GMTInvalidInput("Must specify 'outfile' for ASCII output.") |
| 309 | + |
| 310 | + if isinstance(points, pd.DataFrame): |
| 311 | + column_names = points.columns.to_list() + [newcolname] |
| 312 | + else: |
| 313 | + column_names = None |
309 | 314 |
|
310 |
| - # Read temporary csv output to a pandas table |
311 |
| - if outfile == tmpfile.name: # if user did not set outfile, return pd.DataFrame |
312 |
| - try: |
313 |
| - column_names = points.columns.to_list() + [newcolname] |
314 |
| - result = pd.read_csv(tmpfile.name, sep="\t", names=column_names) |
315 |
| - except AttributeError: # 'str' object has no attribute 'columns' |
316 |
| - result = pd.read_csv(tmpfile.name, sep="\t", header=None, comment=">") |
317 |
| - elif outfile != tmpfile.name: # return None if outfile set, output in outfile |
318 |
| - result = None |
| 315 | + with Session() as lib: |
| 316 | + with lib.virtualfile_from_data( |
| 317 | + check_kind="raster", data=grid |
| 318 | + ) as ingrid, lib.virtualfile_from_data( |
| 319 | + check_kind="vector", data=points, required_data=False |
| 320 | + ) as infile, lib.virtualfile_to_data( |
| 321 | + kind="dataset", fname=outfile |
| 322 | + ) as outvfile: |
| 323 | + kwargs["G"] = ingrid |
| 324 | + lib.call_module( |
| 325 | + module="grdtrack", |
| 326 | + args=build_arg_string(kwargs, infile=infile, outfile=outvfile), |
| 327 | + ) |
319 | 328 |
|
320 |
| - return result |
| 329 | + if output_type == "file": |
| 330 | + return None |
| 331 | + vectors = lib.read_virtualfile(outvfile, kind="dataset").contents.to_vectors() |
| 332 | + if output_type == "numpy": |
| 333 | + return np.array(vectors).T |
| 334 | + return pd.DataFrame(np.array(vectors).T, columns=column_names) |
0 commit comments