diff --git a/src/axiom_py/client.py b/src/axiom_py/client.py index 90e128c..d618314 100644 --- a/src/axiom_py/client.py +++ b/src/axiom_py/client.py @@ -1,6 +1,7 @@ """Client provides an easy-to use client library to connect to Axiom.""" import ndjson +import atexit import gzip import ujson import os @@ -128,6 +129,7 @@ class Client: # pylint: disable=R0903 datasets: DatasetsClient users: UsersClient annotations: AnnotationsClient + is_closed: bool # track if the client has been closed ( for tests ) def __init__( self, @@ -175,6 +177,14 @@ def __init__( self.users = UsersClient(self.session) self.annotations = AnnotationsClient(self.session, self.logger) + # wrap shutdown hook in a lambda passing in self as a ref + atexit.register(lambda: self.shutdown_hook()) + self.is_closed = False + + def shutdown_hook(self): + self.session.close() + self.is_closed = True + def ingest( self, dataset: str, diff --git a/tests/test_client.py b/tests/test_client.py index 5ea4138..55c40b5 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -1,7 +1,9 @@ """This module contains the tests for the axiom client.""" +import sys import os import unittest +from unittest.mock import patch import gzip import ujson import rfc3339 @@ -224,6 +226,17 @@ def test_step005_complex_query(self): agg = res.buckets.totals[0].aggregations[0] self.assertEqual("event_count", agg.op) + @patch("sys.exit") + def test_client_shutdown_atexit(self, mock_exit): + """Test client shutdown atexit""" + # Use the mock to test the firing mechanism + self.assertEqual(self.client.is_closed, False) + sys.exit() + mock_exit.assert_called_once() + # Use the hook implementation to assert the client is closed closed + self.client.shutdown_hook() + self.assertEqual(self.client.is_closed, True) + @classmethod def tearDownClass(cls): """A teardown that checks if the dataset still exists and deletes it,