From f8aeaea8265b281314c160d9c53e575b1e71f1db Mon Sep 17 00:00:00 2001 From: Darrel O'Pry Date: Thu, 19 Oct 2023 22:22:18 -0400 Subject: [PATCH] fix: RedirectURIValidator Encapsulation --- CHANGELOG.md | 1 + oauth2_provider/models.py | 11 +- oauth2_provider/oauth2_validators.py | 1 - oauth2_provider/validators.py | 50 ++++++- tests/test_models.py | 2 +- tests/test_validators.py | 216 ++++++++++++++++++++++----- 6 files changed, 227 insertions(+), 54 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index a61a3ebdb..f516b64c7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -29,6 +29,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 * #1322 Instructions in documentation on how to create a code challenge and code verifier * #1284 Allow to logout with no id_token_hint even if the browser session already expired * #1296 Added reverse function in migration 0006_alter_application_client_secret +* #1336 Fix encapsulation for Redirect URI scheme validation ## [2.3.0] 2023-05-31 diff --git a/oauth2_provider/models.py b/oauth2_provider/models.py index 80d8f3487..661bd7dfc 100644 --- a/oauth2_provider/models.py +++ b/oauth2_provider/models.py @@ -20,7 +20,7 @@ from .scopes import get_scopes_backend from .settings import oauth2_settings from .utils import jwk_from_pem -from .validators import AllowedURIValidator, RedirectURIValidator, WildcardSet +from .validators import AllowedURIValidator logger = logging.getLogger(__name__) @@ -202,12 +202,11 @@ def clean(self): allowed_schemes = set(s.lower() for s in self.get_allowed_schemes()) if redirect_uris: - validator = RedirectURIValidator(WildcardSet()) + validator = AllowedURIValidator( + allowed_schemes, name="redirect uri", allow_path=True, allow_query=True + ) for uri in redirect_uris: validator(uri) - scheme = urlparse(uri).scheme - if scheme not in allowed_schemes: - raise ValidationError(_("Unauthorized redirect scheme: {scheme}").format(scheme=scheme)) elif self.authorization_grant_type in grant_types: raise ValidationError( @@ -218,7 +217,7 @@ def clean(self): allowed_origins = self.allowed_origins.strip().split() if allowed_origins: # oauthlib allows only https scheme for CORS - validator = AllowedURIValidator(oauth2_settings.ALLOWED_SCHEMES, "Origin") + validator = AllowedURIValidator(oauth2_settings.ALLOWED_SCHEMES, "allowed origin") for uri in allowed_origins: validator(uri) diff --git a/oauth2_provider/oauth2_validators.py b/oauth2_provider/oauth2_validators.py index 00497db9a..61238aef5 100644 --- a/oauth2_provider/oauth2_validators.py +++ b/oauth2_provider/oauth2_validators.py @@ -305,7 +305,6 @@ def authenticate_client_id(self, client_id, request, *args, **kwargs): proceed only if the client exists and is not of type "Confidential". """ if self._load_application(client_id, request) is not None: - log.debug("Application %r has type %r" % (client_id, request.client.client_type)) return request.client.client_type != AbstractApplication.CLIENT_CONFIDENTIAL return False diff --git a/oauth2_provider/validators.py b/oauth2_provider/validators.py index df3d9e753..04c323ab7 100644 --- a/oauth2_provider/validators.py +++ b/oauth2_provider/validators.py @@ -1,4 +1,5 @@ import re +import warnings from urllib.parse import urlsplit from django.core.exceptions import ValidationError @@ -20,6 +21,7 @@ class URIValidator(URLValidator): class RedirectURIValidator(URIValidator): def __init__(self, allowed_schemes, allow_fragments=False): + warnings.warn("This class is deprecated and will be removed in version 2.5.0.", DeprecationWarning) super().__init__(schemes=allowed_schemes) self.allow_fragments = allow_fragments @@ -32,6 +34,8 @@ def __call__(self, value): class AllowedURIValidator(URIValidator): + # TODO: find a way to get these associated with their form fields in place of passing name + # TODO: submit PR to get `cause` included in the parent class ValidationError params` def __init__(self, schemes, name, allow_path=False, allow_query=False, allow_fragments=False): """ :param schemes: List of allowed schemes. E.g.: ["https"] @@ -47,15 +51,47 @@ def __init__(self, schemes, name, allow_path=False, allow_query=False, allow_fra self.allow_fragments = allow_fragments def __call__(self, value): - super().__call__(value) value = force_str(value) - scheme, netloc, path, query, fragment = urlsplit(value) + try: + scheme, netloc, path, query, fragment = urlsplit(value) + except ValueError as e: + raise ValidationError( + "%(name)s URI validation error. %(cause)s: %(value)s", + params={"name": self.name, "value": value, "cause": e}, + ) + + # send better validation errors + if scheme not in self.schemes: + raise ValidationError( + "%(name)s URI Validation error. %(cause)s: %(value)s", + params={"name": self.name, "value": value, "cause": "invalid_scheme"}, + ) + if query and not self.allow_query: - raise ValidationError("{} URIs must not contain query".format(self.name)) + raise ValidationError( + "%(name)s URI validation error. %(cause)s: %(value)s", + params={ "name": self.name, "value": value, "cause": 'query string not allowed'} + ) if fragment and not self.allow_fragments: - raise ValidationError("{} URIs must not contain fragments".format(self.name)) + raise ValidationError( + "%(name)s URI validation error. %(cause)s: %(value)s", + params={ "name": self.name, "value": value, "cause": 'fragment not allowed'} + ) if path and not self.allow_path: - raise ValidationError("{} URIs must not contain path".format(self.name)) + raise ValidationError( + "%(name)s URI validation error. %(cause)s: %(value)s", + params={ "name": self.name, "value": value, "cause": 'path not allowed'} + ) + + try: + super().__call__(value) + except ValidationError as e: + raise ValidationError( + "%(name)s URI validation error. %(cause)s: %(value)s", + params={ "name": self.name, "value": value, "cause": e} + ) + + ## @@ -69,5 +105,9 @@ class WildcardSet(set): A set that always returns True on `in`. """ + def __init__(self, *args, **kwargs): + warnings.warn("This class is deprecated and will be removed in version 2.5.0.", DeprecationWarning) + super().__init__(*args, **kwargs) + def __contains__(self, item): return True diff --git a/tests/test_models.py b/tests/test_models.py index 8c62e2c99..5bcd7d6ba 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -591,7 +591,7 @@ def test_application_clean(oauth2_settings, application): application.allowed_origins = "http://example.com" with pytest.raises(ValidationError) as exc: application.clean() - assert "Enter a valid URL" in str(exc.value) + assert "allowed origin URI Validation error. invalid_scheme: http://example.com" in str(exc.value) application.allowed_origins = "https://example.com" application.clean() diff --git a/tests/test_validators.py b/tests/test_validators.py index 6cbc23172..66d746966 100644 --- a/tests/test_validators.py +++ b/tests/test_validators.py @@ -2,7 +2,7 @@ from django.core.validators import ValidationError from django.test import TestCase -from oauth2_provider.validators import AllowedURIValidator, RedirectURIValidator +from oauth2_provider.validators import AllowedURIValidator, RedirectURIValidator, WildcardSet @pytest.mark.usefixtures("oauth2_settings") @@ -36,11 +36,6 @@ def test_validate_custom_uri_scheme(self): # Check ValidationError not thrown validator(uri) - validator = AllowedURIValidator(["my-scheme", "https", "git+ssh"], "Origin") - for uri in good_uris: - # Check ValidationError not thrown - validator(uri) - def test_validate_bad_uris(self): validator = RedirectURIValidator(allowed_schemes=["https"]) self.oauth2_settings.ALLOWED_REDIRECT_URI_SCHEMES = ["https", "good"] @@ -67,47 +62,73 @@ def test_validate_bad_uris(self): with self.assertRaises(ValidationError): validator(uri) - def test_validate_good_origin_uris(self): - """ - Test AllowedURIValidator validates origin URIs if they match requirements - """ - validator = AllowedURIValidator( - ["https"], - "Origin", - allow_path=False, - allow_query=False, - allow_fragments=False, - ) + def test_validate_wildcard_scheme__bad_uris(self): + validator = RedirectURIValidator(allowed_schemes=WildcardSet()) + bad_uris = [ + "http:/example.com#fragment", + "HTTP://localhost#fragment", + "http://example.com/#fragment", + "good://example.com/#fragment", + " ", + "", + # Bad IPv6 URL, urlparse behaves differently for these + 'https://[">', + ] + + for uri in bad_uris: + with self.assertRaises(ValidationError, msg=uri): + validator(uri) + + def test_validate_wildcard_scheme_good_uris(self): + validator = RedirectURIValidator(allowed_schemes=WildcardSet()) good_uris = [ + "my-scheme://example.com", + "my-scheme://example", + "my-scheme://localhost", "https://example.com", - "https://example.com:8080", - "https://example", - "https://localhost", - "https://1.1.1.1", - "https://127.0.0.1", - "https://255.255.255.255", + "HTTPS://example.com", + "HTTPS://example.com.", + "git+ssh://example.com", + "ANY://localhost", + "scheme://example.com", + "at://example.com", + "all://example.com", ] for uri in good_uris: # Check ValidationError not thrown validator(uri) - def test_validate_bad_origin_uris(self): - """ - Test AllowedURIValidator rejects origin URIs if they do not match requirements - """ - validator = AllowedURIValidator( - ["https"], - "Origin", - allow_path=False, - allow_query=False, - allow_fragments=False, - ) + +@pytest.mark.usefixtures("oauth2_settings") +class TestAllowedURIValidator(TestCase): + # TODO: verify the specifics of the ValidationErrors + def test_valid_schemes(self): + validator = AllowedURIValidator(["my-scheme", "https", "git+ssh"], "test") + good_uris = [ + "my-scheme://example.com", + "my-scheme://example", + "my-scheme://localhost", + "https://example.com", + "HTTPS://example.com", + "git+ssh://example.com", + ] + for uri in good_uris: + # Check ValidationError not thrown + validator(uri) + + def test_invalid_schemes(self): + validator = AllowedURIValidator(["https"], "test") bad_uris = [ "http:/example.com", "HTTP://localhost", "HTTP://example.com", + "HTTP://-example.com-", # triggers an exception in the upstream validators + "HTTP://example.com/path", + "HTTP://example.com/path?query=string", + "HTTP://example.com/path?query=string#fragmemt", "HTTP://example.com.", - "http://example.com/#fragment", + "http://example.com/path/#fragment", + "http://example.com?query=string#fragment", "123://example.com", "http://fe80::1", "git+ssh://example.com", @@ -119,12 +140,125 @@ def test_validate_bad_origin_uris(self): "", # Bad IPv6 URL, urlparse behaves differently for these 'https://[">', - # Origin uri should not contain path, query of fragment parts - # https://www.rfc-editor.org/rfc/rfc6454#section-7.1 - "https://example.com/", - "https://example.com/test", - "https://example.com/?q=test", - "https://example.com/#test", + ] + + for uri in bad_uris: + with self.assertRaises(ValidationError): + validator(uri) + + def test_allow_paths_valid_urls(self): + validator = AllowedURIValidator(["https", "myapp"], "test", allow_path=True) + good_uris = [ + "https://example.com", + "https://example.com:8080", + "https://example", + "https://example.com/path", + "https://example.com:8080/path", + "https://example/path", + "https://localhost/path", + "myapp://host/path", + ] + for uri in good_uris: + # Check ValidationError not thrown + validator(uri) + + def test_allow_paths_invalid_urls(self): + validator = AllowedURIValidator(["https", "myapp"], "test", allow_path=True) + bad_uris = [ + "https://example.com?query=string", + "https://example.com#fragment", + "https://example.com/path?query=string", + "https://example.com/path#fragment", + "https://example.com/path?query=string#fragment", + "myapp://example.com/path?query=string", + "myapp://example.com/path#fragment", + "myapp://example.com/path?query=string#fragment", + "bad://example.com/path", + ] + + for uri in bad_uris: + with self.assertRaises(ValidationError): + validator(uri) + + def test_allow_query_valid_urls(self): + validator = AllowedURIValidator(["https", "myapp"], "test", allow_query=True) + good_uris = [ + "https://example.com", + "https://example.com:8080", + "https://example.com?query=string", + "https://example", + "myapp://example.com?query=string", + "myapp://example?query=string", + ] + for uri in good_uris: + # Check ValidationError not thrown + validator(uri) + + def test_allow_query_invalid_urls(self): + validator = AllowedURIValidator(["https", "myapp"], "test", allow_query=True) + bad_uris = [ + "https://example.com/path", + "https://example.com#fragment", + "https://example.com/path?query=string", + "https://example.com/path#fragment", + "https://example.com/path?query=string#fragment", + "https://example.com:8080/path", + "https://example/path", + "https://localhost/path", + "myapp://example.com/path?query=string", + "myapp://example.com/path#fragment", + "myapp://example.com/path?query=string#fragment", + "bad://example.com/path", + ] + + for uri in bad_uris: + with self.assertRaises(ValidationError): + validator(uri) + + def test_allow_fragment_valid_urls(self): + validator = AllowedURIValidator(["https", "myapp"], "test", allow_fragments=True) + good_uris = [ + "https://example.com", + "https://example.com#fragment", + "https://example.com:8080", + "https://example.com:8080#fragment", + "https://example", + "https://example#fragment", + "myapp://example", + "myapp://example#fragment", + "myapp://example.com", + "myapp://example.com#fragment", + ] + for uri in good_uris: + # Check ValidationError not thrown + validator(uri) + + def test_allow_fragment_invalid_urls(self): + validator = AllowedURIValidator(["https", "myapp"], "test", allow_fragments=True) + bad_uris = [ + "https://example.com?query=string", + "https://example.com?query=string#fragment", + "https://example.com/path", + "https://example.com/path?query=string", + "https://example.com/path#fragment", + "https://example.com/path?query=string#fragment", + "https://example.com:8080/path", + "https://example?query=string", + "https://example?query=string#fragment", + "https://example/path", + "https://example/path?query=string", + "https://example/path#fragment", + "https://example/path?query=string#fragment", + "myapp://example?query=string", + "myapp://example?query=string#fragment", + "myapp://example/path", + "myapp://example/path?query=string", + "myapp://example/path#fragment", + "myapp://example.com/path?query=string", + "myapp://example.com/path#fragment", + "myapp://example.com/path?query=string#fragment", + "myapp://example.com?query=string", + "bad://example.com", ] for uri in bad_uris: