refactor(local-identity): post-Stage4 cleanups and micro-fixes

- audit: chmod only on file creation, not every append (TOCTOU fix)
- jwt_utils: add extract_unverified_payload() helper
- cli: use extract_unverified_payload + JWTError instead of inline decode
- keys: extract _public_key_bytes() helper, import _b64url from jwt_utils
- security: FileNotFoundError try/except instead of path.exists() (TOCTOU fix)
- serve: cache JWK response at server init instead of per-request recompute

138 tests passing.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-03-02 08:25:21 +01:00
parent 3890dca25d
commit 52d44daec2
6 changed files with 37 additions and 20 deletions

View File

@@ -33,8 +33,10 @@ def log_event(command: str, username: str | None, outcome: str) -> None:
entry = f"{timestamp}\t{command}\t{username or '-'}\t{outcome}\n"
path = _audit_log_path()
try:
is_new = not path.exists()
with open(path, "a", encoding="utf-8") as fh:
fh.write(entry)
os.chmod(path, 0o600)
if is_new:
os.chmod(path, 0o600)
except OSError:
pass

View File

@@ -20,11 +20,11 @@ Environment:
"""
import argparse
import base64
import json
import sys
from .gecos import current_username, get_gecos_fullname
from .jwt_utils import JWTError, extract_unverified_payload
from .user import UserRecord, make_test_user
from . import audit
from . import export as export_mod
@@ -163,16 +163,12 @@ def cmd_revoke_token(args: argparse.Namespace) -> None:
if token_or_jti.count(".") == 2:
# Looks like a JWT — extract the JTI from the payload
try:
payload_b64 = token_or_jti.split(".")[1]
pad = (4 - len(payload_b64) % 4) % 4
payload = json.loads(
base64.urlsafe_b64decode(payload_b64 + "=" * pad)
)
payload = extract_unverified_payload(token_or_jti)
jti = payload.get("jti")
if not jti:
print("Error: JWT has no 'jti' claim.", file=sys.stderr)
sys.exit(1)
except Exception as exc:
except JWTError as exc:
print(f"Error decoding JWT: {exc}", file=sys.stderr)
sys.exit(1)
else:

View File

@@ -114,3 +114,17 @@ def verify_token(token: str, public_key: RSAPublicKey) -> dict:
raise JWTError("token has expired")
return payload
def extract_unverified_payload(token: str) -> dict:
"""
Decode the payload of a JWT without verifying the signature.
Raises JWTError if the token is malformed.
"""
parts = token.split(".")
if len(parts) != 3:
raise JWTError("malformed token: expected 3 parts")
try:
return json.loads(_b64url_decode(parts[1]))
except Exception as exc:
raise JWTError(f"cannot decode payload: {exc}") from exc

View File

@@ -8,7 +8,6 @@ The corresponding public key is never stored separately — it is always
derived from the private key on load.
"""
import base64
import hashlib
import os
from pathlib import Path
@@ -17,6 +16,7 @@ from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey
from .jwt_utils import _b64url_encode as _b64url
from .store import _store_dir
@@ -48,27 +48,28 @@ def ensure_signing_key() -> RSAPrivateKey:
return private_key
def _b64url(data: bytes) -> str:
return base64.urlsafe_b64encode(data).rstrip(b"=").decode("ascii")
def _public_key_bytes(private_key: RSAPrivateKey) -> tuple[bytes, bytes]:
"""Return (n_bytes, e_bytes) for the public key — shared by key_id and jwk_public."""
pub = private_key.public_key().public_numbers()
n_bytes = pub.n.to_bytes((pub.n.bit_length() + 7) // 8, byteorder="big")
e_bytes = pub.e.to_bytes((pub.e.bit_length() + 7) // 8, byteorder="big")
return n_bytes, e_bytes
def key_id(private_key: RSAPrivateKey) -> str:
"""Return a stable 16-hex-char key ID derived from the public key modulus."""
n = private_key.public_key().public_numbers().n
n_bytes = n.to_bytes((n.bit_length() + 7) // 8, byteorder="big")
n_bytes, _ = _public_key_bytes(private_key)
return hashlib.sha256(n_bytes).hexdigest()[:16]
def jwk_public(private_key: RSAPrivateKey) -> dict:
"""Return the RSA public key as a JWK dict (RS256, sig use)."""
pub = private_key.public_key().public_numbers()
n_bytes = pub.n.to_bytes((pub.n.bit_length() + 7) // 8, byteorder="big")
e_bytes = pub.e.to_bytes((pub.e.bit_length() + 7) // 8, byteorder="big")
n_bytes, e_bytes = _public_key_bytes(private_key)
return {
"kty": "RSA",
"use": "sig",
"alg": "RS256",
"kid": key_id(private_key),
"kid": hashlib.sha256(n_bytes).hexdigest()[:16],
"n": _b64url(n_bytes),
"e": _b64url(e_bytes),
}

View File

@@ -24,9 +24,10 @@ class CheckResult:
def _check_mode(path: Path, expected: int) -> CheckResult:
"""Return a CheckResult for a single path against the expected mode."""
if not path.exists():
try:
actual = stat.S_IMODE(os.stat(path).st_mode)
except FileNotFoundError:
return CheckResult(str(path), "warn", "does not exist (skipped)")
actual = stat.S_IMODE(os.stat(path).st_mode)
if actual != expected:
return CheckResult(
str(path), "fail",

View File

@@ -78,6 +78,7 @@ def make_handler(
so multiple test servers can run concurrently without sharing state.
"""
codes: dict = {}
jwks_response = {"keys": [jwk_public(private_key)]}
class _Handler(OIDCHandler):
_private_key = private_key
@@ -85,6 +86,7 @@ def make_handler(
_token_ttl = token_ttl
_codes = codes
_scheme = scheme
_jwks = jwks_response
return _Handler
@@ -100,6 +102,7 @@ class OIDCHandler(http.server.BaseHTTPRequestHandler):
_token_ttl: int = 3600
_codes: dict = {}
_scheme: str = "https"
_jwks: dict = {}
def log_message(self, fmt: str, *args) -> None:
pass # silence default Apache-style logging
@@ -161,7 +164,7 @@ class OIDCHandler(http.server.BaseHTTPRequestHandler):
# ---------------------------------------------------------------- #
def _handle_jwks(self) -> None:
self._send_json({"keys": [jwk_public(self._private_key)]})
self._send_json(self._jwks)
# ---------------------------------------------------------------- #
# Endpoint: GET /auth — display login form #