Files
net-kingdom/tools/iam-profile-conformance/iam_profile_conformance.py

707 lines
27 KiB
Python

#!/usr/bin/env python3
"""
Executable conformance checks for NetKingdom IAM Profile v0.2.
The suite intentionally uses provider-neutral OIDC/JWT rules. It can run
against key-cape, Keycloak, or a fixture issuer as long as the issuer
exposes standard discovery and JWKS metadata.
"""
from __future__ import annotations
import argparse
import base64
import json
import sys
import time
import urllib.error
import urllib.parse
import urllib.request
from dataclasses import dataclass
from typing import Any
from cryptography.exceptions import InvalidSignature
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric import padding, rsa
# Profile revision certified by this suite (shown in the text summary line).
PROFILE_VERSION = "0.2"
# Default clock-skew tolerance, in seconds, for the exp/iat/nbf checks.
DEFAULT_SKEW_SECONDS = 60

# Discovery metadata every conforming issuer must publish.
REQUIRED_DISCOVERY_FIELDS = {
    # identity and endpoints
    "issuer", "authorization_endpoint", "token_endpoint", "jwks_uri", "userinfo_endpoint",
    # capability advertisements
    "scopes_supported", "response_types_supported", "grant_types_supported",
    "id_token_signing_alg_values_supported", "code_challenge_methods_supported",
}

# Allowed values of the principal_type claim.
PRINCIPAL_TYPES = {"human", "service", "agent"}

# Allowed values of assurance.level.
ASSURANCE_LEVELS = {"aal0", "aal1", "aal2", "aal3", "break_glass"}

# Roles that require aal2/aal3/break_glass assurance (see the assurance checks).
HIGH_IMPACT_ROLES = {
    "admin", "platform-admin", "platform_operator", "steward", "emergency", "break-glass",
}
@dataclass
class Config:
    """Options for a single conformance run; main() fills most of them from CLI flags."""

    # Issuer base URL (or the literal "local-identity") under test.
    issuer: str
    # Audience the supplied access token must include.
    audience: str
    # JWT to validate; without it the full suite records a token-provided FAIL.
    access_token: str | None = None
    # Test client and redirect URI used by the active PKCE probe.
    client_id: str | None = None
    redirect_uri: str | None = None
    # "production" makes the local-issuer policy reject local issuers.
    environment: str = "production"
    # Timeout in seconds for discovery, JWKS and PKCE-probe HTTP requests.
    timeout: float = 10.0
    # Stop after the discovery/JWKS checks.
    discovery_only: bool = False
    # Report the PKCE probe as WARN instead of sending it.
    skip_pkce_probe: bool = False
    # Clock-skew tolerance (seconds) for exp/iat/nbf comparisons.
    skew_seconds: int = DEFAULT_SKEW_SECONDS
@dataclass
class Result:
    """Outcome of one conformance check."""

    # Stable check identifier, e.g. "discovery-issuer".
    name: str
    # "PASS", "WARN" or "FAIL".
    status: str
    # Human-readable explanation of the outcome.
    message: str
    # Optional JSON-serialisable context, printed next to the message.
    detail: dict[str, Any] | None = None
class NoRedirect(urllib.request.HTTPRedirectHandler):
    """Redirect handler that never follows 3xx responses.

    Declining to build a follow-up request leaves the redirect for urllib's
    default error handling, so the caller can inspect the status code and
    ``Location`` header itself.
    """

    def redirect_request(self, req, fp, code, msg, headers, newurl):  # noqa: N802
        # Deliberately refuse every redirect.
        return None
def _make_result(status: str, name: str, message: str, detail: dict[str, Any] | None) -> Result:
    """Build a Result carrying the given status label."""
    return Result(name=name, status=status, message=message, detail=detail)


def pass_result(name: str, message: str, detail: dict[str, Any] | None = None) -> Result:
    """Return a PASS result for check *name*."""
    return _make_result("PASS", name, message, detail)


def warn_result(name: str, message: str, detail: dict[str, Any] | None = None) -> Result:
    """Return a WARN result for check *name*."""
    return _make_result("WARN", name, message, detail)


def fail_result(name: str, message: str, detail: dict[str, Any] | None = None) -> Result:
    """Return a FAIL result for check *name*."""
    return _make_result("FAIL", name, message, detail)
def b64url_decode(value: str) -> bytes:
    """Decode base64url text that may have had its ``=`` padding stripped (as in JWS/JWK)."""
    # -len % 4 is the number of '=' characters needed to reach a multiple of 4.
    pad = "=" * (-len(value) % 4)
    return base64.urlsafe_b64decode(value + pad)
