refactor(local-identity): post-Stage4 cleanups and micro-fixes

- audit: chmod only on file creation, not every append (TOCTOU fix)
- jwt_utils: add extract_unverified_payload() helper
- cli: use extract_unverified_payload + JWTError instead of inline decode
- keys: extract _public_key_bytes() helper, import _b64url from jwt_utils
- security: FileNotFoundError try/except instead of path.exists() (TOCTOU fix)
- serve: cache JWK response at server init instead of per-request recompute

138 tests passing.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-03-02 08:25:21 +01:00
parent 3890dca25d
commit 52d44daec2
6 changed files with 37 additions and 20 deletions

View File

@@ -33,8 +33,10 @@ def log_event(command: str, username: str | None, outcome: str) -> None:
entry = f"{timestamp}\t{command}\t{username or '-'}\t{outcome}\n" entry = f"{timestamp}\t{command}\t{username or '-'}\t{outcome}\n"
path = _audit_log_path() path = _audit_log_path()
try: try:
is_new = not path.exists()
with open(path, "a", encoding="utf-8") as fh: with open(path, "a", encoding="utf-8") as fh:
fh.write(entry) fh.write(entry)
os.chmod(path, 0o600) if is_new:
os.chmod(path, 0o600)
except OSError: except OSError:
pass pass

View File

@@ -20,11 +20,11 @@ Environment:
""" """
import argparse import argparse
import base64
import json import json
import sys import sys
from .gecos import current_username, get_gecos_fullname from .gecos import current_username, get_gecos_fullname
from .jwt_utils import JWTError, extract_unverified_payload
from .user import UserRecord, make_test_user from .user import UserRecord, make_test_user
from . import audit from . import audit
from . import export as export_mod from . import export as export_mod
@@ -163,16 +163,12 @@ def cmd_revoke_token(args: argparse.Namespace) -> None:
if token_or_jti.count(".") == 2: if token_or_jti.count(".") == 2:
# Looks like a JWT — extract the JTI from the payload # Looks like a JWT — extract the JTI from the payload
try: try:
payload_b64 = token_or_jti.split(".")[1] payload = extract_unverified_payload(token_or_jti)
pad = (4 - len(payload_b64) % 4) % 4
payload = json.loads(
base64.urlsafe_b64decode(payload_b64 + "=" * pad)
)
jti = payload.get("jti") jti = payload.get("jti")
if not jti: if not jti:
print("Error: JWT has no 'jti' claim.", file=sys.stderr) print("Error: JWT has no 'jti' claim.", file=sys.stderr)
sys.exit(1) sys.exit(1)
except Exception as exc: except JWTError as exc:
print(f"Error decoding JWT: {exc}", file=sys.stderr) print(f"Error decoding JWT: {exc}", file=sys.stderr)
sys.exit(1) sys.exit(1)
else: else:

View File

@@ -114,3 +114,17 @@ def verify_token(token: str, public_key: RSAPublicKey) -> dict:
raise JWTError("token has expired") raise JWTError("token has expired")
return payload return payload
def extract_unverified_payload(token: str) -> dict:
"""
Decode the payload of a JWT without verifying the signature.
Raises JWTError if the token is malformed.
"""
parts = token.split(".")
if len(parts) != 3:
raise JWTError("malformed token: expected 3 parts")
try:
return json.loads(_b64url_decode(parts[1]))
except Exception as exc:
raise JWTError(f"cannot decode payload: {exc}") from exc

View File

