From 3c9f43000fe31912163c6097ff62bfbc2dc47a52 Mon Sep 17 00:00:00 2001 From: Jiakuan Li Date: Fri, 17 May 2024 04:10:18 +0800 Subject: [PATCH] fix: :bug: Fix single watchlist exception --- asynctradier/clients/watchlist_client.py | 6 +++++- tests/test_tradier.py | 25 ++++++++++++++++++++++++ 2 files changed, 30 insertions(+), 1 deletion(-) diff --git a/asynctradier/clients/watchlist_client.py b/asynctradier/clients/watchlist_client.py index d2569ea..e80d06a 100644 --- a/asynctradier/clients/watchlist_client.py +++ b/asynctradier/clients/watchlist_client.py @@ -32,7 +32,11 @@ async def get_watchlists(self) -> List[Watchlist]: """ url = "/v1/watchlists" response = await self.session.get(url) - watchlists = response.get("watchlists", {}).get("watchlist", []) + watchlists = response.get("watchlists", {}).get("watchlist") + if watchlists is None: + return [] + elif isinstance(watchlists, dict): + return [Watchlist(**watchlists)] return [Watchlist(**watchlist) for watchlist in watchlists] async def get_watchlist(self, watchlist_id: str) -> Watchlist: diff --git a/tests/test_tradier.py b/tests/test_tradier.py index 82a9d79..e74deda 100644 --- a/tests/test_tradier.py +++ b/tests/test_tradier.py @@ -3316,6 +3316,31 @@ def mock_get(path: str, params: dict = None): tradier_client.session.get.assert_called_once_with("/v1/watchlists") +@pytest.mark.asyncio() +async def test_get_watchlists_single_element(mocker, tradier_client): + def mock_get(path: str, params: dict = None): + return { + "watchlists": { + "watchlist": { + "name": "default", + "id": "default", + "public_id": "public-atea42pd", + } + } + } + + mocker.patch.object(tradier_client.session, "get", side_effect=mock_get) + + watchlists = await tradier_client.get_watchlists() + + assert len(watchlists) == 1 + assert watchlists[0].name == "default" + assert watchlists[0].id == "default" + assert watchlists[0].public_id == "public-atea42pd" + + tradier_client.session.get.assert_called_once_with("/v1/watchlists") + + @pytest.mark.asyncio() async def test_get_watchlist(mocker, tradier_client): def mock_get(path: str, params: dict = None):