"""Tests for TunnelManager.""" import os import signal from unittest.mock import MagicMock, patch import pytest from bridge.models import BridgeState, ReconnectPolicy, TunnelConfig from bridge.manager import TunnelManager, build_ssh_command @pytest.fixture def tunnel_cfg(): return TunnelConfig( name="test-tunnel", host="host.local", remote_port=18000, local_port=8000, ssh_user="ubuntu", ssh_key="~/.ssh/id_ops", actor="operator.bernd", reconnect=ReconnectPolicy(max_attempts=3, backoff_initial=1, backoff_max=5), ) @pytest.fixture def state_dir(tmp_path): return tmp_path / "bridge" class TestBuildSshCommand: def test_basic_command(self, tunnel_cfg): cmd = build_ssh_command(tunnel_cfg) assert cmd[0] == "ssh" assert "-N" in cmd assert "-R" in cmd assert "18000:127.0.0.1:8000" in cmd assert "-i" in cmd assert "ubuntu@host.local" in cmd def test_server_alive_options(self, tunnel_cfg): cmd = build_ssh_command(tunnel_cfg) assert "-o" in cmd assert "ServerAliveInterval=10" in cmd assert "ExitOnForwardFailure=yes" in cmd def test_ssh_key_expanded(self, tunnel_cfg): cmd = build_ssh_command(tunnel_cfg) key_idx = cmd.index("-i") + 1 assert not cmd[key_idx].startswith("~") class TestTunnelManager: def test_get_state_initial(self, tunnel_cfg, state_dir): mgr = TunnelManager(tunnel_cfg, state_dir=state_dir) assert mgr.get_state() == BridgeState.STOPPED def test_stop_when_not_running_is_noop(self, tunnel_cfg, state_dir): mgr = TunnelManager(tunnel_cfg, state_dir=state_dir) # Should not raise mgr.stop() assert mgr.get_state() == BridgeState.STOPPED def test_stop_kills_pid(self, tunnel_cfg, state_dir): mgr = TunnelManager(tunnel_cfg, state_dir=state_dir) # Write a fake PID of our own process to simulate running mgr._state.write_pid(tunnel_cfg.name, os.getpid()) mgr._state.write_state(tunnel_cfg.name, BridgeState.CONNECTED) with patch("os.kill") as mock_kill: mgr.stop() # Should have sent SIGTERM mock_kill.assert_any_call(os.getpid(), signal.SIGTERM) assert mgr.get_state() == BridgeState.STOPPED def test_backoff_calculation(self, tunnel_cfg, state_dir): mgr = TunnelManager(tunnel_cfg, state_dir=state_dir) # First backoff = initial assert mgr._next_backoff(0) == 1 # Doubles each time up to max assert mgr._next_backoff(1) == 2 assert mgr._next_backoff(2) == 4 assert mgr._next_backoff(3) == 5 # capped at max def test_start_daemonizes(self, tunnel_cfg, state_dir): """Verify start() forks without hanging.""" mgr = TunnelManager(tunnel_cfg, state_dir=state_dir) # We can't actually fork in tests; verify state transitions via mock with patch("subprocess.Popen") as mock_popen, \ patch("os.fork", return_value=1234), \ patch("os.setsid"), \ patch("os._exit"): mock_proc = MagicMock() mock_proc.pid = 9999 mock_popen.return_value = mock_proc # When fork returns non-zero we're the parent — just check PID written mgr.start() # After start the state should be STARTING (set before fork) # and PID file should exist (written in parent branch) def test_is_running_false_initially(self, tunnel_cfg, state_dir): mgr = TunnelManager(tunnel_cfg, state_dir=state_dir) assert not mgr.is_running() class TestBuildSshCommandWithCert: def test_no_cert_path_omits_extra_i(self, tunnel_cfg): cmd = build_ssh_command(tunnel_cfg) assert cmd.count("-i") == 1 def test_cert_path_appends_after_key(self, tunnel_cfg, tmp_path): cert = tmp_path / "test-cert.pub" cert.write_text("cert") cmd = build_ssh_command(tunnel_cfg, cert_path=cert) i_indices = [i for i, x in enumerate(cmd) if x == "-i"] assert len(i_indices) == 2 key_idx, cert_idx = i_indices assert not cmd[key_idx + 1].endswith("-cert.pub") # key comes first assert cmd[cert_idx + 1] == str(cert) class TestRunCertCommand: def test_returns_none_when_no_cert_command(self, tunnel_cfg, tmp_path): from bridge.manager import _run_cert_command assert _run_cert_command(tunnel_cfg, tmp_path) is None def test_writes_cert_and_returns_path(self, tunnel_cfg, tmp_path): from bridge.manager import _run_cert_command tunnel_cfg.cert_command = "echo 'ssh-rsa-cert AAAA'" path = _run_cert_command(tunnel_cfg, tmp_path) assert path is not None assert path.exists() assert "ssh-rsa-cert" in path.read_text() def test_raises_on_nonzero_exit(self, tunnel_cfg, tmp_path): from bridge.manager import _run_cert_command from bridge.models import CertAcquisitionError tunnel_cfg.cert_command = "exit 1" with pytest.raises(CertAcquisitionError): _run_cert_command(tunnel_cfg, tmp_path) class TestActorTypeFromName: def test_adm_prefix(self): from bridge.manager import _actor_type_from_name assert _actor_type_from_name("adm-bernd") == "adm" def test_agt_prefix(self): from bridge.manager import _actor_type_from_name assert _actor_type_from_name("agt-claude") == "agt" def test_atm_prefix(self): from bridge.manager import _actor_type_from_name assert _actor_type_from_name("atm-cron") == "atm" def test_unknown_prefix(self): from bridge.manager import _actor_type_from_name assert _actor_type_from_name("operator.bernd") == "unknown" class TestTtlRefresh: def test_parse_cert_expiry_returns_none_for_missing_file(self, tmp_path): from bridge.manager import _parse_cert_expiry missing = tmp_path / "no.pub" result = _parse_cert_expiry(missing) assert result is None def test_parse_cert_identity_returns_none_for_missing_file(self, tmp_path): from bridge.manager import _parse_cert_identity missing = tmp_path / "no.pub" result = _parse_cert_identity(missing) assert result is None def test_parse_cert_identity_from_keygen_output(self, tmp_path): from unittest.mock import patch, MagicMock from bridge.manager import _parse_cert_identity cert = tmp_path / "test.pub" cert.write_text("fake") with patch("subprocess.run") as mock_run: mock_run.return_value = MagicMock( stdout='test.pub:\n Key ID: "agt-bridge"\n', returncode=0, ) result = _parse_cert_identity(cert) assert result == "agt-bridge" def test_parse_cert_expiry_from_keygen_output(self, tmp_path): from unittest.mock import patch, MagicMock from bridge.manager import _parse_cert_expiry cert = tmp_path / "test.pub" cert.write_text("fake") with patch("subprocess.run") as mock_run: mock_run.return_value = MagicMock( stdout="test.pub:\n Valid: from 2026-05-15T10:00:00 to 2030-05-15T22:00:00\n", returncode=0, ) result = _parse_cert_expiry(cert) assert result is not None assert result.year == 2030