Skip to content

Commit

Permalink
Encoding EC keys with a fixed bit length
Browse files Browse the repository at this point in the history
  • Loading branch information
way-dave committed Oct 10, 2024
1 parent 6c7cc61 commit b834e89
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 12 deletions.
13 changes: 10 additions & 3 deletions jwt/algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -581,13 +581,20 @@ def to_jwk(key_obj: AllowedECKeys, as_dict: bool = False) -> JWKDict | str:
obj: dict[str, Any] = {
"kty": "EC",
"crv": crv,
"x": to_base64url_uint(public_numbers.x).decode(),
"y": to_base64url_uint(public_numbers.y).decode(),
"x": to_base64url_uint(
public_numbers.x,
bit_length=key_obj.curve.key_size,
).decode(),
"y": to_base64url_uint(
public_numbers.y,
bit_length=key_obj.curve.key_size,
).decode(),
}

if isinstance(key_obj, EllipticCurvePrivateKey):
obj["d"] = to_base64url_uint(
key_obj.private_numbers().private_value
key_obj.private_numbers().private_value,
bit_length=key_obj.curve.key_size,
).decode()

if as_dict:
Expand Down
15 changes: 6 additions & 9 deletions jwt/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,11 @@ def base64url_encode(input: bytes) -> bytes:
return base64.urlsafe_b64encode(input).replace(b"=", b"")


def to_base64url_uint(val: int) -> bytes:
def to_base64url_uint(val: int, *, bit_length: int | None = None) -> bytes:
if val < 0:
raise ValueError("Must be a positive integer")

int_bytes = bytes_from_int(val)
int_bytes = bytes_from_int(val, bit_length=bit_length)

if len(int_bytes) == 0:
int_bytes = b"\x00"
Expand All @@ -63,13 +63,10 @@ def bytes_to_number(string: bytes) -> int:
return int(binascii.b2a_hex(string), 16)


def bytes_from_int(val: int) -> bytes:
remaining = val
byte_length = 0

while remaining != 0:
remaining >>= 8
byte_length += 1
def bytes_from_int(val: int, *, bit_length: int | None = None) -> bytes:
if bit_length is None:
bit_length = val.bit_length()
byte_length = (bit_length + 7) // 8

return val.to_bytes(byte_length, "big", signed=False)

Expand Down

0 comments on commit b834e89

Please sign in to comment.