172 lines
4.9 KiB
Python
172 lines
4.9 KiB
Python
import asyncio
|
|
import itertools
|
|
import json
|
|
from datetime import datetime, timezone
|
|
from pathlib import Path
|
|
from typing import Any, Dict, List, Set
|
|
|
|
from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect
|
|
from fastapi.responses import HTMLResponse, StreamingResponse
|
|
from pydantic import BaseModel
|
|
|
|
|
|
class OutboundPayload(BaseModel):
    """Request body accepted by ``POST /clients/{client_id}/send``.

    The handler JSON-encodes ``payload`` and writes the resulting text to the
    target client's websocket.
    """

    # Any JSON-compatible value; no schema is imposed on it here.
    payload: Any
|
|
|
|
|
|
app = FastAPI()


# ---------------------------------------------------------------------------
# In-memory registries (process-local state; nothing here is persisted)
# ---------------------------------------------------------------------------

# Live websocket connections keyed by the integer id drawn from client_id_seq.
clients: Dict[int, WebSocket] = {}
# Serializes reads/writes of ``clients`` across request handlers.
client_lock = asyncio.Lock()
# Monotonically increasing id source; the first connection receives id 1.
client_id_seq = itertools.count(start=1)

# One queue per open /events stream; broadcast_event pushes SSE frames into each.
sse_subscribers: Set[asyncio.Queue[str]] = set()
# Serializes mutation/iteration of ``sse_subscribers``.
sse_lock = asyncio.Lock()
|
|
|
|
|
|
def now_iso() -> str:
    """Return the current time as a timezone-aware ISO 8601 string in UTC."""
    current = datetime.now(tz=timezone.utc)
    return current.isoformat()
|
|
|
|
|
|
def format_sse_event(data: Dict[str, Any]) -> str:
    """Serialize *data* into a single Server-Sent Events frame.

    If ``data["event"]`` is truthy, an ``event:`` field naming it is emitted
    first. The whole dict (including that key) is then sent as compact JSON in
    ``data:`` field(s), and the frame is terminated by a blank line.
    """
    name = data.get("event")
    encoded = json.dumps(data, separators=(",", ":"))

    frame: List[str] = [f"event: {name}"] if name else []
    frame.extend(f"data: {chunk}" for chunk in encoded.splitlines())

    # Every field line ends with "\n"; the trailing extra "\n" is the blank
    # line that tells the client the event is complete.
    return "".join(f"{field}\n" for field in frame) + "\n"
|
|
|
|
|
|
async def broadcast_event(event: Dict[str, Any]) -> None:
    """Fan *event* out to every registered SSE subscriber queue.

    The event is formatted once, then delivered with ``put_nowait`` so the
    caller (e.g. the websocket receive loop) never awaits a slow subscriber.
    """
    frame = format_sse_event(event)

    async with sse_lock:
        # Iterate over a snapshot so the live set is never walked while mutated.
        targets = tuple(sse_subscribers)
        for subscriber in targets:
            try:
                subscriber.put_nowait(frame)
            except asyncio.QueueFull:
                # Subscriber queues are created without a maxsize, so this
                # branch is only a guard; the frame is dropped for that queue.
                continue
|
|
|
|
|
|
@app.get("/", response_class=HTMLResponse)
async def index() -> HTMLResponse:
    """Serve the web UI stored in ``webui.html`` next to this module.

    The file is read in a worker thread via ``asyncio.to_thread`` so the
    blocking disk read does not stall the event loop that also drives the
    websocket and SSE connections. A missing file still surfaces as an
    unhandled ``FileNotFoundError`` (HTTP 500), as before.
    """
    html_path = Path(__file__).parent / "webui.html"
    html = await asyncio.to_thread(html_path.read_text, encoding="utf-8")
    return HTMLResponse(html)
|
|
|
|
|
|
@app.get("/clients")
async def list_clients() -> Dict[str, List[int]]:
    """Return the ids of all currently connected websocket clients, ascending."""
    async with client_lock:
        connected = sorted(clients)
    return {"clients": connected}
