Skip to content
This repository was archived by the owner on Mar 18, 2019. It is now read-only.

Commit a91c765

Browse files
committed
Add support for auth schemes
1 parent b82a8f2 commit a91c765

File tree

10 files changed

+283
-47
lines changed

10 files changed

+283
-47
lines changed

coreapi/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
# coding: utf-8
2-
from coreapi import codecs, exceptions, transports, utils
2+
from coreapi import auth, codecs, exceptions, transports, utils
33
from coreapi.client import Client
44
from coreapi.document import Array, Document, Link, Object, Error, Field
55

66

7-
__version__ = '2.2.4'
7+
__version__ = '2.3.0'
88
__all__ = [
99
'Array', 'Document', 'Link', 'Object', 'Error', 'Field',
1010
'Client',
11-
'codecs', 'exceptions', 'transports', 'utils',
11+
'auth', 'codecs', 'exceptions', 'transports', 'utils',
1212
]

coreapi/auth.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
from coreapi.utils import domain_matches
2+
from requests.auth import AuthBase, HTTPBasicAuth
3+
4+
5+
class BasicAuthentication(HTTPBasicAuth):
6+
allow_cookies = False
7+
8+
def __init__(self, username, password, domain=None):
9+
self.domain = domain
10+
super(BasicAuthentication, self).__init__(username, password)
11+
12+
def __call__(self, request):
13+
if not domain_matches(request, self.domain):
14+
return request
15+
16+
return super(BasicAuthentication, self).__call__(request)
17+
18+
19+
class TokenAuthentication(AuthBase):
20+
allow_cookies = False
21+
prefix = 'Bearer'
22+
23+
def __init__(self, token, prefix=None, domain=None):
24+
"""
25+
* Use an unauthenticated client, and make a request to obtain a token.
26+
* Create an authenticated client using eg. `TokenAuthentication(token="<token>")`
27+
"""
28+
self.token = token
29+
self.domain = domain
30+
if prefix is not None:
31+
self.prefix = prefix
32+
33+
def __call__(self, request):
34+
if not domain_matches(request, self.domain):
35+
return request
36+
37+
request.headers['Authorization'] = '%s %s' % (self.prefix, self.token)
38+
return request
39+
40+
41+
class SessionAuthentication(AuthBase):
42+
"""
43+
Enables session based login.
44+
45+
* Make an initial request to obtain a CSRF token.
46+
* Make a login request.
47+
"""
48+
allow_cookies = True
49+
safe_methods = ('GET', 'HEAD', 'OPTIONS', 'TRACE')
50+
51+
def __init__(self, csrf_cookie_name=None, csrf_header_name=None, domain=None):
52+
self.csrf_cookie_name = csrf_cookie_name
53+
self.csrf_header_name = csrf_header_name
54+
self.csrf_token = None
55+
self.domain = domain
56+
57+
def store_csrf_token(self, response, **kwargs):
58+
if self.csrf_cookie_name in response.cookies:
59+
self.csrf_token = response.cookies[self.csrf_cookie_name]
60+
61+
def __call__(self, request):
62+
if not domain_matches(request, self.domain):
63+
return request
64+
65+
if self.csrf_token and self.csrf_header_name is not None and (request.method not in self.safe_methods):
66+
request.headers[self.csrf_header_name] = self.csrf_token
67+
if self.csrf_cookie_name is not None:
68+
request.register_hook('response', self.store_csrf_token)
69+
return request

coreapi/client.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -89,18 +89,23 @@ def get_default_decoders():
8989
]
9090

9191

92-
def get_default_transports():
92+
def get_default_transports(auth=None, session=None):
9393
return [
94-
transports.HTTPTransport()
94+
transports.HTTPTransport(auth=auth, session=session)
9595
]
9696

9797

