"""Tests for warden.vault — VaultCA backend."""

import json
from unittest.mock import MagicMock, patch

import httpx
import pytest

from warden.ca import CAError
from warden.config import VaultConfig
from warden.models import ActorType, CertSpec
from warden.vault import VaultCA

# Fake signed certificate returned by the mocked Vault HTTP endpoint.
SAMPLE_CERT = "ssh-ed25519-cert-v01@openssh.com AAAA_fake_cert_data"

# Canned stdout handed back by the mocked ``subprocess.run``.
# NOTE(review): the internal whitespace of this fixture was reconstructed in
# the usual ``ssh-keygen -L`` layout (8-space fields, 16-space list items);
# confirm it matches what the parser in ``warden.ca`` expects.
SAMPLE_SSHKEYGEN_L = """\
/tmp/key-cert.pub:
        Type: ssh-ed25519-cert-v01@openssh.com user certificate
        Public key: ED25519-CERT SHA256:abc123
        Signing CA: ED25519 SHA256:xyz (using ssh-ed25519)
        Key ID: "agt-test"
        Serial: 0
        Valid: from 2026-03-28T10:00:00 to 2026-03-29T10:00:00
        Principals:
                agt-task
        Critical Options: (none)
        Extensions:
                permit-pty
"""

# URL used to build the httpx request objects attached to simulated errors.
_SIGN_URL = "http://127.0.0.1:8200/v1/ssh/sign/agt-role"


def _make_cfg(**overrides) -> VaultConfig:
    """Build a ``VaultConfig`` from test defaults, updated with *overrides*."""
    defaults = {
        "addr": "http://127.0.0.1:8200",
        "mount": "ssh",
        "token_env": "VAULT_TOKEN",
        "role_map": {"agt": "agt-role", "adm": "adm-role", "atm": "atm-role"},
    }
    defaults.update(overrides)
    return VaultConfig(**defaults)


def _make_spec(tmp_path, **overrides) -> CertSpec:
    """Build a ``CertSpec`` from test defaults, updated with *overrides*.

    A placeholder public key is always written to ``tmp_path / "key.pub"``,
    even when ``pubkey_path`` is overridden to point somewhere else.
    """
    pubkey = tmp_path / "key.pub"
    pubkey.write_text("ssh-ed25519 AAAA actor-key")
    defaults = {
        "actor_name": "agt-test",
        "actor_type": ActorType.AGT,
        "pubkey_path": pubkey,
        "ttl_hours": 24,
        "principals": ["agt-task"],
        "identity": "agt-test",
    }
    defaults.update(overrides)
    return CertSpec(**defaults)


def _mock_httpx_post(signed_key: str) -> MagicMock:
    """Return a fake response whose JSON body carries *signed_key*.

    ``raise_for_status()`` on the fake is a no-op returning ``None``.
    """
    resp = MagicMock()
    resp.json.return_value = {"data": {"signed_key": signed_key}}
    resp.raise_for_status.return_value = None
    return resp


def _mock_ssh_keygen_L(cmd, **kwargs) -> MagicMock:
    """Stand-in for ``subprocess.run`` that always succeeds.

    The arguments are ignored; the result has return code 0, empty stderr
    and ``SAMPLE_SSHKEYGEN_L`` as stdout.
    """
    result = MagicMock()
    result.returncode = 0
    result.stdout = SAMPLE_SSHKEYGEN_L
    result.stderr = ""
    return result


def _sign_with_mocked_backends(ca, spec, signed_key=SAMPLE_CERT):
    """Call ``ca.sign(spec)`` with the HTTP call and subprocess call mocked.

    ``warden.vault.httpx.post`` returns a fake response carrying
    *signed_key*, and ``warden.ca.subprocess.run`` is replaced by
    ``_mock_ssh_keygen_L``. Returns whatever ``ca.sign`` returns.
    """
    with (
        patch(
            "warden.vault.httpx.post",
            return_value=_mock_httpx_post(signed_key),
        ),
        patch("warden.ca.subprocess.run", side_effect=_mock_ssh_keygen_L),
    ):
        return ca.sign(spec)


# ---------------------------------------------------------------------------
# VaultCA.sign — success
# ---------------------------------------------------------------------------


def test_vault_ca_sign_success(tmp_path, monkeypatch):
    """A successful sign returns a matching record and writes the cert file."""
    monkeypatch.setenv("VAULT_TOKEN", "fake-token")
    spec = _make_spec(tmp_path)
    cfg = _make_cfg()
    ca = VaultCA(cfg, tmp_path / "state")

    record = _sign_with_mocked_backends(ca, spec)

    assert record.actor_name == "agt-test"
    assert record.identity == "agt-test"
    assert record.principals == ["agt-task"]
    dest = tmp_path / "state" / "agt-test-cert.pub"
    assert dest.exists()
    assert SAMPLE_CERT in dest.read_text()