def fetch_json(url: str, timeout: float) -> dict[str, Any]:
    """GET *url* with an ``Accept: application/json`` header and return the body as an object.

    Raises:
        ValueError: if the body is not JSON, or its top-level value is not an object.
        Any error raised by ``urllib.request.urlopen`` propagates unchanged.
    """
    request = urllib.request.Request(url, headers={"Accept": "application/json"})
    with urllib.request.urlopen(request, timeout=timeout) as response:
        raw = response.read()
    document = json.loads(raw)
    if isinstance(document, dict):
        return document
    raise ValueError(f"expected JSON object from {url}")
def discovery_url(issuer: str) -> str:
    """Return the OpenID discovery document URL for *issuer*, ignoring trailing slashes."""
    base = issuer.rstrip("/")
    return f"{base}/.well-known/openid-configuration"
def normalize_issuer(value: str) -> str:
    """Strip trailing slashes so issuer comparisons ignore them."""
    trimmed = value.rstrip("/")
    return trimmed
def as_list(value: Any) -> list[Any]:
    """Coerce a metadata value to a list.

    Lists are returned as-is (same object), tuples are copied into a list,
    ``None`` becomes ``[]`` and any other value is wrapped as ``[value]``.
    """
    if isinstance(value, list):
        return value
    if isinstance(value, tuple):
        return list(value)
    return [] if value is None else [value]
def normalize_scopes(payload: dict[str, Any]) -> list[str]:
    """Merge the ``scope`` and ``scp`` claims into a sorted, de-duplicated list.

    Each claim may be a space-delimited string or a list (members are
    stringified); values of any other type are ignored.
    """
    collected: set[str] = set()
    for claim in ("scope", "scp"):
        raw = payload.get(claim)
        if isinstance(raw, str):
            # str.split() with no argument never yields empty tokens.
            collected.update(raw.split())
        elif isinstance(raw, list):
            collected.update(str(item) for item in raw)
    return sorted(collected)
def normalize_roles(payload: dict[str, Any]) -> tuple[list[str], str]:
    """Return ``(roles, source)`` for the token's role claim.

    The canonical ``roles`` list wins; otherwise Keycloak-style
    ``realm_access.roles`` is used. ``source`` names where the roles came
    from, or is ``"missing"`` (with an empty list) when neither is a list.
    """
    canonical = payload.get("roles")
    if isinstance(canonical, list):
        return [str(role) for role in canonical], "roles"
    realm = payload.get("realm_access")
    provider_native = realm.get("roles") if isinstance(realm, dict) else None
    if isinstance(provider_native, list):
        return [str(role) for role in provider_native], "realm_access.roles"
    return [], "missing"
def is_local_issuer(issuer: str) -> bool:
    """Return True when *issuer* looks like a local-development issuer.

    Treated as local:
    - the literal marker ``local-identity``;
    - any plain-``http`` issuer;
    - ``localhost``, ``*.localhost`` (RFC 6761) and ``*.local`` hosts;
    - any loopback or unspecified IP literal (127.0.0.0/8, ``::1``,
      ``0.0.0.0``, ``::``), not only ``127.0.0.1``.
    """
    import ipaddress

    if issuer == "local-identity":
        return True
    parsed = urllib.parse.urlparse(issuer)
    if parsed.scheme == "http":
        return True
    host = (parsed.hostname or "").lower()
    if host == "localhost" or host.endswith((".localhost", ".local")):
        return True
    try:
        address = ipaddress.ip_address(host)
    except ValueError:
        # Not an IP literal (a DNS name, or empty).
        return False
    return address.is_loopback or address.is_unspecified