9898
class Client(itypes.Object):
99-
def __init__(self, decoders=None, transports=None):
99+
def __init__(self, decoders=None, transports=None, auth=None, session=None):
100+
assert transports is None or auth is None, (
101+
"Cannot specify both 'auth' and 'transports'. "
102+
"When specifying transport instances explicitly you should set "
103+
"the authentication directly on the transport."
104+
)
100105
if decoders is None:
101106
decoders = get_default_decoders()
102107
if transports is None:
103-
transports = get_default_transports()
108+
transports = get_default_transports(auth=auth)
104109
self._decoders = itypes.List(decoders)
105110
self._transports = itypes.List(transports)
106111

coreapi/compat.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
try:
1313
# Python 2
1414
import urlparse
15+
import cookielib as cookiejar
1516

1617
string_types = (basestring,)
1718
text_type = unicode
@@ -26,6 +27,7 @@ def b64encode(input_string):
2627
# Python 3
2728
import urllib.parse as urlparse
2829
from io import IOBase
30+
from http import cookiejar
2931

3032
string_types = (str,)
3133
text_type = str

coreapi/transports/http.py

Lines changed: 79 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from __future__ import unicode_literals
33
from collections import OrderedDict
44
from coreapi import exceptions, utils
5-
from coreapi.compat import urlparse
5+
from coreapi.compat import cookiejar, urlparse
66
from coreapi.document import Document, Object, Link, Array, Error
77
from coreapi.transports.base import BaseTransport
88
from coreapi.utils import guess_filename, is_file, File
@@ -11,23 +11,75 @@
1111
import itypes
1212
import mimetypes
1313
import uritemplate
14+
import warnings
1415

1516

1617
Params = collections.namedtuple('Params', ['path', 'query', 'data', 'files'])
1718
empty_params = Params({}, {}, {}, {})
1819

1920

2021
class ForceMultiPartDict(dict):
21-
# A dictionary that always evaluates as True.
22-
# Allows us to force requests to use multipart encoding, even when no
23-
# file parameters are passed.
22+
"""
23+
A dictionary that always evaluates as True.
24+
Allows us to force requests to use multipart encoding, even when no
25+
file parameters are passed.
26+
"""
2427
def __bool__(self):
2528
return True
2629

2730
def __nonzero__(self):
2831
return True
2932

3033

34+
class BlockAll(cookiejar.CookiePolicy):
35+
"""
36+
A cookie policy that rejects all cookies.
37+
Used to override the default `requests` behavior.
38+
"""
39+
return_ok = set_ok = domain_return_ok = path_return_ok = lambda self, *args, **kwargs: False
40+
netscape = True
41+
rfc2965 = hide_cookie2 = False
42+
43+
44+
class DomainCredentials(requests.auth.AuthBase):
45+
"""
46+
Custom auth class to support deprecated 'credentials' argument.
47+
"""
48+
allow_cookies = False
49+
credentials = None
50+
51+
def __init__(self, credentials=None):
52+
self.credentials = credentials
53+
54+
def __call__(self, request):
55+
if not self.credentials:
56+
return request
57+
58+
# Include any authorization credentials relevant to this domain.
59+
url_components = urlparse.urlparse(request.url)
60+
host = url_components.hostname
61+
if host in self.credentials:
62+
request.headers['Authorization'] = self.credentials[host]
63+
return request
64+
65+
66+
class CallbackAdapter(requests.adapters.HTTPAdapter):
67+
"""
68+
Custom requests HTTP adapter, to support deprecated callback arguments.
69+
"""
70+
def __init__(self, request_callback=None, response_callback=None):
71+
self.request_callback = request_callback
72+
self.response_callback = response_callback
73+
74+
def send(self, request, **kwargs):
75+
if self.request_callback is not None:
76+
self.request_callback(request)
77+
response = super(CallbackAdapter, self).send(request, **kwargs)
78+
if self.response_callback is not None:
79+
self.response_callback(response)
80+
return response
81+
82+
3183
def _get_method(action):
3284
if not action:
3385
return 'GET'
@@ -107,7 +159,7 @@ def _get_url(url, path_params):
107159
return url
108160

