From eaf090a4dba2c6a99a6a3849b15a50adb40f7fb3 Mon Sep 17 00:00:00 2001 From: Alexander Mohr Date: Thu, 17 Oct 2024 11:06:42 -0700 Subject: [PATCH] initial support --- opensearchpy/connection/http_async.py | 8 ++++++- opensearchpy/helpers/asyncsigner.py | 31 +++++++++++++++++---------- 2 files changed, 27 insertions(+), 12 deletions(-) diff --git a/opensearchpy/connection/http_async.py b/opensearchpy/connection/http_async.py index d04908788..f970426b5 100644 --- a/opensearchpy/connection/http_async.py +++ b/opensearchpy/connection/http_async.py @@ -10,6 +10,7 @@ import asyncio +import inspect import os import ssl import warnings @@ -196,10 +197,15 @@ async def perform_request( auth = ( self._http_auth if isinstance(self._http_auth, aiohttp.BasicAuth) else None ) + if callable(self._http_auth): + http_auth_result = self._http_auth(method, url, query_string, body) + if inspect.isawaitable(http_auth_result): + http_auth_result = await http_auth_result + req_headers = { **req_headers, - **self._http_auth(method, url, query_string, body), + **http_auth_result, } start = self.loop.time() diff --git a/opensearchpy/helpers/asyncsigner.py b/opensearchpy/helpers/asyncsigner.py index c045f1384..2122deea0 100644 --- a/opensearchpy/helpers/asyncsigner.py +++ b/opensearchpy/helpers/asyncsigner.py @@ -6,8 +6,14 @@ # # Modifications Copyright OpenSearch Contributors. See # GitHub history for details. +import inspect +from typing import Dict, Optional, Union, TYPE_CHECKING -from typing import Any, Dict, Optional, Union +if TYPE_CHECKING: + from botocore.credentials import Credentials, RefreshableCredentials + from aiobotocore.credentials import AioCredentials, AioRefreshableCredentials + + CredentialTypes = Credentials | RefreshableCredentials | AioCredentials | AioRefreshableCredentials class AWSV4SignerAsyncAuth: @@ -15,7 +21,7 @@ class AWSV4SignerAsyncAuth: AWS V4 Request Signer for Async Requests. """ - def __init__(self, credentials: Any, region: str, service: str = "es") -> None: + def __init__(self, credentials: 'CredentialTypes', region: str, service: str = "es") -> None: if not credentials: raise ValueError("Credentials cannot be empty") self.credentials = credentials @@ -28,16 +34,16 @@ def __init__(self, credentials: Any, region: str, service: str = "es") -> None: raise ValueError("Service name cannot be empty") self.service = service - def __call__( + async def __call__( self, method: str, url: str, query_string: Optional[str] = None, body: Optional[Union[str, bytes]] = None, ) -> Dict[str, str]: - return self._sign_request(method, url, query_string, body) + return await self._sign_request(method, url, query_string, body) - def _sign_request( + async def _sign_request( self, method: str, url: str, @@ -67,12 +73,15 @@ def _sign_request( # correspond to the secret_key used to sign the request. To avoid this, # get_frozen_credentials() which returns non-refreshing credentials is # called if it exists. - credentials = ( - self.credentials.get_frozen_credentials() - if hasattr(self.credentials, "get_frozen_credentials") - and callable(self.credentials.get_frozen_credentials) - else self.credentials - ) + if ( + hasattr(self.credentials, "get_frozen_credentials") + and callable(self.credentials.get_frozen_credentials) + ): + credentials = self.credentials.get_frozen_credentials() + if inspect.isawaitable(credentials): + credentials = await credentials + else: + credentials = self.credentials sig_v4_auth = SigV4Auth(credentials, self.service, self.region) sig_v4_auth.add_auth(aws_request)