def check_discovery(config: Config, discovery: dict[str, Any]) -> list[Result]:
    """Check an OIDC discovery document against the v0.2 metadata rules.

    Emits, in order: discovery-fields, discovery-issuer,
    authorization-code-flow, authorization-code-grant, service-account-flow,
    openid-scope, recommended-scopes (WARN when profile/email are missing),
    signing-algorithm and pkce-metadata.
    """

    def advertised(field: str) -> set[str]:
        # Metadata may be a list, a scalar or absent; compare members as strings.
        return {str(item) for item in as_list(discovery.get(field))}

    findings: list[Result] = []

    absent_fields = sorted(REQUIRED_DISCOVERY_FIELDS.difference(discovery))
    if not absent_fields:
        findings.append(pass_result("discovery-fields", "required metadata fields are present"))
    else:
        findings.append(
            fail_result("discovery-fields", "missing required metadata fields", {"missing": absent_fields})
        )

    advertised_issuer = str(discovery.get("issuer", ""))
    # Trailing slashes are ignored on both sides of the comparison.
    if normalize_issuer(advertised_issuer) != normalize_issuer(config.issuer):
        findings.append(
            fail_result(
                "discovery-issuer",
                "discovery issuer does not match configured issuer",
                {"configured": config.issuer, "advertised": advertised_issuer},
            )
        )
    else:
        findings.append(pass_result("discovery-issuer", "discovery issuer matches configured issuer"))

    if "code" not in advertised("response_types_supported"):
        findings.append(fail_result("authorization-code-flow", "response_types_supported must include code"))
    else:
        findings.append(
            pass_result("authorization-code-flow", "authorization code response type is advertised")
        )

    grant_types = advertised("grant_types_supported")
    if "authorization_code" not in grant_types:
        findings.append(
            fail_result("authorization-code-grant", "grant_types_supported must include authorization_code")
        )
    else:
        findings.append(pass_result("authorization-code-grant", "authorization_code grant is advertised"))

    # Either client credentials or the RFC 8693 token-exchange grant counts as
    # a non-interactive (service/workload) flow.
    workload_grants = {"client_credentials", "urn:ietf:params:oauth:grant-type:token-exchange"}
    if grant_types.isdisjoint(workload_grants):
        findings.append(
            fail_result(
                "service-account-flow",
                "grant_types_supported must include client_credentials or a workload token-exchange grant",
            )
        )
    else:
        findings.append(
            pass_result("service-account-flow", "service-account or workload-token grant is advertised")
        )

    scopes = advertised("scopes_supported")
    if "openid" not in scopes:
        findings.append(fail_result("openid-scope", "scopes_supported must include openid"))
    else:
        findings.append(pass_result("openid-scope", "openid scope is advertised"))

    missing_recommended = sorted({"profile", "email"} - scopes)
    if not missing_recommended:
        findings.append(pass_result("recommended-scopes", "profile and email scopes are advertised"))
    else:
        findings.append(
            warn_result(
                "recommended-scopes",
                "profile/email are recommended for human profile claims",
                {"missing": missing_recommended},
            )
        )

    if "RS256" not in advertised("id_token_signing_alg_values_supported"):
        findings.append(fail_result("signing-algorithm", "RS256 must be advertised for v0.2 conformance"))
    else:
        findings.append(pass_result("signing-algorithm", "RS256 is advertised"))

    if "S256" not in advertised("code_challenge_methods_supported"):
        findings.append(fail_result("pkce-metadata", "code_challenge_methods_supported must include S256"))
    else:
        findings.append(pass_result("pkce-metadata", "PKCE S256 is advertised"))

    return findings
def check_local_issuer_policy(config: Config, issuer: str) -> Result:
    """Fail local-development issuers in production; accept them elsewhere.

    A non-local issuer always passes. A local issuer fails when
    ``config.environment`` is "production" and otherwise passes with a
    message naming the environment that allowed it.
    """
    check = "local-issuer-policy"
    if not is_local_issuer(issuer):
        return pass_result(check, "issuer is not a local-development issuer")
    if config.environment == "production":
        return fail_result(check, "production mode must reject local-development issuers", {"issuer": issuer})
    return pass_result(
        check,
        f"local issuer accepted only because environment={config.environment}",
        {"issuer": issuer},
    )
def check_jwks(jwks: dict[str, Any]) -> list[Result]:
    """Check that a JWKS document can support RS256 token verification.

    Returns a single jwks-keys FAIL when ``keys`` is not a non-empty list.
    Otherwise it returns two results:
    - jwks-key-ids: every key is an object with a truthy ``kid``;
    - jwks-rsa: at least one RSA key has ``n``/``e`` and ``use`` either
      absent or ``"sig"``.
    """
    keys = jwks.get("keys")
    if not isinstance(keys, list) or not keys:
        return [fail_result("jwks-keys", "JWKS must contain at least one signing key")]

    results: list[Result] = []
    missing_kids = [idx for idx, key in enumerate(keys) if not isinstance(key, dict) or not key.get("kid")]
    if missing_kids:
        results.append(fail_result("jwks-key-ids", "all JWKS keys must carry kid", {"indexes": missing_kids}))
    else:
        results.append(pass_result("jwks-key-ids", "all JWKS keys carry kid"))

    usable = [
        key
        for key in keys
        if isinstance(key, dict)
        and key.get("kty") == "RSA"
        and key.get("n")
        and key.get("e")
        # Tuple membership compares with ==, so a malformed non-string "use"
        # (e.g. a JSON list) is simply not usable. A set lookup would hash it
        # and raise TypeError.
        and key.get("use") in (None, "sig")
    ]
    if usable:
        results.append(pass_result("jwks-rsa", "JWKS contains usable RSA signing keys", {"count": len(usable)}))
    else:
        results.append(fail_result("jwks-rsa", "JWKS must contain RSA signing keys for RS256"))
    return results