109161

110-
def _get_headers(url, decoders, credentials=None):
162+
def _get_headers(url, decoders):
111163
"""
112164
Return a dictionary of HTTP headers to use in the outgoing request.
113165
"""
@@ -120,13 +172,6 @@ def _get_headers(url, decoders, credentials=None):
120172
'user-agent': 'coreapi'
121173
}
122174

123-
if credentials:
124-
# Include any authorization credentials relevant to this domain.
125-
url_components = urlparse.urlparse(url)
126-
host = url_components.hostname
127-
if host in credentials:
128-
headers['authorization'] = credentials[host]
129-
130175
return headers
131176

132177

@@ -288,24 +333,34 @@ def _handle_inplace_replacements(document, link, link_ancestors):
288333
class HTTPTransport(BaseTransport):
289334
schemes = ['http', 'https']
290335

291-
def __init__(self, credentials=None, headers=None, session=None, request_callback=None, response_callback=None):
336+
def __init__(self, credentials=None, headers=None, auth=None, session=None, request_callback=None, response_callback=None):
292337
if headers:
293338
headers = {key.lower(): value for key, value in headers.items()}
294339
if session is None:
295340
session = requests.Session()
296-
self._credentials = itypes.Dict(credentials or {})
341+
if auth is not None:
342+
session.auth = auth
343+
if not getattr(session.auth, 'allow_cookies', False):
344+
session.cookies.set_policy(BlockAll())
345+
346+
if credentials is not None:
347+
warnings.warn(
348+
"The 'credentials' argument is now deprecated in favor of 'auth'.",
349+
DeprecationWarning
350+
)
351+
auth = DomainCredentials(credentials)
352+
if request_callback is not None or response_callback is not None:
353+
warnings.warn(
354+
"The 'request_callback' and 'response_callback' arguments are now deprecated. "
355+
"Use a custom 'session' instance instead.",
356+
DeprecationWarning
357+
)
358+
session.mount('https://', CallbackAdapter(request_callback, response_callback))
359+
session.mount('http://', CallbackAdapter(request_callback, response_callback))
360+
297361
self._headers = itypes.Dict(headers or {})
298362
self._session = session
299363

300-
# Fallback for v1.x overrides.
301-
# Will be removed at some point, most likely in a 2.1 release.
302-
self._request_callback = request_callback
303-
self._response_callback = response_callback
304-
305-
@property
306-
def credentials(self):
307-
return self._credentials
308-
309364
@property
310365
def headers(self):
311366
return self._headers
@@ -316,19 +371,11 @@ def transition(self, link, decoders, params=None, link_ancestors=None, force_cod
316371
encoding = _get_encoding(link.encoding)
317372
params = _get_params(method, encoding, link.fields, params)
318373
url = _get_url(link.url, params.path)
319-
headers = _get_headers(url, decoders, self.credentials)
374+
headers = _get_headers(url, decoders)
320375
headers.update(self.headers)
321376

322377
request = _build_http_request(session, url, method, headers, encoding, params)
323-
324-
if self._request_callback is not None:
325-
self._request_callback(request)
326-
327378
response = session.send(request)
328-
329-
if self._response_callback is not None:
330-
self._response_callback(response)
331-
332379
result = _decode_result(response, decoders, force_codec)
333380

334381
if isinstance(result, Document) and link_ancestors:

coreapi/utils.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,20 @@
66
import tempfile
77

88

9+
def domain_matches(request, domain):
10+
"""
11+
Domain string matching against an outgoing request.
12+
Patterns starting with '*' indicate a wildcard domain.
13+
"""
14+
if (domain is None) or (domain == '*'):
15+
return True
16+
17+
host = urlparse.urlparse(request.url).hostname
18+
if domain.startswith('*'):
19+
return host.endswith(domain[1:])
20+
return host == domain
21+
22+
923
def get_installed_codecs():
1024
packages = [
1125
(package, package.load()) for package in

0 commit comments

Comments
 (0)