def test_vault_ca_sign_cert_mode_600(tmp_path, monkeypatch):
    """The certificate file is written with permission bits 0o600."""
    monkeypatch.setenv("VAULT_TOKEN", "fake-token")
    spec = _make_spec(tmp_path)
    ca = VaultCA(_make_cfg(), tmp_path / "state")

    record = _sign_with_mocked_backends(ca, spec)

    # Compare as octal strings so a failure message shows readable modes.
    assert oct(record.cert_path.stat().st_mode & 0o777) == oct(0o600)


def test_vault_ca_sign_writes_signature_log(tmp_path, monkeypatch):
    """Signing writes a JSON entry to ``signatures.log`` in the state dir."""
    monkeypatch.setenv("VAULT_TOKEN", "fake-token")
    spec = _make_spec(tmp_path)
    ca = VaultCA(_make_cfg(), tmp_path / "state")

    _sign_with_mocked_backends(ca, spec)

    log_path = tmp_path / "state" / "signatures.log"
    assert log_path.exists()
    # The whole stripped file is parsed as one JSON document, so this
    # expects exactly one entry after a single sign call.
    entry = json.loads(log_path.read_text().strip())
    assert entry["backend"] == "vault"
    assert entry["actor"] == "agt-test"


# ---------------------------------------------------------------------------
# VaultCA.sign — failure paths
# ---------------------------------------------------------------------------


def test_vault_ca_sign_http_403(tmp_path, monkeypatch):
    """An ``httpx.HTTPStatusError`` (403) surfaces as ``CAError`` with "403"."""
    monkeypatch.setenv("VAULT_TOKEN", "bad-token")
    spec = _make_spec(tmp_path)
    ca = VaultCA(_make_cfg(), tmp_path / "state")

    request = httpx.Request("POST", _SIGN_URL)
    response = httpx.Response(403, request=request, text="permission denied")
    exc = httpx.HTTPStatusError("403", request=request, response=response)

    with (
        patch("warden.vault.httpx.post", side_effect=exc),
        pytest.raises(CAError, match="403"),
    ):
        ca.sign(spec)


def test_vault_ca_sign_request_error(tmp_path, monkeypatch):
    """An ``httpx.ConnectError`` surfaces as ``CAError`` with "unreachable"."""
    monkeypatch.setenv("VAULT_TOKEN", "fake-token")
    spec = _make_spec(tmp_path)
    ca = VaultCA(_make_cfg(), tmp_path / "state")

    request = httpx.Request("POST", _SIGN_URL)
    exc = httpx.ConnectError("connection refused", request=request)

    with (
        patch("warden.vault.httpx.post", side_effect=exc),
        pytest.raises(CAError, match="unreachable"),
    ):
        ca.sign(spec)


def test_vault_ca_sign_missing_token(tmp_path, monkeypatch):
    """With ``VAULT_TOKEN`` unset, signing raises ``CAError`` naming it."""
    monkeypatch.delenv("VAULT_TOKEN", raising=False)
    spec = _make_spec(tmp_path)
    ca = VaultCA(_make_cfg(), tmp_path / "state")

    with pytest.raises(CAError, match="VAULT_TOKEN"):
        ca.sign(spec)


def test_vault_ca_sign_missing_role(tmp_path, monkeypatch):
    """An empty ``role_map`` makes signing raise ``CAError`` ("role_map")."""
    monkeypatch.setenv("VAULT_TOKEN", "fake-token")
    cfg = _make_cfg(role_map={})  # no roles mapped
    spec = _make_spec(tmp_path)
    ca = VaultCA(cfg, tmp_path / "state")

    with pytest.raises(CAError, match="role_map"):
        ca.sign(spec)


def test_vault_ca_sign_missing_pubkey(tmp_path, monkeypatch):
    """A non-existent ``pubkey_path`` raises ``CAError`` ("Public key not found")."""
    monkeypatch.setenv("VAULT_TOKEN", "fake-token")
    spec = _make_spec(tmp_path, pubkey_path=tmp_path / "nonexistent.pub")
    ca = VaultCA(_make_cfg(), tmp_path / "state")

    with pytest.raises(CAError, match="Public key not found"):
        ca.sign(spec)


def test_vault_ca_sign_ttl_enforcement(tmp_path, monkeypatch):
    """A TTL of 100 hours for an AGT actor raises ``CAError`` ("exceeds maximum")."""
    monkeypatch.setenv("VAULT_TOKEN", "fake-token")
    spec = _make_spec(tmp_path, ttl_hours=100)  # AGT max is 24h
    ca = VaultCA(_make_cfg(), tmp_path / "state")

    with pytest.raises(CAError, match="exceeds maximum"):
        ca.sign(spec)


def test_vault_ca_sign_evicts_existing_cert(tmp_path, monkeypatch):
    """Signing replaces a certificate file that already exists for the actor."""
    monkeypatch.setenv("VAULT_TOKEN", "fake-token")
    state = tmp_path / "state"
    state.mkdir()
    old_cert = state / "agt-test-cert.pub"
    old_cert.write_text("old cert")

    spec = _make_spec(tmp_path)
    ca = VaultCA(_make_cfg(), state)

    record = _sign_with_mocked_backends(ca, spec)

    assert record.cert_path.read_text().strip() == SAMPLE_CERT
    assert len(list(state.glob("agt-test-cert.pub"))) == 1