Files
net-kingdom/local-identity/src/local_identity/jwt_utils.py
tegwick 52d44daec2 refactor(local-identity): post-Stage4 cleanups and micro-fixes
- audit: chmod only on file creation, not every append (TOCTOU fix)
- jwt_utils: add extract_unverified_payload() helper
- cli: use extract_unverified_payload + JWTError instead of inline decode
- keys: extract _public_key_bytes() helper, import _b64url from jwt_utils
- security: FileNotFoundError try/except instead of path.exists() (TOCTOU fix)
- serve: cache JWK response at server init instead of per-request recompute

138 tests passing.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-02 08:25:21 +01:00

131 lines
4.2 KiB
Python

"""
JWT creation and verification for local-identity OIDC serve.
Uses RS256 (RSA-PKCS1v15 + SHA-256) signing via the cryptography library.
No third-party JWT library is used — only stdlib base64/json and cryptography
primitives.
"""
import base64
import json
import time
import uuid
from cryptography.exceptions import InvalidSignature
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric import padding
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey, RSAPublicKey
# ------------------------------------------------------------------ #
# Base64url helpers #
# ------------------------------------------------------------------ #
def _b64url_encode(data: bytes) -> str:
return base64.urlsafe_b64encode(data).rstrip(b"=").decode("ascii")
def _b64url_decode(s: str) -> bytes:
pad = (4 - len(s) % 4) % 4
return base64.urlsafe_b64decode(s + "=" * pad)
# ------------------------------------------------------------------ #
# Token creation #
# ------------------------------------------------------------------ #
def create_token(
    private_key: RSAPrivateKey,
    kid: str,
    sub: str,
    iss: str,
    aud: str,
    email: str,
    name: str,
    preferred_username: str,
    ttl: int = 3600,
    nonce: str | None = None,
) -> str:
    """
    Build an RS256-signed JWT and return its compact ``h.p.s`` serialisation.

    Header: ``alg="RS256"``, ``typ="JWT"`` and the supplied ``kid``.

    Payload claims, in this order:
      - sub, iss, aud
      - exp (``iat + ttl``)
      - iat (current Unix time, whole seconds)
      - jti (a random UUID4 string)
      - email, name, preferred_username
      - nonce, only when it is not None

    Both JSON documents are serialised without whitespace before base64url
    encoding. The signature is RSA-PKCS1v15 over SHA-256.
    """

    def encode_segment(obj: dict) -> str:
        # Compact separators keep the token short and the encoding stable.
        raw = json.dumps(obj, separators=(",", ":")).encode()
        return _b64url_encode(raw)

    issued_at = int(time.time())
    # Unpacked last so that "nonce" (when present) ends up as the final key.
    optional_claims = {} if nonce is None else {"nonce": nonce}
    claims: dict = {
        "sub": sub,
        "iss": iss,
        "aud": aud,
        "exp": issued_at + ttl,
        "iat": issued_at,
        "jti": str(uuid.uuid4()),
        "email": email,
        "name": name,
        "preferred_username": preferred_username,
        **optional_claims,
    }

    head_segment = encode_segment({"alg": "RS256", "typ": "JWT", "kid": kid})
    body_segment = encode_segment(claims)
    to_sign = ".".join((head_segment, body_segment))
    sig = private_key.sign(to_sign.encode("ascii"), padding.PKCS1v15(), hashes.SHA256())
    return ".".join((to_sign, _b64url_encode(sig)))
# ------------------------------------------------------------------ #
# Token verification #
# ------------------------------------------------------------------ #
class JWTError(Exception):
    """Raised when a JWT is malformed, cannot be decoded, or fails verification."""
def verify_token(token: str, public_key: RSAPublicKey) -> dict:
    """
    Verify signature and expiry of a JWT. Returns the decoded payload.

    Only RS256 (RSA-PKCS1v15 + SHA-256) is accepted. The header's ``alg`` is
    checked before any cryptographic work, so a token that names a different
    algorithm is rejected outright.

    Args:
        token: Compact serialisation ``header.payload.signature``.
        public_key: RSA public key matching the private key that signed it.

    Returns:
        The decoded payload (claims) as a dict.

    Raises:
        JWTError: for any failure. This covers a token without exactly three
            segments, an undecodable or non-object header or payload, an
            undecodable signature, an unsupported ``alg``, an invalid
            signature, a non-numeric ``exp``, and an expired token.
    """
    parts = token.split(".")
    if len(parts) != 3:
        raise JWTError("malformed token: expected 3 parts")
    header_b64, payload_b64, sig_b64 = parts

    # binascii.Error, UnicodeDecodeError and JSONDecodeError are all
    # ValueError subclasses. Deeply nested JSON raises RecursionError.
    try:
        header = json.loads(_b64url_decode(header_b64))
    except (ValueError, RecursionError) as exc:
        raise JWTError(f"cannot decode header: {exc}") from exc
    # Valid JSON that is not an object (e.g. "[]") would otherwise make
    # header.get raise AttributeError instead of JWTError.
    if not isinstance(header, dict):
        raise JWTError("malformed header: not a JSON object")
    if header.get("alg") != "RS256":
        raise JWTError(f"unsupported algorithm: {header.get('alg')!r}")

    try:
        payload = json.loads(_b64url_decode(payload_b64))
    except (ValueError, RecursionError) as exc:
        raise JWTError(f"cannot decode payload: {exc}") from exc
    if not isinstance(payload, dict):
        raise JWTError("malformed payload: not a JSON object")

    # A malformed signature segment used to escape as binascii.Error.
    # Report it as JWTError like every other failure.
    try:
        sig_bytes = _b64url_decode(sig_b64)
    except ValueError as exc:
        raise JWTError(f"cannot decode signature: {exc}") from exc

    # Both segments decoded successfully above, which implies they are ASCII.
    signing_input = f"{header_b64}.{payload_b64}".encode("ascii")
    try:
        public_key.verify(sig_bytes, signing_input, padding.PKCS1v15(), hashes.SHA256())
    except InvalidSignature as exc:
        raise JWTError("invalid signature") from exc

    if "exp" in payload:
        exp = payload["exp"]
        # bool is an int subclass. Any other non-number would make the
        # comparison below raise TypeError.
        if isinstance(exp, bool) or not isinstance(exp, (int, float)):
            raise JWTError("invalid exp claim: not a number")
        # RFC 7519 §4.1.4: the token MUST NOT be accepted on or after exp.
        if time.time() >= exp:
            raise JWTError("token has expired")
    return payload
def extract_unverified_payload(token: str) -> dict:
    """
    Decode the payload of a JWT without verifying the signature.

    The returned claims are NOT authenticated and may have been forged. Do not
    use them for authorisation decisions; use ``verify_token`` for that.

    Args:
        token: Compact serialisation ``header.payload.signature``.

    Returns:
        The decoded payload as a dict.

    Raises:
        JWTError: if the token does not have exactly three segments, the
            payload segment is not base64url-encoded JSON, or the payload
            is not a JSON object.
    """
    parts = token.split(".")
    if len(parts) != 3:
        raise JWTError("malformed token: expected 3 parts")
    # binascii.Error, UnicodeDecodeError and JSONDecodeError are all
    # ValueError subclasses. Deeply nested JSON raises RecursionError.
    try:
        payload = json.loads(_b64url_decode(parts[1]))
    except (ValueError, RecursionError) as exc:
        raise JWTError(f"cannot decode payload: {exc}") from exc
    # Honour the declared return type: a JSON array or scalar is not a claims set.
    if not isinstance(payload, dict):
        raise JWTError("malformed payload: not a JSON object")
    return payload