Skip to content

Commit

Permalink
fix: RedirectURIValidator Encapsulation
Browse files Browse the repository at this point in the history
  • Loading branch information
dopry committed Oct 20, 2023
1 parent 2c83e6c commit f8aeaea
Show file tree
Hide file tree
Showing 6 changed files with 227 additions and 54 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
* #1322 Instructions in documentation on how to create a code challenge and code verifier
* #1284 Allow to logout with no id_token_hint even if the browser session already expired
* #1296 Added reverse function in migration 0006_alter_application_client_secret
* #1336 Fix encapsulation for Redirect URI scheme validation

## [2.3.0] 2023-05-31

Expand Down
11 changes: 5 additions & 6 deletions oauth2_provider/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from .scopes import get_scopes_backend
from .settings import oauth2_settings
from .utils import jwk_from_pem
from .validators import AllowedURIValidator, RedirectURIValidator, WildcardSet
from .validators import AllowedURIValidator


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -202,12 +202,11 @@ def clean(self):
allowed_schemes = set(s.lower() for s in self.get_allowed_schemes())

if redirect_uris:
validator = RedirectURIValidator(WildcardSet())
validator = AllowedURIValidator(
allowed_schemes, name="redirect uri", allow_path=True, allow_query=True
)
for uri in redirect_uris:
validator(uri)
scheme = urlparse(uri).scheme
if scheme not in allowed_schemes:
raise ValidationError(_("Unauthorized redirect scheme: {scheme}").format(scheme=scheme))

elif self.authorization_grant_type in grant_types:
raise ValidationError(
Expand All @@ -218,7 +217,7 @@ def clean(self):
allowed_origins = self.allowed_origins.strip().split()
if allowed_origins:
# oauthlib allows only https scheme for CORS
validator = AllowedURIValidator(oauth2_settings.ALLOWED_SCHEMES, "Origin")
validator = AllowedURIValidator(oauth2_settings.ALLOWED_SCHEMES, "allowed origin")
for uri in allowed_origins:
validator(uri)

Expand Down
1 change: 0 additions & 1 deletion oauth2_provider/oauth2_validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,6 @@ def authenticate_client_id(self, client_id, request, *args, **kwargs):
proceed only if the client exists and is not of type "Confidential".
"""
if self._load_application(client_id, request) is not None:
log.debug("Application %r has type %r" % (client_id, request.client.client_type))
return request.client.client_type != AbstractApplication.CLIENT_CONFIDENTIAL
return False

Expand Down
50 changes: 45 additions & 5 deletions oauth2_provider/validators.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import re
import warnings
from urllib.parse import urlsplit

from django.core.exceptions import ValidationError
Expand All @@ -20,6 +21,7 @@ class URIValidator(URLValidator):

class RedirectURIValidator(URIValidator):
def __init__(self, allowed_schemes, allow_fragments=False):
warnings.warn("This class is deprecated and will be removed in version 2.5.0.", DeprecationWarning)
super().__init__(schemes=allowed_schemes)
self.allow_fragments = allow_fragments

Expand All @@ -32,6 +34,8 @@ def __call__(self, value):


class AllowedURIValidator(URIValidator):
# TODO: find a way to get these associated with their form fields in place of passing name
# TODO: submit PR to get `cause` included in the parent class ValidationError params`
def __init__(self, schemes, name, allow_path=False, allow_query=False, allow_fragments=False):
"""
:param schemes: List of allowed schemes. E.g.: ["https"]
Expand All @@ -47,15 +51,47 @@ def __init__(self, schemes, name, allow_path=False, allow_query=False, allow_fra
self.allow_fragments = allow_fragments

def __call__(self, value):
super().__call__(value)
value = force_str(value)
scheme, netloc, path, query, fragment = urlsplit(value)
try:
scheme, netloc, path, query, fragment = urlsplit(value)
except ValueError as e:
raise ValidationError(
"%(name)s URI validation error. %(cause)s: %(value)s",
params={"name": self.name, "value": value, "cause": e},
)

# send better validation errors
if scheme not in self.schemes:
raise ValidationError(
"%(name)s URI Validation error. %(cause)s: %(value)s",
params={"name": self.name, "value": value, "cause": "invalid_scheme"},
)

if query and not self.allow_query:
raise ValidationError("{} URIs must not contain query".format(self.name))
raise ValidationError(
"%(name)s URI validation error. %(cause)s: %(value)s",
params={ "name": self.name, "value": value, "cause": 'query string not allowed'}
)
if fragment and not self.allow_fragments:
raise ValidationError("{} URIs must not contain fragments".format(self.name))
raise ValidationError(
"%(name)s URI validation error. %(cause)s: %(value)s",
params={ "name": self.name, "value": value, "cause": 'fragment not allowed'}
)
if path and not self.allow_path:
raise ValidationError("{} URIs must not contain path".format(self.name))
raise ValidationError(
"%(name)s URI validation error. %(cause)s: %(value)s",
params={ "name": self.name, "value": value, "cause": 'path not allowed'}
)

try:
super().__call__(value)
except ValidationError as e:
raise ValidationError(

Check warning on line 89 in oauth2_provider/validators.py

View check run for this annotation

Codecov / codecov/patch

oauth2_provider/validators.py#L88-L89

Added lines #L88 - L89 were not covered by tests
"%(name)s URI validation error. %(cause)s: %(value)s",
params={ "name": self.name, "value": value, "cause": e}
)




##
Expand All @@ -69,5 +105,9 @@ class WildcardSet(set):
A set that always returns True on `in`.
"""

def __init__(self, *args, **kwargs):
warnings.warn("This class is deprecated and will be removed in version 2.5.0.", DeprecationWarning)
super().__init__(*args, **kwargs)

def __contains__(self, item):
return True
2 changes: 1 addition & 1 deletion tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -591,7 +591,7 @@ def test_application_clean(oauth2_settings, application):
application.allowed_origins = "http://example.com"
with pytest.raises(ValidationError) as exc:
application.clean()
assert "Enter a valid URL" in str(exc.value)
assert "allowed origin URI Validation error. invalid_scheme: http://example.com" in str(exc.value)
application.allowed_origins = "https://example.com"
application.clean()

Expand Down
Loading

0 comments on commit f8aeaea

Please sign in to comment.