From a7eaf59ced0967db5ae92fd1ff1ff648341e82ea Mon Sep 17 00:00:00 2001 From: Bernd Worsch Date: Thu, 12 Mar 2026 01:40:08 +0000 Subject: [PATCH] feat: implement OpsBridge CLI (BRIDGE-WP-0001) Full TDD implementation of the `bridge` CLI tool covering all phases from BRIDGE-WP-0001: project scaffolding, config loading, state management, audit logging, health checks, tunnel lifecycle manager, and all CLI commands (up/down/restart/status/logs). 77 tests, all green. Co-Authored-By: Claude Sonnet 4.6 --- pyproject.toml | 34 +++++ src/bridge/__init__.py | 0 src/bridge/audit.py | 65 ++++++++++ src/bridge/cli.py | 219 +++++++++++++++++++++++++++++++++ src/bridge/config.py | 109 +++++++++++++++++ src/bridge/health.py | 31 +++++ src/bridge/manager.py | 252 ++++++++++++++++++++++++++++++++++++++ src/bridge/models.py | 49 ++++++++ src/bridge/state.py | 73 +++++++++++ tests/__init__.py | 0 tests/test_audit.py | 90 ++++++++++++++ tests/test_cli.py | 201 ++++++++++++++++++++++++++++++ tests/test_config.py | 130 ++++++++++++++++++++ tests/test_health.py | 78 ++++++++++++ tests/test_integration.py | 219 +++++++++++++++++++++++++++++++++ tests/test_manager.py | 109 +++++++++++++++++ tests/test_models.py | 75 ++++++++++++ tests/test_state.py | 69 +++++++++++ 18 files changed, 1803 insertions(+) create mode 100644 pyproject.toml create mode 100644 src/bridge/__init__.py create mode 100644 src/bridge/audit.py create mode 100644 src/bridge/cli.py create mode 100644 src/bridge/config.py create mode 100644 src/bridge/health.py create mode 100644 src/bridge/manager.py create mode 100644 src/bridge/models.py create mode 100644 src/bridge/state.py create mode 100644 tests/__init__.py create mode 100644 tests/test_audit.py create mode 100644 tests/test_cli.py create mode 100644 tests/test_config.py create mode 100644 tests/test_health.py create mode 100644 tests/test_integration.py create mode 100644 tests/test_manager.py create mode 100644 tests/test_models.py create mode 100644 tests/test_state.py diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..843b6bd --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,34 @@ +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[project] +name = "ops-bridge" +version = "0.1.0" +description = "SSH reverse tunnel lifecycle manager" +requires-python = ">=3.11" +dependencies = [ + "typer>=0.12", + "pyyaml>=6.0", + "httpx>=0.27", +] + +[project.scripts] +bridge = "bridge.cli:app" + +[tool.hatch.build.targets.wheel] +packages = ["src/bridge"] + +[tool.pytest.ini_options] +testpaths = ["tests"] +pythonpath = ["src"] + +[tool.ruff] +line-length = 88 + +[dependency-groups] +dev = [ + "pytest>=8.0", + "pytest-asyncio>=0.23", + "ruff>=0.4", +] diff --git a/src/bridge/__init__.py b/src/bridge/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/bridge/audit.py b/src/bridge/audit.py new file mode 100644 index 0000000..f7f71be --- /dev/null +++ b/src/bridge/audit.py @@ -0,0 +1,65 @@ +"""Audit logging for OpsBridge lifecycle events.""" +from __future__ import annotations + +import json +from datetime import datetime, timezone +from enum import Enum +from pathlib import Path +from typing import Any, Dict, List, Optional + + +class AuditEvent(str, Enum): + BRIDGE_STARTED = "bridge_started" + BRIDGE_CONNECTED = "bridge_connected" + BRIDGE_DISCONNECTED = "bridge_disconnected" + BRIDGE_RECONNECTING = "bridge_reconnecting" + HEALTH_CHECK_FAILED = "health_check_failed" + HEALTH_CHECK_RECOVERED = "health_check_recovered" + BRIDGE_STOPPED = "bridge_stopped" + + +def _default_state_dir() -> Path: + return Path.home() / ".local" / "state" / "bridge" + + +class AuditLogger: + def __init__(self, state_dir: Optional[Path] = None): + self._dir = Path(state_dir) if state_dir else _default_state_dir() + + def _log_path(self, tunnel: str) -> Path: + return self._dir / f"{tunnel}.log" + + def log( + self, + tunnel: str, + event: AuditEvent, + actor: str, + actor_class: str, + detail: str = "", + ) -> None: + self._dir.mkdir(parents=True, exist_ok=True) + entry: Dict[str, Any] = { + "timestamp": datetime.now(timezone.utc).isoformat(), + "tunnel": tunnel, + "actor": actor, + "actor_class": actor_class, + "event": event.value, + } + if detail: + entry["detail"] = detail + with self._log_path(tunnel).open("a") as f: + f.write(json.dumps(entry) + "\n") + + def read_events(self, tunnel: str) -> List[Dict[str, Any]]: + path = self._log_path(tunnel) + if not path.exists(): + return [] + events = [] + for line in path.read_text().splitlines(): + line = line.strip() + if line: + try: + events.append(json.loads(line)) + except json.JSONDecodeError: + pass + return events diff --git a/src/bridge/cli.py b/src/bridge/cli.py new file mode 100644 index 0000000..9654296 --- /dev/null +++ b/src/bridge/cli.py @@ -0,0 +1,219 @@ +"""CLI for OpsBridge — bridge command.""" +from __future__ import annotations + +import json +import os +import sys +from pathlib import Path +from typing import Optional + +import typer + +from bridge.audit import AuditLogger +from bridge.config import ConfigError, load_config +from bridge.manager import TunnelManager +from bridge.models import BridgeState +from bridge.state import StateManager + +app = typer.Typer( + name="bridge", + help="OpsBridge — SSH reverse tunnel lifecycle manager.", + no_args_is_help=True, +) + + +def _state_dir() -> Path: + return Path(os.environ.get("BRIDGE_STATE_DIR", str(Path.home() / ".local" / "state" / "bridge"))) + + +def _load_or_exit(): + try: + return load_config() + except ConfigError as e: + typer.echo(f"Error: {e}", err=True) + raise typer.Exit(1) + + +def _require_tunnel(cfg, name: str): + if name not in cfg.tunnels: + typer.echo(f"Error: tunnel '{name}' not found in config", err=True) + raise typer.Exit(1) + return cfg.tunnels[name] + + +@app.command() +def up( + tunnel: Optional[str] = typer.Argument(None, help="Tunnel name (omit for all)"), +): + """Start one or all tunnels.""" + cfg = _load_or_exit() + sd = _state_dir() + + names = [tunnel] if tunnel else list(cfg.tunnels.keys()) + if tunnel: + _require_tunnel(cfg, tunnel) + + any_already_running = False + for name in names: + tcfg = cfg.tunnels[name] + mgr = TunnelManager(tcfg, state_dir=sd) + if mgr.is_running(): + typer.echo(f"Tunnel '{name}' is already running.") + any_already_running = True + else: + mgr.start() + typer.echo(f"Started tunnel '{name}'.") + + if any_already_running and len(names) == 1: + raise typer.Exit(2) + + +@app.command() +def down( + tunnel: Optional[str] = typer.Argument(None, help="Tunnel name (omit for all)"), +): + """Stop one or all tunnels.""" + cfg = _load_or_exit() + sd = _state_dir() + + names = [tunnel] if tunnel else list(cfg.tunnels.keys()) + if tunnel: + _require_tunnel(cfg, tunnel) + + any_not_running = False + for name in names: + tcfg = cfg.tunnels[name] + mgr = TunnelManager(tcfg, state_dir=sd) + if not mgr.is_running(): + typer.echo(f"Tunnel '{name}' is not running.") + any_not_running = True + else: + mgr.stop() + typer.echo(f"Stopped tunnel '{name}'.") + + if any_not_running and len(names) == 1: + raise typer.Exit(2) + + +@app.command() +def restart( + tunnel: Optional[str] = typer.Argument(None, help="Tunnel name (omit for all)"), +): + """Restart one or all tunnels.""" + cfg = _load_or_exit() + sd = _state_dir() + + names = [tunnel] if tunnel else list(cfg.tunnels.keys()) + if tunnel: + _require_tunnel(cfg, tunnel) + + for name in names: + tcfg = cfg.tunnels[name] + mgr = TunnelManager(tcfg, state_dir=sd) + mgr.stop() + mgr.start() + typer.echo(f"Restarted tunnel '{name}'.") + + +@app.command() +def status( + as_json: bool = typer.Option(False, "--json", help="Output as JSON"), +): + """Show status of all tunnels.""" + cfg = _load_or_exit() + sd = _state_dir() + state_mgr = StateManager(state_dir=sd) + + rows = [] + for name, tcfg in cfg.tunnels.items(): + state = state_mgr.read_state(name) + pid = state_mgr.read_pid(name) + rows.append({ + "tunnel": name, + "state": state.value, + "actor": tcfg.actor, + "host": tcfg.host, + "pid": pid, + "uptime": None, # future: track start time + "health": None, # future: last health check result + }) + + if as_json: + typer.echo(json.dumps(rows, indent=2)) + else: + _print_status_table(rows) + + +def _print_status_table(rows): + headers = ["TUNNEL", "STATE", "ACTOR", "HOST", "PID"] + col_widths = [max(len(h), max((len(str(r.get(h.lower(), "") or "")) for r in rows), default=0)) for h in headers] + + def _fmt_row(vals): + return " ".join(str(v).ljust(w) for v, w in zip(vals, col_widths)) + + typer.echo(_fmt_row(headers)) + typer.echo(_fmt_row(["-" * w for w in col_widths])) + for row in rows: + typer.echo(_fmt_row([ + row["tunnel"], + row["state"], + row["actor"], + row["host"], + str(row["pid"] or ""), + ])) + + +@app.command() +def logs( + tunnel: str = typer.Argument(..., help="Tunnel name"), + lines: int = typer.Option(50, "--lines", "-n", help="Number of lines to show"), + follow: bool = typer.Option(False, "--follow", "-f", help="Follow the log"), +): + """Show audit log for a tunnel.""" + cfg = _load_or_exit() + _require_tunnel(cfg, tunnel) + + sd = _state_dir() + logger = AuditLogger(state_dir=sd) + events = logger.read_events(tunnel) + + if not events: + typer.echo(f"No log entries for tunnel '{tunnel}'.") + return + + # Show last N lines + for entry in events[-lines:]: + ts = entry.get("timestamp", "") + event = entry.get("event", "") + actor = entry.get("actor", "") + detail = entry.get("detail", "") + parts = [ts, event, f"actor={actor}"] + if detail: + parts.append(detail) + typer.echo(" ".join(parts)) + + if follow: + import time + log_path = sd / f"{tunnel}.log" + try: + with log_path.open() as f: + f.seek(0, 2) # seek to end + while True: + line = f.readline() + if line: + try: + entry = json.loads(line) + ts = entry.get("timestamp", "") + event = entry.get("event", "") + actor = entry.get("actor", "") + detail = entry.get("detail", "") + parts = [ts, event, f"actor={actor}"] + if detail: + parts.append(detail) + typer.echo(" ".join(parts)) + except json.JSONDecodeError: + pass + else: + time.sleep(0.5) + except KeyboardInterrupt: + pass diff --git a/src/bridge/config.py b/src/bridge/config.py new file mode 100644 index 0000000..9294815 --- /dev/null +++ b/src/bridge/config.py @@ -0,0 +1,109 @@ +"""Config loading for OpsBridge.""" +from __future__ import annotations + +import os +from dataclasses import dataclass +from pathlib import Path +from typing import Dict + +import yaml + +from bridge.models import ActorInfo, HealthCheckConfig, ReconnectPolicy, TunnelConfig + + +class ConfigError(Exception): + """Raised when config is invalid or missing.""" + + +@dataclass +class BridgeConfig: + tunnels: Dict[str, TunnelConfig] + actors: Dict[str, ActorInfo] + + +def _default_config_path() -> Path: + return Path.home() / ".config" / "bridge" / "tunnels.yaml" + + +def load_config() -> BridgeConfig: + """Load and validate tunnels.yaml. Respects BRIDGE_CONFIG env var.""" + path = Path(os.environ.get("BRIDGE_CONFIG", str(_default_config_path()))) + + if not path.exists(): + raise ConfigError(f"Config file not found: {path}") + + try: + with path.open() as f: + raw = yaml.safe_load(f) + except yaml.YAMLError as e: + raise ConfigError(f"Invalid YAML in {path}: {e}") from e + + if not isinstance(raw, dict): + raise ConfigError(f"Config must be a YAML mapping, got: {type(raw)}") + + tunnels = _parse_tunnels(raw.get("tunnels") or {}) + actors = _parse_actors(raw.get("actors") or {}) + return BridgeConfig(tunnels=tunnels, actors=actors) + + +def _parse_tunnels(raw: dict) -> Dict[str, TunnelConfig]: + tunnels = {} + for name, data in raw.items(): + if not isinstance(data, dict): + raise ConfigError(f"Tunnel '{name}' must be a mapping") + tunnels[name] = _parse_tunnel(name, data) + return tunnels + + +def _parse_tunnel(name: str, data: dict) -> TunnelConfig: + required = ["host", "remote_port", "local_port", "ssh_user", "ssh_key", "actor"] + for field in required: + if field not in data: + raise ConfigError(f"Tunnel '{name}' missing required field: {field}") + + reconnect = ReconnectPolicy() + if "reconnect" in data and data["reconnect"]: + r = data["reconnect"] + reconnect = ReconnectPolicy( + max_attempts=r.get("max_attempts", 0), + backoff_initial=r.get("backoff_initial", 5), + backoff_max=r.get("backoff_max", 60), + ) + + health_check = None + if "health_check" in data and data["health_check"]: + hc = data["health_check"] + if "url" not in hc: + raise ConfigError(f"Tunnel '{name}' health_check missing required field: url") + health_check = HealthCheckConfig( + url=hc["url"], + interval_seconds=hc.get("interval_seconds", 30), + timeout_seconds=hc.get("timeout_seconds", 5), + ) + + return TunnelConfig( + name=name, + host=str(data["host"]), + remote_port=int(data["remote_port"]), + local_port=int(data["local_port"]), + ssh_user=str(data["ssh_user"]), + ssh_key=str(data["ssh_key"]), + actor=str(data["actor"]), + reconnect=reconnect, + health_check=health_check, + ) + + +def _parse_actors(raw: dict) -> Dict[str, ActorInfo]: + actors = {} + for name, data in raw.items(): + if not isinstance(data, dict): + raise ConfigError(f"Actor '{name}' must be a mapping") + if "class" not in data: + raise ConfigError(f"Actor '{name}' missing required field: class") + actors[name] = ActorInfo( + name=name, + actor_class=str(data["class"]), + description=str(data.get("description", "")), + ) + return actors diff --git a/src/bridge/health.py b/src/bridge/health.py new file mode 100644 index 0000000..ba9614e --- /dev/null +++ b/src/bridge/health.py @@ -0,0 +1,31 @@ +"""HTTP health checker for OpsBridge.""" +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Optional + +import httpx + + +@dataclass +class HealthResult: + ok: bool + status_code: Optional[int] = None + error: Optional[str] = None + + +class HealthChecker: + def __init__(self, url: str, timeout_seconds: int = 5): + self._url = url + self._timeout = timeout_seconds + + async def check(self) -> HealthResult: + try: + async with httpx.AsyncClient(timeout=self._timeout) as client: + response = await client.get(self._url) + response.raise_for_status() + return HealthResult(ok=True, status_code=response.status_code) + except httpx.HTTPStatusError as e: + return HealthResult(ok=False, status_code=e.response.status_code, error=str(e)) + except Exception as e: + return HealthResult(ok=False, error=str(e)) diff --git a/src/bridge/manager.py b/src/bridge/manager.py new file mode 100644 index 0000000..af47c5c --- /dev/null +++ b/src/bridge/manager.py @@ -0,0 +1,252 @@ +"""Tunnel lifecycle manager for OpsBridge.""" +from __future__ import annotations + +import logging +import os +import signal +import subprocess +import time +from pathlib import Path +from typing import List, Optional + +from bridge.audit import AuditEvent, AuditLogger +from bridge.config import BridgeConfig +from bridge.health import HealthChecker +from bridge.models import BridgeState, TunnelConfig +from bridge.state import StateManager + +log = logging.getLogger(__name__) + + +def build_ssh_command(cfg: TunnelConfig) -> List[str]: + """Build the SSH reverse tunnel command.""" + key = os.path.expanduser(cfg.ssh_key) + return [ + "ssh", + "-N", + "-R", f"{cfg.remote_port}:127.0.0.1:{cfg.local_port}", + "-i", key, + "-o", "ServerAliveInterval=10", + "-o", "ServerAliveCountMax=3", + "-o", "ExitOnForwardFailure=yes", + "-o", "StrictHostKeyChecking=accept-new", + f"{cfg.ssh_user}@{cfg.host}", + ] + + +class TunnelManager: + """Manages a single named SSH reverse tunnel. + + start() daemonises: forks a child that runs the reconnect loop, then the + parent returns immediately after writing the manager PID. + """ + + def __init__(self, cfg: TunnelConfig, state_dir: Optional[Path] = None): + self._cfg = cfg + self._state = StateManager(state_dir=state_dir) + self._audit = AuditLogger(state_dir=state_dir) + + def get_state(self) -> BridgeState: + return self._state.read_state(self._cfg.name) + + def is_running(self) -> bool: + return self._state.is_running(self._cfg.name) + + def _actor_info(self): + return self._cfg.actor, "unknown" + + def _next_backoff(self, attempt: int) -> int: + initial = self._cfg.reconnect.backoff_initial + max_b = self._cfg.reconnect.backoff_max + value = initial * (2 ** attempt) + return min(value, max_b) + + def start(self) -> None: + """Start the tunnel manager as a daemonised subprocess.""" + if self.is_running(): + log.info("Tunnel %s already running", self._cfg.name) + return + + self._state.write_state(self._cfg.name, BridgeState.STARTING) + actor, actor_class = self._actor_info() + self._audit.log( + tunnel=self._cfg.name, + event=AuditEvent.BRIDGE_STARTED, + actor=actor, + actor_class=actor_class, + ) + + pid = os.fork() + if pid > 0: + # Parent: record manager PID and return + self._state.write_pid(self._cfg.name, pid) + return + + # Child: become a daemon + os.setsid() + + try: + self._run_loop() + except Exception as e: + log.exception("Tunnel manager loop crashed: %s", e) + finally: + self._state.write_state(self._cfg.name, BridgeState.STOPPED) + self._state.clear_pid(self._cfg.name) + self._audit.log( + tunnel=self._cfg.name, + event=AuditEvent.BRIDGE_STOPPED, + actor=actor, + actor_class=actor_class, + ) + + os._exit(0) + + def stop(self) -> None: + """Stop the running tunnel manager.""" + pid = self._state.read_pid(self._cfg.name) + if pid is None: + self._state.write_state(self._cfg.name, BridgeState.STOPPED) + return + + try: + os.kill(pid, signal.SIGTERM) + # Give up to 5 seconds for graceful shutdown + for _ in range(50): + try: + os.kill(pid, 0) + time.sleep(0.1) + except ProcessLookupError: + break + else: + # Force kill if still running + try: + os.kill(pid, signal.SIGKILL) + except ProcessLookupError: + pass + except ProcessLookupError: + pass + + self._state.clear_pid(self._cfg.name) + self._state.write_state(self._cfg.name, BridgeState.STOPPED) + actor, actor_class = self._actor_info() + self._audit.log( + tunnel=self._cfg.name, + event=AuditEvent.BRIDGE_STOPPED, + actor=actor, + actor_class=actor_class, + ) + + def _run_loop(self) -> None: + """Reconnect loop running in daemon child.""" + import asyncio + + cfg = self._cfg + actor, actor_class = self._actor_info() + attempt = 0 + max_attempts = cfg.reconnect.max_attempts # 0 = infinite + + # Setup signal handler for graceful shutdown + _stop = [False] + + def _on_term(signum, frame): + _stop[0] = True + + signal.signal(signal.SIGTERM, _on_term) + signal.signal(signal.SIGINT, _on_term) + + while not _stop[0]: + if max_attempts > 0 and attempt >= max_attempts: + self._state.write_state(cfg.name, BridgeState.FAILED) + break + + cmd = build_ssh_command(cfg) + log.info("Starting SSH: %s", " ".join(cmd)) + self._state.write_state(cfg.name, BridgeState.STARTING) + + try: + proc = subprocess.Popen(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) + except FileNotFoundError: + self._state.write_state(cfg.name, BridgeState.FAILED) + self._audit.log( + tunnel=cfg.name, + event=AuditEvent.BRIDGE_DISCONNECTED, + actor=actor, + actor_class=actor_class, + detail="ssh binary not found", + ) + break + + # Wait briefly then assume connected if still running + time.sleep(2) + if proc.poll() is None: + self._state.write_state(cfg.name, BridgeState.CONNECTED) + self._audit.log( + tunnel=cfg.name, + event=AuditEvent.BRIDGE_CONNECTED, + actor=actor, + actor_class=actor_class, + ) + attempt = 0 + + # Health check loop + if cfg.health_check: + checker = HealthChecker( + url=cfg.health_check.url, + timeout_seconds=cfg.health_check.timeout_seconds, + ) + health_failing = False + while not _stop[0] and proc.poll() is None: + result = asyncio.run(checker.check()) + if result.ok: + if health_failing: + health_failing = False + self._state.write_state(cfg.name, BridgeState.CONNECTED) + self._audit.log( + tunnel=cfg.name, + event=AuditEvent.HEALTH_CHECK_RECOVERED, + actor=actor, + actor_class=actor_class, + ) + else: + if not health_failing: + health_failing = True + self._state.write_state(cfg.name, BridgeState.DEGRADED) + self._audit.log( + tunnel=cfg.name, + event=AuditEvent.HEALTH_CHECK_FAILED, + actor=actor, + actor_class=actor_class, + detail=result.error or f"HTTP {result.status_code}", + ) + time.sleep(cfg.health_check.interval_seconds) + else: + while not _stop[0] and proc.poll() is None: + time.sleep(1) + + # SSH exited + if proc.poll() is not None: + self._audit.log( + tunnel=cfg.name, + event=AuditEvent.BRIDGE_DISCONNECTED, + actor=actor, + actor_class=actor_class, + detail=f"exit code {proc.returncode}", + ) + + if _stop[0]: + if proc.poll() is None: + proc.terminate() + break + + attempt += 1 + backoff = self._next_backoff(attempt - 1) + self._state.write_state(cfg.name, BridgeState.RECONNECTING) + self._audit.log( + tunnel=cfg.name, + event=AuditEvent.BRIDGE_RECONNECTING, + actor=actor, + actor_class=actor_class, + detail=f"retry {attempt}, backoff {backoff}s", + ) + log.info("Reconnecting in %ds (attempt %d)", backoff, attempt) + time.sleep(backoff) diff --git a/src/bridge/models.py b/src/bridge/models.py new file mode 100644 index 0000000..8beca65 --- /dev/null +++ b/src/bridge/models.py @@ -0,0 +1,49 @@ +"""Domain models for OpsBridge.""" +from __future__ import annotations + +from dataclasses import dataclass, field +from enum import Enum +from typing import Optional + + +class BridgeState(str, Enum): + STOPPED = "stopped" + STARTING = "starting" + CONNECTED = "connected" + DEGRADED = "degraded" + RECONNECTING = "reconnecting" + FAILED = "failed" + + +@dataclass +class ReconnectPolicy: + max_attempts: int = 0 # 0 = infinite + backoff_initial: int = 5 + backoff_max: int = 60 + + +@dataclass +class HealthCheckConfig: + url: str + interval_seconds: int = 30 + timeout_seconds: int = 5 + + +@dataclass +class TunnelConfig: + name: str + host: str + remote_port: int + local_port: int + ssh_user: str + ssh_key: str + actor: str + reconnect: ReconnectPolicy = field(default_factory=ReconnectPolicy) + health_check: Optional[HealthCheckConfig] = None + + +@dataclass +class ActorInfo: + name: str + actor_class: str # "human" or "automation" + description: str = "" diff --git a/src/bridge/state.py b/src/bridge/state.py new file mode 100644 index 0000000..e9e2550 --- /dev/null +++ b/src/bridge/state.py @@ -0,0 +1,73 @@ +"""State file management for OpsBridge.""" +from __future__ import annotations + +import os +from pathlib import Path +from typing import Optional + +from bridge.models import BridgeState + + +def _default_state_dir() -> Path: + return Path.home() / ".local" / "state" / "bridge" + + +class StateManager: + def __init__(self, state_dir: Optional[Path] = None): + self._dir = Path(state_dir) if state_dir else _default_state_dir() + + def _ensure_dir(self) -> None: + self._dir.mkdir(parents=True, exist_ok=True) + + def _state_path(self, name: str) -> Path: + return self._dir / f"{name}.state" + + def _pid_path(self, name: str) -> Path: + return self._dir / f"{name}.pid" + + def read_state(self, name: str) -> BridgeState: + path = self._state_path(name) + if not path.exists(): + return BridgeState.STOPPED + text = path.read_text().strip() + try: + return BridgeState(text) + except ValueError: + return BridgeState.STOPPED + + def write_state(self, name: str, state: BridgeState) -> None: + self._ensure_dir() + self._state_path(name).write_text(state.value) + + def read_pid(self, name: str) -> Optional[int]: + path = self._pid_path(name) + if not path.exists(): + return None + try: + pid = int(path.read_text().strip()) + except (ValueError, OSError): + return None + if _pid_alive(pid): + return pid + return None + + def write_pid(self, name: str, pid: int) -> None: + self._ensure_dir() + self._pid_path(name).write_text(str(pid)) + + def clear_pid(self, name: str) -> None: + path = self._pid_path(name) + if path.exists(): + path.unlink() + + def is_running(self, name: str) -> bool: + return self.read_pid(name) is not None + + +def _pid_alive(pid: int) -> bool: + """Return True if the process with given PID exists.""" + try: + os.kill(pid, 0) + return True + except (ProcessLookupError, PermissionError): + return False diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_audit.py b/tests/test_audit.py new file mode 100644 index 0000000..ad24d50 --- /dev/null +++ b/tests/test_audit.py @@ -0,0 +1,90 @@ +"""Tests for audit logging.""" +import json +from pathlib import Path + +import pytest + +from bridge.audit import AuditLogger, AuditEvent + + +@pytest.fixture +def log_dir(tmp_path): + return tmp_path / "bridge" + + +@pytest.fixture +def logger(log_dir): + return AuditLogger(state_dir=log_dir) + + +class TestAuditLogger: + def test_log_event_creates_file(self, logger, log_dir): + logger.log( + tunnel="my-tunnel", + event=AuditEvent.BRIDGE_STARTED, + actor="operator.bernd", + actor_class="human", + ) + log_file = log_dir / "my-tunnel.log" + assert log_file.exists() + + def test_log_event_is_json_line(self, logger, log_dir): + logger.log( + tunnel="my-tunnel", + event=AuditEvent.BRIDGE_STARTED, + actor="operator.bernd", + actor_class="human", + ) + lines = (log_dir / "my-tunnel.log").read_text().strip().splitlines() + assert len(lines) == 1 + entry = json.loads(lines[0]) + assert entry["tunnel"] == "my-tunnel" + assert entry["event"] == "bridge_started" + assert entry["actor"] == "operator.bernd" + assert entry["actor_class"] == "human" + assert "timestamp" in entry + + def test_multiple_events_append(self, logger, log_dir): + for event in [AuditEvent.BRIDGE_STARTED, AuditEvent.BRIDGE_CONNECTED, AuditEvent.BRIDGE_STOPPED]: + logger.log(tunnel="t", event=event, actor="a", actor_class="human") + lines = (log_dir / "t.log").read_text().strip().splitlines() + assert len(lines) == 3 + + def test_log_with_detail(self, logger, log_dir): + logger.log( + tunnel="t", + event=AuditEvent.HEALTH_CHECK_FAILED, + actor="a", + actor_class="automation", + detail="connection refused", + ) + entry = json.loads((log_dir / "t.log").read_text().strip()) + assert entry["detail"] == "connection refused" + + def test_all_event_types_defined(self): + events = {e.value for e in AuditEvent} + assert "bridge_started" in events + assert "bridge_connected" in events + assert "bridge_disconnected" in events + assert "bridge_reconnecting" in events + assert "health_check_failed" in events + assert "health_check_recovered" in events + assert "bridge_stopped" in events + + def test_timestamp_is_iso8601(self, logger, log_dir): + from datetime import datetime + logger.log(tunnel="t", event=AuditEvent.BRIDGE_STOPPED, actor="a", actor_class="human") + entry = json.loads((log_dir / "t.log").read_text().strip()) + # Should parse without error + dt = datetime.fromisoformat(entry["timestamp"]) + assert dt.tzinfo is not None or True # UTC or naive both acceptable + + def test_read_events(self, logger, log_dir): + logger.log(tunnel="t", event=AuditEvent.BRIDGE_STARTED, actor="a", actor_class="human") + logger.log(tunnel="t", event=AuditEvent.BRIDGE_STOPPED, actor="a", actor_class="human") + events = logger.read_events("t") + assert len(events) == 2 + assert events[0]["event"] == "bridge_started" + + def test_read_events_missing_returns_empty(self, logger): + assert logger.read_events("nonexistent") == [] diff --git a/tests/test_cli.py b/tests/test_cli.py new file mode 100644 index 0000000..84a2637 --- /dev/null +++ b/tests/test_cli.py @@ -0,0 +1,201 @@ +"""Tests for CLI commands.""" +import json +import os +import textwrap +from pathlib import Path +from unittest.mock import MagicMock, patch + +import pytest +from typer.testing import CliRunner + +from bridge.cli import app + + +VALID_CONFIG = textwrap.dedent("""\ + tunnels: + test-tunnel: + host: host.local + remote_port: 18000 + local_port: 8000 + ssh_user: ubuntu + ssh_key: ~/.ssh/id_ops + actor: operator.bernd + actors: + operator.bernd: + class: human + description: Bernd +""") + +runner = CliRunner() + + +@pytest.fixture +def config_file(tmp_path): + f = tmp_path / "tunnels.yaml" + f.write_text(VALID_CONFIG) + return f + + +@pytest.fixture +def state_dir(tmp_path): + return tmp_path / "state" + + +@pytest.fixture +def env(config_file, state_dir): + return {"BRIDGE_CONFIG": str(config_file), "BRIDGE_STATE_DIR": str(state_dir)} + + +class TestHelpCommand: + def test_app_help(self): + result = runner.invoke(app, ["--help"]) + assert result.exit_code == 0 + assert "bridge" in result.output.lower() or "Usage" in result.output + + def test_up_help(self): + result = runner.invoke(app, ["up", "--help"]) + assert result.exit_code == 0 + + def test_down_help(self): + result = runner.invoke(app, ["down", "--help"]) + assert result.exit_code == 0 + + def test_status_help(self): + result = runner.invoke(app, ["status", "--help"]) + assert result.exit_code == 0 + + def test_logs_help(self): + result = runner.invoke(app, ["logs", "--help"]) + assert result.exit_code == 0 + + def test_restart_help(self): + result = runner.invoke(app, ["restart", "--help"]) + assert result.exit_code == 0 + + +class TestStatusCommand: + def test_status_shows_tunnels(self, env, state_dir): + result = runner.invoke(app, ["status"], env=env) + assert result.exit_code == 0 + assert "test-tunnel" in result.output + + def test_status_json_flag(self, env, state_dir): + result = runner.invoke(app, ["status", "--json"], env=env) + assert result.exit_code == 0 + data = json.loads(result.output) + assert isinstance(data, list) + assert len(data) == 1 + assert data[0]["tunnel"] == "test-tunnel" + assert "state" in data[0] + assert "actor" in data[0] + assert "host" in data[0] + + def test_status_shows_state(self, env, state_dir): + result = runner.invoke(app, ["status"], env=env) + assert result.exit_code == 0 + assert "stopped" in result.output.lower() + + def test_status_unknown_config_exit_1(self, tmp_path): + result = runner.invoke(app, ["status"], env={"BRIDGE_CONFIG": str(tmp_path / "no.yaml")}) + assert result.exit_code == 1 + + +class TestUpCommand: + def test_up_unknown_tunnel_exit_1(self, env): + result = runner.invoke(app, ["up", "nonexistent"], env=env) + assert result.exit_code == 1 + assert "nonexistent" in result.output + + def test_up_calls_manager_start(self, env, state_dir): + with patch("bridge.cli.TunnelManager") as mock_mgr_cls: + mock_mgr = MagicMock() + mock_mgr.is_running.return_value = False + mock_mgr_cls.return_value = mock_mgr + + result = runner.invoke(app, ["up", "test-tunnel"], env=env) + + assert result.exit_code == 0 + mock_mgr.start.assert_called_once() + + def test_up_already_running_exit_2(self, env, state_dir): + with patch("bridge.cli.TunnelManager") as mock_mgr_cls: + mock_mgr = MagicMock() + mock_mgr.is_running.return_value = True + mock_mgr_cls.return_value = mock_mgr + + result = runner.invoke(app, ["up", "test-tunnel"], env=env) + + assert result.exit_code == 2 + + +class TestDownCommand: + def test_down_unknown_tunnel_exit_1(self, env): + result = runner.invoke(app, ["down", "nonexistent"], env=env) + assert result.exit_code == 1 + + def test_down_calls_manager_stop(self, env, state_dir): + with patch("bridge.cli.TunnelManager") as mock_mgr_cls: + mock_mgr = MagicMock() + mock_mgr.is_running.return_value = True + mock_mgr_cls.return_value = mock_mgr + + result = runner.invoke(app, ["down", "test-tunnel"], env=env) + + assert result.exit_code == 0 + mock_mgr.stop.assert_called_once() + + def test_down_not_running_exit_2(self, env, state_dir): + with patch("bridge.cli.TunnelManager") as mock_mgr_cls: + mock_mgr = MagicMock() + mock_mgr.is_running.return_value = False + mock_mgr_cls.return_value = mock_mgr + + result = runner.invoke(app, ["down", "test-tunnel"], env=env) + + assert result.exit_code == 2 + + +class TestLogsCommand: + def test_logs_unknown_tunnel_exit_1(self, env): + result = runner.invoke(app, ["logs", "nonexistent"], env=env) + assert result.exit_code == 1 + + def test_logs_no_log_file_shows_empty(self, env, state_dir): + result = runner.invoke(app, ["logs", "test-tunnel"], env=env) + assert result.exit_code == 0 + + def test_logs_shows_events(self, env, state_dir): + import json as _json + state_dir.mkdir(parents=True, exist_ok=True) + log_file = state_dir / "test-tunnel.log" + log_file.write_text( + _json.dumps({ + "timestamp": "2026-01-01T00:00:00+00:00", + "tunnel": "test-tunnel", + "actor": "operator.bernd", + "actor_class": "human", + "event": "bridge_started", + }) + "\n" + ) + result = runner.invoke(app, ["logs", "test-tunnel"], env=env) + assert result.exit_code == 0 + assert "bridge_started" in result.output + + +class TestRestartCommand: + def test_restart_unknown_tunnel_exit_1(self, env): + result = runner.invoke(app, ["restart", "nonexistent"], env=env) + assert result.exit_code == 1 + + def test_restart_calls_stop_then_start(self, env): + with patch("bridge.cli.TunnelManager") as mock_mgr_cls: + mock_mgr = MagicMock() + mock_mgr_cls.return_value = mock_mgr + call_order = [] + mock_mgr.stop.side_effect = lambda: call_order.append("stop") + mock_mgr.start.side_effect = lambda: call_order.append("start") + + result = runner.invoke(app, ["restart", "test-tunnel"], env=env) + + assert result.exit_code == 0 + assert call_order == ["stop", "start"] diff --git a/tests/test_config.py b/tests/test_config.py new file mode 100644 index 0000000..7147ffc --- /dev/null +++ b/tests/test_config.py @@ -0,0 +1,130 @@ +"""Tests for config loading.""" +import os +import textwrap + +import pytest + +from bridge.config import ConfigError, load_config + + +VALID_YAML = textwrap.dedent("""\ + tunnels: + state-hub-coulombcore: + host: coulombcore.local + remote_port: 18000 + local_port: 8000 + ssh_user: ubuntu + ssh_key: ~/.ssh/id_ops + actor: agent.claude-coulombcore + health_check: + url: http://127.0.0.1:18000/health + interval_seconds: 30 + timeout_seconds: 5 + reconnect: + max_attempts: 0 + backoff_initial: 5 + backoff_max: 60 + + actors: + agent.claude-coulombcore: + class: automation + description: Claude Code agent on CoulombCore + operator.bernd: + class: human + description: Bernd Worsch +""") + + +@pytest.fixture +def config_file(tmp_path): + f = tmp_path / "tunnels.yaml" + f.write_text(VALID_YAML) + return f + + +def test_load_valid_config(config_file, monkeypatch): + monkeypatch.setenv("BRIDGE_CONFIG", str(config_file)) + cfg = load_config() + assert "state-hub-coulombcore" in cfg.tunnels + t = cfg.tunnels["state-hub-coulombcore"] + assert t.host == "coulombcore.local" + assert t.remote_port == 18000 + assert t.local_port == 8000 + assert t.ssh_user == "ubuntu" + assert t.actor == "agent.claude-coulombcore" + + +def test_health_check_loaded(config_file, monkeypatch): + monkeypatch.setenv("BRIDGE_CONFIG", str(config_file)) + cfg = load_config() + t = cfg.tunnels["state-hub-coulombcore"] + assert t.health_check is not None + assert t.health_check.url == "http://127.0.0.1:18000/health" + assert t.health_check.interval_seconds == 30 + + +def test_reconnect_policy_loaded(config_file, monkeypatch): + monkeypatch.setenv("BRIDGE_CONFIG", str(config_file)) + cfg = load_config() + t = cfg.tunnels["state-hub-coulombcore"] + assert t.reconnect.max_attempts == 0 + assert t.reconnect.backoff_initial == 5 + assert t.reconnect.backoff_max == 60 + + +def test_actors_loaded(config_file, monkeypatch): + monkeypatch.setenv("BRIDGE_CONFIG", str(config_file)) + cfg = load_config() + assert "agent.claude-coulombcore" in cfg.actors + a = cfg.actors["agent.claude-coulombcore"] + assert a.actor_class == "automation" + assert "operator.bernd" in cfg.actors + + +def test_missing_required_field_raises(tmp_path, monkeypatch): + f = tmp_path / "bad.yaml" + f.write_text(textwrap.dedent("""\ + tunnels: + broken: + remote_port: 18000 + local_port: 8000 + actors: {} + """)) + monkeypatch.setenv("BRIDGE_CONFIG", str(f)) + with pytest.raises(ConfigError, match="host"): + load_config() + + +def test_invalid_yaml_raises(tmp_path, monkeypatch): + f = tmp_path / "bad.yaml" + f.write_text("tunnels: [\nnot: valid: yaml") + monkeypatch.setenv("BRIDGE_CONFIG", str(f)) + with pytest.raises(ConfigError): + load_config() + + +def test_missing_config_file_raises(tmp_path, monkeypatch): + monkeypatch.setenv("BRIDGE_CONFIG", str(tmp_path / "nonexistent.yaml")) + with pytest.raises(ConfigError, match="not found"): + load_config() + + +def test_tunnel_without_health_check(tmp_path, monkeypatch): + f = tmp_path / "tunnels.yaml" + f.write_text(textwrap.dedent("""\ + tunnels: + simple: + host: host.local + remote_port: 9000 + local_port: 8000 + ssh_user: ubuntu + ssh_key: ~/.ssh/id_rsa + actor: operator.bernd + actors: + operator.bernd: + class: human + description: Bernd + """)) + monkeypatch.setenv("BRIDGE_CONFIG", str(f)) + cfg = load_config() + assert cfg.tunnels["simple"].health_check is None diff --git a/tests/test_health.py b/tests/test_health.py new file mode 100644 index 0000000..6c12cdc --- /dev/null +++ b/tests/test_health.py @@ -0,0 +1,78 @@ +"""Tests for health checking.""" +import pytest +from unittest.mock import MagicMock, patch, AsyncMock + +from bridge.health import HealthChecker, HealthResult + + +class TestHealthResult: + def test_ok(self): + r = HealthResult(ok=True, status_code=200) + assert r.ok + assert r.status_code == 200 + assert r.error is None + + def test_failure(self): + r = HealthResult(ok=False, error="connection refused") + assert not r.ok + assert r.error == "connection refused" + + +class TestHealthChecker: + @pytest.mark.asyncio + async def test_check_ok(self): + checker = HealthChecker(url="http://127.0.0.1:18000/health", timeout_seconds=5) + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.raise_for_status = MagicMock() + + with patch("httpx.AsyncClient") as mock_client_cls: + mock_client = AsyncMock() + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=False) + mock_client.get = AsyncMock(return_value=mock_response) + mock_client_cls.return_value = mock_client + + result = await checker.check() + + assert result.ok + assert result.status_code == 200 + + @pytest.mark.asyncio + async def test_check_connection_error(self): + import httpx + checker = HealthChecker(url="http://127.0.0.1:19999/health", timeout_seconds=1) + + with patch("httpx.AsyncClient") as mock_client_cls: + mock_client = AsyncMock() + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=False) + mock_client.get = AsyncMock(side_effect=httpx.ConnectError("refused")) + mock_client_cls.return_value = mock_client + + result = await checker.check() + + assert not result.ok + assert result.error is not None + + @pytest.mark.asyncio + async def test_check_http_error(self): + import httpx + checker = HealthChecker(url="http://127.0.0.1:18000/health", timeout_seconds=5) + mock_response = MagicMock() + mock_response.status_code = 503 + mock_response.raise_for_status = MagicMock( + side_effect=httpx.HTTPStatusError("503", request=MagicMock(), response=mock_response) + ) + + with patch("httpx.AsyncClient") as mock_client_cls: + mock_client = AsyncMock() + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=False) + mock_client.get = AsyncMock(return_value=mock_response) + mock_client_cls.return_value = mock_client + + result = await checker.check() + + assert not result.ok + assert result.status_code == 503 diff --git a/tests/test_integration.py b/tests/test_integration.py new file mode 100644 index 0000000..f1880d9 --- /dev/null +++ b/tests/test_integration.py @@ -0,0 +1,219 @@ +"""Integration tests for OpsBridge.""" +import json +import os +import textwrap +import time +from pathlib import Path +from unittest.mock import MagicMock, patch + +import pytest + +from bridge.config import load_config +from bridge.manager import TunnelManager +from bridge.models import BridgeState, ReconnectPolicy, TunnelConfig +from bridge.state import StateManager + + +MINIMAL_CONFIG = textwrap.dedent("""\ + tunnels: + local-test: + host: 127.0.0.1 + remote_port: 19000 + local_port: 8000 + ssh_user: testuser + ssh_key: ~/.ssh/id_rsa + actor: operator.bernd + reconnect: + max_attempts: 2 + backoff_initial: 1 + backoff_max: 2 + actors: + operator.bernd: + class: human + description: Bernd +""") + + +@pytest.fixture +def config_file(tmp_path): + f = tmp_path / "tunnels.yaml" + f.write_text(MINIMAL_CONFIG) + return f + + +@pytest.fixture +def state_dir(tmp_path): + return tmp_path / "bridge" + + +@pytest.fixture +def tunnel_cfg(): + return TunnelConfig( + name="local-test", + host="127.0.0.1", + remote_port=19000, + local_port=8000, + ssh_user="testuser", + ssh_key="~/.ssh/id_rsa", + actor="operator.bernd", + reconnect=ReconnectPolicy(max_attempts=2, backoff_initial=1, backoff_max=2), + ) + + +class TestConfigRoundtrip: + def test_load_config_from_file(self, config_file, monkeypatch): + monkeypatch.setenv("BRIDGE_CONFIG", str(config_file)) + cfg = load_config() + assert "local-test" in cfg.tunnels + t = cfg.tunnels["local-test"] + assert t.host == "127.0.0.1" + assert t.reconnect.max_attempts == 2 + assert t.reconnect.backoff_initial == 1 + + +class TestStateRoundtrip: + def test_state_persists_across_manager_instances(self, state_dir, tunnel_cfg): + mgr1 = TunnelManager(tunnel_cfg, state_dir=state_dir) + mgr1._state.write_state(tunnel_cfg.name, BridgeState.CONNECTED) + + mgr2 = TunnelManager(tunnel_cfg, state_dir=state_dir) + assert mgr2.get_state() == BridgeState.CONNECTED + + def test_stale_pid_cleanup(self, state_dir, tunnel_cfg): + sm = StateManager(state_dir=state_dir) + sm.write_pid(tunnel_cfg.name, 999999) # guaranteed not alive + sm.write_state(tunnel_cfg.name, BridgeState.CONNECTED) + + # is_running should return False for dead pid + mgr = TunnelManager(tunnel_cfg, state_dir=state_dir) + assert not mgr.is_running() + + +class TestReconnectLoop: + def test_reconnect_loop_gives_up_after_max_attempts(self, state_dir, tunnel_cfg): + """Manager should set FAILED state after exhausting max_attempts.""" + mgr = TunnelManager(tunnel_cfg, state_dir=state_dir) + + attempt_count = [0] + + def fake_popen(cmd, **kwargs): + proc = MagicMock() + proc.poll.return_value = 1 # immediately "dead" + proc.returncode = 1 + attempt_count[0] += 1 + return proc + + with patch("subprocess.Popen", side_effect=fake_popen), \ + patch("time.sleep"): # skip sleeps for speed + mgr._run_loop() + + assert attempt_count[0] >= 1 + assert mgr.get_state() == BridgeState.FAILED + + def test_reconnect_logs_events(self, state_dir, tunnel_cfg): + """Audit log should contain reconnect events.""" + mgr = TunnelManager(tunnel_cfg, state_dir=state_dir) + + def fake_popen(cmd, **kwargs): + proc = MagicMock() + proc.poll.return_value = 1 + proc.returncode = 1 + return proc + + with patch("subprocess.Popen", side_effect=fake_popen), \ + patch("time.sleep"): + mgr._run_loop() + + events = mgr._audit.read_events(tunnel_cfg.name) + event_types = [e["event"] for e in events] + assert "bridge_started" in event_types or "bridge_reconnecting" in event_types or "bridge_disconnected" in event_types + + +class TestHealthCheckDegradedPath: + def test_degraded_state_on_health_failure(self, state_dir): + """Health check failure sets state to DEGRADED.""" + from bridge.health import HealthChecker, HealthResult + + hc_cfg = MagicMock() + hc_cfg.url = "http://127.0.0.1:19001/health" + hc_cfg.interval_seconds = 0 + hc_cfg.timeout_seconds = 1 + + tunnel_cfg = TunnelConfig( + name="hc-test", + host="127.0.0.1", + remote_port=19001, + local_port=8001, + ssh_user="u", + ssh_key="k", + actor="operator.bernd", + reconnect=ReconnectPolicy(max_attempts=1, backoff_initial=1, backoff_max=1), + health_check=hc_cfg, + ) + + mgr = TunnelManager(tunnel_cfg, state_dir=state_dir) + + proc_call_count = [0] + + def fake_popen(cmd, **kwargs): + proc = MagicMock() + # First call: "alive" for 1 health check cycle then dies + proc_call_count[0] += 1 + if proc_call_count[0] == 1: + # Poll returns None (alive) once then dies + poll_calls = [None, 1] + proc.poll.side_effect = poll_calls + [1] * 100 + proc.returncode = 1 + else: + proc.poll.return_value = 1 + proc.returncode = 1 + return proc + + failed_result = HealthResult(ok=False, error="connection refused") + recovered_result = HealthResult(ok=True, status_code=200) + + import asyncio + + async def fake_check_failing(): + return failed_result + + with patch("subprocess.Popen", side_effect=fake_popen), \ + patch("time.sleep"), \ + patch("bridge.manager.HealthChecker") as mock_hc_cls: + mock_checker = MagicMock() + mock_checker.check = MagicMock(side_effect=lambda: failed_result) + # Use asyncio.run compatibility + mock_hc_cls.return_value = mock_checker + + with patch("asyncio.run", side_effect=lambda coro: failed_result): + mgr._run_loop() + + # Should have set degraded at some point — check audit log + events = mgr._audit.read_events("hc-test") + event_types = [e["event"] for e in events] + assert "health_check_failed" in event_types or "bridge_disconnected" in event_types + + +class TestAuditTrail: + def test_full_lifecycle_logged(self, state_dir, tunnel_cfg): + """A start + immediate-exit SSH produces at minimum started + disconnected events.""" + mgr = TunnelManager(tunnel_cfg, state_dir=state_dir) + + def fake_popen(cmd, **kwargs): + proc = MagicMock() + proc.poll.return_value = 1 + proc.returncode = 1 + return proc + + with patch("subprocess.Popen", side_effect=fake_popen), \ + patch("time.sleep"): + mgr._run_loop() + + events = mgr._audit.read_events(tunnel_cfg.name) + assert len(events) >= 2 + # Each event has required fields + for e in events: + assert "timestamp" in e + assert "tunnel" in e + assert "actor" in e + assert "event" in e diff --git a/tests/test_manager.py b/tests/test_manager.py new file mode 100644 index 0000000..3cb3c26 --- /dev/null +++ b/tests/test_manager.py @@ -0,0 +1,109 @@ +"""Tests for TunnelManager.""" +import os +import signal +import time +from pathlib import Path +from unittest.mock import MagicMock, patch, call + +import pytest + +from bridge.models import BridgeState, ReconnectPolicy, TunnelConfig +from bridge.manager import TunnelManager, build_ssh_command + + +@pytest.fixture +def tunnel_cfg(): + return TunnelConfig( + name="test-tunnel", + host="host.local", + remote_port=18000, + local_port=8000, + ssh_user="ubuntu", + ssh_key="~/.ssh/id_ops", + actor="operator.bernd", + reconnect=ReconnectPolicy(max_attempts=3, backoff_initial=1, backoff_max=5), + ) + + +@pytest.fixture +def state_dir(tmp_path): + return tmp_path / "bridge" + + +class TestBuildSshCommand: + def test_basic_command(self, tunnel_cfg): + cmd = build_ssh_command(tunnel_cfg) + assert cmd[0] == "ssh" + assert "-N" in cmd + assert "-R" in cmd + assert "18000:127.0.0.1:8000" in cmd + assert "-i" in cmd + assert "ubuntu@host.local" in cmd + + def test_server_alive_options(self, tunnel_cfg): + cmd = build_ssh_command(tunnel_cfg) + assert "-o" in cmd + assert "ServerAliveInterval=10" in cmd + assert "ExitOnForwardFailure=yes" in cmd + + def test_ssh_key_expanded(self, tunnel_cfg): + cmd = build_ssh_command(tunnel_cfg) + key_idx = cmd.index("-i") + 1 + assert not cmd[key_idx].startswith("~") + + +class TestTunnelManager: + def test_get_state_initial(self, tunnel_cfg, state_dir): + mgr = TunnelManager(tunnel_cfg, state_dir=state_dir) + assert mgr.get_state() == BridgeState.STOPPED + + def test_stop_when_not_running_is_noop(self, tunnel_cfg, state_dir): + mgr = TunnelManager(tunnel_cfg, state_dir=state_dir) + # Should not raise + mgr.stop() + assert mgr.get_state() == BridgeState.STOPPED + + def test_stop_kills_pid(self, tunnel_cfg, state_dir): + mgr = TunnelManager(tunnel_cfg, state_dir=state_dir) + # Write a fake PID of our own process to simulate running + mgr._state.write_pid(tunnel_cfg.name, os.getpid()) + mgr._state.write_state(tunnel_cfg.name, BridgeState.CONNECTED) + + with patch("os.kill") as mock_kill: + mgr.stop() + + # Should have sent SIGTERM + mock_kill.assert_any_call(os.getpid(), signal.SIGTERM) + assert mgr.get_state() == BridgeState.STOPPED + + def test_backoff_calculation(self, tunnel_cfg, state_dir): + mgr = TunnelManager(tunnel_cfg, state_dir=state_dir) + # First backoff = initial + assert mgr._next_backoff(0) == 1 + # Doubles each time up to max + assert mgr._next_backoff(1) == 2 + assert mgr._next_backoff(2) == 4 + assert mgr._next_backoff(3) == 5 # capped at max + + def test_start_daemonizes(self, tunnel_cfg, state_dir): + """Verify start() forks without hanging.""" + mgr = TunnelManager(tunnel_cfg, state_dir=state_dir) + + # We can't actually fork in tests; verify state transitions via mock + with patch("subprocess.Popen") as mock_popen, \ + patch("os.fork", return_value=1234) as mock_fork, \ + patch("os.setsid"), \ + patch("os._exit"): + mock_proc = MagicMock() + mock_proc.pid = 9999 + mock_popen.return_value = mock_proc + + # When fork returns non-zero we're the parent — just check PID written + mgr.start() + + # After start the state should be STARTING (set before fork) + # and PID file should exist (written in parent branch) + + def test_is_running_false_initially(self, tunnel_cfg, state_dir): + mgr = TunnelManager(tunnel_cfg, state_dir=state_dir) + assert not mgr.is_running() diff --git a/tests/test_models.py b/tests/test_models.py new file mode 100644 index 0000000..c5e8ae3 --- /dev/null +++ b/tests/test_models.py @@ -0,0 +1,75 @@ +"""Tests for domain models.""" +import pytest +from bridge.models import ( + ActorInfo, + BridgeState, + HealthCheckConfig, + ReconnectPolicy, + TunnelConfig, +) + + +class TestBridgeState: + def test_all_states_defined(self): + states = {s.value for s in BridgeState} + assert states == {"stopped", "starting", "connected", "degraded", "reconnecting", "failed"} + + def test_state_is_string(self): + assert BridgeState.STOPPED == "stopped" + + +class TestReconnectPolicy: + def test_defaults(self): + p = ReconnectPolicy() + assert p.max_attempts == 0 + assert p.backoff_initial == 5 + assert p.backoff_max == 60 + + def test_custom(self): + p = ReconnectPolicy(max_attempts=3, backoff_initial=2, backoff_max=30) + assert p.max_attempts == 3 + + +class TestHealthCheckConfig: + def test_required_url(self): + h = HealthCheckConfig(url="http://127.0.0.1:18000/health") + assert h.url == "http://127.0.0.1:18000/health" + assert h.interval_seconds == 30 + assert h.timeout_seconds == 5 + + +class TestTunnelConfig: + def test_minimal(self): + t = TunnelConfig( + name="test-tunnel", + host="host.local", + remote_port=18000, + local_port=8000, + ssh_user="ubuntu", + ssh_key="~/.ssh/id_ops", + actor="operator.bernd", + ) + assert t.name == "test-tunnel" + assert t.health_check is None + assert isinstance(t.reconnect, ReconnectPolicy) + + def test_with_health_check(self): + hc = HealthCheckConfig(url="http://127.0.0.1:18000/health") + t = TunnelConfig( + name="test", + host="h", + remote_port=1, + local_port=2, + ssh_user="u", + ssh_key="k", + actor="a", + health_check=hc, + ) + assert t.health_check is hc + + +class TestActorInfo: + def test_fields(self): + a = ActorInfo(name="operator.bernd", actor_class="human", description="Bernd") + assert a.name == "operator.bernd" + assert a.actor_class == "human" diff --git a/tests/test_state.py b/tests/test_state.py new file mode 100644 index 0000000..01e9d04 --- /dev/null +++ b/tests/test_state.py @@ -0,0 +1,69 @@ +"""Tests for state management.""" +import os +import signal + +import pytest + +from bridge.models import BridgeState +from bridge.state import StateManager + + +@pytest.fixture +def state_dir(tmp_path): + return tmp_path / "bridge" + + +@pytest.fixture +def mgr(state_dir): + return StateManager(state_dir=state_dir) + + +class TestStateManager: + def test_read_state_no_file_returns_stopped(self, mgr): + assert mgr.read_state("my-tunnel") == BridgeState.STOPPED + + def test_write_and_read_state(self, mgr): + mgr.write_state("my-tunnel", BridgeState.CONNECTED) + assert mgr.read_state("my-tunnel") == BridgeState.CONNECTED + + def test_state_roundtrip_all_values(self, mgr): + for state in BridgeState: + mgr.write_state("t", state) + assert mgr.read_state("t") == state + + def test_write_pid(self, mgr): + # Write a live PID (our own process) so read_pid can confirm it's alive + pid = os.getpid() + mgr.write_pid("my-tunnel", pid) + assert mgr.read_pid("my-tunnel") == pid + + def test_read_pid_no_file_returns_none(self, mgr): + assert mgr.read_pid("nonexistent") is None + + def test_stale_pid_returns_none(self, mgr): + # PID 999999 almost certainly does not exist + mgr.write_pid("my-tunnel", 999999) + assert mgr.read_pid("my-tunnel") is None + + def test_current_pid_is_alive(self, mgr): + mgr.write_pid("my-tunnel", os.getpid()) + assert mgr.read_pid("my-tunnel") == os.getpid() + + def test_clear_pid(self, mgr): + mgr.write_pid("my-tunnel", os.getpid()) + mgr.clear_pid("my-tunnel") + assert mgr.read_pid("my-tunnel") is None + + def test_state_dir_created_on_write(self, state_dir): + assert not state_dir.exists() + mgr = StateManager(state_dir=state_dir) + mgr.write_state("t", BridgeState.STOPPED) + assert state_dir.exists() + + def test_is_running_false_when_stopped(self, mgr): + assert not mgr.is_running("my-tunnel") + + def test_is_running_true_when_pid_alive(self, mgr): + mgr.write_pid("my-tunnel", os.getpid()) + mgr.write_state("my-tunnel", BridgeState.CONNECTED) + assert mgr.is_running("my-tunnel")