generated from coulomb/repo-seed
refactor(local-identity): post-Stage4 cleanups and micro-fixes
- audit: chmod only on file creation, not every append (TOCTOU fix) - jwt_utils: add extract_unverified_payload() helper - cli: use extract_unverified_payload + JWTError instead of inline decode - keys: extract _public_key_bytes() helper, import _b64url from jwt_utils - security: FileNotFoundError try/except instead of path.exists() (TOCTOU fix) - serve: cache JWK response at server init instead of per-request recompute 138 tests passing. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -33,8 +33,10 @@ def log_event(command: str, username: str | None, outcome: str) -> None:
|
||||
entry = f"{timestamp}\t{command}\t{username or '-'}\t{outcome}\n"
|
||||
path = _audit_log_path()
|
||||
try:
|
||||
is_new = not path.exists()
|
||||
with open(path, "a", encoding="utf-8") as fh:
|
||||
fh.write(entry)
|
||||
os.chmod(path, 0o600)
|
||||
if is_new:
|
||||
os.chmod(path, 0o600)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
@@ -20,11 +20,11 @@ Environment:
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import base64
|
||||
import json
|
||||
import sys
|
||||
|
||||
from .gecos import current_username, get_gecos_fullname
|
||||
from .jwt_utils import JWTError, extract_unverified_payload
|
||||
from .user import UserRecord, make_test_user
|
||||
from . import audit
|
||||
from . import export as export_mod
|
||||
@@ -163,16 +163,12 @@ def cmd_revoke_token(args: argparse.Namespace) -> None:
|
||||
if token_or_jti.count(".") == 2:
|
||||
# Looks like a JWT — extract the JTI from the payload
|
||||
try:
|
||||
payload_b64 = token_or_jti.split(".")[1]
|
||||
pad = (4 - len(payload_b64) % 4) % 4
|
||||
payload = json.loads(
|
||||
base64.urlsafe_b64decode(payload_b64 + "=" * pad)
|
||||
)
|
||||
payload = extract_unverified_payload(token_or_jti)
|
||||
jti = payload.get("jti")
|
||||
if not jti:
|
||||
print("Error: JWT has no 'jti' claim.", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
except Exception as exc:
|
||||
except JWTError as exc:
|
||||
print(f"Error decoding JWT: {exc}", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
else:
|
||||
|
||||
@@ -114,3 +114,17 @@ def verify_token(token: str, public_key: RSAPublicKey) -> dict:
|
||||
raise JWTError("token has expired")
|
||||
|
||||
return payload
|
||||
|
||||
|
||||
def extract_unverified_payload(token: str) -> dict:
|
||||
"""
|
||||
Decode the payload of a JWT without verifying the signature.
|
||||
Raises JWTError if the token is malformed.
|
||||
"""
|
||||
parts = token.split(".")
|
||||
if len(parts) != 3:
|
||||
raise JWTError("malformed token: expected 3 parts")
|
||||
try:
|
||||
return json.loads(_b64url_decode(parts[1]))
|
||||
except Exception as exc:
|
||||
raise JWTError(f"cannot decode payload: {exc}") from exc
|
||||
|
||||
@@ -8,7 +8,6 @@ The corresponding public key is never stored separately — it is always
|
||||
derived from the private key on load.
|
||||
"""
|
||||
|
||||
import base64
|
||||
import hashlib
|
||||
import os
|
||||
from pathlib import Path
|
||||
@@ -17,6 +16,7 @@ from cryptography.hazmat.primitives import serialization
|
||||
from cryptography.hazmat.primitives.asymmetric import rsa
|
||||
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey
|
||||
|
||||
from .jwt_utils import _b64url_encode as _b64url
|
||||
from .store import _store_dir
|
||||
|
||||
|
||||
@@ -48,27 +48,28 @@ def ensure_signing_key() -> RSAPrivateKey:
|
||||
return private_key
|
||||
|
||||
|
||||
def _b64url(data: bytes) -> str:
|
||||
return base64.urlsafe_b64encode(data).rstrip(b"=").decode("ascii")
|
||||
def _public_key_bytes(private_key: RSAPrivateKey) -> tuple[bytes, bytes]:
|
||||
"""Return (n_bytes, e_bytes) for the public key — shared by key_id and jwk_public."""
|
||||
pub = private_key.public_key().public_numbers()
|
||||
n_bytes = pub.n.to_bytes((pub.n.bit_length() + 7) // 8, byteorder="big")
|
||||
e_bytes = pub.e.to_bytes((pub.e.bit_length() + 7) // 8, byteorder="big")
|
||||
return n_bytes, e_bytes
|
||||
|
||||
|
||||
def key_id(private_key: RSAPrivateKey) -> str:
|
||||
"""Return a stable 16-hex-char key ID derived from the public key modulus."""
|
||||
n = private_key.public_key().public_numbers().n
|
||||
n_bytes = n.to_bytes((n.bit_length() + 7) // 8, byteorder="big")
|
||||
n_bytes, _ = _public_key_bytes(private_key)
|
||||
return hashlib.sha256(n_bytes).hexdigest()[:16]
|
||||
|
||||
|
||||
def jwk_public(private_key: RSAPrivateKey) -> dict:
|
||||
"""Return the RSA public key as a JWK dict (RS256, sig use)."""
|
||||
pub = private_key.public_key().public_numbers()
|
||||
n_bytes = pub.n.to_bytes((pub.n.bit_length() + 7) // 8, byteorder="big")
|
||||
e_bytes = pub.e.to_bytes((pub.e.bit_length() + 7) // 8, byteorder="big")
|
||||
n_bytes, e_bytes = _public_key_bytes(private_key)
|
||||
return {
|
||||
"kty": "RSA",
|
||||
"use": "sig",
|
||||
"alg": "RS256",
|
||||
"kid": key_id(private_key),
|
||||
"kid": hashlib.sha256(n_bytes).hexdigest()[:16],
|
||||
"n": _b64url(n_bytes),
|
||||
"e": _b64url(e_bytes),
|
||||
}
|
||||
|
||||
@@ -24,9 +24,10 @@ class CheckResult:
|
||||
|
||||
def _check_mode(path: Path, expected: int) -> CheckResult:
|
||||
"""Return a CheckResult for a single path against the expected mode."""
|
||||
if not path.exists():
|
||||
try:
|
||||
actual = stat.S_IMODE(os.stat(path).st_mode)
|
||||
except FileNotFoundError:
|
||||
return CheckResult(str(path), "warn", "does not exist (skipped)")
|
||||
actual = stat.S_IMODE(os.stat(path).st_mode)
|
||||
if actual != expected:
|
||||
return CheckResult(
|
||||
str(path), "fail",
|
||||
|
||||
@@ -78,6 +78,7 @@ def make_handler(
|
||||
so multiple test servers can run concurrently without sharing state.
|
||||
"""
|
||||
codes: dict = {}
|
||||
jwks_response = {"keys": [jwk_public(private_key)]}
|
||||
|
||||
class _Handler(OIDCHandler):
|
||||
_private_key = private_key
|
||||
@@ -85,6 +86,7 @@ def make_handler(
|
||||
_token_ttl = token_ttl
|
||||
_codes = codes
|
||||
_scheme = scheme
|
||||
_jwks = jwks_response
|
||||
|
||||
return _Handler
|
||||
|
||||
@@ -100,6 +102,7 @@ class OIDCHandler(http.server.BaseHTTPRequestHandler):
|
||||
_token_ttl: int = 3600
|
||||
_codes: dict = {}
|
||||
_scheme: str = "https"
|
||||
_jwks: dict = {}
|
||||
|
||||
def log_message(self, fmt: str, *args) -> None:
|
||||
pass # silence default Apache-style logging
|
||||
@@ -161,7 +164,7 @@ class OIDCHandler(http.server.BaseHTTPRequestHandler):
|
||||
# ---------------------------------------------------------------- #
|
||||
|
||||
def _handle_jwks(self) -> None:
|
||||
self._send_json({"keys": [jwk_public(self._private_key)]})
|
||||
self._send_json(self._jwks)
|
||||
|
||||
# ---------------------------------------------------------------- #
|
||||
# Endpoint: GET /auth — display login form #
|
||||
|
||||
Reference in New Issue
Block a user