"""Tests for warden.ca — LocalCA and parse_cert_metadata.""" from datetime import datetime, timezone from pathlib import Path from unittest.mock import MagicMock, patch import pytest import json from warden.ca import CAError, LocalCA, _enforce_ttl, _evict_cert, _append_signature_log, parse_cert_metadata from warden.models import ActorType, CertSpec, CertRecord SAMPLE_SSHKEYGEN_L = """\ /tmp/key-cert.pub: Type: ssh-ed25519-cert-v01@openssh.com user certificate Public key: ED25519-CERT SHA256:abc123 Signing CA: ED25519 SHA256:xyz (using ssh-ed25519) Key ID: "agt-state-hub-bridge" Serial: 0 Valid: from 2026-03-28T10:00:00 to 2026-03-29T10:00:00 Principals: agt-task-bridge Critical Options: (none) Extensions: permit-pty """ CERT_CONTENT = "ssh-ed25519-cert-v01@openssh.com AAAA_fake_cert_data" def _mock_run_factory(cert_content: str): """Returns a mock subprocess.run that writes the cert file on sign and returns SAMPLE_SSHKEYGEN_L on -L.""" def mock_run(cmd, **kwargs): result = MagicMock() result.returncode = 0 result.stdout = "" result.stderr = "" if not isinstance(cmd, list) or not cmd: return result if cmd[0] == "ssh-keygen" and "-s" in cmd: # Signing: write cert next to the pubkey copy (last arg) pubkey_path = Path(cmd[-1]) cert_path = pubkey_path.parent / (pubkey_path.stem + "-cert.pub") cert_path.write_text(cert_content) elif cmd[0] == "ssh-keygen" and "-L" in cmd: result.stdout = SAMPLE_SSHKEYGEN_L return result return mock_run # --------------------------------------------------------------------------- # parse_cert_metadata # --------------------------------------------------------------------------- def test_parse_cert_metadata(tmp_path): cert_path = tmp_path / "key-cert.pub" cert_path.write_text(CERT_CONTENT) mock_result = MagicMock(returncode=0, stdout=SAMPLE_SSHKEYGEN_L, stderr="") with patch("warden.ca.subprocess.run", return_value=mock_result): meta = parse_cert_metadata(cert_path) assert meta["identity"] == "agt-state-hub-bridge" assert meta["principals"] == ["agt-task-bridge"] assert meta["valid_before"] == datetime(2026, 3, 29, 10, 0, 0, tzinfo=timezone.utc) def test_parse_cert_metadata_failure(tmp_path): cert_path = tmp_path / "key-cert.pub" cert_path.write_text("not a cert") mock_result = MagicMock(returncode=1, stdout="", stderr="not a certificate") with patch("warden.ca.subprocess.run", return_value=mock_result): with pytest.raises(CAError, match="ssh-keygen -L failed"): parse_cert_metadata(cert_path) def test_parse_cert_metadata_missing_valid_before(tmp_path): cert_path = tmp_path / "key-cert.pub" cert_path.write_text(CERT_CONTENT) output_no_valid = SAMPLE_SSHKEYGEN_L.replace( " Valid: from 2026-03-28T10:00:00 to 2026-03-29T10:00:00\n", "" ) mock_result = MagicMock(returncode=0, stdout=output_no_valid, stderr="") with patch("warden.ca.subprocess.run", return_value=mock_result): with pytest.raises(CAError, match="valid_before"): parse_cert_metadata(cert_path) # --------------------------------------------------------------------------- # LocalCA.sign # --------------------------------------------------------------------------- def test_local_ca_sign(tmp_path): ca_key = tmp_path / "ca_key" ca_key.write_text("fake-ca-private-key") pubkey = tmp_path / "key.pub" pubkey.write_text("ssh-ed25519 AAAA actor-key") spec = CertSpec( actor_name="agt-state-hub-bridge", actor_type=ActorType.AGT, pubkey_path=pubkey, ttl_hours=24, principals=["agt-task-bridge"], identity="agt-state-hub-bridge", ) with patch("warden.ca.subprocess.run", side_effect=_mock_run_factory(CERT_CONTENT)): ca = LocalCA(ca_key, tmp_path / "state") record = ca.sign(spec) assert record.identity == "agt-state-hub-bridge" assert record.actor_name == "agt-state-hub-bridge" assert record.principals == ["agt-task-bridge"] cert_dest = tmp_path / "state" / "agt-state-hub-bridge-cert.pub" assert cert_dest.exists() assert cert_dest.read_text().strip() == CERT_CONTENT def test_local_ca_sign_missing_pubkey(tmp_path): ca_key = tmp_path / "ca_key" ca_key.write_text("fake-ca") spec = CertSpec( actor_name="agt-test", actor_type=ActorType.AGT, pubkey_path=tmp_path / "nonexistent.pub", ttl_hours=24, principals=["agt-test"], ) ca = LocalCA(ca_key, tmp_path / "state") with pytest.raises(CAError, match="Public key not found"): ca.sign(spec) def test_local_ca_sign_missing_ca_key(tmp_path): pubkey = tmp_path / "key.pub" pubkey.write_text("ssh-ed25519 AAAA") spec = CertSpec( actor_name="agt-test", actor_type=ActorType.AGT, pubkey_path=pubkey, ttl_hours=24, principals=["agt-test"], ) ca = LocalCA(tmp_path / "nonexistent_ca", tmp_path / "state") with pytest.raises(CAError, match="CA key not found"): ca.sign(spec) def test_local_ca_sign_ssh_keygen_failure(tmp_path): ca_key = tmp_path / "ca_key" ca_key.write_text("fake-ca") pubkey = tmp_path / "key.pub" pubkey.write_text("ssh-ed25519 AAAA") spec = CertSpec( actor_name="agt-test", actor_type=ActorType.AGT, pubkey_path=pubkey, ttl_hours=24, principals=["agt-test"], ) def fail_run(cmd, **kwargs): result = MagicMock() result.returncode = 1 result.stderr = "load key: invalid format" result.stdout = "" return result ca = LocalCA(ca_key, tmp_path / "state") with patch("warden.ca.subprocess.run", side_effect=fail_run): with pytest.raises(CAError, match="Signing failed"): ca.sign(spec) # --------------------------------------------------------------------------- # _enforce_ttl # --------------------------------------------------------------------------- @pytest.mark.parametrize("actor_type,max_h", [ (ActorType.ADM, 48), (ActorType.AGT, 24), (ActorType.ATM, 8), ]) def test_enforce_ttl_rejects_over_max(actor_type, max_h, tmp_path): spec = CertSpec( actor_name=f"{actor_type.value}-test", actor_type=actor_type, pubkey_path=tmp_path / "k.pub", ttl_hours=max_h + 1, principals=["x"], ) with pytest.raises(CAError, match="exceeds maximum"): _enforce_ttl(spec) @pytest.mark.parametrize("actor_type,max_h", [ (ActorType.ADM, 48), (ActorType.AGT, 24), (ActorType.ATM, 8), ]) def test_enforce_ttl_accepts_at_max(actor_type, max_h, tmp_path): spec = CertSpec( actor_name=f"{actor_type.value}-test", actor_type=actor_type, pubkey_path=tmp_path / "k.pub", ttl_hours=max_h, principals=["x"], ) _enforce_ttl(spec) # must not raise # --------------------------------------------------------------------------- # _evict_cert # --------------------------------------------------------------------------- def test_evict_cert_removes_existing(tmp_path): cert = tmp_path / "agt-test-cert.pub" cert.write_text("old cert") _evict_cert("agt-test", tmp_path) assert not cert.exists() def test_evict_cert_noop_when_absent(tmp_path): _evict_cert("agt-test", tmp_path) # must not raise # --------------------------------------------------------------------------- # _append_signature_log # --------------------------------------------------------------------------- def test_append_signature_log_creates_file(tmp_path): record = CertRecord( identity="agt-test", valid_before=datetime(2026, 3, 29, 10, 0, 0, tzinfo=timezone.utc), cert_path=tmp_path / "agt-test-cert.pub", signed_at=datetime(2026, 3, 28, 10, 0, 0, tzinfo=timezone.utc), principals=["agt-task"], actor_name="agt-test", ) spec = CertSpec( actor_name="agt-test", actor_type=ActorType.AGT, pubkey_path=tmp_path / "k.pub", ttl_hours=24, principals=["agt-task"], ) _append_signature_log(record, spec, tmp_path, "local") log_path = tmp_path / "signatures.log" assert log_path.exists() entry = json.loads(log_path.read_text().strip()) assert entry["actor"] == "agt-test" assert entry["actor_type"] == "agt" assert entry["ttl_hours"] == 24 assert entry["backend"] == "local" assert entry["principals"] == ["agt-task"] def test_append_signature_log_appends(tmp_path): record = CertRecord( identity="agt-test", valid_before=datetime(2026, 3, 29, 10, 0, 0, tzinfo=timezone.utc), cert_path=tmp_path / "agt-test-cert.pub", signed_at=datetime(2026, 3, 28, 10, 0, 0, tzinfo=timezone.utc), principals=["agt-task"], actor_name="agt-test", ) spec = CertSpec( actor_name="agt-test", actor_type=ActorType.AGT, pubkey_path=tmp_path / "k.pub", ttl_hours=24, principals=["agt-task"], ) _append_signature_log(record, spec, tmp_path, "local") _append_signature_log(record, spec, tmp_path, "local") lines = (tmp_path / "signatures.log").read_text().strip().splitlines() assert len(lines) == 2 # --------------------------------------------------------------------------- # LocalCA.sign with TTL enforcement, eviction, and log # --------------------------------------------------------------------------- def test_local_ca_sign_enforces_ttl(tmp_path): ca_key = tmp_path / "ca_key" ca_key.write_text("fake-ca") pubkey = tmp_path / "key.pub" pubkey.write_text("ssh-ed25519 AAAA") spec = CertSpec( actor_name="agt-test", actor_type=ActorType.AGT, pubkey_path=pubkey, ttl_hours=100, # exceeds AGT max of 24h principals=["agt-test"], ) ca = LocalCA(ca_key, tmp_path / "state") with pytest.raises(CAError, match="exceeds maximum"): ca.sign(spec) def test_local_ca_sign_evicts_existing_cert(tmp_path): ca_key = tmp_path / "ca_key" ca_key.write_text("fake-ca") pubkey = tmp_path / "key.pub" pubkey.write_text("ssh-ed25519 AAAA actor-key") state = tmp_path / "state" state.mkdir() old_cert = state / "agt-state-hub-bridge-cert.pub" old_cert.write_text("old cert content") spec = CertSpec( actor_name="agt-state-hub-bridge", actor_type=ActorType.AGT, pubkey_path=pubkey, ttl_hours=24, principals=["agt-task-bridge"], identity="agt-state-hub-bridge", ) with patch("warden.ca.subprocess.run", side_effect=_mock_run_factory(CERT_CONTENT)): ca = LocalCA(ca_key, state) record = ca.sign(spec) assert record.cert_path.read_text().strip() == CERT_CONTENT # Only one cert file for this actor (old was replaced) assert len(list(state.glob("agt-state-hub-bridge-cert.pub"))) == 1 def test_local_ca_sign_cert_mode_600(tmp_path): ca_key = tmp_path / "ca_key" ca_key.write_text("fake-ca") pubkey = tmp_path / "key.pub" pubkey.write_text("ssh-ed25519 AAAA actor-key") state = tmp_path / "state" spec = CertSpec( actor_name="agt-state-hub-bridge", actor_type=ActorType.AGT, pubkey_path=pubkey, ttl_hours=24, principals=["agt-task-bridge"], identity="agt-state-hub-bridge", ) with patch("warden.ca.subprocess.run", side_effect=_mock_run_factory(CERT_CONTENT)): ca = LocalCA(ca_key, state) record = ca.sign(spec) assert oct(record.cert_path.stat().st_mode & 0o777) == oct(0o600) def test_local_ca_sign_writes_signature_log(tmp_path): ca_key = tmp_path / "ca_key" ca_key.write_text("fake-ca") pubkey = tmp_path / "key.pub" pubkey.write_text("ssh-ed25519 AAAA actor-key") state = tmp_path / "state" spec = CertSpec( actor_name="agt-state-hub-bridge", actor_type=ActorType.AGT, pubkey_path=pubkey, ttl_hours=24, principals=["agt-task-bridge"], identity="agt-state-hub-bridge", ) with patch("warden.ca.subprocess.run", side_effect=_mock_run_factory(CERT_CONTENT)): ca = LocalCA(ca_key, state) ca.sign(spec) log_path = state / "signatures.log" assert log_path.exists() entry = json.loads(log_path.read_text().strip()) assert entry["actor"] == "agt-state-hub-bridge" assert entry["backend"] == "local" assert entry["ttl_hours"] == 24 # --------------------------------------------------------------------------- # LocalCA.generate_keypair # --------------------------------------------------------------------------- def _mock_keygen_gen(cmd, **kwargs): """Mock for generate_keypair: writes privkey and pubkey based on -f arg.""" result = MagicMock() result.returncode = 0 result.stdout = "" result.stderr = "" if "-f" in cmd: idx = cmd.index("-f") privkey = Path(cmd[idx + 1]) privkey.parent.mkdir(parents=True, exist_ok=True) privkey.write_text("fake private key") (privkey.parent / (privkey.name + ".pub")).write_text("ssh-ed25519 AAAA pubkey") return result def test_generate_keypair_returns_paths(tmp_path): ca_key = tmp_path / "ca_key" ca_key.write_text("fake-ca") ca = LocalCA(ca_key, tmp_path / "state") with patch("warden.ca.subprocess.run", side_effect=_mock_keygen_gen): privkey, pubkey = ca.generate_keypair("agt-test") assert privkey.name == "agt-test_ed25519" assert pubkey.name == "agt-test_ed25519.pub" assert str(privkey).endswith("state/keys/agt-test_ed25519") def test_generate_keypair_ed25519_no_passphrase(tmp_path): ca_key = tmp_path / "ca_key" ca_key.write_text("fake-ca") ca = LocalCA(ca_key, tmp_path / "state") calls = [] def capturing_mock(cmd, **kwargs): calls.append(cmd) return _mock_keygen_gen(cmd, **kwargs) with patch("warden.ca.subprocess.run", side_effect=capturing_mock): ca.generate_keypair("agt-test") assert len(calls) == 1 cmd = calls[0] assert "-t" in cmd and cmd[cmd.index("-t") + 1] == "ed25519" assert "-N" in cmd and cmd[cmd.index("-N") + 1] == "" assert "-C" in cmd and cmd[cmd.index("-C") + 1] == "agt-test" def test_generate_keypair_overwrites_existing(tmp_path): ca_key = tmp_path / "ca_key" ca_key.write_text("fake-ca") state = tmp_path / "state" keys_dir = state / "keys" keys_dir.mkdir(parents=True) old_priv = keys_dir / "agt-test_ed25519" old_pub = keys_dir / "agt-test_ed25519.pub" old_priv.write_text("old key") old_pub.write_text("old pubkey") ca = LocalCA(ca_key, state) with patch("warden.ca.subprocess.run", side_effect=_mock_keygen_gen): privkey, pubkey = ca.generate_keypair("agt-test") assert privkey.read_text() == "fake private key" def test_generate_keypair_ca_error_on_failure(tmp_path): ca_key = tmp_path / "ca_key" ca_key.write_text("fake-ca") ca = LocalCA(ca_key, tmp_path / "state") def fail_run(cmd, **kwargs): result = MagicMock() result.returncode = 1 result.stderr = "failed to generate key" return result with patch("warden.ca.subprocess.run", side_effect=fail_run): with pytest.raises(CAError, match="Key generation failed"): ca.generate_keypair("agt-test") def test_generate_keypair_sets_permissions(tmp_path): ca_key = tmp_path / "ca_key" ca_key.write_text("fake-ca") ca = LocalCA(ca_key, tmp_path / "state") with patch("warden.ca.subprocess.run", side_effect=_mock_keygen_gen): privkey, pubkey = ca.generate_keypair("agt-test") assert oct(privkey.stat().st_mode & 0o777) == oct(0o600) assert oct(pubkey.stat().st_mode & 0o777) == oct(0o644)