generated from coulomb/repo-seed
- audit: chmod only on file creation, not every append (TOCTOU fix) - jwt_utils: add extract_unverified_payload() helper - cli: use extract_unverified_payload + JWTError instead of inline decode - keys: extract _public_key_bytes() helper, import _b64url from jwt_utils - security: FileNotFoundError try/except instead of path.exists() (TOCTOU fix) - serve: cache JWK response at server init instead of per-request recompute 138 tests passing. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
131 lines
4.2 KiB
Python
131 lines
4.2 KiB
Python
"""
|
|
JWT creation and verification for local-identity OIDC serve.
|
|
|
|
Uses RS256 (RSA-PKCS1v15 + SHA-256) signing via the cryptography library.
|
|
No third-party JWT library is used — only stdlib base64/json and cryptography
|
|
primitives.
|
|
"""
|
|
|
|
import base64
|
|
import json
|
|
import time
|
|
import uuid
|
|
|
|
from cryptography.exceptions import InvalidSignature
|
|
from cryptography.hazmat.primitives import hashes
|
|
from cryptography.hazmat.primitives.asymmetric import padding
|
|
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey, RSAPublicKey
|
|
|
|
|
|
# ------------------------------------------------------------------ #
|
|
# Base64url helpers #
|
|
# ------------------------------------------------------------------ #
|
|
|
|
def _b64url_encode(data: bytes) -> str:
|
|
return base64.urlsafe_b64encode(data).rstrip(b"=").decode("ascii")
|
|
|
|
|
|
def _b64url_decode(s: str) -> bytes:
|
|
pad = (4 - len(s) % 4) % 4
|
|
return base64.urlsafe_b64decode(s + "=" * pad)
|
|
|
|
|
|
# ------------------------------------------------------------------ #
|
|
# Token creation #
|
|
# ------------------------------------------------------------------ #
|
|
|
|
def create_token(
    private_key: RSAPrivateKey,
    kid: str,
    sub: str,
    iss: str,
    aud: str,
    email: str,
    name: str,
    preferred_username: str,
    ttl: int = 3600,
    nonce: str | None = None,
) -> str:
    """
    Build an RS256-signed JWT and return its compact serialisation.

    The header holds ``alg`` ("RS256"), ``typ`` ("JWT") and ``kid``.
    The claims are, in order:

    - ``sub``, ``iss``, ``aud``;
    - ``exp`` (issue time + ``ttl`` seconds) and ``iat`` (issue time, in
      whole seconds);
    - a random UUID4 ``jti``;
    - ``email``, ``name`` and ``preferred_username``;
    - ``nonce``, only when one is given.

    The result is ``"<header>.<payload>.<signature>"``, each part
    base64url-encoded without padding. The signature is RSA PKCS#1 v1.5
    over SHA-256 of the first two parts.
    """
    issued_at = int(time.time())

    def encode_segment(obj: dict) -> str:
        # Compact separators so the serialised JSON carries no whitespace.
        return _b64url_encode(json.dumps(obj, separators=(",", ":")).encode())

    # dict(**kwargs) keeps keyword order, so the claim order matches the spec above.
    claims: dict = dict(
        sub=sub,
        iss=iss,
        aud=aud,
        exp=issued_at + ttl,
        iat=issued_at,
        jti=str(uuid.uuid4()),
        email=email,
        name=name,
        preferred_username=preferred_username,
    )
    if nonce is not None:
        claims["nonce"] = nonce

    segments = [
        encode_segment({"alg": "RS256", "typ": "JWT", "kid": kid}),
        encode_segment(claims),
    ]
    to_sign = ".".join(segments).encode("ascii")
    segments.append(
        _b64url_encode(private_key.sign(to_sign, padding.PKCS1v15(), hashes.SHA256()))
    )
    return ".".join(segments)
|
|
|
|
|
|
# ------------------------------------------------------------------ #
|
|
# Token verification #
|
|
# ------------------------------------------------------------------ #
|
|
|
|
class JWTError(Exception):
    """Raised when a JWT is malformed, fails signature verification, or has expired."""
