From 1b609413762aadc93ef2d366b4f0a072242af362 Mon Sep 17 00:00:00 2001
From: FallenDeity <triyanmukherjee@gmail.com>
Date: Thu, 14 Dec 2023 03:53:34 +0530
Subject: [PATCH] fix: add upload sarif for codeql

---
 .github/workflows/codeql.yml |  7 ------
 examples/evolutions.py       | 20 ++++++++++++++++
 pokelance/client.py          | 23 +++++++++----------
 pokelance/http/__init__.py   | 44 ++++++++++++++++++++++++------------
 pokelance/logger.py          |  2 +-
 5 files changed, 62 insertions(+), 34 deletions(-)
 create mode 100644 examples/evolutions.py

diff --git a/.github/workflows/codeql.yml b/.github/workflows/codeql.yml
index 29a81a2..7880c27 100644
--- a/.github/workflows/codeql.yml
+++ b/.github/workflows/codeql.yml
@@ -74,10 +74,3 @@ jobs:
       uses: github/codeql-action/analyze@v2
       with:
         category: "/language:${{matrix.language}}"
-
-    # upload results to GitHub
-    - name: Upload SARIF results
-      uses: github/codeql-action/upload-sarif@v2
-      with:
-        sarif_file: "../results/python.sarif"
-        
\ No newline at end of file
diff --git a/examples/evolutions.py b/examples/evolutions.py
new file mode 100644
index 0000000..2e07ce8
--- /dev/null
+++ b/examples/evolutions.py
@@ -0,0 +1,20 @@
+import asyncio
+
+from pokelance import PokeLance
+
+
+async def main() -> None:
+    client = PokeLance()
+    # mon = await client.pokemon.fetch_pokemon("pichu")
+    # print(mon.forms)
+    print(await client.ping())
+    # ab = (await client.pokemon.fetch_ability("static"))
+    # print(ab.name)
+    await asyncio.sleep(10)
+    await client.close()
+
+
+if __name__ == "__main__":
+    loop = asyncio.get_event_loop()
+    loop.run_until_complete(main())
+    loop.close()
diff --git a/pokelance/client.py b/pokelance/client.py
index 7ac4bb3..59442c6 100644
--- a/pokelance/client.py
+++ b/pokelance/client.py
@@ -33,7 +33,7 @@ class PokeLance:
         The HTTP client used to make requests to the PokeAPI.
     _logger : logging.Logger
         The logger used to log information about the client.
-    _loaders : t.List[t.Tuple[t.Coroutine[t.Any, t.Any, None], str]]
+    _ext_tasks : t.List[t.Tuple[t.Coroutine[t.Any, t.Any, None], str]]
         A list of coroutines to load extension data.
     EXTENSIONS : Path
         The path to the extensions directory.
