"""Signal extractors (PRD §6.2; T04).

Every extractor is a pure function over a *session digest* (Tier 2) — the
compact, durable per-session view — and emits zero or more :class:`Signal`
objects. A signal records the session it came from, a *locus* (what it is
about), a *polarity* (problem vs. success) and a *magnitude*. Signals are the
atoms the clusterer groups into candidate patterns.

No new capture happens here: everything is derived from digests the Capture
layer has already written, so detection is cheap and re-runnable.
"""
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any, Callable, Optional

# Allowed values for ``Signal.polarity``.
PROBLEM = "problem"
SUCCESS = "success"


@dataclass
class Signal:
    """A single observation about one session, emitted by an extractor.

    Signals are the atoms the clusterer groups into candidate patterns.
    """

    # Provenance: which session (and agent flavor / repo) this came from.
    session_uid: str
    flavor: str
    repo: Optional[str]
    # Classification.
    type: str  # signal kind, e.g. "budget_overrun", "clean_pass"
    polarity: str  # PROBLEM | SUCCESS
    locus: str  # normalized subject key (tool, marker, ...)
    # Weight and free-form payload.
    magnitude: float = 1.0  # strength / cost weight
    detail: dict[str, Any] = field(default_factory=dict)  # extractor-specific extras


# --- individual extractors --------------------------------------------------
# Every extractor has the shape (digest, ctx) -> list[Signal]. ``ctx`` carries
# corpus-level stats (e.g. cost percentiles) so a session can be compared to
# its peers.
def _base(digest, type_, polarity, locus, magnitude=1.0, **detail) -> Signal:
    """Build a :class:`Signal` carrying *digest*'s identity fields.

    ``session_uid`` and ``flavor`` must be present in *digest* (``KeyError``
    otherwise); ``repo`` is optional and defaults to ``None``. Any extra
    keyword arguments become the signal's ``detail`` dict.
    """
    return Signal(
        session_uid=digest["session_uid"],
        flavor=digest["flavor"],
        repo=digest.get("repo"),
        type=type_,
        polarity=polarity,
        locus=locus,
        magnitude=magnitude,
        detail=detail,
    )


def _section(digest, key) -> dict:
    """Return the sub-dict ``digest[key]``, or ``{}`` when it is missing or ``None``.

    ``digest.get(key, {})`` alone only covers a *missing* key; a digest that
    stores an explicit ``None`` would make the follow-up ``.get`` raise
    ``AttributeError``.
    """
    return digest.get(key) or {}


def _marker(digest, name) -> int:
    """Return marker *name* from ``digest["markers"]``; missing or ``None`` counts as 0."""
    return _section(digest, "markers").get(name) or 0


def _total_tokens(digest) -> int:
    """Return input + output tokens from ``digest["cost"]``; missing/``None`` parts count as 0."""
    cost = _section(digest, "cost")
    return (cost.get("input_tokens") or 0) + (cost.get("output_tokens") or 0)


def sig_retry_storm(digest, ctx) -> list[Signal]:
    """Problem: ``retries`` marker >= ``ctx["retry_storm_threshold"]`` (default 3).

    Magnitude is the retry count.
    """
    retries = _marker(digest, "retries")
    if retries >= ctx.get("retry_storm_threshold", 3):
        return [_base(digest, "retry_storm", PROBLEM, "retries", float(retries),
                      retries=retries)]
    return []


def sig_repeated_errors(digest, ctx) -> list[Signal]:
    """Problem: ``errors`` marker >= ``ctx["error_threshold"]`` (default 3).

    Magnitude is the error count.
    """
    errors = _marker(digest, "errors")
    if errors >= ctx.get("error_threshold", 3):
        return [_base(digest, "repeated_errors", PROBLEM, "errors", float(errors),
                      errors=errors)]
    return []


def sig_budget_overrun(digest, ctx) -> list[Signal]:
    """Problem: total (input + output) tokens strictly above ``ctx["tokens_p90"]``.

    Disabled when ``tokens_p90`` is absent or 0. Magnitude is
    ``total / max(p90, 1)``.
    """
    total = _total_tokens(digest)
    p90 = ctx.get("tokens_p90", 0)
    if p90 and total > p90:
        return [_base(digest, "budget_overrun", PROBLEM, "tokens",
                      float(total) / max(p90, 1), tokens=total, p90=p90)]
    return []


def sig_abandoned(digest, ctx) -> list[Signal]:
    """Problem: the digest's ``outcome`` is ``"abandoned"``."""
    if digest.get("outcome") == "abandoned":
        return [_base(digest, "abandoned", PROBLEM, "outcome", 1.0)]
    return []