|
|
|
|
|
|
def _decode_json_object(segment: str, what: str) -> dict:
    """Decode one base64url JWT segment into a JSON object, raising JWTError otherwise."""
    try:
        value = json.loads(_b64url_decode(segment))
    except (ValueError, RecursionError) as exc:
        # ValueError covers binascii.Error (bad base64), non-ASCII input,
        # UnicodeDecodeError and json.JSONDecodeError; RecursionError covers
        # pathologically nested JSON.
        raise JWTError(f"cannot decode {what}: {exc}") from exc
    if not isinstance(value, dict):
        # Valid JSON that is not an object (list, string, number...) is still
        # a malformed JWT segment; reject it here instead of crashing later.
        raise JWTError(
            f"cannot decode {what}: expected a JSON object, got {type(value).__name__}"
        )
    return value


def verify_token(token: str, public_key: RSAPublicKey) -> dict:
    """
    Verify the signature and expiry of a JWT. Returns the decoded payload.

    Only RS256 tokens (RSA PKCS#1 v1.5 + SHA-256) are accepted.
    ``exp`` is optional, but when present it must be a number, and the
    token is rejected once the current time reaches it (RFC 7519 §4.1.4).
    No other claims (``iss``, ``aud``, ``nbf``...) are checked here.

    Raises JWTError for any failure: malformed token or segment, unsupported
    algorithm, bad signature, invalid or past ``exp``.
    """
    parts = token.split(".")
    if len(parts) != 3:
        raise JWTError("malformed token: expected 3 parts")
    header_b64, payload_b64, sig_b64 = parts

    header = _decode_json_object(header_b64, "header")
    # Pin the algorithm from our side rather than trusting the header.
    if header.get("alg") != "RS256":
        raise JWTError(f"unsupported algorithm: {header.get('alg')!r}")

    payload = _decode_json_object(payload_b64, "payload")

    try:
        sig_bytes = _b64url_decode(sig_b64)
    except ValueError as exc:
        # Previously escaped as a raw binascii.Error / ValueError.
        raise JWTError(f"cannot decode signature: {exc}") from exc

    # Both segments decoded successfully, so they are pure ASCII here.
    signing_input = f"{header_b64}.{payload_b64}".encode("ascii")
    try:
        public_key.verify(sig_bytes, signing_input, padding.PKCS1v15(), hashes.SHA256())
    except InvalidSignature as exc:
        raise JWTError("invalid signature") from exc

    if "exp" in payload:
        exp = payload["exp"]
        # bool is an int subclass; it is never a meaningful timestamp.
        if isinstance(exp, bool) or not isinstance(exp, (int, float)):
            raise JWTError(f"invalid exp claim: {exp!r}")
        # RFC 7519 §4.1.4: the current time must be strictly *before* exp.
        # Compare against the float clock so fractional exp values work too.
        if time.time() >= exp:
            raise JWTError("token has expired")

    return payload
|
|
|
|
|
|
def extract_unverified_payload(token: str) -> dict:
    """
    Decode the payload of a JWT without verifying the signature.

    The returned claims are untrusted and must not be used for
    authorization decisions.

    Raises JWTError if the token is malformed: wrong number of parts, a
    payload segment that is not valid base64url/JSON, or JSON that is not
    an object.
    """
    parts = token.split(".")
    if len(parts) != 3:
        raise JWTError("malformed token: expected 3 parts")
    try:
        payload = json.loads(_b64url_decode(parts[1]))
    except (ValueError, RecursionError) as exc:
        # ValueError covers bad base64, non-ASCII input, bad UTF-8 and bad
        # JSON; RecursionError covers pathologically nested JSON.
        raise JWTError(f"cannot decode payload: {exc}") from exc
    if not isinstance(payload, dict):
        # e.g. a payload of "[1]" or "\"x\"" is valid JSON but not a claim set.
        raise JWTError(
            f"cannot decode payload: expected a JSON object, got {type(payload).__name__}"
        )
    return payload
|