def probe_pkce(config: Config, discovery: dict[str, Any]) -> Result:
if config.skip_pkce_probe:
return warn_result("pkce-probe", "PKCE probe skipped by operator request")
if not config.client_id or not config.redirect_uri:
return fail_result(
"pkce-probe",
"client id and redirect URI are required to probe PKCE enforcement",
)
endpoint = discovery.get("authorization_endpoint")
if not isinstance(endpoint, str) or not endpoint:
return fail_result("pkce-probe", "authorization_endpoint is missing")
query = urllib.parse.urlencode(
{
"response_type": "code",
"client_id": config.client_id,
"redirect_uri": config.redirect_uri,
"scope": "openid",
"state": "iam-profile-conformance",
"nonce": "iam-profile-conformance",
}
)
url = endpoint + ("&" if "?" in endpoint else "?") + query
opener = urllib.request.build_opener(NoRedirect())
req = urllib.request.Request(url, headers={"Accept": "application/json,text/html,*/*"})
try:
response = opener.open(req, timeout=config.timeout)
status = getattr(response, "status", 200)
location = response.headers.get("Location", "")
body = response.read(2048).decode("utf-8", errors="replace").lower()
except urllib.error.HTTPError as exc:
status = exc.code
location = exc.headers.get("Location", "")
body = exc.read(2048).decode("utf-8", errors="replace").lower()
except Exception as exc: # pragma: no cover - network diagnostics
return fail_result("pkce-probe", f"PKCE probe failed to reach authorization endpoint: {exc}")
location_lower = location.lower()
rejection_text = " ".join([location_lower, body])
rejected_for_pkce = (
status in {302, 303, 307, 308, 400, 401}
and ("invalid_request" in rejection_text or "code_challenge" in rejection_text or "pkce" in rejection_text)
)
if rejected_for_pkce:
return pass_result("pkce-probe", "authorization request without code_challenge was rejected")
return fail_result(
"pkce-probe",
"authorization request without code_challenge was not clearly rejected",
{"status": status, "location": location},
)
def decode_jwt(token: str) -> tuple[dict[str, Any], dict[str, Any]]:
    """Decode, without verifying, the header and payload of a compact JWS.

    Returns:
        ``(header, payload)``, both JSON objects.

    Raises:
        ValueError: if the token is not three dot-separated segments, a
            segment is not valid base64url/UTF-8/JSON, or the header or
            payload is not a JSON object.
    """
    parts = token.split(".")
    if len(parts) != 3:
        raise ValueError("JWT must have three compact-serialization parts")
    header = json.loads(b64url_decode(parts[0]))
    payload = json.loads(b64url_decode(parts[1]))
    # Callers use .get() on both values. A JSON array or string would otherwise
    # crash later checks instead of becoming a jwt-structure failure.
    if not isinstance(header, dict):
        raise ValueError("JWT header must be a JSON object")
    if not isinstance(payload, dict):
        raise ValueError("JWT payload must be a JSON object")
    return header, payload
def jwk_to_rsa_public_key(jwk: dict[str, Any]):
    """Build an RSA public key from the base64url ``n`` (modulus) and ``e`` (exponent) JWK members.

    Raises KeyError when ``n`` or ``e`` is absent (``n`` is read first).
    Decoding errors, and any key-validation errors raised by ``cryptography``,
    propagate to the caller.
    """

    def member_as_int(member: str) -> int:
        # JWK integers are big-endian, unsigned, base64url-encoded octets.
        return int.from_bytes(b64url_decode(str(jwk[member])), "big")

    modulus = member_as_int("n")
    exponent = member_as_int("e")
    return rsa.RSAPublicNumbers(exponent, modulus).public_key()