@@ -8,7 +8,6 @@ The corresponding public key is never stored separately — it is always
derived from the private key on load. derived from the private key on load.
""" """
import base64
import hashlib import hashlib
import os import os
from pathlib import Path from pathlib import Path
@@ -17,6 +16,7 @@ from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import rsa from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey
from .jwt_utils import _b64url_encode as _b64url
from .store import _store_dir from .store import _store_dir
@@ -48,27 +48,28 @@ def ensure_signing_key() -> RSAPrivateKey:
return private_key return private_key
def _b64url(data: bytes) -> str: def _public_key_bytes(private_key: RSAPrivateKey) -> tuple[bytes, bytes]:
return base64.urlsafe_b64encode(data).rstrip(b"=").decode("ascii") """Return (n_bytes, e_bytes) for the public key — shared by key_id and jwk_public."""
pub = private_key.public_key().public_numbers()
n_bytes = pub.n.to_bytes((pub.n.bit_length() + 7) // 8, byteorder="big")
e_bytes = pub.e.to_bytes((pub.e.bit_length() + 7) // 8, byteorder="big")
return n_bytes, e_bytes
def key_id(private_key: RSAPrivateKey) -> str: def key_id(private_key: RSAPrivateKey) -> str:
"""Return a stable 16-hex-char key ID derived from the public key modulus.""" """Return a stable 16-hex-char key ID derived from the public key modulus."""
n = private_key.public_key().public_numbers().n n_bytes, _ = _public_key_bytes(private_key)
n_bytes = n.to_bytes((n.bit_length() + 7) // 8, byteorder="big")
return hashlib.sha256(n_bytes).hexdigest()[:16] return hashlib.sha256(n_bytes).hexdigest()[:16]
def jwk_public(private_key: RSAPrivateKey) -> dict: def jwk_public(private_key: RSAPrivateKey) -> dict:
"""Return the RSA public key as a JWK dict (RS256, sig use).""" """Return the RSA public key as a JWK dict (RS256, sig use)."""
pub = private_key.public_key().public_numbers() n_bytes, e_bytes = _public_key_bytes(private_key)
n_bytes = pub.n.to_bytes((pub.n.bit_length() + 7) // 8, byteorder="big")
e_bytes = pub.e.to_bytes((pub.e.bit_length() + 7) // 8, byteorder="big")
return { return {
"kty": "RSA", "kty": "RSA",
"use": "sig", "use": "sig",
"alg": "RS256", "alg": "RS256",
"kid": key_id(private_key), "kid": hashlib.sha256(n_bytes).hexdigest()[:16],
"n": _b64url(n_bytes), "n": _b64url(n_bytes),
"e": _b64url(e_bytes), "e": _b64url(e_bytes),
} }

View File

@@ -24,9 +24,10 @@ class CheckResult:
def _check_mode(path: Path, expected: int) -> CheckResult: def _check_mode(path: Path, expected: int) -> CheckResult:
"""Return a CheckResult for a single path against the expected mode.""" """Return a CheckResult for a single path against the expected mode."""
if not path.exists(): try:
actual = stat.S_IMODE(os.stat(path).st_mode)
except FileNotFoundError:
return CheckResult(str(path), "warn", "does not exist (skipped)") return CheckResult(str(path), "warn", "does not exist (skipped)")
actual = stat.S_IMODE(os.stat(path).st_mode)
if actual != expected: if actual != expected:
return CheckResult( return CheckResult(
str(path), "fail", str(path), "fail",

View File

@@ -78,6 +78,7 @@ def make_handler(
so multiple test servers can run concurrently without sharing state. so multiple test servers can run concurrently without sharing state.
""" """
codes: dict = {} codes: dict = {}
jwks_response = {"keys": [jwk_public(private_key)]}
class _Handler(OIDCHandler): class _Handler(OIDCHandler):
_private_key = private_key _private_key = private_key
@@ -85,6 +86,7 @@ def make_handler(
_token_ttl = token_ttl _token_ttl = token_ttl
_codes = codes _codes = codes
_scheme = scheme _scheme = scheme
_jwks = jwks_response
return _Handler return _Handler
@@ -100,6 +102,7 @@ class OIDCHandler(http.server.BaseHTTPRequestHandler):
_token_ttl: int = 3600 _token_ttl: int = 3600
_codes: dict = {} _codes: dict = {}
_scheme: str = "https" _scheme: str = "https"
_jwks: dict = {}
def log_message(self, fmt: str, *args) -> None: def log_message(self, fmt: str, *args) -> None:
pass # silence default Apache-style logging pass # silence default Apache-style logging
@@ -161,7 +164,7 @@ class OIDCHandler(http.server.BaseHTTPRequestHandler):
# ---------------------------------------------------------------- # # ---------------------------------------------------------------- #
def _handle_jwks(self) -> None: def _handle_jwks(self) -> None:
self._send_json({"keys": [jwk_public(self._private_key)]}) self._send_json(self._jwks)
# ---------------------------------------------------------------- # # ---------------------------------------------------------------- #
# Endpoint: GET /auth — display login form # # Endpoint: GET /auth — display login form #