""" JWT creation and verification for local-identity OIDC serve. Uses RS256 (RSA-PKCS1v15 + SHA-256) signing via the cryptography library. No third-party JWT library is used — only stdlib base64/json and cryptography primitives. """ import base64 import json import time import uuid from cryptography.exceptions import InvalidSignature from cryptography.hazmat.primitives import hashes from cryptography.hazmat.primitives.asymmetric import padding from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey, RSAPublicKey # ------------------------------------------------------------------ # # Base64url helpers # # ------------------------------------------------------------------ # def _b64url_encode(data: bytes) -> str: return base64.urlsafe_b64encode(data).rstrip(b"=").decode("ascii") def _b64url_decode(s: str) -> bytes: pad = (4 - len(s) % 4) % 4 return base64.urlsafe_b64decode(s + "=" * pad) # ------------------------------------------------------------------ # # Token creation # # ------------------------------------------------------------------ # def create_token( private_key: RSAPrivateKey, kid: str, sub: str, iss: str, aud: str, email: str, name: str, preferred_username: str, ttl: int = 3600, nonce: str | None = None, ) -> str: """Create and sign a JWT with RS256. Returns the compact serialisation.""" now = int(time.time()) header = {"alg": "RS256", "typ": "JWT", "kid": kid} payload: dict = { "sub": sub, "iss": iss, "aud": aud, "exp": now + ttl, "iat": now, "jti": str(uuid.uuid4()), "email": email, "name": name, "preferred_username": preferred_username, } if nonce is not None: payload["nonce"] = nonce header_b64 = _b64url_encode(json.dumps(header, separators=(",", ":")).encode()) payload_b64 = _b64url_encode(json.dumps(payload, separators=(",", ":")).encode()) signing_input = f"{header_b64}.{payload_b64}".encode("ascii") signature = private_key.sign(signing_input, padding.PKCS1v15(), hashes.SHA256()) return f"{header_b64}.{payload_b64}.{_b64url_encode(signature)}" # ------------------------------------------------------------------ # # Token verification # # ------------------------------------------------------------------ # class JWTError(Exception): pass def verify_token(token: str, public_key: RSAPublicKey) -> dict: """ Verify signature and expiry of a JWT. Returns the decoded payload. Raises JWTError for any failure (malformed, bad signature, expired). """ parts = token.split(".") if len(parts) != 3: raise JWTError("malformed token: expected 3 parts") header_b64, payload_b64, sig_b64 = parts try: header = json.loads(_b64url_decode(header_b64)) except Exception as exc: raise JWTError(f"cannot decode header: {exc}") from exc if header.get("alg") != "RS256": raise JWTError(f"unsupported algorithm: {header.get('alg')!r}") try: payload = json.loads(_b64url_decode(payload_b64)) except Exception as exc: raise JWTError(f"cannot decode payload: {exc}") from exc signing_input = f"{header_b64}.{payload_b64}".encode("ascii") try: sig_bytes = _b64url_decode(sig_b64) public_key.verify(sig_bytes, signing_input, padding.PKCS1v15(), hashes.SHA256()) except InvalidSignature as exc: raise JWTError("invalid signature") from exc if "exp" in payload and int(time.time()) > payload["exp"]: raise JWTError("token has expired") return payload def extract_unverified_payload(token: str) -> dict: """ Decode the payload of a JWT without verifying the signature. Raises JWTError if the token is malformed. """ parts = token.split(".") if len(parts) != 3: raise JWTError("malformed token: expected 3 parts") try: return json.loads(_b64url_decode(parts[1])) except Exception as exc: raise JWTError(f"cannot decode payload: {exc}") from exc