def verify_signature(header: dict[str, Any], token: str, jwks: dict[str, Any]) -> Result:
    """Verify an RS256 compact JWS against the JWKS key with the matching ``kid``.

    Returns a jwt-signature result. It FAILs when:
    - ``alg`` is not RS256;
    - no JWKS key carries the header's ``kid``;
    - the token is malformed or the key is unusable;
    - the signature does not verify.
    Only the first key with a matching ``kid`` is tried.
    """
    if header.get("alg") != "RS256":
        return fail_result("jwt-signature", "JWT alg must be RS256", {"alg": header.get("alg")})
    kid = header.get("kid")
    keys = jwks.get("keys") if isinstance(jwks.get("keys"), list) else []
    matching = [key for key in keys if isinstance(key, dict) and key.get("kid") == kid]
    if not matching:
        return fail_result("jwt-signature", "JWT kid was not found in JWKS", {"kid": kid})
    try:
        # All token decoding stays inside the try. A malformed segment (wrong
        # base64url length, non-ASCII text, wrong segment count) raises
        # ValueError subclasses and must become a FAIL, not a crash.
        encoded_header, encoded_payload, encoded_signature = token.split(".")
        signing_input = f"{encoded_header}.{encoded_payload}".encode("ascii")
        signature = b64url_decode(encoded_signature)
        public_key = jwk_to_rsa_public_key(matching[0])
        public_key.verify(signature, signing_input, padding.PKCS1v15(), hashes.SHA256())
    except (InvalidSignature, ValueError, KeyError) as exc:
        return fail_result("jwt-signature", f"JWT signature verification failed: {exc}")
    return pass_result("jwt-signature", "JWT signature verifies against JWKS", {"kid": kid})
def audience_matches(audience_claim: Any, expected: str) -> bool:
    """Return True when the ``aud`` claim equals, or (as a list) contains, *expected*.

    List members are stringified before comparison; any other claim type
    never matches.
    """
    if isinstance(audience_claim, list):
        return any(str(entry) == expected for entry in audience_claim)
    return isinstance(audience_claim, str) and audience_claim == expected
def check_token_lifetime(payload: dict[str, Any], config: Config) -> list[Result]:
    """Check exp/iat/nbf against the current time and check the token TTL.

    Comparisons allow ``config.skew_seconds`` of clock skew. Results, in order:
    - token-expiry and token-issued-at;
    - token-not-before (WARN when nbf is absent);
    - token-ttl, only when both exp and iat are integer timestamps.
    """

    def is_timestamp(value: Any) -> bool:
        # JSON true/false decode to bool, which subclasses int. They are not
        # timestamps and must not pass the integer-timestamp checks.
        return isinstance(value, int) and not isinstance(value, bool)

    results: list[Result] = []
    now = int(time.time())
    skew = config.skew_seconds
    exp = payload.get("exp")
    iat = payload.get("iat")
    nbf = payload.get("nbf")

    if not is_timestamp(exp):
        results.append(fail_result("token-expiry", "exp must be an integer timestamp"))
    elif exp <= now - skew:
        results.append(fail_result("token-expiry", "token is expired", {"exp": exp, "now": now}))
    else:
        results.append(pass_result("token-expiry", "token is not expired"))

    if not is_timestamp(iat):
        results.append(fail_result("token-issued-at", "iat must be an integer timestamp"))
    elif iat > now + skew:
        results.append(fail_result("token-issued-at", "iat is in the future", {"iat": iat, "now": now}))
    else:
        results.append(pass_result("token-issued-at", "iat is valid"))

    if nbf is None:
        results.append(warn_result("token-not-before", "nbf is recommended for production tokens"))
    elif not is_timestamp(nbf):
        results.append(fail_result("token-not-before", "nbf must be an integer timestamp"))
    elif nbf > now + skew:
        results.append(fail_result("token-not-before", "nbf is in the future", {"nbf": nbf, "now": now}))
    else:
        results.append(pass_result("token-not-before", "nbf is valid"))

    if is_timestamp(exp) and is_timestamp(iat):
        ttl = exp - iat
        # 3600 s is the TTL ceiling the WARN message calls the profile default.
        if ttl > 3600:
            results.append(warn_result("token-ttl", "access token TTL is longer than the profile default", {"ttl": ttl}))
        elif ttl <= 0:
            results.append(fail_result("token-ttl", "token TTL must be positive", {"ttl": ttl}))
        else:
            results.append(pass_result("token-ttl", "token TTL is within conformance tolerance", {"ttl": ttl}))
    return results
