Skip to content

Commit 9d6fd85

Browse files
committed
Allow for io.StringIO inputs into legend's spec argument
There is a new 'buffer' data_kind which is intended for io.StringIO stream data. This 'buffer' data is passed into `legend` via an intermediate file created using the helper tempfile_from_buffer function.
1 parent b084960 commit 9d6fd85

File tree

3 files changed

+31
-18
lines changed

3 files changed

+31
-18
lines changed

pygmt/base_plotting.py

+19-12
Original file line numberDiff line numberDiff line change
@@ -4,19 +4,21 @@
44
"""
55
import contextlib
66
import csv
7+
78
import numpy as np
89
import pandas as pd
910

1011
from .clib import Session
1112
from .exceptions import GMTError, GMTInvalidInput
1213
from .helpers import (
14+
GMTTempFile,
1315
build_arg_string,
14-
dummy_context,
1516
data_kind,
17+
dummy_context,
1618
fmt_docstring,
17-
GMTTempFile,
18-
use_alias,
1919
kwargs_to_strings,
20+
tempfile_from_buffer,
21+
use_alias,
2022
)
2123

2224

@@ -801,10 +803,11 @@ def legend(self, spec=None, position="JTR+jTR+o0.2c", box="+gwhite+p1p", **kwarg
801803
802804
Parameters
803805
----------
804-
spec : None or str
805-
Either None (default) for using the automatically generated legend
806-
specification file, or a filename pointing to the legend
807-
specification file.
806+
spec : None or str or io.StringIO
807+
Set to None (default) for using the automatically generated legend
808+
specification file. Alternatively, pass in a filename or an
809+
io.StringIO in-memory stream buffer pointing to the legend
810+
specification text.
808811
{J}
809812
{R}
810813
position : str
@@ -829,13 +832,17 @@ def legend(self, spec=None, position="JTR+jTR+o0.2c", box="+gwhite+p1p", **kwarg
829832

830833
with Session() as lib:
831834
if spec is None:
832-
specfile = ""
835+
file_context = dummy_context("")
833836
elif data_kind(spec) == "file":
834-
specfile = spec
837+
file_context = dummy_context(spec)
838+
elif data_kind(spec) == "buffer":
839+
file_context = tempfile_from_buffer(spec)
835840
else:
836-
raise GMTInvalidInput("Unrecognized data type: {}".format(type(spec)))
837-
arg_str = " ".join([specfile, build_arg_string(kwargs)])
838-
lib.call_module("legend", arg_str)
841+
raise GMTInvalidInput(f"Unrecognized data type: {type(spec)}")
842+
843+
with file_context as fname:
844+
arg_str = " ".join([fname, build_arg_string(kwargs)])
845+
lib.call_module("legend", arg_str)
839846

840847
@fmt_docstring
841848
@use_alias(

pygmt/helpers/utils.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
"""
22
Utilities and common tasks for wrapping the GMT modules.
33
"""
4-
import sys
4+
import io
55
import shutil
66
import subprocess
7+
import sys
78
import webbrowser
89
from collections.abc import Iterable
910
from contextlib import contextmanager
@@ -62,6 +63,8 @@ def data_kind(data, x=None, y=None, z=None):
6263

6364
if isinstance(data, str):
6465
kind = "file"
66+
elif isinstance(data, io.StringIO):
67+
kind = "buffer"
6568
elif isinstance(data, xr.DataArray):
6669
kind = "grid"
6770
elif data is not None:

pygmt/tests/test_legend.py

+8-5
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
"""
22
Tests for legend
33
"""
4+
import io
5+
46
import pytest
57

68
from .. import Figure
@@ -44,9 +46,6 @@ def test_legend_default_position():
4446
return fig
4547

4648

47-
@pytest.mark.xfail(
48-
reason="Baseline image not updated to use earth relief grid in GMT 6.1.0",
49-
)
5049
@pytest.mark.mpl_image_compare
5150
def test_legend_entries():
5251
"""
@@ -73,7 +72,8 @@ def test_legend_entries():
7372

7473

7574
@pytest.mark.mpl_image_compare
76-
def test_legend_specfile():
75+
@pytest.mark.parametrize("usebuffer", [True, False])
76+
def test_legend_specfile(usebuffer):
7777
"""
7878
Test specfile functionality.
7979
"""
@@ -113,7 +113,10 @@ def test_legend_specfile():
113113
fig = Figure()
114114

115115
fig.basemap(projection="x6i", region=[0, 1, 0, 1], frame=True)
116-
fig.legend(specfile.name, position="JTM+jCM+w5i")
116+
117+
spec = io.StringIO(specfile_contents) if usebuffer else specfile.name
118+
119+
fig.legend(spec=spec, position="JTM+jCM+w5i")
117120

118121
return fig
119122

0 commit comments

Comments
 (0)