Commit 6bbcc79

GMT_DATASET.to_dataframe: Add the 'header' parameter for parsing column names from data header
1 parent 85d4ed2 commit 6bbcc79

File tree: 3 files changed, +83 -8 lines changed
pygmt/clib/session.py (+5 lines)

@@ -1746,6 +1746,7 @@ def virtualfile_to_dataset(
         self,
         vfname: str,
         output_type: Literal["pandas", "numpy", "file"] = "pandas",
+        header: int | None = None,
         column_names: list[str] | None = None,
         dtype: type | dict[str, type] | None = None,
         index_col: str | int | None = None,
@@ -1766,6 +1767,9 @@ def virtualfile_to_dataset(
             - ``"pandas"`` will return a :class:`pandas.DataFrame` object.
             - ``"numpy"`` will return a :class:`numpy.ndarray` object.
             - ``"file"`` means the result was saved to a file and will return ``None``.
+        header
+            Row number containing column names. ``header=None`` means not to parse the
+            column names from data header.
         column_names
             The column names for the :class:`pandas.DataFrame` output.
         dtype
@@ -1862,6 +1866,7 @@ def virtualfile_to_dataset(
 
         # Read the virtual file as a GMT dataset and convert to pandas.DataFrame
         result = self.read_virtualfile(vfname, kind="dataset").contents.to_dataframe(
+            header=header,
             column_names=column_names,
             dtype=dtype,
             index_col=index_col,

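With this change, Session.virtualfile_to_dataset simply forwards the new 'header' argument to _GMT_DATASET.to_dataframe. Below is a minimal usage sketch of the updated internal API (it assumes a working GMT installation; the sample file name and its contents are made up for illustration):

    from pathlib import Path

    from pygmt.clib import Session
    from pygmt.helpers import GMTTempFile

    with GMTTempFile(suffix=".txt") as tmpfile:
        # A tiny sample table whose first header line carries the column names.
        Path(tmpfile.name).write_text("# lon lat z\n1.0 2.0 3.0\n4.0 5.0 6.0\n")
        with Session() as lib:
            with lib.virtualfile_out(kind="dataset") as vouttbl:
                lib.call_module("read", f"{tmpfile.name} {vouttbl} -Td")
                # header=0 parses names from the first header line;
                # header=None (the default) keeps the numeric names 0, 1, 2.
                df = lib.virtualfile_to_dataset(vfname=vouttbl, header=0)
    print(df.columns.tolist())  # expected: ['lon', 'lat', 'z']
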
pygmt/datatypes/dataset.py (+19 -6 lines)

@@ -26,6 +26,7 @@ class _GMT_DATASET(ctp.Structure):  # noqa: N801
     >>> with GMTTempFile(suffix=".txt") as tmpfile:
     ...     # Prepare the sample data file
     ...     with Path(tmpfile.name).open(mode="w") as fp:
+    ...         print("# x y z name", file=fp)
     ...         print(">", file=fp)
     ...         print("1.0 2.0 3.0 TEXT1 TEXT23", file=fp)
     ...         print("4.0 5.0 6.0 TEXT4 TEXT567", file=fp)
@@ -42,7 +43,8 @@ class _GMT_DATASET(ctp.Structure):  # noqa: N801
     ...             print(ds.min[: ds.n_columns], ds.max[: ds.n_columns])
     ...             # The table
     ...             tbl = ds.table[0].contents
-    ...             print(tbl.n_columns, tbl.n_segments, tbl.n_records)
+    ...             print(tbl.n_columns, tbl.n_segments, tbl.n_records, tbl.n_headers)
+    ...             print(tbl.header[: tbl.n_headers])
     ...             print(tbl.min[: tbl.n_columns], ds.max[: tbl.n_columns])
     ...             for i in range(tbl.n_segments):
     ...                 seg = tbl.segment[i].contents
@@ -51,7 +53,8 @@ class _GMT_DATASET(ctp.Structure):  # noqa: N801
     ...                 print(seg.text[: seg.n_rows])
     1 3 2
     [1.0, 2.0, 3.0] [10.0, 11.0, 12.0]
-    3 2 4
+    3 2 4 1
+    [b'x y z name']
     [1.0, 2.0, 3.0] [10.0, 11.0, 12.0]
     [1.0, 4.0]
     [2.0, 5.0]
@@ -144,8 +147,9 @@ class _GMT_DATASEGMENT(ctp.Structure):  # noqa: N801
             ("hidden", ctp.c_void_p),
         ]
 
-    def to_dataframe(
+    def to_dataframe(  # noqa: PLR0912
         self,
+        header: int | None = None,
         column_names: pd.Index | None = None,
         dtype: type | Mapping[Any, type] | None = None,
         index_col: str | int | None = None,
@@ -164,6 +168,9 @@ def to_dataframe(
         ----------
         column_names
             A list of column names.
+        header
+            Row number containing column names. ``header=None`` means not to parse the
+            column names from data header.
         dtype
             Data type. Can be a single type for all columns or a dictionary mapping
             column names to types.
@@ -184,6 +191,7 @@ def to_dataframe(
         >>> with GMTTempFile(suffix=".txt") as tmpfile:
         ...     # prepare the sample data file
         ...     with Path(tmpfile.name).open(mode="w") as fp:
+        ...         print("# col1 col2 col3 colstr", file=fp)
         ...         print(">", file=fp)
         ...         print("1.0 2.0 3.0 TEXT1 TEXT23", file=fp)
         ...         print("4.0 5.0 6.0 TEXT4 TEXT567", file=fp)
@@ -194,9 +202,9 @@ def to_dataframe(
         ...     with lib.virtualfile_out(kind="dataset") as vouttbl:
         ...         lib.call_module("read", f"{tmpfile.name} {vouttbl} -Td")
        ...         ds = lib.read_virtualfile(vouttbl, kind="dataset")
-        ...         df = ds.contents.to_dataframe()
+        ...         df = ds.contents.to_dataframe(header=0)
         >>> df
-        0 1 2 3
+        col1 col2 col3 colstr
         0 1.0 2.0 3.0 TEXT1 TEXT23
         1 4.0 5.0 6.0 TEXT4 TEXT567
         2 7.0 8.0 9.0 TEXT8 TEXT90
@@ -230,14 +238,19 @@ def to_dataframe(
                 pd.Series(data=np.char.decode(textvector), dtype=pd.StringDtype())
             )
 
+        if header is not None:
+            tbl = self.table[0].contents  # Use the first table!
+            if header < tbl.n_headers:
+                column_names = tbl.header[header].decode().split()
+
         if len(vectors) == 0:
             # Return an empty DataFrame if no columns are found.
             df = pd.DataFrame(columns=column_names)
         else:
             # Create a DataFrame object by concatenating multiple columns
             df = pd.concat(objs=vectors, axis="columns")
             if column_names is not None:  # Assign column names
-                df.columns = column_names
+                df.columns = column_names[: df.shape[1]]
         if dtype is not None:  # Set dtype for the whole dataset or individual columns
             df = df.astype(dtype)
         if index_col is not None:  # Use a specific column as index

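The core of the change is the new block in to_dataframe() that decodes a table header record and splits it on whitespace, plus the slice column_names[: df.shape[1]] that silently drops surplus names. The following self-contained sketch mimics that logic with plain pandas so the behavior can be seen without a GMT session (the header bytes and data values are invented for illustration):

    import pandas as pd

    # What tbl.header[: tbl.n_headers] might hold for a table with one header line.
    header_lines = [b"lon lat z text1 text2"]
    header = 0  # row number of the header line to parse

    column_names = None
    if header is not None and header < len(header_lines):
        # Same parsing as in to_dataframe(): decode the bytes and split on whitespace.
        column_names = header_lines[header].decode().split()

    df = pd.DataFrame([[1.0, 2.0, 3.0, "TEXT1 TEXT23"], [4.0, 5.0, 6.0, "TEXT4 TEXT567"]])
    if column_names is not None:
        # Surplus names are truncated, mirroring df.columns = column_names[: df.shape[1]].
        df.columns = column_names[: df.shape[1]]
    print(df.columns.tolist())  # ['lon', 'lat', 'z', 'text1']

If 'header' is greater than or equal to the number of header lines, column_names stays None and the default numeric names are kept, which is exactly what the new test_dataset_header_greater_than_nheaders test below checks.
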
pygmt/tests/test_datatypes_dataset.py (+59 -2 lines)

@@ -39,14 +39,14 @@ def dataframe_from_pandas(filepath_or_buffer, sep=r"\s+", comment="#", header=None):
     return df
 
 
-def dataframe_from_gmt(fname):
+def dataframe_from_gmt(fname, **kwargs):
     """
     Read tabular data as pandas.DataFrame using GMT virtual file.
     """
     with Session() as lib:
         with lib.virtualfile_out(kind="dataset") as vouttbl:
             lib.call_module("read", f"{fname} {vouttbl} -Td")
-            df = lib.virtualfile_to_dataset(vfname=vouttbl)
+            df = lib.virtualfile_to_dataset(vfname=vouttbl, **kwargs)
     return df
 
 
@@ -81,3 +81,60 @@ def test_dataset_empty():
         assert df.empty  # Empty DataFrame
         expected_df = dataframe_from_pandas(tmpfile.name)
         pd.testing.assert_frame_equal(df, expected_df)
+
+
+def test_dataset_header():
+    """
+    Test parsing column names from dataset header.
+    """
+    with GMTTempFile(suffix=".txt") as tmpfile:
+        with Path(tmpfile.name).open(mode="w") as fp:
+            print("# lon lat z text", file=fp)
+            print("1.0 2.0 3.0 TEXT1 TEXT23", file=fp)
+            print("4.0 5.0 6.0 TEXT4 TEXT567", file=fp)
+
+        # Parse column names from the first header line.
+        df = dataframe_from_gmt(tmpfile.name, header=0)
+        assert df.columns.tolist() == ["lon", "lat", "z", "text"]
+        # pd.read_csv() can't parse the header line with a leading '#'.
+        # So, we need to skip the header line and manually set the column names.
+        expected_df = dataframe_from_pandas(tmpfile.name, header=None)
+        expected_df.columns = df.columns.tolist()
+        pd.testing.assert_frame_equal(df, expected_df)
+
+
+def test_dataset_header_greater_than_nheaders():
+    """
+    Test passing a header line number that is greater than the number of header lines.
+    """
+    with GMTTempFile(suffix=".txt") as tmpfile:
+        with Path(tmpfile.name).open(mode="w") as fp:
+            print("# lon lat z text", file=fp)
+            print("1.0 2.0 3.0 TEXT1 TEXT23", file=fp)
+            print("4.0 5.0 6.0 TEXT4 TEXT567", file=fp)
+
+        # Try to parse column names from the second header line.
+        df = dataframe_from_gmt(tmpfile.name, header=1)
+        # There is only one header line, so the column names should be default.
+        assert df.columns.tolist() == [0, 1, 2, 3]
+        expected_df = dataframe_from_pandas(tmpfile.name, header=None)
+        pd.testing.assert_frame_equal(df, expected_df)
+
+
+def test_dataset_header_too_many_names():
+    """
+    Test passing a header line with more column names than the number of columns.
+    """
+    with GMTTempFile(suffix=".txt") as tmpfile:
+        with Path(tmpfile.name).open(mode="w") as fp:
+            print("# lon lat z text1 text2", file=fp)
+            print("1.0 2.0 3.0 TEXT1 TEXT23", file=fp)
+            print("4.0 5.0 6.0 TEXT4 TEXT567", file=fp)
+
+        df = dataframe_from_gmt(tmpfile.name, header=0)
+        assert df.columns.tolist() == ["lon", "lat", "z", "text1"]
+        # pd.read_csv() can't parse the header line with a leading '#'.
+        # So, we need to skip the header line and manually set the column names.
+        expected_df = dataframe_from_pandas(tmpfile.name, header=None)
+        expected_df.columns = df.columns.tolist()
+        pd.testing.assert_frame_equal(df, expected_df)

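One way to run only the new tests locally is to filter by name with pytest, for example (this assumes a PyGMT development environment with GMT installed; the -k expression is just an illustration):

    import pytest

    # Run only the header-related tests added in this commit.
    pytest.main(["pygmt/tests/test_datatypes_dataset.py", "-k", "header"])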