|
|
|
|
|
|
@app.post("/clients/{client_id}/send")
async def send_to_client(client_id: int, body: OutboundPayload) -> Dict[str, Any]:
    """Send ``body.payload`` as a JSON text frame to one connected client.

    On success, an ``outbound`` event describing the send is broadcast to all
    SSE subscribers and ``{"status": "sent"}`` is returned.

    Raises:
        HTTPException: 404 if no client with ``client_id`` is registered;
            422 if the payload cannot be encoded as standard JSON (this
            includes NaN/Infinity, which ``json.dumps`` would otherwise emit
            as non-standard tokens); 500 if writing to the websocket fails.
    """
    async with client_lock:
        websocket = clients.get(client_id)
    if websocket is None:
        raise HTTPException(status_code=404, detail="Client not connected")

    try:
        # allow_nan=False: the request parser accepts NaN/Infinity, but
        # standard JSON parsers on the receiving side reject those tokens.
        payload_text = json.dumps(body.payload, allow_nan=False)
    except (TypeError, ValueError) as exc:
        raise HTTPException(
            status_code=422, detail="Payload must be JSON serializable"
        ) from exc

    try:
        await websocket.send_text(payload_text)
    except Exception as exc:  # noqa: BLE001 - any transport failure maps to 500
        raise HTTPException(
            status_code=500, detail=f"Failed to send to client {client_id}"
        ) from exc

    await broadcast_event(
        {
            "event": "outbound",
            "timestamp": now_iso(),
            "client_id": client_id,
            "payload": body.payload,
        }
    )
    return {"status": "sent"}
|
|
|
|
|
|
@app.get("/events")
async def sse_events() -> StreamingResponse:
    """Stream broadcast events to the caller as Server-Sent Events.

    Each connection gets its own queue, which ``broadcast_event`` fills. When
    no message arrives for 15 seconds, a ``keepalive`` event is emitted to
    keep the connection alive.
    """
    queue: asyncio.Queue[str] = asyncio.Queue()

    async def event_stream():
        # Registration lives inside the generator so that it is always paired
        # with the finally below: if the response is torn down before the
        # first iteration, the generator body never runs, and registering
        # outside it would leave the queue in sse_subscribers forever.
        try:
            async with sse_lock:
                sse_subscribers.add(queue)
            while True:
                try:
                    message = await asyncio.wait_for(queue.get(), timeout=15)
                except asyncio.TimeoutError:
                    # Keep connection alive
                    yield "event: keepalive\ndata: {}\n\n"
                else:
                    yield message
        finally:
            # Runs on client disconnect (CancelledError is allowed to
            # propagate rather than being swallowed) and on generator close.
            async with sse_lock:
                sse_subscribers.discard(queue)

    return StreamingResponse(event_stream(), media_type="text/event-stream")
|
|
|
|
|
|
@app.websocket("/ws")
async def websocket_handler(websocket: WebSocket) -> None:
    """Register a websocket client and relay its messages to SSE subscribers.

    Flow: accept the socket, assign the next id from ``client_id_seq``,
    register it in ``clients``, and broadcast ``client_connected``. Each text
    frame is parsed as JSON; non-JSON text is wrapped as ``{"raw": text}``.
    The result is broadcast as an ``inbound`` event. However the loop ends,
    the client is unregistered and ``client_disconnected`` is broadcast.
    """
    import logging  # local import keeps this handler self-contained

    logger = logging.getLogger(__name__)

    await websocket.accept()
    client_id = next(client_id_seq)
    async with client_lock:
        clients[client_id] = websocket

    # Everything after registration is inside try so that the finally below
    # always unregisters the client, even if the first broadcast fails.
    try:
        await broadcast_event(
            {
                "event": "client_connected",
                "timestamp": now_iso(),
                "client_id": client_id,
            }
        )

        while True:
            message = await websocket.receive_text()
            try:
                parsed = json.loads(message)
            except json.JSONDecodeError:
                # Non-JSON text is still relayed, wrapped under "raw". Log via
                # the logging system instead of print; %.200r bounds how much
                # client-supplied text ends up in the log line.
                logger.warning(
                    "json decode error from client %s: %.200r", client_id, message
                )
                parsed = {"raw": message}

            await broadcast_event(
                {
                    "event": "inbound",
                    "timestamp": now_iso(),
                    "client_id": client_id,
                    "payload": parsed,
                }
            )
    except WebSocketDisconnect:
        # Client-initiated close is the normal way out of the receive loop.
        pass
    finally:
        async with client_lock:
            clients.pop(client_id, None)
        await broadcast_event(
            {
                "event": "client_disconnected",
                "timestamp": now_iso(),
                "client_id": client_id,
            }
        )
|
|
|
|
|
|
if __name__ == "__main__":
    import uvicorn

    # reload=True needs an import string. Derive it from this file's name
    # instead of hard-coding "main:app", which breaks if the file is renamed.
    # NOTE(review): host 0.0.0.0 with reload=True exposes a development server
    # on every interface; confirm this is intended outside local use.
    uvicorn.run(
        f"{Path(__file__).stem}:app", host="0.0.0.0", port=8000, reload=True
    )
|