def sig_clean_pass(digest, ctx) -> list[Signal]:
    """Success: outcome ``success``, >= 1 test run, zero errors and zero retries.

    NOTE(review): this signal was also described as requiring "modest cost",
    but no cost condition is checked here — confirm whether one is intended
    (e.g. total tokens <= ``ctx["tokens_p90"]``).
    """
    test_runs = _marker(digest, "test_runs")
    if (digest.get("outcome") == "success"
            and test_runs >= 1
            and _marker(digest, "errors") == 0
            and _marker(digest, "retries") == 0):
        return [_base(digest, "clean_pass", SUCCESS, "outcome", 1.0, test_runs=test_runs)]
    return []


def sig_error_then_recovery(digest, ctx) -> list[Signal]:
    """Success despite hitting errors — a recovery worth learning from.

    Fires when outcome is ``success`` and the ``errors`` marker is >= 1;
    magnitude is the error count.
    """
    errors = _marker(digest, "errors")
    if digest.get("outcome") == "success" and errors >= 1:
        return [_base(digest, "error_then_recovery", SUCCESS, "errors", float(errors),
                      errors=errors)]
    return []
# --- tool-mix / infrastructure-overhead signals (WP-0005 T02) ----------------
# These read the captured ``tool_histogram`` — friction that the outcome+marker
# signals above are blind to (sessions still "succeed", just expensively).

# Tool names per bucket. The sets mix naming conventions from different agent
# flavors (e.g. "Bash" vs "run_terminal_command").
_TASK_MGMT_TOOLS = frozenset({
    "TaskUpdate", "TaskCreate", "TaskGet", "TaskList", "TaskOutput", "TaskStop",
    "todo_write", "update_task_status",
})
_SHELL_TOOLS = frozenset({"Bash", "run_terminal_command"})
_EDIT_TOOLS = frozenset({"Edit", "Write", "search_replace", "write", "NotebookEdit"})
_READ_TOOLS = frozenset({"Read", "read_file", "grep", "Grep", "glob", "Glob"})


def tool_bucket(tool: str) -> str:
    """Group a tool name into a coarse activity bucket (flavor-agnostic).

    Returns one of ``statehub_mcp``, ``task_mgmt``, ``schema_load``, ``shell``,
    ``edit``, ``read`` or ``other``. Checks run in that order, so the
    ``mcp__state-hub`` prefix takes precedence over any set membership.
    """
    if tool.startswith("mcp__state-hub"):
        return "statehub_mcp"
    if tool in _TASK_MGMT_TOOLS:
        return "task_mgmt"
    if tool == "ToolSearch":
        return "schema_load"
    if tool in _SHELL_TOOLS:
        return "shell"
    if tool in _EDIT_TOOLS:
        return "edit"
    if tool in _READ_TOOLS:
        return "read"
    return "other"


def _bucketed(digest) -> tuple[dict, int]:
    """Aggregate ``digest["tool_histogram"]`` (tool -> call count) by :func:`tool_bucket`.

    Returns ``(bucket -> call count, total calls)``. A missing or ``None``
    histogram yields ``({}, 0)``.
    """
    buckets: dict[str, int] = {}
    for tool, calls in (digest.get("tool_histogram") or {}).items():
        bucket = tool_bucket(tool)  # classify once per tool
        buckets[bucket] = buckets.get(bucket, 0) + calls
    return buckets, sum(buckets.values())


def sig_infra_overhead(digest, ctx) -> list[Signal]:
    """Problem: a large share of tool calls is hub/task/schema plumbing, not work.

    Skipped when the session made no tool calls or fewer than
    ``ctx["infra_min_calls"]`` (default 20). Fires when
    (statehub_mcp + task_mgmt + schema_load) / total is at least
    ``ctx["infra_overhead_threshold"]`` (default 0.30). Magnitude is that share
    rounded to 3 decimals.
    """
    buckets, total = _bucketed(digest)
    # ``not total`` guards the division below even if infra_min_calls is set to 0.
    if not total or total < ctx.get("infra_min_calls", 20):
        return []
    overhead = (buckets.get("statehub_mcp", 0)
                + buckets.get("task_mgmt", 0)
                + buckets.get("schema_load", 0))
    share = overhead / total
    if share >= ctx.get("infra_overhead_threshold", 0.30):
        return [_base(digest, "infra_overhead", PROBLEM, "infra_overhead", round(share, 3),
                      overhead_calls=overhead, total_calls=total,
                      statehub=buckets.get("statehub_mcp", 0),
                      task_mgmt=buckets.get("task_mgmt", 0),
                      schema_load=buckets.get("schema_load", 0))]
    return []


