WebSocket apps

WebSockets enable bidirectional, real-time communication between clients and servers. Flyte apps can serve WebSocket endpoints for real-time applications like chat, live updates, or streaming data.

Example: Basic WebSocket app

Here’s a simple FastAPI app with WebSocket support:

basic_websocket.py
# /// script
# requires-python = ">=3.12"
# dependencies = [
#    "flyte>=2.0.0b52",
#    "fastapi",
#    "websockets",
# ]
# ///

"""A FastAPI app with WebSocket support."""

import pathlib
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.responses import HTMLResponse
import asyncio
import json
from datetime import UTC, datetime
import flyte
from flyte.app.extras import FastAPIAppEnvironment

# The FastAPI application. The WebSocket route below is registered on it, and
# it is passed to FastAPIAppEnvironment at the bottom of the file for deployment.
app = FastAPI(
    title="Flyte WebSocket Demo",
    description="A FastAPI app with WebSocket support",
    version="1.0.0",
)

class ConnectionManager:
    """Manages WebSocket connections.

    Keeps an in-process list of accepted connections so that messages can be
    sent to a single client or broadcast to every connected client.
    """

    def __init__(self):
        self.active_connections: list[WebSocket] = []

    async def connect(self, websocket: WebSocket):
        """Accept and register a new WebSocket connection."""
        await websocket.accept()
        self.active_connections.append(websocket)
        print(f"Client connected. Total: {len(self.active_connections)}")

    def disconnect(self, websocket: WebSocket):
        """Remove a WebSocket connection.

        Safe to call for a connection that is no longer registered, e.g. one
        that ``broadcast`` already dropped after a failed send.
        """
        if websocket in self.active_connections:
            self.active_connections.remove(websocket)
            print(f"Client disconnected. Total: {len(self.active_connections)}")

    async def send_personal_message(self, message: str, websocket: WebSocket):
        """Send a message to a specific WebSocket connection."""
        await websocket.send_text(message)

    async def broadcast(self, message: str):
        """Broadcast a message to all active connections.

        A connection whose send fails is dropped, so later broadcasts don't
        keep retrying a dead socket.
        """
        # Iterate over a snapshot: every ``await`` yields to the event loop, and
        # other handlers may connect or disconnect clients in the meantime.
        # Mutating the live list mid-iteration would skip recipients.
        for connection in list(self.active_connections):
            try:
                await connection.send_text(message)
            except Exception as e:
                print(f"Error broadcasting: {e}")
                self.disconnect(connection)


manager = ConnectionManager()

@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """WebSocket endpoint for real-time communication.

    Sends a welcome message on connect. For every text message received, it
    echoes the message back to the sender and then broadcasts it to all
    connected clients. When the client disconnects, the remaining clients
    are notified.
    """
    await manager.connect(websocket)

    try:
        # Send welcome message
        await manager.send_personal_message(
            json.dumps({
                "type": "system",
                "message": "Welcome! You are connected.",
                "timestamp": datetime.now(UTC).isoformat(),
            }),
            websocket,
        )

        # Listen for messages
        while True:
            data = await websocket.receive_text()

            # Echo back to sender
            await manager.send_personal_message(
                json.dumps({
                    "type": "echo",
                    "message": f"Echo: {data}",
                    "timestamp": datetime.now(UTC).isoformat(),
                }),
                websocket,
            )

            # Broadcast to all clients
            await manager.broadcast(
                json.dumps({
                    "type": "broadcast",
                    "message": f"Broadcast: {data}",
                    "timestamp": datetime.now(UTC).isoformat(),
                    "connections": len(manager.active_connections),
                })
            )

    except WebSocketDisconnect:
        pass
    finally:
        # Always deregister, even when the handler fails with an error other
        # than WebSocketDisconnect. Otherwise the dead socket would stay in the
        # broadcast list.
        manager.disconnect(websocket)

    # Only reached after a clean client disconnect; other errors propagate.
    await manager.broadcast(
        json.dumps({
            "type": "system",
            "message": "A client disconnected",
            "timestamp": datetime.now(UTC).isoformat(),
            "connections": len(manager.active_connections),
        })
    )

# Deployment definition: wraps the FastAPI `app` defined above, together with
# the container image it runs in and the compute resources it requests.
env = FastAPIAppEnvironment(
    name="websocket-app",
    app=app,
    # Debian-based image with Python 3.12 plus fastapi, uvicorn and websockets.
    # NOTE(review): uvicorn is presumably the ASGI server used to serve `app`;
    # confirm against the Flyte app docs.
    image=flyte.Image.from_debian_base(python_version=(3, 12)).with_pip_packages(
        "fastapi",
        "uvicorn",
        "websockets",
    ),
    resources=flyte.Resources(cpu=1, memory="1Gi"),
    # NOTE(review): presumably makes the endpoint reachable without
    # authentication. Fine for a demo; enable auth for production (confirm
    # exact semantics in the Flyte docs).
    requires_auth=False,
)

if __name__ == "__main__":
    # Initialize Flyte with this script's directory as root_dir, then deploy
    # the app environment and report the first deployment returned.
    script_dir = pathlib.Path(__file__).parent
    flyte.init_from_config(root_dir=script_dir)
    deployments = flyte.deploy(env)
    first_deployment = deployments[0]
    print(f"Deployed websocket app: {first_deployment.summary_repr()}")

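WebSocket patterns

The snippets below are excerpts from websocket_patterns.py. They reuse the imports, `app`, and `manager` from the basic example above; the streaming example additionally needs `import random`.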

Echo server

websocket_patterns.py
@app.websocket("/echo")
async def echo(websocket: WebSocket):
    """Reply to every text frame with the same text prefixed by "Echo: "."""
    await websocket.accept()
    try:
        while True:
            incoming = await websocket.receive_text()
            reply = f"Echo: {incoming}"
            await websocket.send_text(reply)
    except WebSocketDisconnect:
        # The client closed the connection; there is no state to clean up.
        return

Broadcast server

websocket_patterns.py
@app.websocket("/broadcast")
async def broadcast(websocket: WebSocket):
    """Relay every text message from this client to all connected clients.

    The sender is included among the recipients.
    """
    await manager.connect(websocket)
    try:
        while True:
            data = await websocket.receive_text()
            await manager.broadcast(data)
    except WebSocketDisconnect:
        pass
    finally:
        # Deregister on every exit path, not only on a clean disconnect, so a
        # failed handler doesn't leave its socket in the broadcast list.
        manager.disconnect(websocket)

Real-time data streaming

websocket_patterns.py
@app.websocket("/stream")
async def stream_data(websocket: WebSocket):
    """Push a timestamped random value to the client once per second."""
    await websocket.accept()
    try:
        while True:
            # Build one sample; replace random.random() with a real data source.
            sample = {
                "timestamp": datetime.now(UTC).isoformat(),
                "value": random.random(),
            }
            await websocket.send_json(sample)
            # Pace the stream at one update per second.
            await asyncio.sleep(1)
    except WebSocketDisconnect:
        return

Chat application

websocket_patterns.py
class ChatRoom:
    """A named group of WebSocket connections that relay messages to each other."""

    def __init__(self, name: str):
        self.name = name
        self.connections: list[WebSocket] = []

    async def join(self, websocket: WebSocket):
        """Add an already-accepted connection to the room."""
        self.connections.append(websocket)

    async def leave(self, websocket: WebSocket):
        """Remove a connection from the room; a no-op if it already left."""
        if websocket in self.connections:
            self.connections.remove(websocket)

    async def broadcast(self, message: str, sender: WebSocket):
        """Send ``message`` to every member of the room except ``sender``.

        If the send to one member fails, that member is dropped. The failure
        does not abort delivery to the others, and it does not propagate into
        the sender's handler. Otherwise the sender's handler could mistake
        another member's disconnect for its own.
        """
        # Snapshot the member list: each await yields to the event loop, and
        # other handlers may join or leave in the meantime.
        for connection in list(self.connections):
            if connection is sender:
                continue
            try:
                await connection.send_text(message)
            except Exception as e:
                print(f"Error sending to room {self.name!r}: {e}")
                await self.leave(connection)


rooms: dict[str, ChatRoom] = {}


@app.websocket("/chat/{room_name}")
async def chat(websocket: WebSocket, room_name: str):
    """Join the client to ``room_name`` and relay its messages to the room."""
    await websocket.accept()

    # Create the room on first use.
    if room_name not in rooms:
        rooms[room_name] = ChatRoom(room_name)

    room = rooms[room_name]
    await room.join(websocket)

    try:
        while True:
            data = await websocket.receive_text()
            await room.broadcast(data, websocket)
    except WebSocketDisconnect:
        pass
    finally:
        # Leave on every exit path, not only on a clean disconnect.
        await room.leave(websocket)
        # Drop empty rooms so client-chosen room names don't accumulate forever.
        if not room.connections and rooms.get(room_name) is room:
            del rooms[room_name]

Using WebSockets with Flyte tasks

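You can trigger Flyte tasks from WebSocket messages. The snippet below also assumes `remote` refers to Flyte's remote module (e.g. `from flyte import remote`):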

task_runner_websocket.py
@app.websocket("/task-runner")
async def task_runner(websocket: WebSocket):
    """Run Flyte tasks on request and report their progress over the socket.

    Each incoming text message must be a JSON object with ``project``,
    ``domain``, ``task`` and ``version`` keys. It may also include an
    ``inputs`` object of keyword arguments for the task (defaults to none).
    A malformed request gets an error reply, and the connection stays open.
    """
    # NOTE(review): any connected client can launch arbitrary tasks through
    # this endpoint. Authenticate callers and restrict the allowed
    # projects/tasks before using this in production.
    await websocket.accept()

    try:
        while True:
            # Receive task request
            message = await websocket.receive_text()

            # Validate the request so bad input is reported to the client
            # instead of crashing the handler and dropping the connection.
            try:
                request = json.loads(message)
                project = request["project"]
                domain = request["domain"]
                name = request["task"]
                version = request["version"]
                inputs = request.get("inputs", {})
                if not isinstance(inputs, dict):
                    raise TypeError("'inputs' must be a JSON object")
            except (json.JSONDecodeError, KeyError, TypeError, AttributeError) as exc:
                await websocket.send_json({
                    "status": "error",
                    "message": f"Invalid request: {exc!r}",
                })
                continue

            # Trigger Flyte task
            task = remote.Task.get(
                project=project,
                domain=domain,
                name=name,
                version=version,
            )

            run = await flyte.run.aio(task, **inputs)

            # Send run info back
            await websocket.send_json({
                "run_id": run.id,
                "url": run.url,
                "status": "started",
            })

            # Optionally stream updates
            async for update in run.stream():
                await websocket.send_json({
                    "status": update.status,
                    "message": update.message,
                })

    except WebSocketDisconnect:
        pass

WebSocket client example

Connect from Python:

import asyncio
import websockets
import json

async def client():
    """Connect to the /ws endpoint, send one message, and print the replies."""
    # Use wss:// instead of ws:// when the app is served over HTTPS.
    uri = "ws://your-app-url/ws"
    async with websockets.connect(uri) as websocket:
        # The server sends a welcome message as soon as the connection opens,
        # so read it first; otherwise it would be mistaken for the reply.
        welcome = await websocket.recv()
        print(f"Received: {welcome}")

        # Send message
        await websocket.send("Hello, Server!")

        # The first reply to our message is the echo (a broadcast follows it).
        response = await websocket.recv()
        print(f"Received: {response}")


if __name__ == "__main__":
    asyncio.run(client())

Best practices

  1. Connection management: Track active connections and handle disconnections gracefully.
  2. Heartbeats: Implement ping/pong for connection health monitoring.
  3. Rate limiting: Consider rate limiting for production deployments.
  4. Error handling: Handle WebSocket errors and connection drops.
  5. Authentication: The example above sets `requires_auth=False` for simplicity; require authentication before exposing WebSocket endpoints in production.