Skip to content

Commit b561e9d

Browse files
seismanweiji14
andauthored
Improve performance by avoiding loading the GMT library repeatedly (#2930)
Co-authored-by: Wei Ji <[email protected]>
1 parent 88ab1ca commit b561e9d

File tree

2 files changed

+43
-1
lines changed

2 files changed

+43
-1
lines changed

pygmt/clib/session.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,9 @@
8787
np.datetime64: "GMT_DATETIME",
8888
}
8989

90+
# Load the GMT library outside the Session class to avoid repeated loading.
91+
_libgmt = load_libgmt()
92+
9093

9194
class Session:
9295
"""
@@ -308,7 +311,7 @@ def get_libgmt_func(self, name, argtypes=None, restype=None):
308311
<class 'ctypes.CDLL.__init__.<locals>._FuncPtr'>
309312
"""
310313
if not hasattr(self, "_libgmt"):
311-
self._libgmt = load_libgmt()
314+
self._libgmt = _libgmt
312315
function = getattr(self._libgmt, name)
313316
if argtypes is not None:
314317
function.argtypes = argtypes

pygmt/tests/test_clib_loading.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
import pytest
1313
from pygmt.clib.loading import check_libgmt, clib_full_names, clib_names, load_libgmt
14+
from pygmt.clib.session import Session
1415
from pygmt.exceptions import GMTCLibError, GMTCLibNotFoundError, GMTOSError
1516

1617

@@ -207,6 +208,44 @@ def test_brokenlib_brokenlib_workinglib(self):
207208
assert check_libgmt(load_libgmt(lib_fullnames=lib_fullnames)) is None
208209

209210

211+
class TestLibgmtCount:
212+
"""
213+
Test that the GMT library is not repeatedly loaded in every session.
214+
"""
215+
216+
loaded_libgmt = load_libgmt() # Load the GMT library and reuse it when necessary
217+
counter = 0 # Global counter for how many times ctypes.CDLL is called
218+
219+
def _mock_ctypes_cdll_return(self, libname): # noqa: ARG002
220+
"""
221+
Mock ctypes.CDLL to count how many times the function is called.
222+
223+
If ctypes.CDLL is called, the counter increases by one.
224+
"""
225+
self.counter += 1 # Increase the counter
226+
return self.loaded_libgmt
227+
228+
def test_libgmt_load_counter(self, monkeypatch):
229+
"""
230+
Make sure that the GMT library is not loaded in every session.
231+
"""
232+
# Monkeypatch the ctypes.CDLL function
233+
monkeypatch.setattr(ctypes, "CDLL", self._mock_ctypes_cdll_return)
234+
235+
# Create two sessions and check the global counter
236+
with Session() as lib:
237+
_ = lib
238+
with Session() as lib:
239+
_ = lib
240+
assert self.counter == 0 # ctypes.CDLL is not called after two sessions.
241+
242+
# Explicitly calling load_libgmt to make sure the mock function is correct
243+
load_libgmt()
244+
assert self.counter == 1
245+
load_libgmt()
246+
assert self.counter == 2
247+
248+
210249
###############################################################################
211250
# Test clib_full_names
212251
@pytest.fixture(scope="module", name="gmt_lib_names")

0 commit comments

Comments
 (0)