Skip to content

Commit 7815c6a

Browse files
authored
Small refactor of AWS Signer classes for both sync and async clients (#866)
* made custom headers be available to async aws signer Signed-off-by: Bruno Murino <[email protected]> * updated changelog Signed-off-by: Bruno Murino <[email protected]> * added tests for using host header for AWS request signature on both sync and async clients Signed-off-by: Bruno Murino <[email protected]> * added documentation guide about aws auth when accessing via tunnel Signed-off-by: Bruno Murino <[email protected]> * small refactor of AWS Signer classes on sync and async clients; improved testing on them as well Signed-off-by: Bruno Murino <[email protected]> * changelog Signed-off-by: Bruno Murino <[email protected]> * fixed test Signed-off-by: Bruno Murino <[email protected]> * lint fix Signed-off-by: Bruno Murino <[email protected]> --------- Signed-off-by: Bruno Murino <[email protected]>
1 parent 87aebcd commit 7815c6a

File tree

5 files changed

+81
-124
lines changed

5 files changed

+81
-124
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.0.0/)
77
- Added sync and async sample that uses `search_after` parameter ([859](https://github.com/opensearch-project/opensearch-py/pull/859))
88
### Updated APIs
99
### Changed
10+
- Small refactor of AWS Signer classes for both sync and async clients ([866](https://github.com/opensearch-project/opensearch-py/pull/866))
1011
### Deprecated
1112
### Removed
1213
### Fixed

opensearchpy/helpers/asyncsigner.py

Lines changed: 9 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@
88
# GitHub history for details.
99

1010
from typing import Any, Dict, Optional, Union
11-
from urllib.parse import parse_qs, urlencode, urlparse
11+
12+
from opensearchpy.helpers.signer import AWSV4Signer
1213

1314

1415
class AWSV4SignerAsyncAuth:
@@ -17,33 +18,21 @@ class AWSV4SignerAsyncAuth:
1718
"""
1819

1920
def __init__(self, credentials: Any, region: str, service: str = "es") -> None:
20-
if not credentials:
21-
raise ValueError("Credentials cannot be empty")
22-
self.credentials = credentials
23-
24-
if not region:
25-
raise ValueError("Region cannot be empty")
26-
self.region = region
27-
28-
if not service:
29-
raise ValueError("Service name cannot be empty")
30-
self.service = service
21+
self.signer = AWSV4Signer(credentials, region, service)
3122

3223
def __call__(
3324
self,
3425
method: str,
3526
url: str,
36-
query_string: Optional[str] = None,
3727
body: Optional[Union[str, bytes]] = None,
3828
headers: Optional[Dict[str, str]] = None,
3929
) -> Dict[str, str]:
40-
return self._sign_request(method, url, query_string, body, headers)
30+
return self._sign_request(method=method, url=url, body=body, headers=headers)
4131

4232
def _sign_request(
4333
self,
4434
method: str,
4535
url: str,
46-
query_string: Optional[str],
4736
body: Optional[Union[str, bytes]],
4837
headers: Optional[Dict[str, str]],
4938
) -> Dict[str, str]:
@@ -53,58 +42,10 @@ def _sign_request(
5342
:return: signed headers
5443
"""
5544

56-
from botocore.auth import SigV4Auth
57-
from botocore.awsrequest import AWSRequest
58-
59-
signature_host = self._fetch_url(url, headers or dict())
60-
61-
# create an AWS request object and sign it using SigV4Auth
62-
aws_request = AWSRequest(
45+
updated_headers = self.signer.sign(
6346
method=method,
64-
url=signature_host,
65-
data=body,
66-
)
67-
68-
# credentials objects expose access_key, secret_key and token attributes
69-
# via @property annotations that call _refresh() on every access,
70-
# creating a race condition if the credentials expire before secret_key
71-
# is called but after access_key- the end result is the access_key doesn't
72-
# correspond to the secret_key used to sign the request. To avoid this,
73-
# get_frozen_credentials() which returns non-refreshing credentials is
74-
# called if it exists.
75-
credentials = (
76-
self.credentials.get_frozen_credentials()
77-
if hasattr(self.credentials, "get_frozen_credentials")
78-
and callable(self.credentials.get_frozen_credentials)
79-
else self.credentials
47+
url=url,
48+
body=body,
49+
headers=headers,
8050
)
81-
82-
sig_v4_auth = SigV4Auth(credentials, self.service, self.region)
83-
sig_v4_auth.add_auth(aws_request)
84-
aws_request.headers["X-Amz-Content-SHA256"] = sig_v4_auth.payload(aws_request)
85-
86-
# copy the headers from AWS request object into the prepared_request
87-
return dict(aws_request.headers.items())
88-
89-
def _fetch_url(self, url: str, headers: Optional[Dict[str, str]]) -> str:
90-
"""
91-
This is a util method that helps in reconstructing the request url.
92-
:param prepared_request: unsigned request
93-
:return: reconstructed url
94-
"""
95-
parsed_url = urlparse(url)
96-
path = parsed_url.path or "/"
97-
98-
# fetch the query string if present in the request
99-
querystring = ""
100-
if parsed_url.query:
101-
querystring = "?" + urlencode(
102-
parse_qs(parsed_url.query, keep_blank_values=True), doseq=True
103-
)
104-
105-
# fetch the host information from headers
106-
headers = {key.lower(): value for key, value in (headers or dict()).items()}
107-
location = headers.get("host") or parsed_url.netloc
108-
109-
# construct the url and return
110-
return parsed_url.scheme + "://" + location + path + querystring
51+
return updated_headers

opensearchpy/helpers/signer.py

Lines changed: 42 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
# Modifications Copyright OpenSearch Contributors. See
88
# GitHub history for details.
99

10-
from typing import Any, Callable, Dict
10+
from typing import Any, Callable, Dict, Optional
1111
from urllib.parse import parse_qs, urlencode, urlparse
1212

1313
import requests
@@ -31,7 +31,9 @@ def __init__(self, credentials, region: str, service: str = "es") -> Any: # typ
3131
raise ValueError("Service name cannot be empty")
3232
self.service = service
3333

34-
def sign(self, method: str, url: str, body: Any) -> Dict[str, str]:
34+
def sign(
35+
self, method: str, url: str, body: Any, headers: Optional[Dict[str, str]] = None
36+
) -> Dict[str, str]:
3537
"""
3638
This method signs the request and returns headers.
3739
:param method: HTTP method
@@ -43,8 +45,10 @@ def sign(self, method: str, url: str, body: Any) -> Dict[str, str]:
4345
from botocore.auth import SigV4Auth
4446
from botocore.awsrequest import AWSRequest
4547

48+
signature_host = self._fetch_url(url, headers or dict())
49+
4650
# create an AWS request object and sign it using SigV4Auth
47-
aws_request = AWSRequest(method=method.upper(), url=url, data=body)
51+
aws_request = AWSRequest(method=method.upper(), url=signature_host, data=body)
4852

4953
# credentials objects expose access_key, secret_key and token attributes
5054
# via @property annotations that call _refresh() on every access,
@@ -69,6 +73,30 @@ def sign(self, method: str, url: str, body: Any) -> Dict[str, str]:
6973

7074
return headers
7175

76+
@staticmethod
77+
def _fetch_url(url: str, headers: Optional[Dict[str, str]]) -> str:
78+
"""
79+
This is a util method that helps in reconstructing the request url.
80+
:param prepared_request: unsigned request
81+
:return: reconstructed url
82+
"""
83+
parsed_url = urlparse(url)
84+
path = parsed_url.path or "/"
85+
86+
# fetch the query string if present in the request
87+
querystring = ""
88+
if parsed_url.query:
89+
querystring = "?" + urlencode(
90+
parse_qs(parsed_url.query, keep_blank_values=True), doseq=True
91+
)
92+
93+
# fetch the host information from headers
94+
headers = {key.lower(): value for key, value in (headers or dict()).items()}
95+
location = headers.get("host") or parsed_url.netloc
96+
97+
# construct the url and return
98+
return parsed_url.scheme + "://" + location + path + querystring
99+
72100

73101
class RequestsAWSV4SignerAuth(requests.auth.AuthBase):
74102
"""
@@ -89,40 +117,16 @@ def _sign_request(self, prepared_request): # type: ignore
89117
:return: signed request
90118
"""
91119

92-
prepared_request.headers.update(
93-
self.signer.sign(
94-
prepared_request.method,
95-
self._fetch_url(prepared_request),
96-
prepared_request.body,
97-
)
120+
updated_headers = self.signer.sign(
121+
method=prepared_request.method,
122+
url=prepared_request.url,
123+
body=prepared_request.body,
124+
headers=prepared_request.headers,
98125
)
99126

100-
return prepared_request
101-
102-
def _fetch_url(self, prepared_request: requests.PreparedRequest) -> str:
103-
"""
104-
This is a util method that helps in reconstructing the request url.
105-
:param prepared_request: unsigned request
106-
:return: reconstructed url
107-
"""
108-
url = urlparse(prepared_request.url)
109-
path = url.path or "/"
110-
111-
# fetch the query string if present in the request
112-
querystring = ""
113-
if url.query:
114-
querystring = "?" + urlencode(
115-
parse_qs(url.query, keep_blank_values=True), doseq=True # type: ignore
116-
)
127+
prepared_request.headers.update(updated_headers)
117128

118-
# fetch the host information from headers
119-
headers = {
120-
key.lower(): value for key, value in prepared_request.headers.items()
121-
}
122-
location = headers.get("host") or url.netloc
123-
124-
# construct the url and return
125-
return url.scheme + "://" + location + path + querystring # type: ignore
129+
return prepared_request
126130

127131

128132
# Deprecated: use RequestsAWSV4SignerAuth
@@ -135,5 +139,7 @@ def __init__(self, credentials, region, service: str = "es") -> None: # type: i
135139
self.signer = AWSV4Signer(credentials, region, service)
136140
self.service = service # tools like LangChain rely on this, see https://github.com/opensearch-project/opensearch-py/issues/600
137141

138-
def __call__(self, method: str, url: str, body: Any) -> Dict[str, str]:
139-
return self.signer.sign(method, url, body)
142+
def __call__(
143+
self, method: str, url: str, body: Any, headers: Optional[Dict[str, str]] = None
144+
) -> Dict[str, str]:
145+
return self.signer.sign(method, url, body, headers)

test_opensearchpy/test_async/test_signer.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
import uuid
1111
from typing import Any, Collection, Dict, Mapping, Optional, Tuple, Union
12-
from unittest.mock import Mock
12+
from unittest.mock import Mock, patch
1313

1414
import pytest
1515
from _pytest.mark.structures import MarkDecorator
@@ -81,15 +81,18 @@ async def test_aws_signer_async_fetch_url_with_querystring(self) -> None:
8181
region = "us-west-2"
8282
service = "aoss"
8383

84-
from opensearchpy.helpers.asyncsigner import AWSV4SignerAsyncAuth
85-
86-
auth = AWSV4SignerAsyncAuth(self.mock_session(), region, service)
84+
from botocore.awsrequest import AWSRequest
8785

88-
signature_host = auth._fetch_url(
89-
"http://localhost/?foo=bar", headers={"host": "otherhost"}
90-
)
86+
from opensearchpy.helpers.asyncsigner import AWSV4SignerAsyncAuth
9187

92-
assert signature_host == "http://otherhost/?foo=bar"
88+
with patch(
89+
"botocore.awsrequest.AWSRequest", side_effect=AWSRequest
90+
) as mock_aws_request:
91+
auth = AWSV4SignerAsyncAuth(self.mock_session(), region, service)
92+
auth("GET", "http://localhost/?foo=bar", headers={"host": "otherhost:443"})
93+
mock_aws_request.assert_called_with(
94+
method="GET", url="http://otherhost:443/?foo=bar", data=None
95+
)
9396

9497

9598
class TestAsyncSignerWithFrozenCredentials(TestAsyncSigner):
@@ -155,7 +158,6 @@ def _sign_request(
155158
self,
156159
method: str,
157160
url: str,
158-
query_string: Optional[str] = None,
159161
body: Optional[Union[str, bytes]] = None,
160162
headers: Optional[Dict[str, str]] = None,
161163
) -> Dict[str, str]:

test_opensearchpy/test_connection/test_requests_http_connection.py

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -457,22 +457,27 @@ def mock_session(self) -> Any:
457457

458458
return dummy_session
459459

460-
def test_aws_signer_fetch_url_with_querystring(self) -> None:
460+
def test_aws_signer_url_with_querystring_and_custom_header(self) -> None:
461461
region = "us-west-2"
462462

463463
import requests
464+
from botocore.awsrequest import AWSRequest
464465

465466
from opensearchpy.helpers.signer import RequestsAWSV4SignerAuth
466467

467-
auth = RequestsAWSV4SignerAuth(self.mock_session(), region)
468-
469-
prepared_request = requests.Request(
470-
"GET", "http://localhost/?foo=bar", headers={"host": "otherhost:443"}
471-
).prepare()
468+
with patch(
469+
"botocore.awsrequest.AWSRequest", side_effect=AWSRequest
470+
) as mock_aws_request:
472471

473-
signature_host = auth._fetch_url(prepared_request)
472+
auth = RequestsAWSV4SignerAuth(self.mock_session(), region)
473+
prepared_request = requests.Request(
474+
"GET", "http://localhost/?foo=bar", headers={"host": "otherhost:443"}
475+
).prepare()
476+
auth(prepared_request)
474477

475-
assert signature_host == "http://otherhost:443/?foo=bar"
478+
mock_aws_request.assert_called_with(
479+
method="GET", url="http://otherhost:443/?foo=bar", data=None
480+
)
476481

477482
def test_aws_signer_as_http_auth(self) -> None:
478483
region = "us-west-2"
@@ -525,9 +530,11 @@ def test_aws_signer_signs_with_query_string(self, mock_sign: Any) -> None:
525530
).prepare()
526531
auth(prepared_request)
527532
self.assertEqual(mock_sign.call_count, 1)
528-
self.assertEqual(
529-
mock_sign.call_args[0],
530-
("GET", "http://localhost/?key1=value1&key2=value2", None),
533+
mock_sign.assert_called_with(
534+
method="GET",
535+
url="http://localhost/?key1=value1&key2=value2",
536+
body=None,
537+
headers={},
531538
)
532539

533540
def test_aws_signer_consitent_url(self) -> None:

0 commit comments

Comments
 (0)