Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Validate sub and jti claims for the token #1005

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -63,3 +63,6 @@ target/
.mypy_cache
pip-wheel-metadata/
.venv/


.idea
2 changes: 2 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ Changed
jwt.encode({"payload":"abc"}, key=None, algorithm='none')
```

- Added validation for 'sub' (subject) and 'jti' (JWT ID) claims in tokens

Fixed
~~~~~

Expand Down
56 changes: 55 additions & 1 deletion jwt/api_jwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
InvalidAudienceError,
InvalidIssuedAtError,
InvalidIssuerError,
InvalidJTIError,
InvalidSubjectError,
MissingRequiredClaimError,
)
from .warnings import RemovedInPyjwt3Warning
Expand All @@ -39,6 +41,8 @@ def _get_default_options() -> dict[str, bool | list[str]]:
"verify_iat": True,
"verify_aud": True,
"verify_iss": True,
"verify_sub": True,
"verify_jti": True,
"require": [],
}

Expand Down Expand Up @@ -112,6 +116,7 @@ def decode_complete(
# consider putting in options
audience: str | Iterable[str] | None = None,
issuer: str | Sequence[str] | None = None,
subject: str | None = None,
leeway: float | timedelta = 0,
# kwargs
**kwargs: Any,
Expand Down Expand Up @@ -145,6 +150,8 @@ def decode_complete(
options.setdefault("verify_iat", False)
options.setdefault("verify_aud", False)
options.setdefault("verify_iss", False)
options.setdefault("verify_sub", False)
options.setdefault("verify_jti", False)

decoded = api_jws.decode_complete(
jwt,
Expand All @@ -158,7 +165,12 @@ def decode_complete(

merged_options = {**self.options, **options}
self._validate_claims(
payload, merged_options, audience=audience, issuer=issuer, leeway=leeway
payload,
merged_options,
audience=audience,
issuer=issuer,
leeway=leeway,
subject=subject,
)

decoded["payload"] = payload
Expand Down Expand Up @@ -193,6 +205,7 @@ def decode(
# passthrough arguments to _validate_claims
# consider putting in options
audience: str | Iterable[str] | None = None,
subject: str | None = None,
issuer: str | Sequence[str] | None = None,
leeway: float | timedelta = 0,
# kwargs
Expand All @@ -214,6 +227,7 @@ def decode(
verify=verify,
detached_payload=detached_payload,
audience=audience,
subject=subject,
issuer=issuer,
leeway=leeway,
)
Expand All @@ -225,6 +239,7 @@ def _validate_claims(
options: dict[str, Any],
audience=None,
issuer=None,
subject: str | None = None,
leeway: float | timedelta = 0,
) -> None:
if isinstance(leeway, timedelta):
Expand Down Expand Up @@ -254,6 +269,12 @@ def _validate_claims(
payload, audience, strict=options.get("strict_aud", False)
)

if options["verify_sub"]:
self._validate_sub(payload, subject)

if options["verify_jti"]:
self._validate_jti(payload)

def _validate_required_claims(
self,
payload: dict[str, Any],
Expand All @@ -263,6 +284,39 @@ def _validate_required_claims(
if payload.get(claim) is None:
raise MissingRequiredClaimError(claim)

def _validate_sub(self, payload: dict[str, Any], subject=None) -> None:
"""
Checks whether "sub" if in the payload is valid ot not.
This is an Optional claim

:param payload(dict): The payload which needs to be validated
:param subject(str): The subject of the token
"""

if "sub" not in payload:
return

if not isinstance(payload["sub"], str):
raise InvalidSubjectError("Subject must be a string")

if subject is not None:
if payload.get("sub") != subject:
raise InvalidSubjectError("Invalid subject")

def _validate_jti(self, payload: dict[str, Any]) -> None:
"""
Checks whether "jti" if in the payload is valid ot not
This is an Optional claim

:param payload(dict): The payload which needs to be validated
"""

if "jti" not in payload:
return

if not isinstance(payload.get("jti"), str):
raise InvalidJTIError("JWT ID must be a string")

def _validate_iat(
self,
payload: dict[str, Any],
Expand Down
8 changes: 8 additions & 0 deletions jwt/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,3 +72,11 @@ class PyJWKClientError(PyJWTError):

class PyJWKClientConnectionError(PyJWKClientError):
pass


class InvalidSubjectError(InvalidTokenError):
pass


class InvalidJTIError(InvalidTokenError):
pass
120 changes: 120 additions & 0 deletions tests/test_api_jwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
InvalidAudienceError,
InvalidIssuedAtError,
InvalidIssuerError,
InvalidJTIError,
InvalidSubjectError,
MissingRequiredClaimError,
)
from jwt.utils import base64url_decode
Expand Down Expand Up @@ -816,3 +818,121 @@ def test_decode_strict_ok(self, jwt, payload):
options={"strict_aud": True},
algorithms=["HS256"],
)

# -------------------- Sub Claim Tests --------------------

def test_encode_decode_sub_claim(self, jwt):
payload = {
"sub": "user123",
}
secret = "your-256-bit-secret"
token = jwt.encode(payload, secret, algorithm="HS256")
decoded = jwt.decode(token, secret, algorithms=["HS256"])

assert decoded["sub"] == "user123"

def test_decode_without_and_not_required_sub_claim(self, jwt):
payload = {}
secret = "your-256-bit-secret"
token = jwt.encode(payload, secret, algorithm="HS256")

decoded = jwt.decode(token, secret, algorithms=["HS256"])

assert "sub" not in decoded

def test_decode_missing_sub_but_required_claim(self, jwt):
payload = {}
secret = "your-256-bit-secret"
token = jwt.encode(payload, secret, algorithm="HS256")

with pytest.raises(MissingRequiredClaimError):
jwt.decode(
token, secret, algorithms=["HS256"], options={"require": ["sub"]}
)

def test_decode_invalid_int_sub_claim(self, jwt):
payload = {
"sub": 1224344,
}
secret = "your-256-bit-secret"
token = jwt.encode(payload, secret, algorithm="HS256")

with pytest.raises(InvalidSubjectError):
jwt.decode(token, secret, algorithms=["HS256"])

def test_decode_with_valid_sub_claim(self, jwt):
payload = {
"sub": "user123",
}
secret = "your-256-bit-secret"
token = jwt.encode(payload, secret, algorithm="HS256")

decoded = jwt.decode(token, secret, algorithms=["HS256"], subject="user123")

assert decoded["sub"] == "user123"

def test_decode_with_invalid_sub_claim(self, jwt):
payload = {
"sub": "user123",
}
secret = "your-256-bit-secret"
token = jwt.encode(payload, secret, algorithm="HS256")

with pytest.raises(InvalidSubjectError) as exc_info:
jwt.decode(token, secret, algorithms=["HS256"], subject="user456")

assert "Invalid subject" in str(exc_info.value)

def test_decode_with_sub_claim_and_none_subject(self, jwt):
payload = {
"sub": "user789",
}
secret = "your-256-bit-secret"
token = jwt.encode(payload, secret, algorithm="HS256")

decoded = jwt.decode(token, secret, algorithms=["HS256"], subject=None)
assert decoded["sub"] == "user789"

# -------------------- JTI Claim Tests --------------------

def test_encode_decode_with_valid_jti_claim(self, jwt):
payload = {
"jti": "unique-id-456",
}
secret = "your-256-bit-secret"
token = jwt.encode(payload, secret, algorithm="HS256")
decoded = jwt.decode(token, secret, algorithms=["HS256"])

assert decoded["jti"] == "unique-id-456"

def test_decode_missing_jti_when_required_claim(self, jwt):
payload = {"name": "Bob", "admin": False}
secret = "your-256-bit-secret"
token = jwt.encode(payload, secret, algorithm="HS256")

with pytest.raises(MissingRequiredClaimError) as exc_info:
jwt.decode(
token, secret, algorithms=["HS256"], options={"require": ["jti"]}
)

assert "jti" in str(exc_info.value)

def test_decode_missing_jti_claim(self, jwt):
payload = {}
secret = "your-256-bit-secret"
token = jwt.encode(payload, secret, algorithm="HS256")

decoded = jwt.decode(token, secret, algorithms=["HS256"])

assert decoded.get("jti") is None

def test_jti_claim_with_invalid_int_value(self, jwt):
special_jti = 12223
payload = {
"jti": special_jti,
}
secret = "your-256-bit-secret"
token = jwt.encode(payload, secret, algorithm="HS256")

with pytest.raises(InvalidJTIError):
jwt.decode(token, secret, algorithms=["HS256"])
Loading