Files
ops-bridge/tests/test_manager.py
tegwick bd169a07e2 feat(directive): implement BRIDGE-WP-0004 AccessManagementDirective alignment
- ActorType enum (adm/agt/atm) replaces actor_class string; config validates
  naming convention (adm-*/agt-*/atm-*) with hard ConfigError on mismatch;
  legacy 'human'/'automation' values accepted with DeprecationWarning
- cert_command: pluggable shell string run before each SSH launch; cert written
  to state dir; -i cert appended to SSH command alongside -i key
- TTL-aware cert refresh: parses Valid-to via ssh-keygen -L; pre-emptive restart
  5 min before expiry (no backoff, no attempt increment); CERT_EXPIRING logged
- CertAcquisitionError: cert failures trigger normal backoff/retry loop
- cert_identity: Key ID parsed from cert and recorded in BRIDGE_CONNECTED event
- bridge cert-status: new CLI command; exit 1 on expired cert; --json flag
- 233 tests passing, ruff clean

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-15 09:38:29 +02:00

204 lines
7.3 KiB
Python

"""Tests for TunnelManager."""
import os
import signal
from unittest.mock import MagicMock, patch
import pytest
from bridge.models import BridgeState, ReconnectPolicy, TunnelConfig
from bridge.manager import TunnelManager, build_ssh_command
@pytest.fixture
def tunnel_cfg():
    """Return a TunnelConfig for host.local with a small reconnect policy.

    The tunnel forwards remote port 18000 to local port 8000 as user
    "ubuntu" with key "~/.ssh/id_ops". Reconnects are limited to
    max_attempts=3, with backoff_initial=1 and backoff_max=5.

    NOTE(review): the actor "operator.bernd" carries none of the
    adm-/agt-/atm- prefixes (it maps to "unknown"). Building TunnelConfig
    directly is assumed to bypass the config loader's naming validation;
    confirm if that validation moves into the model.
    """
    policy = ReconnectPolicy(max_attempts=3, backoff_initial=1, backoff_max=5)
    cfg = TunnelConfig(
        name="test-tunnel",
        host="host.local",
        remote_port=18000,
        local_port=8000,
        ssh_user="ubuntu",
        ssh_key="~/.ssh/id_ops",
        actor="operator.bernd",
        reconnect=policy,
    )
    return cfg
@pytest.fixture
def state_dir(tmp_path):
    """Return a per-test path for TunnelManager state.

    The directory itself is not created here; only the path is computed.
    """
    bridge_dir = tmp_path / "bridge"
    return bridge_dir
class TestBuildSshCommand:
    """Shape of the argv that build_ssh_command produces for a key-only login."""

    def test_basic_command(self, tunnel_cfg):
        cmd = build_ssh_command(tunnel_cfg)
        assert cmd[0] == "ssh"
        assert "-N" in cmd
        assert "-R" in cmd
        # ssh takes the argv element directly after -R as the forward spec,
        # so the spec must be adjacent to the flag, not merely present.
        assert cmd[cmd.index("-R") + 1] == "18000:127.0.0.1:8000"
        assert "-i" in cmd
        assert "ubuntu@host.local" in cmd

    def test_server_alive_options(self, tunnel_cfg):
        cmd = build_ssh_command(tunnel_cfg)
        # Each option string only takes effect as the argument of a
        # preceding -o flag; check the pairing, not just membership.
        for option in ("ServerAliveInterval=10", "ExitOnForwardFailure=yes"):
            assert option in cmd
            assert cmd[cmd.index(option) - 1] == "-o"

    def test_ssh_key_expanded(self, tunnel_cfg):
        cmd = build_ssh_command(tunnel_cfg)
        key_idx = cmd.index("-i") + 1
        # The fixture key is "~/.ssh/id_ops"; it must reach ssh expanded.
        assert not cmd[key_idx].startswith("~")