def check_claim_shape(payload: dict[str, Any]) -> list[Result]:
    """Check presence and basic shape of the IAM Profile claims.

    Results, in order: claim-shape (all required claims, plus some role and
    some scope), tenant-claim (``tenant:<name>``), groups-claim (a list),
    roles-claim (WARN for provider-native ``realm_access.roles``) and
    scope-claim.
    """
    required_claims = {"iss", "sub", "aud", "exp", "iat", "tenant", "principal_type", "groups", "assurance"}
    roles, role_source = normalize_roles(payload)
    scopes = normalize_scopes(payload)

    absent = sorted(claim for claim in required_claims if claim not in payload)
    if not roles:
        absent.append("roles")
    if not scopes:
        absent.append("scope/scp")

    findings: list[Result] = []
    if not absent:
        findings.append(pass_result("claim-shape", "required IAM Profile claims are present"))
    else:
        findings.append(fail_result("claim-shape", "token is missing required IAM Profile claims", {"missing": absent}))

    tenant = payload.get("tenant")
    prefix = "tenant:"
    # Needs the prefix plus a non-empty tenant name.
    tenant_ok = isinstance(tenant, str) and len(tenant) > len(prefix) and tenant.startswith(prefix)
    if not tenant_ok:
        findings.append(fail_result("tenant-claim", "tenant must be a string like tenant:platform"))
    else:
        findings.append(pass_result("tenant-claim", "tenant claim is well formed", {"tenant": tenant}))

    groups = payload.get("groups")
    if not isinstance(groups, list):
        findings.append(fail_result("groups-claim", "groups must be a list, even when empty"))
    else:
        findings.append(pass_result("groups-claim", "groups claim is a list", {"count": len(groups)}))

    if role_source == "realm_access.roles":
        findings.append(
            warn_result(
                "roles-claim",
                "provider-native realm_access.roles found; emit canonical roles before production consumption",
                {"count": len(roles)},
            )
        )
    elif role_source == "roles":
        findings.append(pass_result("roles-claim", "canonical roles claim is present", {"count": len(roles)}))
    else:
        findings.append(
            fail_result("roles-claim", "roles must be present as roles or normalized from provider-native roles")
        )

    if not scopes:
        findings.append(fail_result("scope-claim", "scope or scp must be present"))
    else:
        findings.append(pass_result("scope-claim", "scope/scp claim is present", {"scopes": scopes}))
    return findings
def check_principal_shape(payload: dict[str, Any]) -> Result:
    """Check the claims required by the token's ``principal_type``.

    - human: needs ``preferred_username``.
    - service: needs ``azp`` or ``client_id``.
    - agent: needs an ``agent`` object with ``id`` and ``mode``
      (autonomous|delegated); delegated agents also need ``actor_sub`` or
      ``act.sub``.
    """
    principal_type = payload.get("principal_type")
    # Require a string before the set lookup. A JSON list/object claim is
    # unhashable and would raise TypeError instead of producing a FAIL.
    if not isinstance(principal_type, str) or principal_type not in PRINCIPAL_TYPES:
        return fail_result(
            "principal-shape",
            "principal_type must be human, service, or agent",
            {"principal_type": principal_type},
        )
    if principal_type == "human":
        if payload.get("preferred_username"):
            return pass_result("principal-shape", "human principal has preferred_username")
        return fail_result("principal-shape", "human principals must include preferred_username")
    if principal_type == "service":
        if payload.get("azp") or payload.get("client_id"):
            return pass_result("principal-shape", "service principal has azp/client_id")
        return fail_result("principal-shape", "service principals must include azp or client_id")

    agent = payload.get("agent")
    if not isinstance(agent, dict):
        return fail_result("principal-shape", "agent principals must include an agent object")
    if not agent.get("id"):
        return fail_result("principal-shape", "agent.id is required")
    mode = agent.get("mode")
    # Tuple membership compares with ==, so a non-string mode cannot raise TypeError.
    if mode not in ("autonomous", "delegated"):
        return fail_result("principal-shape", "agent.mode must be autonomous or delegated", {"mode": mode})
    if mode == "delegated":
        act = payload.get("act")
        has_actor = bool(payload.get("actor_sub")) or (isinstance(act, dict) and bool(act.get("sub")))
        if not has_actor:
            return fail_result("principal-shape", "delegated agents must include actor_sub or act.sub")
    return pass_result("principal-shape", "agent principal shape is valid")