def sig_schema_thrash(digest, ctx) -> list[Signal]:
    """Problem: repeated ToolSearch — deferred-tool schemas reloaded over and over.

    Fires when the ``schema_load`` bucket has at least
    ``ctx["schema_thrash_threshold"]`` calls (default 5). Magnitude is that count.
    """
    buckets, _ = _bucketed(digest)
    n = buckets.get("schema_load", 0)
    if n >= ctx.get("schema_thrash_threshold", 5):
        return [_base(digest, "schema_thrash", PROBLEM, "schema_load", float(n),
                      tool_searches=n)]
    return []


def sig_tool_thrash(digest, ctx) -> list[Signal]:
    """Problem: a single tool is hammered far more than any other — likely churn.

    Looks only at the most-called tool (on a tie, the first one in histogram
    iteration order). Fires when its count reaches ``ctx["tool_thrash_threshold"]``
    (default 80); the locus is ``"tool:<name>"``.
    """
    hist = digest.get("tool_histogram") or {}
    if not hist:
        return []
    tool, n = max(hist.items(), key=lambda kv: kv[1])
    if n >= ctx.get("tool_thrash_threshold", 80):
        return [_base(digest, "tool_thrash", PROBLEM, f"tool:{tool}", float(n),
                      tool=tool, calls=n)]
    return []


def sig_recurring_error(digest, ctx) -> list[Signal]:
    """Problem: a normalized error fingerprint (WP-0006).

    Emits one signal per distinct error in the session, so the same error
    across sessions/repos/flavors clusters into a candidate root-cause pattern
    (locus = fingerprint, magnitude = in-session occurrences). This is the
    content-level 'why', not just a coarse error count. Snippets without a
    truthy ``fingerprint`` are skipped.
    """
    out: list[Signal] = []
    for snip in digest.get("error_snippets", []) or []:
        fp = snip.get("fingerprint")
        if not fp:
            continue
        out.append(_base(digest, "recurring_error", PROBLEM, fp,
                         float(snip.get("count", 1)),
                         sample=snip.get("sample", ""), tool=snip.get("tool"),
                         occurrences=snip.get("count", 1)))
    return out


# Every extractor run by ``extract_signals``, in emission order.
EXTRACTORS: list[Callable[[dict, dict], list[Signal]]] = [
    sig_retry_storm, sig_repeated_errors, sig_budget_overrun, sig_abandoned,
    sig_clean_pass, sig_error_then_recovery,
    sig_infra_overhead, sig_schema_thrash, sig_tool_thrash,
    sig_recurring_error,
]


def build_context(digests: list[dict]) -> dict[str, Any]:
    """Corpus-level stats so extractors can compare a session to its peers.

    ``tokens_p90`` is the sorted per-session (input + output) token total at
    index ``int(0.9 * (n - 1))`` — a lower-rank p90 with no interpolation. It
    is 0 for an empty corpus, which disables ``sig_budget_overrun``. The
    remaining keys are the default thresholds the extractors read.
    """
    def session_tokens(d: dict) -> int:
        # ``or`` treats a missing *or* None ``cost`` / token count as empty / 0.
        cost = d.get("cost") or {}
        return (cost.get("input_tokens") or 0) + (cost.get("output_tokens") or 0)

    totals = sorted(session_tokens(d) for d in digests)
    p90 = totals[int(0.9 * (len(totals) - 1))] if totals else 0
    return {
        "tokens_p90": p90,
        "retry_storm_threshold": 3,
        "error_threshold": 3,
        # tool-mix / infra-overhead thresholds (WP-0005 T02)
        "infra_min_calls": 20,
        "infra_overhead_threshold": 0.30,
        "schema_thrash_threshold": 5,
        "tool_thrash_threshold": 80,
    }


def extract_signals(digests: list[dict], ctx: Optional[dict] = None) -> list[Signal]:
    """Run every extractor in ``EXTRACTORS`` over every digest.

    Signals are returned digest by digest, in ``EXTRACTORS`` order within each
    digest. When *ctx* is ``None`` — or any other falsy value such as ``{}`` —
    it is computed with :func:`build_context`. That iterates *digests* a second
    time, so pass a list rather than a one-shot iterator.
    """
    ctx = ctx or build_context(digests)
    out: list[Signal] = []
    for d in digests:
        for ex in EXTRACTORS:
            out.extend(ex(d, ctx))
    return out