init: add code
This commit is contained in:
commit
c2f9a6e600
4 changed files with 593 additions and 0 deletions
172
main.py
Normal file
172
main.py
Normal file
|
|
@ -0,0 +1,172 @@
|
|||
import asyncio
|
||||
import itertools
|
||||
import json
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Set
|
||||
|
||||
from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect
|
||||
from fastapi.responses import HTMLResponse, StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class OutboundPayload(BaseModel):
|
||||
payload: Any
|
||||
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
# In-memory registries
|
||||
clients: Dict[int, WebSocket] = {}
|
||||
client_lock = asyncio.Lock()
|
||||
client_id_seq = itertools.count(1)
|
||||
|
||||
sse_subscribers: Set[asyncio.Queue[str]] = set()
|
||||
sse_lock = asyncio.Lock()
|
||||
|
||||
|
||||
def now_iso() -> str:
|
||||
return datetime.now(timezone.utc).isoformat()
|
||||
|
||||
|
||||
def format_sse_event(data: Dict[str, Any]) -> str:
|
||||
"""Build a Server-Sent Event payload with an optional event name."""
|
||||
event_name = data.get("event")
|
||||
payload = json.dumps(data, separators=(",", ":"))
|
||||
lines: List[str] = []
|
||||
if event_name:
|
||||
lines.append(f"event: {event_name}")
|
||||
for line in payload.splitlines():
|
||||
lines.append(f"data: {line}")
|
||||
lines.append("")
|
||||
return "\n".join(lines) + "\n"
|
||||
|
||||
|
||||
async def broadcast_event(event: Dict[str, Any]) -> None:
|
||||
"""Send an event to all SSE subscribers without blocking the websocket loop."""
|
||||
message = format_sse_event(event)
|
||||
async with sse_lock:
|
||||
for queue in list(sse_subscribers):
|
||||
try:
|
||||
queue.put_nowait(message)
|
||||
except asyncio.QueueFull:
|
||||
# Should not happen with the default unbounded queues, but guard anyway.
|
||||
pass
|
||||
|
||||
|
||||
@app.get("/", response_class=HTMLResponse)
|
||||
async def index() -> HTMLResponse:
|
||||
html_path = Path(__file__).parent / "webui.html"
|
||||
return HTMLResponse(html_path.read_text(encoding="utf-8"))
|
||||
|
||||
|
||||
@app.get("/clients")
|
||||
async def list_clients() -> Dict[str, List[int]]:
|
||||
async with client_lock:
|
||||
return {"clients": sorted(clients.keys())}
|
||||
|
||||
|
||||
@app.post("/clients/{client_id}/send")
|
||||
async def send_to_client(client_id: int, body: OutboundPayload) -> Dict[str, Any]:
|
||||
async with client_lock:
|
||||
websocket = clients.get(client_id)
|
||||
if websocket is None:
|
||||
raise HTTPException(status_code=404, detail="Client not connected")
|
||||
|
||||
try:
|
||||
payload_text = json.dumps(body.payload)
|
||||
except TypeError as exc: # noqa: BLE001
|
||||
raise HTTPException(
|
||||
status_code=422, detail="Payload must be JSON serializable"
|
||||
) from exc
|
||||
|
||||
try:
|
||||
await websocket.send_text(payload_text)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to send to client {client_id}"
|
||||
) from exc
|
||||
|
||||
await broadcast_event(
|
||||
{
|
||||
"event": "outbound",
|
||||
"timestamp": now_iso(),
|
||||
"client_id": client_id,
|
||||
"payload": body.payload,
|
||||
}
|
||||
)
|
||||
return {"status": "sent"}
|
||||
|
||||
|
||||
@app.get("/events")
|
||||
async def sse_events() -> StreamingResponse:
|
||||
queue: asyncio.Queue[str] = asyncio.Queue()
|
||||
async with sse_lock:
|
||||
sse_subscribers.add(queue)
|
||||
|
||||
async def event_stream():
|
||||
try:
|
||||
while True:
|
||||
try:
|
||||
message = await asyncio.wait_for(queue.get(), timeout=15)
|
||||
yield message
|
||||
except asyncio.TimeoutError:
|
||||
# Keep connection alive
|
||||
yield "event: keepalive\ndata: {}\n\n"
|
||||
except asyncio.CancelledError:
|
||||
# Client disconnected
|
||||
pass
|
||||
finally:
|
||||
async with sse_lock:
|
||||
sse_subscribers.discard(queue)
|
||||
|
||||
return StreamingResponse(event_stream(), media_type="text/event-stream")
|
||||
|
||||
|
||||
@app.websocket("/ws")
|
||||
async def websocket_handler(websocket: WebSocket) -> None:
|
||||
await websocket.accept()
|
||||
client_id = next(client_id_seq)
|
||||
async with client_lock:
|
||||
clients[client_id] = websocket
|
||||
|
||||
await broadcast_event(
|
||||
{"event": "client_connected", "timestamp": now_iso(), "client_id": client_id}
|
||||
)
|
||||
|
||||
try:
|
||||
while True:
|
||||
message = await websocket.receive_text()
|
||||
try:
|
||||
parsed = json.loads(message)
|
||||
except json.JSONDecodeError:
|
||||
print("json decode error")
|
||||
print(message)
|
||||
parsed = {"raw": message}
|
||||
|
||||
await broadcast_event(
|
||||
{
|
||||
"event": "inbound",
|
||||
"timestamp": now_iso(),
|
||||
"client_id": client_id,
|
||||
"payload": parsed,
|
||||
}
|
||||
)
|
||||
except WebSocketDisconnect:
|
||||
pass
|
||||
finally:
|
||||
async with client_lock:
|
||||
clients.pop(client_id, None)
|
||||
await broadcast_event(
|
||||
{
|
||||
"event": "client_disconnected",
|
||||
"timestamp": now_iso(),
|
||||
"client_id": client_id,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
|
||||
uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=True)
|
||||
Loading…
Add table
Add a link
Reference in a new issue