class TestTunnelManager:
    """Lifecycle behaviour of TunnelManager backed by a temporary state dir."""

    def test_get_state_initial(self, tunnel_cfg, state_dir):
        mgr = TunnelManager(tunnel_cfg, state_dir=state_dir)
        assert mgr.get_state() == BridgeState.STOPPED

    def test_stop_when_not_running_is_noop(self, tunnel_cfg, state_dir):
        mgr = TunnelManager(tunnel_cfg, state_dir=state_dir)
        # Should not raise
        mgr.stop()
        assert mgr.get_state() == BridgeState.STOPPED

    def test_stop_kills_pid(self, tunnel_cfg, state_dir):
        mgr = TunnelManager(tunnel_cfg, state_dir=state_dir)
        # Record this test process's own PID as the "running" tunnel.
        # os.kill is patched below, so no real signal reaches the runner.
        mgr._state.write_pid(tunnel_cfg.name, os.getpid())
        mgr._state.write_state(tunnel_cfg.name, BridgeState.CONNECTED)
        with patch("os.kill") as mock_kill:
            mgr.stop()
            # Should have sent SIGTERM
            mock_kill.assert_any_call(os.getpid(), signal.SIGTERM)
            # Checked while os.kill is still mocked, in case get_state
            # probes the recorded PID.
            assert mgr.get_state() == BridgeState.STOPPED

    def test_backoff_calculation(self, tunnel_cfg, state_dir):
        mgr = TunnelManager(tunnel_cfg, state_dir=state_dir)
        # First backoff = backoff_initial (1)
        assert mgr._next_backoff(0) == 1
        # Doubles each attempt ...
        assert mgr._next_backoff(1) == 2
        assert mgr._next_backoff(2) == 4
        # ... until capped at backoff_max (5)
        assert mgr._next_backoff(3) == 5

    def test_start_daemonizes(self, tunnel_cfg, state_dir):
        """Verify start() forks without hanging."""
        mgr = TunnelManager(tunnel_cfg, state_dir=state_dir)
        # We can't actually fork in tests. os.fork is patched to return 1234,
        # a non-zero PID, i.e. the parent side of a fork. setsid, _exit and
        # Popen are patched so no real process work happens.
        with patch("subprocess.Popen") as mock_popen, \
                patch("os.fork", return_value=1234) as mock_fork, \
                patch("os.setsid"), \
                patch("os._exit"):
            mock_proc = MagicMock()
            mock_proc.pid = 9999
            mock_popen.return_value = mock_proc
            mgr.start()
        # Pin the one behaviour the docstring promises: start() forks.
        mock_fork.assert_called()
        # TODO(review): also assert the post-start state (expected STARTING,
        # set before fork) and that the PID file was written in the parent
        # branch, once that contract of start() is confirmed.

    def test_is_running_false_initially(self, tunnel_cfg, state_dir):
        mgr = TunnelManager(tunnel_cfg, state_dir=state_dir)
        assert not mgr.is_running()
class TestBuildSshCommandWithCert:
    """build_ssh_command with and without an SSH certificate path."""

    def test_no_cert_path_omits_extra_i(self, tunnel_cfg):
        argv = build_ssh_command(tunnel_cfg)
        identity_flags = argv.count("-i")
        assert identity_flags == 1

    def test_cert_path_appends_after_key(self, tunnel_cfg, tmp_path):
        cert_file = tmp_path / "test-cert.pub"
        cert_file.write_text("cert")
        argv = build_ssh_command(tunnel_cfg, cert_path=cert_file)
        flag_positions = [pos for pos, arg in enumerate(argv) if arg == "-i"]
        assert len(flag_positions) == 2
        first_value = argv[flag_positions[0] + 1]
        second_value = argv[flag_positions[1] + 1]
        # The private key is the first -i; the certificate is the second.
        assert not first_value.endswith("-cert.pub")
        assert second_value == str(cert_file)