def check_assurance(payload: dict[str, Any]) -> list[Result]:
    """Validate the ``assurance`` claim and whether it is strong enough for the roles.

    Checks, in order:
    - assurance is an object carrying level, methods, mfa and source. If any
      is missing, a single assurance-shape FAIL is returned and checking stops.
    - level is one of ASSURANCE_LEVELS; aal0 produces only a WARN.
    - methods is a list of strings and mfa is a boolean.
    - any HIGH_IMPACT_ROLES member requires aal2/aal3/break_glass and, unless
      break_glass, ``mfa`` must be True.
    - emergency/break-glass roles require level break_glass; for those roles
      a TTL (exp - iat) above 900 s adds an emergency-ttl WARN.
    """
    results: list[Result] = []
    assurance = payload.get("assurance")
    if not isinstance(assurance, dict):
        return [fail_result("assurance-shape", "assurance must be an object")]
    level = assurance.get("level")
    methods = assurance.get("methods")
    mfa = assurance.get("mfa")
    source = assurance.get("source")
    missing = [
        name
        for name, value in {
            "level": level,
            "methods": methods,
            "mfa": mfa,
            "source": source,
        }.items()
        if value is None
    ]
    if missing:
        results.append(fail_result("assurance-shape", "assurance is missing required fields", {"missing": missing}))
        return results

    # Set membership hashes its operand, so a list/object level from the token
    # would raise TypeError. Non-string levels are treated as unknown (None) in
    # every comparison below; the raw value is still reported in details.
    level_name = level if isinstance(level, str) else None
    if level_name not in ASSURANCE_LEVELS:
        results.append(fail_result("assurance-level", "assurance.level has an unsupported value", {"level": level}))
    elif level_name == "aal0":
        results.append(warn_result("assurance-level", "aal0 is local/dev only and not production privileged"))
    else:
        results.append(pass_result("assurance-level", "assurance.level is recognized", {"level": level}))

    if isinstance(methods, list) and all(isinstance(method, str) for method in methods):
        results.append(pass_result("assurance-methods", "assurance.methods is a list"))
    else:
        results.append(fail_result("assurance-methods", "assurance.methods must be a list of strings"))

    if isinstance(mfa, bool):
        results.append(pass_result("assurance-mfa", "assurance.mfa is boolean", {"mfa": mfa}))
    else:
        results.append(fail_result("assurance-mfa", "assurance.mfa must be boolean"))

    roles, _ = normalize_roles(payload)
    has_high_impact_role = bool(HIGH_IMPACT_ROLES & set(roles))
    if has_high_impact_role and level_name not in {"aal2", "aal3", "break_glass"}:
        results.append(fail_result("privileged-assurance", "high-impact roles require aal2, aal3, or break_glass"))
    elif has_high_impact_role and mfa is not True and level_name != "break_glass":
        results.append(fail_result("privileged-assurance", "high-impact roles require MFA evidence"))
    else:
        results.append(pass_result("privileged-assurance", "assurance is sufficient for asserted roles"))

    if "emergency" in roles or "break-glass" in roles:
        if level_name != "break_glass":
            results.append(fail_result("emergency-assurance", "emergency roles require assurance.level=break_glass"))
        else:
            results.append(pass_result("emergency-assurance", "emergency assurance level is explicit"))
        exp = payload.get("exp")
        iat = payload.get("iat")
        if isinstance(exp, int) and isinstance(iat, int) and exp - iat > 900:
            results.append(warn_result("emergency-ttl", "emergency token TTL should be 15 minutes or less"))
    return results
def check_token(config: Config, token: str, jwks: dict[str, Any]) -> list[Result]:
    """Run every access-token check against *token*.

    If the token cannot be decoded, a single jwt-structure FAIL is returned.
    Otherwise the results are, in order:
    - jwt-structure and jwt-signature;
    - token-issuer and token-audience;
    - the lifetime, claim-shape, principal-shape and assurance checks.
    A failed signature does not stop the remaining checks.
    """
    try:
        header, payload = decode_jwt(token)
    except Exception as exc:
        return [fail_result("jwt-structure", f"could not decode JWT: {exc}")]

    results: list[Result] = [pass_result("jwt-structure", "JWT compact serialization decoded")]
    results.append(verify_signature(header, token, jwks))

    token_issuer = payload.get("iss")
    issuer_ok = isinstance(token_issuer, str) and normalize_issuer(token_issuer) == normalize_issuer(config.issuer)
    if not issuer_ok:
        results.append(
            fail_result(
                "token-issuer",
                "token issuer does not match configured issuer",
                {"configured": config.issuer, "token": token_issuer},
            )
        )
    else:
        results.append(pass_result("token-issuer", "token issuer matches configured issuer"))

    token_audience = payload.get("aud")
    if not audience_matches(token_audience, config.audience):
        results.append(
            fail_result(
                "token-audience",
                "token audience does not include configured audience",
                {"expected": config.audience, "token": token_audience},
            )
        )
    else:
        results.append(pass_result("token-audience", "token audience includes configured audience"))

    results += check_token_lifetime(payload, config)
    results += check_claim_shape(payload)
    results.append(check_principal_shape(payload))
    results += check_assurance(payload)
    return results
