|
1 | 1 | """
|
2 | 2 | grdtrack - Sample grids at specified (x,y) locations.
|
3 | 3 | """
|
4 |
| -import numpy as np |
5 | 4 | import pandas as pd
|
6 | 5 | from pygmt.clib import Session
|
7 | 6 | from pygmt.exceptions import GMTInvalidInput
|
8 |
| -from pygmt.helpers import build_arg_string, fmt_docstring, kwargs_to_strings, use_alias |
| 7 | +from pygmt.helpers import ( |
| 8 | + GMTTempFile, |
| 9 | + build_arg_string, |
| 10 | + fmt_docstring, |
| 11 | + kwargs_to_strings, |
| 12 | + use_alias, |
| 13 | +) |
9 | 14 |
|
10 | 15 | __doctest_skip__ = ["grdtrack"]
|
11 | 16 |
|
|
38 | 43 | w="wrap",
|
39 | 44 | )
|
40 | 45 | @kwargs_to_strings(R="sequence", S="sequence", i="sequence_comma", o="sequence_comma")
|
41 |
| -def grdtrack( |
42 |
| - grid, points=None, newcolname=None, output_type="pandas", outfile=None, **kwargs |
43 |
| -): |
| 46 | +def grdtrack(grid, points=None, newcolname=None, outfile=None, **kwargs): |
44 | 47 | r"""
|
45 | 48 | Sample grids at specified (x,y) locations.
|
46 | 49 |
|
@@ -289,30 +292,29 @@ def grdtrack(
|
289 | 292 | if hasattr(points, "columns") and newcolname is None:
|
290 | 293 | raise GMTInvalidInput("Please pass in a str to 'newcolname'")
|
291 | 294 |
|
292 |
| - with Session() as lib: |
293 |
| - with lib.virtualfile_from_data( |
294 |
| - check_kind="raster", data=grid |
295 |
| - ) as grdfile, lib.virtualfile_from_data( |
296 |
| - check_kind="vector", data=points, required_data=False |
297 |
| - ) as csvfile, lib.virtualfile_to_gmtdataset() as outvfile: |
298 |
| - kwargs["G"] = grdfile |
299 |
| - lib.call_module( |
300 |
| - module="grdtrack", |
301 |
| - args=build_arg_string(kwargs, infile=csvfile, outfile=outvfile), |
302 |
| - ) |
303 |
| - if outfile is not None: |
304 |
| - # if output_type == "file": |
305 |
| - lib.call_module("write", f"{outvfile} {outfile} -Td") |
306 |
| - return None |
307 |
| - |
308 |
| - vectors = lib.gmtdataset_to_vectors(outvfile) |
309 |
| - |
310 |
| - if output_type == "numpy": |
311 |
| - return np.array(vectors).T |
| 295 | + with GMTTempFile(suffix=".csv") as tmpfile: |
| 296 | + with Session() as lib: |
| 297 | + with lib.virtualfile_from_data( |
| 298 | + check_kind="raster", data=grid |
| 299 | + ) as grdfile, lib.virtualfile_from_data( |
| 300 | + check_kind="vector", data=points, required_data=False |
| 301 | + ) as csvfile: |
| 302 | + kwargs["G"] = grdfile |
| 303 | + if outfile is None: # Output to tmpfile if outfile is not set |
| 304 | + outfile = tmpfile.name |
| 305 | + lib.call_module( |
| 306 | + module="grdtrack", |
| 307 | + args=build_arg_string(kwargs, infile=csvfile, outfile=outfile), |
| 308 | + ) |
312 | 309 |
|
313 |
| - if isinstance(points, pd.DataFrame): |
314 |
| - column_names = points.columns.to_list() + [newcolname] |
315 |
| - else: |
316 |
| - column_names = None |
| 310 | + # Read temporary csv output to a pandas table |
| 311 | + if outfile == tmpfile.name: # if user did not set outfile, return pd.DataFrame |
| 312 | + try: |
| 313 | + column_names = points.columns.to_list() + [newcolname] |
| 314 | + result = pd.read_csv(tmpfile.name, sep="\t", names=column_names) |
| 315 | + except AttributeError: # 'str' object has no attribute 'columns' |
| 316 | + result = pd.read_csv(tmpfile.name, sep="\t", header=None, comment=">") |
| 317 | + elif outfile != tmpfile.name: # return None if outfile set, output in outfile |
| 318 | + result = None |
317 | 319 |
|
318 |
| - return pd.DataFrame(np.array(vectors).T, columns=column_names) |
| 320 | + return result |
0 commit comments