diff --git a/api/routers/policy.py b/api/routers/policy.py index 85ead6f..3da152e 100644 --- a/api/routers/policy.py +++ b/api/routers/policy.py @@ -1,23 +1,13 @@ import re from pathlib import Path -from fastapi import APIRouter, HTTPException -from pydantic import BaseModel +from fastapi import HTTPException +from hub_core.routers.policy import create_policy_router +from hub_core.schemas.policy import PolicyRead POLICY_DIR = Path(__file__).parent.parent.parent / "policies" _VALID_NAME = re.compile(r"^[a-z0-9][a-z0-9-]{0,63}$") -router = APIRouter(prefix="/policy", tags=["policy"]) - - -class PolicyRead(BaseModel): - name: str - content: str - - -class PolicyUpdate(BaseModel): - content: str - def _policy_path(name: str) -> Path: if not _VALID_NAME.match(name): @@ -28,14 +18,17 @@ def _policy_path(name: str) -> Path: return path -@router.get("/{name}", response_model=PolicyRead) -def get_policy(name: str) -> PolicyRead: +def _load_policy(name: str) -> PolicyRead: path = _policy_path(name) return PolicyRead(name=name, content=path.read_text()) -@router.put("/{name}", response_model=PolicyRead) -def update_policy(name: str, body: PolicyUpdate) -> PolicyRead: +def _update_policy(name: str, content: str) -> PolicyRead: path = _policy_path(name) - path.write_text(body.content) - return PolicyRead(name=name, content=body.content) + path.write_text(content) + return PolicyRead(name=name, content=content) + + +router = create_policy_router(_load_policy, update_policy=_update_policy) + +__all__ = ["router"] diff --git a/tests/test_policy_router.py b/tests/test_policy_router.py new file mode 100644 index 0000000..10349da --- /dev/null +++ b/tests/test_policy_router.py @@ -0,0 +1,21 @@ +from api.routers import policy + + +async def test_policy_router_reads_and_updates_policy(client, tmp_path, monkeypatch) -> None: + monkeypatch.setattr(policy, "POLICY_DIR", tmp_path) + (tmp_path / "example.md").write_text("old content") + read_response = await client.get("/policy/example") + update_response = await client.put("/policy/example", json={"content": "new content"}) + + assert read_response.status_code == 200 + assert read_response.json() == {"name": "example", "content": "old content"} + assert update_response.status_code == 200 + assert update_response.json() == {"name": "example", "content": "new content"} + assert (tmp_path / "example.md").read_text() == "new content" + + +async def test_policy_router_rejects_invalid_policy_name(client, tmp_path, monkeypatch) -> None: + monkeypatch.setattr(policy, "POLICY_DIR", tmp_path) + response = await client.get("/policy/BadName") + + assert response.status_code == 400