def run_suite(config: Config) -> list[Result]:
    """Run all applicable conformance checks and return results in report order.

    If discovery cannot be fetched, only that failure is reported. Otherwise:
    1. discovery and local-issuer checks run.
    2. The JWKS is fetched; on failure an empty key set is used.
    3. Unless ``discovery_only`` is set, the PKCE probe and token checks follow.
    """
    try:
        discovery = fetch_json(discovery_url(config.issuer), config.timeout)
    except Exception as exc:
        return [fail_result("discovery-fetch", f"failed to fetch discovery document: {exc}")]

    results: list[Result] = [pass_result("discovery-fetch", "discovery document fetched")]
    results += check_discovery(config, discovery)
    advertised_issuer = str(discovery.get("issuer", config.issuer))
    results.append(check_local_issuer_policy(config, advertised_issuer))

    jwks: dict[str, Any]
    try:
        jwks = fetch_json(str(discovery["jwks_uri"]), config.timeout)
    except Exception as exc:
        results.append(fail_result("jwks-fetch", f"failed to fetch JWKS: {exc}"))
        # Empty key set so token signature checks still run and fail cleanly.
        jwks = {"keys": []}
    else:
        results.append(pass_result("jwks-fetch", "JWKS fetched"))
        results += check_jwks(jwks)

    if config.discovery_only:
        return results

    results.append(probe_pkce(config, discovery))
    if config.access_token:
        results.append(pass_result("token-provided", "access token provided"))
        results += check_token(config, config.access_token, jwks)
    else:
        results.append(fail_result("token-provided", "full conformance requires --access-token"))
    return results
def print_results(results: list[Result], as_json: bool) -> None:
    """Write *results* to stdout.

    JSON mode prints the list of result fields (sorted keys, indent 2) and
    nothing else. Text mode prints:
    - one line per result;
    - a JSON detail line for any non-empty detail;
    - a blank line, then a fail/warn/total summary.
    """
    if as_json:
        serialisable = [vars(result) for result in results]
        print(json.dumps(serialisable, indent=2, sort_keys=True))
        return

    failures = 0
    warnings = 0
    for result in results:
        print(f"{result.status:4} {result.name}: {result.message}")
        if result.detail:
            print(f" {json.dumps(result.detail, sort_keys=True)}")
        if result.status == "FAIL":
            failures += 1
        elif result.status == "WARN":
            warnings += 1
    print("")
    print(f"IAM Profile v{PROFILE_VERSION} conformance: {failures} fail, {warnings} warn, {len(results)} checks")
def build_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser(description="Run NetKingdom IAM Profile v0.2 conformance checks.")
parser.add_argument("--issuer", required=True, help="OIDC issuer URL or local issuer base")
parser.add_argument("--audience", required=True, help="Expected audience for the supplied token")
parser.add_argument("--access-token", help="JWT access token to validate")
parser.add_argument("--client-id", help="Public test client id for PKCE probe")
parser.add_argument("--redirect-uri", help="Redirect URI registered for the test client")
parser.add_argument(
"--environment",
choices=["production", "nonproduction", "local"],
default="production",
help="Validation environment; production rejects local issuers",
)
parser.add_argument("--timeout", type=float, default=10.0, help="HTTP timeout in seconds")
parser.add_argument("--discovery-only", action="store_true", help="Only run discovery/JWKS checks")
parser.add_argument("--skip-pkce-probe", action="store_true", help="Skip active PKCE enforcement probe")
parser.add_argument("--json", action="store_true", help="Print machine-readable JSON results")
return parser
def main(argv: list[str] | None = None) -> int:
args = build_parser().parse_args(argv)
config = Config(
issuer=args.issuer,
audience=args.audience,
access_token=args.access_token,
client_id=args.client_id,
redirect_uri=args.redirect_uri,
environment=args.environment,
timeout=args.timeout,
discovery_only=args.discovery_only,
skip_pkce_probe=args.skip_pkce_probe,
)
results = run_suite(config)
print_results(results, args.json)
return 1 if any(result.status == "FAIL" for result in results) else 0
if __name__ == "__main__":
    # Propagate main()'s return code as the process exit status.
    raise SystemExit(main())