@@ -112,7 +112,7 @@ def __init__(
         """
         self._logger = logger or Logger(name="pokelance", file_logging=file_logging)
         self._http = HttpClient(client=self, session=session, cache_size=cache_size)
-        self._loaders: t.List[t.Tuple[t.Coroutine[t.Any, t.Any, None], str]] = []
+        self._ext_tasks: t.List[t.Tuple[t.Coroutine[t.Any, t.Any, None], str]] = []
         self._image_cache_size = image_cache_size
         lru_cache(maxsize=self._image_cache_size)(self.get_image_async)
         self.setup_hook()
@@ -126,8 +126,8 @@ async def __aexit__(
         exc_val: t.Optional[BaseException],
         exc_tb: t.Optional["TracebackType"],
     ) -> None:
-        if self._http.session is not None:
-            await self._http.session.close()
+        self.logger.warning("Closing session!")
+        await self._http.close()
 
     def setup_hook(self) -> None:
         """
@@ -154,7 +154,7 @@ def add_extension(self, name: str, extension: "BaseExtension") -> None:
         extension : BaseExtension
             The extension to add.
         """
-        self._loaders.append((extension.setup(), name))
+        self._ext_tasks.append((extension.setup(), name))
         setattr(self, name, extension)
 
     async def ping(self) -> float:
@@ -173,9 +173,8 @@ async def close(self) -> None:
         Closes the client session. Recommended to use this when the client is no longer needed.
         Not needed if the client is used in a context manager.
         """
-        self._logger.info("Closing session")
-        if self._http.session is not None:
-            await self._http.session.close()
+        self.logger.warning("Closing session!")
+        await self._http.close()
 
     async def getch_data(
         self, ext: t.Union[ExtensionEnum, ExtensionsL, str], category: str, id_: t.Union[int, str]
@@ -267,16 +266,16 @@ async def get_image_async(self, url: str) -> bytes:
         return await self._http.load_image(url)
 
     @property
-    def loaders(self) -> t.List[t.Tuple[t.Coroutine[t.Any, t.Any, None], str]]:
+    def ext_tasks(self) -> t.List[t.Tuple[t.Coroutine[t.Any, t.Any, None], str]]:
         """
-        The list of loaders for the extensions.
+        A list of coroutines to load extension data.
 
         Returns
         -------
         typing.List[typing.Tuple[typing.Coroutine[typing.Any, typing.Any, None], str]]
-            The list of loaders.
+            The list of tasks.
         """
-        return self._loaders
+        return self._ext_tasks
 
     @property
     def logger(self) -> "logging.Logger":
diff --git a/pokelance/http/__init__.py b/pokelance/http/__init__.py
index 95f9597..b747468 100644
--- a/pokelance/http/__init__.py
+++ b/pokelance/http/__init__.py
@@ -43,6 +43,8 @@ class HttpClient:
         The cache to use for the HTTP client.
     _client: pokelance.PokeLance
         The client that this HTTP client is for.
+    _tasks_queue: typing.List[asyncio.Task]
+        The queue for the tasks.
     """
 
     __slots__: t.Tuple[str, ...] = (
@@ -50,6 +52,7 @@ class HttpClient:
         "session",
         "_cache",
         "_is_ready",
+        "_tasks_queue",
     )
 
     def __init__(
@@ -75,21 +78,35 @@ def __init__(
         self.session = session
         self._is_ready = False
         self._cache = Cache(max_size=cache_size)
-
-    def _load_endpoints(self) -> None:
-        """Loads the endpoints for the HTTP client."""
-        for num, (coro, name) in enumerate(self._client.loaders):
-            task = asyncio.create_task(coro)
-            task.add_done_callback(
-                lambda _: self._client.logger.info(f"Loaded {self._client.loaders.pop(0)[1]} endpoints.")
-            )
-        self._is_ready = True
+        self._tasks_queue: t.List[asyncio.Task[None]] = []
+
+    async def _load_ext(self, coro: t.Coroutine[t.Any, t.Any, None], message: str) -> None:
+        await coro
+        self._client.logger.info(message)
+
+    async def _schedule_tasks(self) -> None:
+        total = len(self._client.ext_tasks)
+        for num, (coro, name) in enumerate(self._client.ext_tasks):
+            message = f"Extension {name} endpoints ({num + 1}/{total})"
+            self._client.logger.debug(f"Loading {message}")
+            task = asyncio.create_task(self._load_ext(coro, f"Loaded {message}"), name=name)
+            self._tasks_queue.append(task)
+        self._client.ext_tasks.clear()
+
+    async def close(self) -> None:
+        for task in self._tasks_queue:
+            if not task.done():
+                task.cancel()
+                self._client.logger.warning(f"Cancelled task {task.get_name()}")
+        if self.session:
+            await self.session.close()
 
     async def connect(self) -> None:
         """Connects the HTTP client."""
-        if not self._is_ready and self.session is None:
-            self.session = aiohttp.ClientSession()
-            self._load_endpoints()
+        self.session = self.session or aiohttp.ClientSession()
+        if not self._is_ready:
+            await self._schedule_tasks()
+            self._is_ready = True
 
     async def request(self, route: Route) -> t.Any:
         """Makes a request to the PokeAPI.
@@ -111,14 +128,13 @@ async def request(self, route: Route) -> t.Any:
         """
         if self.session is None:
             await self.connect()
-        if not self._is_ready and self._client.loaders:
-            self._load_endpoints()
         if self.session is not None:
             async with self.session.request(route.method, route.url, params=route.payload) as response:
                 if 300 > response.status >= 200:
                     self._client.logger.debug(f"Request to {route.url} was successful.")
                     return await response.json()
                 else:
+                    self._client.logger.error(f"Request to {route.url} was unsuccessful.")
                     raise HTTPException(str(response.reason), route, response.status).create()
         else:
             raise HTTPException("No session was provided.", route, 0).create()
diff --git a/pokelance/logger.py b/pokelance/logger.py
index 344d6e0..93b8a68 100644
--- a/pokelance/logger.py
+++ b/pokelance/logger.py
@@ -103,7 +103,7 @@ class Logger(logging.Logger):
 
     file_handler: t.Optional[FileHandler] = None
 
-    def __init__(self, *, name: str, level: int = logging.INFO, file_logging: bool = False) -> None:
+    def __init__(self, *, name: str, level: int = logging.DEBUG, file_logging: bool = False) -> None:
         super().__init__(name, level)
         self._handler = logging.StreamHandler()
         self._handler.setFormatter(Formatter())