diff --git a/local-identity/src/local_identity/audit.py b/local-identity/src/local_identity/audit.py index 5c80425..7ab2bee 100644 --- a/local-identity/src/local_identity/audit.py +++ b/local-identity/src/local_identity/audit.py @@ -33,8 +33,10 @@ def log_event(command: str, username: str | None, outcome: str) -> None: entry = f"{timestamp}\t{command}\t{username or '-'}\t{outcome}\n" path = _audit_log_path() try: + is_new = not path.exists() with open(path, "a", encoding="utf-8") as fh: fh.write(entry) - os.chmod(path, 0o600) + if is_new: + os.chmod(path, 0o600) except OSError: pass diff --git a/local-identity/src/local_identity/cli.py b/local-identity/src/local_identity/cli.py index e359b7c..c4ed25e 100644 --- a/local-identity/src/local_identity/cli.py +++ b/local-identity/src/local_identity/cli.py @@ -20,11 +20,11 @@ Environment: """ import argparse -import base64 import json import sys from .gecos import current_username, get_gecos_fullname +from .jwt_utils import JWTError, extract_unverified_payload from .user import UserRecord, make_test_user from . import audit from . import export as export_mod @@ -163,16 +163,12 @@ def cmd_revoke_token(args: argparse.Namespace) -> None: if token_or_jti.count(".") == 2: # Looks like a JWT — extract the JTI from the payload try: - payload_b64 = token_or_jti.split(".")[1] - pad = (4 - len(payload_b64) % 4) % 4 - payload = json.loads( - base64.urlsafe_b64decode(payload_b64 + "=" * pad) - ) + payload = extract_unverified_payload(token_or_jti) jti = payload.get("jti") if not jti: print("Error: JWT has no 'jti' claim.", file=sys.stderr) sys.exit(1) - except Exception as exc: + except JWTError as exc: print(f"Error decoding JWT: {exc}", file=sys.stderr) sys.exit(1) else: diff --git a/local-identity/src/local_identity/jwt_utils.py b/local-identity/src/local_identity/jwt_utils.py index 4c8cf52..9a3ca1c 100644 --- a/local-identity/src/local_identity/jwt_utils.py +++ b/local-identity/src/local_identity/jwt_utils.py @@ -114,3 +114,17 @@ def verify_token(token: str, public_key: RSAPublicKey) -> dict: raise JWTError("token has expired") return payload + + +def extract_unverified_payload(token: str) -> dict: + """ + Decode the payload of a JWT without verifying the signature. + Raises JWTError if the token is malformed. + """ + parts = token.split(".") + if len(parts) != 3: + raise JWTError("malformed token: expected 3 parts") + try: + return json.loads(_b64url_decode(parts[1])) + except Exception as exc: + raise JWTError(f"cannot decode payload: {exc}") from exc diff --git a/local-identity/src/local_identity/keys.py b/local-identity/src/local_identity/keys.py index a32896a..91f05a2 100644 --- a/local-identity/src/local_identity/keys.py +++ b/local-identity/src/local_identity/keys.py @@ -8,7 +8,6 @@ The corresponding public key is never stored separately — it is always derived from the private key on load. """ -import base64 import hashlib import os from pathlib import Path @@ -17,6 +16,7 @@ from cryptography.hazmat.primitives import serialization from cryptography.hazmat.primitives.asymmetric import rsa from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey +from .jwt_utils import _b64url_encode as _b64url from .store import _store_dir @@ -48,27 +48,28 @@ def ensure_signing_key() -> RSAPrivateKey: return private_key -def _b64url(data: bytes) -> str: - return base64.urlsafe_b64encode(data).rstrip(b"=").decode("ascii") +def _public_key_bytes(private_key: RSAPrivateKey) -> tuple[bytes, bytes]: + """Return (n_bytes, e_bytes) for the public key — shared by key_id and jwk_public.""" + pub = private_key.public_key().public_numbers() + n_bytes = pub.n.to_bytes((pub.n.bit_length() + 7) // 8, byteorder="big") + e_bytes = pub.e.to_bytes((pub.e.bit_length() + 7) // 8, byteorder="big") + return n_bytes, e_bytes def key_id(private_key: RSAPrivateKey) -> str: """Return a stable 16-hex-char key ID derived from the public key modulus.""" - n = private_key.public_key().public_numbers().n - n_bytes = n.to_bytes((n.bit_length() + 7) // 8, byteorder="big") + n_bytes, _ = _public_key_bytes(private_key) return hashlib.sha256(n_bytes).hexdigest()[:16] def jwk_public(private_key: RSAPrivateKey) -> dict: """Return the RSA public key as a JWK dict (RS256, sig use).""" - pub = private_key.public_key().public_numbers() - n_bytes = pub.n.to_bytes((pub.n.bit_length() + 7) // 8, byteorder="big") - e_bytes = pub.e.to_bytes((pub.e.bit_length() + 7) // 8, byteorder="big") + n_bytes, e_bytes = _public_key_bytes(private_key) return { "kty": "RSA", "use": "sig", "alg": "RS256", - "kid": key_id(private_key), + "kid": hashlib.sha256(n_bytes).hexdigest()[:16], "n": _b64url(n_bytes), "e": _b64url(e_bytes), } diff --git a/local-identity/src/local_identity/security.py b/local-identity/src/local_identity/security.py index b1326fa..14f616f 100644 --- a/local-identity/src/local_identity/security.py +++ b/local-identity/src/local_identity/security.py @@ -24,9 +24,10 @@ class CheckResult: def _check_mode(path: Path, expected: int) -> CheckResult: """Return a CheckResult for a single path against the expected mode.""" - if not path.exists(): + try: + actual = stat.S_IMODE(os.stat(path).st_mode) + except FileNotFoundError: return CheckResult(str(path), "warn", "does not exist (skipped)") - actual = stat.S_IMODE(os.stat(path).st_mode) if actual != expected: return CheckResult( str(path), "fail", diff --git a/local-identity/src/local_identity/serve.py b/local-identity/src/local_identity/serve.py index 1f3024c..6296181 100644 --- a/local-identity/src/local_identity/serve.py +++ b/local-identity/src/local_identity/serve.py @@ -78,6 +78,7 @@ def make_handler( so multiple test servers can run concurrently without sharing state. """ codes: dict = {} + jwks_response = {"keys": [jwk_public(private_key)]} class _Handler(OIDCHandler): _private_key = private_key @@ -85,6 +86,7 @@ def make_handler( _token_ttl = token_ttl _codes = codes _scheme = scheme + _jwks = jwks_response return _Handler @@ -100,6 +102,7 @@ class OIDCHandler(http.server.BaseHTTPRequestHandler): _token_ttl: int = 3600 _codes: dict = {} _scheme: str = "https" + _jwks: dict = {} def log_message(self, fmt: str, *args) -> None: pass # silence default Apache-style logging @@ -161,7 +164,7 @@ class OIDCHandler(http.server.BaseHTTPRequestHandler): # ---------------------------------------------------------------- # def _handle_jwks(self) -> None: - self._send_json({"keys": [jwk_public(self._private_key)]}) + self._send_json(self._jwks) # ---------------------------------------------------------------- # # Endpoint: GET /auth — display login form #