class TestRunCertCommand:
    """_run_cert_command: optional shell hook that obtains a certificate."""

    def test_returns_none_when_no_cert_command(self, tunnel_cfg, tmp_path):
        from bridge.manager import _run_cert_command

        # The fixture sets no cert_command, so there is nothing to run.
        outcome = _run_cert_command(tunnel_cfg, tmp_path)
        assert outcome is None

    def test_writes_cert_and_returns_path(self, tunnel_cfg, tmp_path):
        from bridge.manager import _run_cert_command

        # The echoed text is expected to end up in the returned file.
        tunnel_cfg.cert_command = "echo 'ssh-rsa-cert AAAA'"
        written = _run_cert_command(tunnel_cfg, tmp_path)
        assert written is not None
        assert written.exists()
        contents = written.read_text()
        assert "ssh-rsa-cert" in contents

    def test_raises_on_nonzero_exit(self, tunnel_cfg, tmp_path):
        from bridge.manager import _run_cert_command
        from bridge.models import CertAcquisitionError

        tunnel_cfg.cert_command = "exit 1"
        with pytest.raises(CertAcquisitionError):
            _run_cert_command(tunnel_cfg, tmp_path)
class TestActorTypeFromName:
    """Actor-name prefix to actor-type mapping (adm-/agt-/atm-, else "unknown")."""

    def test_adm_prefix(self):
        from bridge.manager import _actor_type_from_name as classify

        assert classify("adm-bernd") == "adm"

    def test_agt_prefix(self):
        from bridge.manager import _actor_type_from_name as classify

        assert classify("agt-claude") == "agt"

    def test_atm_prefix(self):
        from bridge.manager import _actor_type_from_name as classify

        assert classify("atm-cron") == "atm"

    def test_unknown_prefix(self):
        from bridge.manager import _actor_type_from_name as classify

        # A dotted legacy-style name matches none of the three prefixes.
        assert classify("operator.bernd") == "unknown"
class TestTtlRefresh:
    """Certificate metadata parsing (expiry and Key ID) for TTL-aware refresh.

    Uses the module-level ``patch``/``MagicMock`` import; ssh-keygen output is
    faked by patching ``subprocess.run``.
    """

    def test_parse_cert_expiry_returns_none_for_missing_file(self, tmp_path):
        from bridge.manager import _parse_cert_expiry

        missing = tmp_path / "no.pub"
        assert _parse_cert_expiry(missing) is None

    def test_parse_cert_identity_returns_none_for_missing_file(self, tmp_path):
        from bridge.manager import _parse_cert_identity

        missing = tmp_path / "no.pub"
        assert _parse_cert_identity(missing) is None

    def test_parse_cert_identity_from_keygen_output(self, tmp_path):
        from bridge.manager import _parse_cert_identity

        cert = tmp_path / "test.pub"
        cert.write_text("fake")
        with patch("subprocess.run") as mock_run:
            mock_run.return_value = MagicMock(
                stdout='test.pub:\n Key ID: "agt-bridge"\n',
                returncode=0,
            )
            result = _parse_cert_identity(cert)
        # The identity must come from the (mocked) ssh-keygen output.
        assert mock_run.called
        assert result == "agt-bridge"

    def test_parse_cert_expiry_from_keygen_output(self, tmp_path):
        from bridge.manager import _parse_cert_expiry

        cert = tmp_path / "test.pub"
        cert.write_text("fake")
        with patch("subprocess.run") as mock_run:
            mock_run.return_value = MagicMock(
                stdout="test.pub:\n Valid: from 2026-05-15T10:00:00 to 2030-05-15T22:00:00\n",
                returncode=0,
            )
            result = _parse_cert_expiry(cert)
        assert mock_run.called
        assert result is not None
        # Only the year is pinned: month/day/hour could shift if the parser
        # normalizes the local timestamp to another timezone.
        assert result.year == 2030