import uuid
from collections.abc import Callable
from datetime import datetime, timezone
from typing import Any

from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from hub_core.models.agent_message import AgentMessage
from hub_core.schemas.agent_message import MessageCreate, MessageRead, MessageReply


def create_messages_router(
    get_session: Callable[..., AsyncSession],
    *,
    message_model: type[AgentMessage] = AgentMessage,
) -> APIRouter:
    """Build an ``APIRouter`` exposing the agent-message endpoints under ``/messages``.

    Routes:

    * ``POST /``                      - send a new message (optionally into a thread)
    * ``GET /``                       - list non-archived messages, newest first
    * ``GET /thread/{thread_id}``     - a thread root plus its replies, oldest first
    * ``PATCH /{message_id}/read``    - mark a message read (first read time kept)
    * ``PATCH /{message_id}/archive`` - archive a message (also marks it read)
    * ``POST /{message_id}/reply``    - reply to a message, addressed to its sender

    Threads are kept flat: every message in a thread stores the id of the
    thread's root message in ``thread_id``; the root itself has no
    ``thread_id``. ``get_thread`` relies on this one-level layout.

    Args:
        get_session: FastAPI dependency providing the ``AsyncSession`` used by
            each request.
        message_model: ORM model class to query and persist. Defaults to
            ``AgentMessage``; it must expose the columns used here (``id``,
            ``thread_id``, ``to_agent``, ``from_agent``, ``subject``, ``body``,
            ``read_at``, ``archived_at``, ``created_at``).

    Returns:
        The configured router, ready to be included in an application.
    """
    router = APIRouter(prefix="/messages", tags=["messages"])

    async def _get_message(message_id: uuid.UUID, session: AsyncSession) -> Any:
        """Load a message by primary key, raising HTTP 404 if it does not exist."""
        msg = await session.get(message_model, message_id)
        if msg is None:
            raise HTTPException(status_code=404, detail=f"Message {message_id} not found")
        return msg

    @router.post("/", response_model=MessageRead, status_code=status.HTTP_201_CREATED)
    async def send_message(
        body: MessageCreate,
        session: AsyncSession = Depends(get_session),
    ) -> Any:
        """Create a message.

        If ``thread_id`` is supplied it must reference an existing message
        (HTTP 404 otherwise). When that message is itself a reply, the new
        message is attached to the thread's root instead, keeping threads flat.
        """
        data = body.model_dump()
        if body.thread_id:
            root = await session.get(message_model, body.thread_id)
            if root is None:
                raise HTTPException(status_code=404, detail=f"Thread root {body.thread_id} not found")
            # Same normalisation as reply_to_message: get_thread only looks one
            # level deep, so pointing at a reply would hide this message from
            # GET /thread/{root}.
            data["thread_id"] = root.thread_id or root.id
        msg = message_model(**data)
        session.add(msg)
        await session.commit()
        await session.refresh(msg)
        return msg

    @router.get("/", response_model=list[MessageRead])
    async def list_messages(
        to_agent: str | None = None,
        from_agent: str | None = None,
        unread_only: bool = False,
        # ge=0: a negative LIMIT is rejected by some databases at query time;
        # validate up front so the client gets a 422 instead of a server error.
        limit: int = Query(50, ge=0),
        session: AsyncSession = Depends(get_session),
    ) -> list[Any]:
        """List non-archived messages, newest first.

        ``to_agent`` also matches messages addressed to ``"broadcast"``.
        Empty-string filters are treated as "no filter".
        """
        q = select(message_model).where(message_model.archived_at.is_(None))
        if to_agent:
            q = q.where(
                (message_model.to_agent == to_agent) | (message_model.to_agent == "broadcast")
            )
        if from_agent:
            q = q.where(message_model.from_agent == from_agent)
        if unread_only:
            q = q.where(message_model.read_at.is_(None))
        q = q.order_by(message_model.created_at.desc()).limit(limit)
        result = await session.execute(q)
        return list(result.scalars().all())

    @router.get("/thread/{thread_id}", response_model=list[MessageRead])
    async def get_thread(
        thread_id: uuid.UUID,
        session: AsyncSession = Depends(get_session),
    ) -> list[Any]:
        """Return the thread root and every message whose ``thread_id`` is it, oldest first.

        Archived messages are included. An unknown ``thread_id`` yields an
        empty list rather than a 404.
        """
        q = select(message_model).where(
            (message_model.id == thread_id) | (message_model.thread_id == thread_id)
        ).order_by(message_model.created_at)
        result = await session.execute(q)
        return list(result.scalars().all())

    @router.patch("/{message_id}/read", response_model=MessageRead)
    async def mark_read(
        message_id: uuid.UUID,
        session: AsyncSession = Depends(get_session),
    ) -> Any:
        """Mark a message read; an existing ``read_at`` is preserved."""
        msg = await _get_message(message_id, session)
        if msg.read_at is None:
            msg.read_at = datetime.now(timezone.utc)
        await session.commit()
        await session.refresh(msg)
        return msg

    @router.patch("/{message_id}/archive", response_model=MessageRead)
    async def archive_message(
        message_id: uuid.UUID,
        session: AsyncSession = Depends(get_session),
    ) -> Any:
        """Archive a message, also marking it read if it was unread.

        Idempotent, like ``mark_read``: archiving an already-archived message
        keeps its original ``archived_at`` instead of overwriting it.
        """
        msg = await _get_message(message_id, session)
        if msg.archived_at is None:
            msg.archived_at = datetime.now(timezone.utc)
        if msg.read_at is None:
            # Archiving implies the message has been seen.
            msg.read_at = msg.archived_at
        await session.commit()
        await session.refresh(msg)
        return msg

    @router.post("/{message_id}/reply", response_model=MessageRead, status_code=status.HTTP_201_CREATED)
    async def reply_to_message(
        message_id: uuid.UUID,
        body: MessageReply,
        session: AsyncSession = Depends(get_session),
    ) -> Any:
        """Reply to a message, addressed back to its sender.

        Marks the original as read if needed. The reply joins the original's
        thread (the original itself becomes the root if it was not yet in a
        thread). Its subject is the original's prefixed with ``"Re: "``.
        """
        original = await _get_message(message_id, session)
        if original.read_at is None:
            original.read_at = datetime.now(timezone.utc)
        # Keep threads flat: always point at the root, never at another reply.
        thread_root = original.thread_id or original.id
        reply = message_model(
            from_agent=body.from_agent,
            to_agent=original.from_agent,
            subject=f"Re: {original.subject}",
            body=body.body,
            thread_id=thread_root,
        )
        session.add(reply)
        await session.commit()
        await session.refresh(reply)
        return reply

    return router