Files
ops-bridge/tests/test_manager.py

216 lines
7.8 KiB
Python

"""Tests for TunnelManager."""
import os
import signal
from unittest.mock import MagicMock, patch
from dataclasses import replace
import pytest
from bridge.models import BridgeState, ReconnectPolicy, TunnelConfig
from bridge.manager import TunnelManager, build_ssh_command
@pytest.fixture
def tunnel_cfg():
    """Return the baseline TunnelConfig shared by the tests in this module.

    The reconnect policy is deliberately small (3 attempts, 1s initial
    backoff, 5s cap) so backoff arithmetic is easy to assert on.
    """
    policy = ReconnectPolicy(max_attempts=3, backoff_initial=1, backoff_max=5)
    return TunnelConfig(
        name="test-tunnel",
        host="host.local",
        remote_port=18000,
        local_port=8000,
        ssh_user="ubuntu",
        ssh_key="~/.ssh/id_ops",
        actor="operator.bernd",
        reconnect=policy,
    )
@pytest.fixture
def state_dir(tmp_path):
    """Return a per-test state directory path under tmp_path (not created here)."""
    return tmp_path.joinpath("bridge")
class TestBuildSshCommand:
    """Argument-vector checks for build_ssh_command()."""

    def test_basic_command(self, tunnel_cfg):
        argv = build_ssh_command(tunnel_cfg)
        assert argv[0] == "ssh"
        # Each of these tokens must appear somewhere in the argument vector.
        for expected in ("-N", "-R", "18000:127.0.0.1:8000", "-i", "ubuntu@host.local"):
            assert expected in argv

    def test_remote_host_override_local(self, tunnel_cfg):
        local_cfg = replace(tunnel_cfg, direction="local", remote_host="10.43.103.154")
        argv = build_ssh_command(local_cfg)
        forward_spec = f"{local_cfg.local_port}:10.43.103.154:{local_cfg.remote_port}"
        assert "-L" in argv
        assert forward_spec in argv

    def test_remote_host_default_loopback(self, tunnel_cfg):
        # Without a remote_host override the forward targets 127.0.0.1.
        assert "18000:127.0.0.1:8000" in build_ssh_command(tunnel_cfg)

    def test_server_alive_options(self, tunnel_cfg):
        argv = build_ssh_command(tunnel_cfg)
        for option in ("-o", "ServerAliveInterval=10", "ExitOnForwardFailure=yes"):
            assert option in argv

    def test_ssh_key_expanded(self, tunnel_cfg):
        argv = build_ssh_command(tunnel_cfg)
        # The token right after "-i" is the identity file; "~" must be resolved.
        key_path = argv[argv.index("-i") + 1]
        assert not key_path.startswith("~")
class TestTunnelManager:
    """Lifecycle tests for TunnelManager: state, stop, backoff and start."""

    def test_get_state_initial(self, tunnel_cfg, state_dir):
        mgr = TunnelManager(tunnel_cfg, state_dir=state_dir)
        assert mgr.get_state() == BridgeState.STOPPED

    def test_stop_when_not_running_is_noop(self, tunnel_cfg, state_dir):
        mgr = TunnelManager(tunnel_cfg, state_dir=state_dir)
        # Should not raise
        mgr.stop()
        assert mgr.get_state() == BridgeState.STOPPED

    def test_stop_kills_pid(self, tunnel_cfg, state_dir):
        mgr = TunnelManager(tunnel_cfg, state_dir=state_dir)
        # Write a fake PID of our own process to simulate running.
        # NOTE(review): this is only safe because os.kill is patched below.
        # If the manager ever binds `kill` directly (e.g. `from os import kill`),
        # the patch would miss and SIGTERM would hit the test runner itself.
        mgr._state.write_pid(tunnel_cfg.name, os.getpid())
        mgr._state.write_state(tunnel_cfg.name, BridgeState.CONNECTED)
        with patch("os.kill") as mock_kill:
            mgr.stop()
            # Should have sent SIGTERM
            mock_kill.assert_any_call(os.getpid(), signal.SIGTERM)
        assert mgr.get_state() == BridgeState.STOPPED

    def test_backoff_calculation(self, tunnel_cfg, state_dir):
        mgr = TunnelManager(tunnel_cfg, state_dir=state_dir)
        # Fixture policy: backoff_initial=1, backoff_max=5.
        # First backoff = initial
        assert mgr._next_backoff(0) == 1
        # Doubles each time up to max
        assert mgr._next_backoff(1) == 2
        assert mgr._next_backoff(2) == 4
        assert mgr._next_backoff(3) == 5  # capped at max

    def test_start_daemonizes(self, tunnel_cfg, state_dir):
        """Verify start() forks without hanging."""
        mgr = TunnelManager(tunnel_cfg, state_dir=state_dir)
        # We can't actually fork in tests; fork/setsid/_exit/Popen are mocked.
        with patch("subprocess.Popen") as mock_popen, \
                patch("os.fork", return_value=1234) as mock_fork, \
                patch("os.setsid"), \
                patch("os._exit"):
            mock_proc = MagicMock()
            mock_proc.pid = 9999
            mock_popen.return_value = mock_proc
            # fork() returning non-zero puts start() on the parent branch.
            mgr.start()
        # Without an assertion this test passed even if start() never forked.
        mock_fork.assert_called()
        # TODO(review): the intended checks were also that the state is STARTING
        # (set before fork) and that the PID file exists (written in the parent
        # branch); add them once the expected values are confirmed against
        # TunnelManager.start().

    def test_is_running_false_initially(self, tunnel_cfg, state_dir):
        mgr = TunnelManager(tunnel_cfg, state_dir=state_dir)
        assert not mgr.is_running()
