"""HTTP endpoints for sending, listing, threading, and managing agent messages."""

import uuid
from datetime import datetime, timezone

from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from api.database import get_session
from api.models.agent_message import AgentMessage
from api.schemas.agent_message import MessageCreate, MessageRead, MessageReply

router = APIRouter(prefix="/messages", tags=["messages"])


@router.post("/", response_model=MessageRead, status_code=status.HTTP_201_CREATED)
async def send_message(
    body: MessageCreate,
    session: AsyncSession = Depends(get_session),
) -> AgentMessage:
    """Send a message from one agent to another (or 'broadcast').

    If ``body.thread_id`` is set, it must reference an existing message. The
    new message is attached to that message's thread *root*, so it is
    returned by ``GET /messages/thread/{root_id}`` even when the caller passed
    the id of a reply instead of the root.

    Raises:
        HTTPException: 404 if ``body.thread_id`` does not match any message.
    """
    thread_id = body.thread_id
    if thread_id:
        parent = await session.get(AgentMessage, thread_id)
        if parent is None:
            raise HTTPException(
                status_code=404, detail=f"Thread root {body.thread_id} not found"
            )
        # Normalize to the root (same rule as reply_to_message). get_thread only
        # matches `id == root` or `thread_id == root`, so pointing at a reply
        # would make this message invisible in the thread view.
        thread_id = parent.thread_id or parent.id

    msg = AgentMessage(
        from_agent=body.from_agent,
        to_agent=body.to_agent,
        subject=body.subject,
        body=body.body,
        thread_id=thread_id,
    )
    session.add(msg)
    await session.commit()
    await session.refresh(msg)
    return msg


@router.get("/", response_model=list[MessageRead])
async def list_messages(
    to_agent: str | None = None,
    from_agent: str | None = None,
    unread_only: bool = False,
    # Reject negative limits with a 422 instead of passing them to the database.
    limit: int = Query(50, ge=0),
    session: AsyncSession = Depends(get_session),
) -> list[AgentMessage]:
    """List messages. Filter by recipient, sender, or unread status.

    Archived messages are always excluded. When ``to_agent`` is given,
    messages addressed to ``"broadcast"`` are included as well. Results are
    ordered newest first and capped at ``limit`` rows.
    """
    q = select(AgentMessage).where(AgentMessage.archived_at.is_(None))
    if to_agent:
        q = q.where(
            (AgentMessage.to_agent == to_agent)
            | (AgentMessage.to_agent == "broadcast")
        )
    if from_agent:
        q = q.where(AgentMessage.from_agent == from_agent)
    if unread_only:
        q = q.where(AgentMessage.read_at.is_(None))
    q = q.order_by(AgentMessage.created_at.desc()).limit(limit)
    result = await session.execute(q)
    return list(result.scalars().all())


@router.get("/thread/{thread_id}", response_model=list[MessageRead])
async def get_thread(
    thread_id: uuid.UUID,
    session: AsyncSession = Depends(get_session),
) -> list[AgentMessage]:
    """Get all messages in a thread (root + replies), oldest first.

    Archived messages are not filtered out here. An unknown ``thread_id``
    yields an empty list rather than a 404.
    """
    # Include the root message itself
    q = select(AgentMessage).where(
        (AgentMessage.id == thread_id) | (AgentMessage.thread_id == thread_id)
    ).order_by(AgentMessage.created_at)
    result = await session.execute(q)
    return list(result.scalars().all())


@router.patch("/{message_id}/read", response_model=MessageRead)
async def mark_read(
    message_id: uuid.UUID,
    session: AsyncSession = Depends(get_session),
) -> AgentMessage:
    """Mark a message as read.

    Idempotent: an already-read message keeps its original ``read_at``.

    Raises:
        HTTPException: 404 if the message does not exist.
    """
    msg = await _get_message(message_id, session)
    if msg.read_at is None:
        msg.read_at = datetime.now(timezone.utc)
        await session.commit()
        await session.refresh(msg)
    return msg


@router.patch("/{message_id}/archive", response_model=MessageRead)
async def archive_message(
    message_id: uuid.UUID,
    session: AsyncSession = Depends(get_session),
) -> AgentMessage:
    """Archive a message (soft-delete).

    Sets ``archived_at`` to now. An unread message is also marked read, using
    the same timestamp.

    Raises:
        HTTPException: 404 if the message does not exist.
    """
    msg = await _get_message(message_id, session)
    msg.archived_at = datetime.now(timezone.utc)
    if msg.read_at is None:
        msg.read_at = msg.archived_at
    await session.commit()
    await session.refresh(msg)
    return msg


@router.post(
    "/{message_id}/reply",
    response_model=MessageRead,
    status_code=status.HTTP_201_CREATED,
)
async def reply_to_message(
    message_id: uuid.UUID,
    body: MessageReply,
    session: AsyncSession = Depends(get_session),
) -> AgentMessage:
    """Reply to a message.

    Marks the original as read and creates a reply in the same thread. The
    reply is addressed to the original sender, and its subject is the
    original subject prefixed with ``"Re: "``.

    Raises:
        HTTPException: 404 if the original message does not exist.
    """
    original = await _get_message(message_id, session)

    # Mark original as read
    if original.read_at is None:
        original.read_at = datetime.now(timezone.utc)

    # Thread root is either the original's thread_id or the original itself
    thread_root = original.thread_id or original.id

    reply = AgentMessage(
        from_agent=body.from_agent,
        to_agent=original.from_agent,
        subject=f"Re: {original.subject}",
        body=body.body,
        thread_id=thread_root,
    )
    session.add(reply)
    # A single commit persists both the read flag on the original and the reply.
    await session.commit()
    await session.refresh(reply)
    return reply


async def _get_message(message_id: uuid.UUID, session: AsyncSession) -> AgentMessage:
    """Load a message by primary key, raising a 404 HTTPException if missing."""
    msg = await session.get(AgentMessage, message_id)
    if msg is None:
        raise HTTPException(status_code=404, detail=f"Message {message_id} not found")
    return msg