generated from coulomb/repo-seed
feat: use hub-core policy router
This commit is contained in:
@@ -1,23 +1,13 @@
|
|||||||
import re
|
import re
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from fastapi import APIRouter, HTTPException
|
from fastapi import HTTPException
|
||||||
from pydantic import BaseModel
|
from hub_core.routers.policy import create_policy_router
|
||||||
|
from hub_core.schemas.policy import PolicyRead
|
||||||
|
|
||||||
POLICY_DIR = Path(__file__).parent.parent.parent / "policies"
|
POLICY_DIR = Path(__file__).parent.parent.parent / "policies"
|
||||||
_VALID_NAME = re.compile(r"^[a-z0-9][a-z0-9-]{0,63}$")
|
_VALID_NAME = re.compile(r"^[a-z0-9][a-z0-9-]{0,63}$")
|
||||||
|
|
||||||
router = APIRouter(prefix="/policy", tags=["policy"])
|
|
||||||
|
|
||||||
|
|
||||||
class PolicyRead(BaseModel):
|
|
||||||
name: str
|
|
||||||
content: str
|
|
||||||
|
|
||||||
|
|
||||||
class PolicyUpdate(BaseModel):
|
|
||||||
content: str
|
|
||||||
|
|
||||||
|
|
||||||
def _policy_path(name: str) -> Path:
|
def _policy_path(name: str) -> Path:
|
||||||
if not _VALID_NAME.match(name):
|
if not _VALID_NAME.match(name):
|
||||||
@@ -28,14 +18,17 @@ def _policy_path(name: str) -> Path:
|
|||||||
return path
|
return path
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{name}", response_model=PolicyRead)
|
def _load_policy(name: str) -> PolicyRead:
|
||||||
def get_policy(name: str) -> PolicyRead:
|
|
||||||
path = _policy_path(name)
|
path = _policy_path(name)
|
||||||
return PolicyRead(name=name, content=path.read_text())
|
return PolicyRead(name=name, content=path.read_text())
|
||||||
|
|
||||||
|
|
||||||
@router.put("/{name}", response_model=PolicyRead)
|
def _update_policy(name: str, content: str) -> PolicyRead:
|
||||||
def update_policy(name: str, body: PolicyUpdate) -> PolicyRead:
|
|
||||||
path = _policy_path(name)
|
path = _policy_path(name)
|
||||||
path.write_text(body.content)
|
path.write_text(content)
|
||||||
return PolicyRead(name=name, content=body.content)
|
return PolicyRead(name=name, content=content)
|
||||||
|
|
||||||
|
|
||||||
|
router = create_policy_router(_load_policy, update_policy=_update_policy)
|
||||||
|
|
||||||
|
__all__ = ["router"]
|
||||||
|
|||||||
21
tests/test_policy_router.py
Normal file
21
tests/test_policy_router.py
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
from api.routers import policy
|
||||||
|
|
||||||
|
|
||||||
|
async def test_policy_router_reads_and_updates_policy(client, tmp_path, monkeypatch) -> None:
|
||||||
|
monkeypatch.setattr(policy, "POLICY_DIR", tmp_path)
|
||||||
|
(tmp_path / "example.md").write_text("old content")
|
||||||
|
read_response = await client.get("/policy/example")
|
||||||
|
update_response = await client.put("/policy/example", json={"content": "new content"})
|
||||||
|
|
||||||
|
assert read_response.status_code == 200
|
||||||
|
assert read_response.json() == {"name": "example", "content": "old content"}
|
||||||
|
assert update_response.status_code == 200
|
||||||
|
assert update_response.json() == {"name": "example", "content": "new content"}
|
||||||
|
assert (tmp_path / "example.md").read_text() == "new content"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_policy_router_rejects_invalid_policy_name(client, tmp_path, monkeypatch) -> None:
|
||||||
|
monkeypatch.setattr(policy, "POLICY_DIR", tmp_path)
|
||||||
|
response = await client.get("/policy/BadName")
|
||||||
|
|
||||||
|
assert response.status_code == 400
|
||||||
Reference in New Issue
Block a user