class TestBuildSshCommandWithCert:
    """build_ssh_command() with and without an SSH certificate path."""

    def test_no_cert_path_omits_extra_i(self, tunnel_cfg):
        argv = build_ssh_command(tunnel_cfg)
        # Only the private key is passed via -i when no cert is given.
        assert sum(1 for arg in argv if arg == "-i") == 1

    def test_cert_path_appends_after_key(self, tunnel_cfg, tmp_path):
        cert_file = tmp_path / "test-cert.pub"
        cert_file.write_text("cert")
        argv = build_ssh_command(tunnel_cfg, cert_path=cert_file)
        identity_positions = [pos for pos, arg in enumerate(argv) if arg == "-i"]
        assert len(identity_positions) == 2
        first, second = identity_positions
        # The private key is the first identity; the certificate is the second.
        assert not argv[first + 1].endswith("-cert.pub")
        assert argv[second + 1] == str(cert_file)
class TestRunCertCommand:
    """Tests for _run_cert_command(), which executes the config's cert_command."""

    def test_returns_none_when_no_cert_command(self, tunnel_cfg, tmp_path):
        from bridge.manager import _run_cert_command
        assert _run_cert_command(tunnel_cfg, tmp_path) is None

    def test_writes_cert_and_returns_path(self, tunnel_cfg, tmp_path):
        from bridge.manager import _run_cert_command
        # Derive a variant config with dataclasses.replace rather than
        # mutating the fixture object in place (also works if frozen).
        cfg = replace(tunnel_cfg, cert_command="echo 'ssh-rsa-cert AAAA'")
        path = _run_cert_command(cfg, tmp_path)
        assert path is not None
        assert path.exists()
        assert "ssh-rsa-cert" in path.read_text()

    def test_raises_on_nonzero_exit(self, tunnel_cfg, tmp_path):
        from bridge.manager import _run_cert_command
        from bridge.models import CertAcquisitionError
        cfg = replace(tunnel_cfg, cert_command="exit 1")
        with pytest.raises(CertAcquisitionError):
            _run_cert_command(cfg, tmp_path)
class TestActorTypeFromName:
    """Actor names are classified by prefix; unrecognized names are 'unknown'."""

    @staticmethod
    def _classify(actor_name):
        # Import at call time, like the other tests in this file.
        from bridge.manager import _actor_type_from_name
        return _actor_type_from_name(actor_name)

    def test_adm_prefix(self):
        assert self._classify("adm-bernd") == "adm"

    def test_agt_prefix(self):
        assert self._classify("agt-claude") == "agt"

    def test_atm_prefix(self):
        assert self._classify("atm-cron") == "atm"

    def test_unknown_prefix(self):
        assert self._classify("operator.bernd") == "unknown"
class TestTtlRefresh:
    """Tests for the certificate parsing helpers (expiry and Key ID)."""

    @staticmethod
    def _keygen_result(stdout):
        """Build a fake successful subprocess.run() result carrying *stdout*."""
        return MagicMock(stdout=stdout, returncode=0)

    def test_parse_cert_expiry_returns_none_for_missing_file(self, tmp_path):
        from bridge.manager import _parse_cert_expiry
        missing = tmp_path / "no.pub"
        assert _parse_cert_expiry(missing) is None

    def test_parse_cert_identity_returns_none_for_missing_file(self, tmp_path):
        from bridge.manager import _parse_cert_identity
        missing = tmp_path / "no.pub"
        assert _parse_cert_identity(missing) is None

    def test_parse_cert_identity_from_keygen_output(self, tmp_path):
        # patch/MagicMock come from the module-level unittest.mock import.
        from bridge.manager import _parse_cert_identity
        cert = tmp_path / "test.pub"
        cert.write_text("fake")
        with patch("subprocess.run") as mock_run:
            mock_run.return_value = self._keygen_result(
                'test.pub:\n Key ID: "agt-bridge"\n'
            )
            result = _parse_cert_identity(cert)
        assert result == "agt-bridge"

    def test_parse_cert_expiry_from_keygen_output(self, tmp_path):
        from bridge.manager import _parse_cert_expiry
        cert = tmp_path / "test.pub"
        cert.write_text("fake")
        with patch("subprocess.run") as mock_run:
            mock_run.return_value = self._keygen_result(
                "test.pub:\n Valid: from 2026-05-15T10:00:00 to 2030-05-15T22:00:00\n"
            )
            result = _parse_cert_expiry(cert)
        assert result is not None
        # 2030 is the "to" bound; 2026 would mean the "from" bound was parsed.
        assert result.year == 2030