# Serving graphs

A *serving graph* is a set of Flyte apps that talk to each other inside the
cluster. Instead of putting every stage of a request into one process, you
split the work across multiple `AppEnvironment`s that you deploy together —
each one sized for its own bottleneck, with its own image and scaling policy.

This pattern is useful for:

- **Heterogeneous resource requirements**: CPU pre/postprocessing in front of a GPU forward pass
- **Microservice architectures**: Independent components with distinct lifecycles
- **A/B testing and canary rollouts**: A root app routes traffic across variant apps
- **Proxy / gateway patterns**: One app fronts several backends
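In the A/B and canary case, the root app needs a rule for picking a variant per request. Deterministic hash bucketing is one common approach. The sketch below uses only the standard library. The variant names, the 90/10 split, and the `pick_variant` helper are hypothetical, and none of it is a Flyte API. In a real graph, each name would presumably correspond to its own `AppEnvironment` whose endpoint the root app calls.

```python
import hashlib

# Hypothetical variants behind a root app. Weights are relative traffic shares.
VARIANTS: list[tuple[str, int]] = [("model-stable", 90), ("model-canary", 10)]


def pick_variant(user_id: str, variants: list[tuple[str, int]] = VARIANTS) -> str:
    """Deterministically assign a user to a variant by hashing the user id.

    The same id always maps to the same variant, so a user does not flip
    between versions from one request to the next during a rollout.
    """
    if any(weight < 0 for _, weight in variants):
        raise ValueError("variant weights must be non-negative")
    total = sum(weight for _, weight in variants)
    if total <= 0:
        raise ValueError("variant weights must sum to a positive number")
    digest = hashlib.sha256(user_id.encode("utf-8")).digest()
    bucket = int.from_bytes(digest[:8], "big") % total
    # Walk the cumulative weights until the bucket falls inside a variant's range.
    cumulative = 0
    for name, weight in variants:
        cumulative += weight
        if bucket < cumulative:
            return name
    raise AssertionError("unreachable: bucket is always below the total weight")


if __name__ == "__main__":
    counts: dict[str, int] = {}
    for i in range(10_000):
        name = pick_variant(f"user-{i}")
        counts[name] = counts.get(name, 0) + 1
    print(counts)  # roughly a 90/10 split
```

Hashing instead of drawing a random number per request keeps each user's assignment sticky. That is usually what you want when comparing variants.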

## Core concepts: a minimal two-app chain

The simplest serving graph — `app2` proxies HTTP calls to `app1` — is enough
to introduce every core concept: deploying multiple apps together, discovering
an upstream app's endpoint, and sizing each app independently.

Both apps share an image and live in the same Python file:

```python
import logging
import os
import pathlib
import typing

import httpx
from fastapi import FastAPI

import flyte
import flyte.app
from flyte.app.extras import FastAPIAppEnvironment

# One image shared by both apps.
image = flyte.Image.from_debian_base(python_version=(3, 12)).with_pip_packages("fastapi", "uvicorn", "httpx")

app1 = FastAPI(
    title="App 1",
    description="A FastAPI app that runs some computations",
)

app2 = FastAPI(
    title="App 2",
    description="A FastAPI app that proxies requests to another FastAPI app",
)

# Callee: no upstream dependencies of its own.
env1 = FastAPIAppEnvironment(
    name="app1-is-called-by-app2",
    app=app1,
    image=image,
    resources=flyte.Resources(cpu=1, memory="512Mi"),
    requires_auth=True,
)

# Caller: receives app1's URL as the APP1_URL env var, and depends_on=[env1]
# makes flyte.serve(env2) deploy env1 as well.
env2 = FastAPIAppEnvironment(
    name="app2-calls-app1",
    app=app2,
    image=image,
    resources=flyte.Resources(cpu=1, memory="512Mi"),
    requires_auth=True,
    parameters=[
        flyte.app.Parameter(
            name="app1_url",
            value=flyte.app.AppEndpoint(app_name="app1-is-called-by-app2"),
            env_var="APP1_URL",
        ),
    ],
    depends_on=[env1],
    env_vars={"LOG_LEVEL": "10"},  # 10 == logging.DEBUG
)

@app1.get("/greeting/{name}")
async def greeting(name: str) -> str:
    return f"Hello, {name}!"

# Pattern A: read the upstream URL from the env object in scope.
@app2.get("/app1-endpoint")
async def get_app1_endpoint() -> str:
    return env1.endpoint

@app2.get("/greeting/{name}")
async def greeting_proxy(name: str) -> typing.Any:
    async with httpx.AsyncClient() as client:
        response = await client.get(f"{env1.endpoint}/greeting/{name}")
        return response.json()

# Pattern B: read the URL injected through the APP1_URL parameter.
@app2.get("/app1-url")
async def get_app1_url() -> str:
    return os.getenv("APP1_URL")

if __name__ == "__main__":
    flyte.init_from_config(
        root_dir=pathlib.Path(__file__).parent,
        log_level=logging.DEBUG,
    )
    app = flyte.serve(env2)
    print(f"Deployed FastAPI app: {app.url}")
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/serving_graphs/two_app_chain.py*

### Deploying multiple apps together with `depends_on`

The callee env is straightforward — it has no upstream dependencies of its
own:

```python
env1 = FastAPIAppEnvironment(
    name="app1-is-called-by-app2",
    app=app1,
    image=image,
    resources=flyte.Resources(cpu=1, memory="512Mi"),
    requires_auth=True,
)
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/serving_graphs/two_app_chain.py*

The caller declares `depends_on=[env1]`, which tells Flyte that `env1` must
be deployed alongside this one:

```python
env2 = FastAPIAppEnvironment(
    name="app2-calls-app1",
    app=app2,
    image=image,
    resources=flyte.Resources(cpu=1, memory="512Mi"),
    requires_auth=True,
    parameters=[
        flyte.app.Parameter(
            name="app1_url",
            value=flyte.app.AppEndpoint(app_name="app1-is-called-by-app2"),
            env_var="APP1_URL",
        ),
    ],
    depends_on=[env1],
    env_vars={"LOG_LEVEL": "10"},
)
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/serving_graphs/two_app_chain.py*

Calling `flyte.serve(env2)` then deploys the whole dependency closure
transitively, so you only ever name the entry-point app:

```python
if __name__ == "__main__":
    flyte.init_from_config(
        root_dir=pathlib.Path(__file__).parent,
        log_level=logging.DEBUG,
    )
    app = flyte.serve(env2)
    print(f"Deployed FastAPI app: {app.url}")
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/serving_graphs/two_app_chain.py*

`depends_on` is about deployment co-scheduling, not request-time ordering —
at runtime each app is independent.
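To make "dependency closure" concrete, here is a stdlib sketch. It models each app's `depends_on` as a list of names and collects everything reachable from the entry point. The `DEPENDS_ON` mapping and the `deployment_closure` helper are hypothetical, because Flyte resolves `depends_on` on the `AppEnvironment` objects itself. The dependencies-first ordering is just a convenient way to list the closure. It is not a statement about the order in which Flyte rolls the apps out.

```python
from graphlib import TopologicalSorter

# Hypothetical stand-in for AppEnvironment.depends_on: app name -> names it depends on.
DEPENDS_ON: dict[str, list[str]] = {
    "app2-calls-app1": ["app1-is-called-by-app2"],
    "app1-is-called-by-app2": [],
}


def deployment_closure(entry: str, graph: dict[str, list[str]]) -> list[str]:
    """Return every app reachable from `entry`, each listed after its dependencies."""
    # Collect the transitive closure with an explicit stack.
    seen: set[str] = set()
    stack = [entry]
    while stack:
        name = stack.pop()
        if name in seen:
            continue
        seen.add(name)
        stack.extend(graph.get(name, []))
    # Order the closure; TopologicalSorter raises graphlib.CycleError on cycles.
    sorter = TopologicalSorter({name: graph.get(name, []) for name in seen})
    return list(sorter.static_order())


if __name__ == "__main__":
    print(deployment_closure("app2-calls-app1", DEPENDS_ON))
    # ['app1-is-called-by-app2', 'app2-calls-app1']
```

Naming only the entry point is enough because the walk reaches every upstream app, however deep the chain goes.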

### Getting an upstream app's endpoint

There are two ways for one app to discover another app's URL. Both resolve
correctly across local, in-cluster, and external contexts.

**Pattern A — `env.endpoint` (Python property).** When both apps live in the
same Python module, the upstream env object is in scope and you can read
`env.endpoint` directly. The example above uses this pattern in `app2`'s
proxy endpoint:

```python
@app2.get("/app1-endpoint")
async def get_app1_endpoint() -> str:
    return env1.endpoint

@app2.get("/greeting/{name}")
async def greeting_proxy(name: str) -> typing.Any:
    async with httpx.AsyncClient() as client:
        response = await client.get(f"{env1.endpoint}/greeting/{name}")
        return response.json()
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/serving_graphs/two_app_chain.py*

**Pattern B — `flyte.app.AppEndpoint` as a parameter.** When the upstream env
object isn't importable (different file, different process, looking it up by
name), declare it as a `flyte.app.Parameter` and have Flyte inject the
resolved URL via an environment variable. The `env2` definition above shows
this — `app1_url` becomes available as `os.getenv("APP1_URL")` at runtime:

```python
@app2.get("/app1-url")
async def get_app1_url() -> str:
    return os.getenv("APP1_URL")
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/serving_graphs/two_app_chain.py*
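With Pattern B, it can help to wrap the environment lookup so that a missing variable fails loudly at the call site. Otherwise an f-string silently produces a URL such as `None/greeting/...`. Here is a minimal sketch. `upstream_url` is a hypothetical helper, not part of `flyte.app`, and the host in the demo is a placeholder.

```python
import os
from urllib.parse import quote


def upstream_url(env_var: str, path: str = "") -> str:
    """Join `path` onto the upstream base URL stored in `env_var`.

    `env_var` is the name passed as env_var= on the flyte.app.Parameter
    (APP1_URL in the two-app example).
    """
    base = os.environ.get(env_var)
    if not base:
        # Likely causes: the Parameter is not declared on this AppEnvironment,
        # or the code is running outside the deployed app.
        raise RuntimeError(f"{env_var} is not set; is the upstream Parameter declared?")
    # Normalize slashes so "http://host/" + "/path" does not become "//path".
    return base.rstrip("/") + "/" + path.lstrip("/")


if __name__ == "__main__":
    os.environ["APP1_URL"] = "http://app1.example.internal"  # placeholder for the demo
    print(upstream_url("APP1_URL", f"/greeting/{quote('Ada Lovelace')}"))
    # http://app1.example.internal/greeting/Ada%20Lovelace
```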

### Sizing each node independently

Each `AppEnvironment` carries its own image, resources, and scaling policy,
which is the point of splitting. In an inference graph, for example, the GPU
side can stay narrow with `scaling=flyte.app.Scaling(replicas=(1, 2))` while
the CPU side scales wide with `scaling=flyte.app.Scaling(replicas=(1, 8))`;
the two apps share no autoscaling policy. The CPU / GPU inference split below
shows this in practice.
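Little's law gives a back-of-the-envelope view of why the bounds differ. The number of requests in flight in a stage equals the arrival rate times the time each request spends there. Dividing that by how many requests one replica handles concurrently gives a rough replica count per stage. The workload numbers below are hypothetical, chosen only to illustrate the arithmetic, and were not measured on the example apps.

```python
import math


def replicas_needed(rps: float, seconds_per_request: float, concurrency_per_replica: int) -> int:
    """Rough replica count for one stage via Little's law.

    in_flight = arrival_rate * time_per_request; divide by the number of
    requests one replica works on at once and round up (minimum one replica).
    """
    if rps < 0 or seconds_per_request < 0 or concurrency_per_replica < 1:
        raise ValueError("rps and seconds_per_request must be >= 0; concurrency must be >= 1")
    in_flight = rps * seconds_per_request
    return max(1, math.ceil(in_flight / concurrency_per_replica))


if __name__ == "__main__":
    # Hypothetical workload: 200 req/s; 50 ms of CPU pre/post per request and
    # 5 ms of GPU forward pass, with 4 requests in flight per replica on each side.
    cpu = replicas_needed(rps=200, seconds_per_request=0.050, concurrency_per_replica=4)
    gpu = replicas_needed(rps=200, seconds_per_request=0.005, concurrency_per_replica=4)
    print(cpu, gpu)  # 3 1
```

The stage that does ten times more work per request needs roughly ten times more concurrency. That ratio is what independent `replicas` bounds let you express.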

## Example: CPU / GPU inference split

The canonical heterogeneous-resource pipeline: heavy CPU preprocessing in
front of a fast GPU forward pass, talking to each other over HTTP inside the
cluster.

```mermaid
flowchart LR
    client["client"] --> cpu["cpu_app (×N replicas)<br/>decode + resize<br/>+ softmax"]
    cpu --> gpu["gpu_app (×M replicas)<br/>ResNet18 forward only"]
    gpu --> cpu
    cpu --> client
```

In a typical vision/audio pipeline, the GPU forward pass takes milliseconds
but is sandwiched between slow CPU work (image decode, resize, normalization,
softmax, label lookup). If both stages share one process you pay for an idle
GPU during preprocessing. Splitting them lets each side scale independently:
cheap CPU wide, expensive GPU narrow.
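The idle-GPU cost can be put in numbers. Assume one colocated process handles requests one at a time. The GPU is then busy only for the forward-pass share of each request. Once the stages are split, roughly `cpu_ms / gpu_ms` single-request CPU workers are needed to keep one GPU fed. The 45 ms / 5 ms timings in this sketch are hypothetical.

```python
import math


def gpu_busy_fraction(cpu_ms: float, gpu_ms: float) -> float:
    """Share of wall time the GPU is busy when one process runs each request's
    CPU stage and GPU stage back to back, one request at a time."""
    total = cpu_ms + gpu_ms
    if total <= 0:
        raise ValueError("stage times must sum to a positive value")
    return gpu_ms / total


def cpu_workers_to_feed_gpu(cpu_ms: float, gpu_ms: float) -> int:
    """Single-request CPU workers needed so the GPU never waits on preprocessing.

    One GPU finishes a request every gpu_ms; one CPU worker produces one every
    cpu_ms, so ceil(cpu_ms / gpu_ms) workers keep pace.
    """
    if cpu_ms < 0 or gpu_ms <= 0:
        raise ValueError("cpu_ms must be >= 0 and gpu_ms must be > 0")
    return max(1, math.ceil(cpu_ms / gpu_ms))


if __name__ == "__main__":
    print(f"colocated GPU busy: {gpu_busy_fraction(45, 5):.0%}")  # colocated GPU busy: 10%
    print(f"CPU workers per GPU: {cpu_workers_to_feed_gpu(45, 5)}")  # CPU workers per GPU: 9
```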

### Disjoint images per node

The two apps share a small base image and add their own disjoint stacks. The
CPU app never imports `torch`; the GPU app never imports `PIL`:

```python
"""
Serving graph — CPU pre/post split from a GPU forward pass.

This example shows the canonical "two-app" inference graph: heavy CPU work
on one app, the GPU forward pass on another, talking to each other over HTTP
inside the cluster.

Why split? In a typical vision/audio/feature-engineering pipeline the GPU
forward pass is fast (millis) but is sandwiched between slow CPU work
(image decode, resize, denoise, NMS, label lookup, etc.). If you put both
stages in one process you pay for an idle GPU during preprocessing. Splitting
them lets each side scale independently:

    client ──► [cpu_app  x N replicas]  ──► [gpu_app x M replicas] ──► back
                preprocess + postprocess        model.forward only
                cheap CPU, scale wide           expensive GPU, scale narrow

Wire format between the two apps is raw float32 bytes (not JSON) — for
anything tensor-shaped this is the single biggest perf knob.
"""

import io
import ipaddress
import logging
import pathlib
import socket
from contextlib import asynccontextmanager

import httpx
import numpy as np
from fastapi import FastAPI, HTTPException, Request, Response
from PIL import Image, ImageFilter
from pydantic import BaseModel

import flyte
import flyte.app
from flyte.app.extras import FastAPIAppEnvironment

# ---------------------------------------------------------------------------
# Images
# ---------------------------------------------------------------------------
# Shared base with the deps both apps need (HTTP server + numpy). The CPU and
# GPU images extend it with their own disjoint stacks — the CPU app never
# imports torch and the GPU app never imports PIL. Sharing the base layer
# means the registry only stores one copy of fastapi/uvicorn/numpy.

base_image = flyte.Image.from_debian_base(python_version=(3, 12)).with_pip_packages(
    "fastapi",
    "uvicorn",
    "numpy",
)

cpu_image = base_image.with_pip_packages(
    "httpx",
    "pillow",
)

gpu_image = base_image.with_pip_packages(
    "torch==2.7.1",
    "torchvision==0.22.1",
)

# ---------------------------------------------------------------------------
# Shared tensor layout
# ---------------------------------------------------------------------------

INPUT_C, INPUT_H, INPUT_W = 3, 224, 224
NUM_CLASSES = 1000
TENSOR_DTYPE = np.float32

# ===========================================================================
# GPU app — model.forward only
# ===========================================================================

@asynccontextmanager
async def _gpu_lifespan(app: FastAPI):
    # Imported lazily so the CPU app never has to import torch.
    import torch
    from torchvision.models import ResNet18_Weights, resnet18

    weights = ResNet18_Weights.IMAGENET1K_V1
    model = resnet18(weights=weights).eval()
    device = "cuda" if torch.cuda.is_available() else "cpu"
    if device == "cuda":
        model = model.to("cuda")
    app.state.model = model
    app.state.device = device
    app.state.categories = list(weights.meta["categories"])
    logging.getLogger(__name__).info("model loaded on %s", device)
    yield

gpu_app = FastAPI(
    title="inference-gpu",
    description="ResNet18 forward pass.",
    lifespan=_gpu_lifespan,
)

@gpu_app.get("/health")
async def gpu_health() -> dict:
    return {"status": "ok", "device": gpu_app.state.device}

@gpu_app.get("/labels")
async def labels() -> list[str]:
    # Exposed so the CPU side can fetch labels once at startup instead of
    # hard-coding the ImageNet class list.
    return gpu_app.state.categories

@gpu_app.post("/infer")
async def infer(request: Request) -> Response:
    """Run a batched forward pass.

    Request body:  raw float32 bytes, shape (B, 3, 224, 224), C-contiguous.
    Response body: raw float32 bytes, shape (B, 1000) — raw logits.

    We deliberately do NOT use JSON here. For a batch of 32 images the tensor
    is ~19MB; JSON-serializing that is the dominant cost end-to-end.
    """
    import torch

    raw = await request.body()
    arr = np.frombuffer(raw, dtype=TENSOR_DTYPE)
    if arr.size % (INPUT_C * INPUT_H * INPUT_W) != 0:
        raise HTTPException(400, "payload size is not a multiple of one image tensor")
    batch = arr.reshape(-1, INPUT_C, INPUT_H, INPUT_W)

    x = torch.from_numpy(batch).to(gpu_app.state.device)
    with torch.inference_mode():
        logits = gpu_app.state.model(x)
    out = logits.detach().to("cpu").numpy().astype(TENSOR_DTYPE, copy=False)
    return Response(content=out.tobytes(), media_type="application/octet-stream")

gpu_env = FastAPIAppEnvironment(
    name="serving-graph-gpu",
    app=gpu_app,
    image=gpu_image,
    resources=flyte.Resources(cpu=2, memory="8Gi", gpu="A10G:1"),
    # GPU replicas are expensive; keep at least one warm so model weights stay
    # resident, and cap the max. Bump if a single replica saturates.
    scaling=flyte.app.Scaling(replicas=(1, 2)),
    requires_auth=True,
)

# ===========================================================================
# CPU app — pre/postprocess, calls the GPU app
# ===========================================================================

IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=TENSOR_DTYPE).reshape(3, 1, 1)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=TENSOR_DTYPE).reshape(3, 1, 1)

class ClassifyRequest(BaseModel):
    image_url: str
    top_k: int = 5

class Prediction(BaseModel):
    label: str
    score: float

def _preprocess(img_bytes: bytes) -> np.ndarray:
    """Decode → denoise → resize → normalize. CPU-bound, deliberately so.

    Real preprocessing stacks (detection, OCR, audio) do substantially more
    than this — sliding window crops, color-space conversion, etc. The point
    is that none of it benefits from a GPU sitting next to it.
    """
    img = Image.open(io.BytesIO(img_bytes)).convert("RGB")
    img = img.filter(ImageFilter.GaussianBlur(radius=1.0))
    img = img.resize((INPUT_W, INPUT_H), Image.BILINEAR)
    arr = np.asarray(img, dtype=TENSOR_DTYPE) / 255.0
    arr = arr.transpose(2, 0, 1)  # HWC → CHW
    arr = (arr - IMAGENET_MEAN) / IMAGENET_STD
    return np.ascontiguousarray(arr, dtype=TENSOR_DTYPE)

def _softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

@asynccontextmanager
async def _cpu_lifespan(app: FastAPI):
    # Resolved at serving time via the cluster-internal endpoint pattern,
    # so this stays correct across local/remote deploys without an env var.
    gpu_url = gpu_env.endpoint
    log = logging.getLogger(__name__)
    log.info("resolved GPU endpoint: %s", gpu_url)
    async with httpx.AsyncClient(timeout=30.0) as bootstrap:
        try:
            r = await bootstrap.get(f"{gpu_url}/labels")
            r.raise_for_status()
        except (httpx.HTTPError, OSError) as e:
            # Most common reason on a fresh deploy: GPU replica hasn't finished
            # pulling its image / loading weights yet. Crash-looping is fine —
            # the next attempt will likely succeed — but make the cause obvious.
            log.error("downstream GPU app at %s not ready: %s", gpu_url, e)
            raise
        app.state.labels = r.json()
    # One persistent client per replica — avoids TCP/TLS handshake per request,
    # which matters once you're doing 100s of req/s.
    async with httpx.AsyncClient(
        base_url=gpu_url,
        timeout=httpx.Timeout(30.0, connect=5.0),
        limits=httpx.Limits(max_connections=64, max_keepalive_connections=32),
    ) as client:
        app.state.client = client
        yield

cpu_app = FastAPI(
    title="inference-cpu",
    description="Pre/post around the GPU forward pass.",
    lifespan=_cpu_lifespan,
)

@cpu_app.get("/health")
async def cpu_health() -> dict:
    return {"status": "ok", "labels_loaded": len(cpu_app.state.labels)}

async def validate_public_image_url(image_url: str) -> str:
    """Basic SSRF guard: accept only http(s) URLs whose host resolves
    exclusively to global (public) addresses, rejecting private, loopback,
    link-local, and other non-global ranges."""
    try:
        parsed = httpx.URL(image_url)
    except Exception as exc:
        raise HTTPException(status_code=400, detail="Invalid image_url.") from exc
    if parsed.scheme not in {"http", "https"}:
        raise HTTPException(status_code=400, detail="image_url must use http or https.")
    host = parsed.host
    if not host:
        raise HTTPException(status_code=400, detail="image_url must include a hostname.")
    try:
        addr_info = socket.getaddrinfo(host, parsed.port or (443 if parsed.scheme == "https" else 80))
    except socket.gaierror as exc:
        raise HTTPException(status_code=400, detail="image_url host could not be resolved.") from exc
    # Every resolved address must be public, not just the first one.
    for info in addr_info:
        ip_text = info[4][0]
        ip_obj = ipaddress.ip_address(ip_text)
        if not ip_obj.is_global:
            raise HTTPException(status_code=400, detail="image_url host resolves to a non-public address.")
    return str(parsed)

@cpu_app.post("/classify", response_model=list[Prediction])
async def classify(req: ClassifyRequest) -> list[Prediction]:
    async with httpx.AsyncClient(timeout=30.0) as client:
        img_resp = await client.get(await validate_public_image_url(req.image_url))
        img_resp.raise_for_status()

    tensor = _preprocess(img_resp.content)  # heavy CPU
    batch = tensor[np.newaxis, ...]  # add batch dim

    gpu_resp = await cpu_app.state.client.post(
        "/infer",
        content=batch.tobytes(),
        headers={"content-type": "application/octet-stream"},
    )
    gpu_resp.raise_for_status()
    logits = np.frombuffer(gpu_resp.content, dtype=TENSOR_DTYPE).reshape(1, NUM_CLASSES)

    probs = _softmax(logits, axis=-1)[0]  # back to CPU work
    top_idx = np.argsort(-probs)[: req.top_k]
    return [Prediction(label=cpu_app.state.labels[i], score=float(probs[i])) for i in top_idx]

cpu_env = FastAPIAppEnvironment(
    name="serving-graph-cpu",
    app=cpu_app,
    image=cpu_image,
    resources=flyte.Resources(cpu=4, memory="4Gi"),
    # Cheap, so scale wide. Use scale-to-zero (replicas=(0, 8)) for bursty
    # traffic; keep replicas=(1, 8) here to avoid cold starts in the demo.
    scaling=flyte.app.Scaling(replicas=(1, 8)),
    requires_auth=True,
    depends_on=[gpu_env],
)

# ===========================================================================
# Deploy
# ===========================================================================

if __name__ == "__main__":
    flyte.init_from_config(
        root_dir=pathlib.Path(__file__).parent,
        log_level=logging.INFO,
    )
    app = flyte.serve(cpu_env)
    print(f"Deployed serving graph; public CPU endpoint: {app.url}")
    print("Try: curl -X POST $URL/classify -H 'content-type: application/json' \\")
    print(
        '       -d \'{"image_url": "https://upload.wikimedia.org/wikipedia/commons/4/41/Sunflower_from_Silesia2.jpg"}\''
    )
# {{/docs-fragment deploy}}
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/serving_graphs/image_classification.py*
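
The `/classify` handler only fetches images from hosts that resolve to globally routable addresses. This stops callers from using the CPU app to reach cluster-internal services. However, the validator calls `socket.getaddrinfo`, which is blocking, from inside an `async` function. A slow DNS lookup therefore stalls the event loop, and with it every in-flight request on that worker. The sketch below is a minimal non-blocking variant, assuming you are free to restructure the validator. It resolves through `loop.getaddrinfo` instead. `is_public_ip` and `host_resolves_public` are hypothetical helper names, not part of the example or of Flyte:

```python
import asyncio
import ipaddress
import socket


def is_public_ip(ip_text: str) -> bool:
    """True only for globally routable addresses (not private, loopback, link-local, ...)."""
    return ipaddress.ip_address(ip_text).is_global


async def host_resolves_public(host: str, port: int) -> bool:
    """Resolve `host` without blocking the event loop; every address must be public."""
    loop = asyncio.get_running_loop()
    try:
        # asyncio's default loop runs the lookup in a worker thread.
        infos = await loop.getaddrinfo(host, port, type=socket.SOCK_STREAM)
    except socket.gaierror:
        return False
    # Each entry is (family, type, proto, canonname, sockaddr); sockaddr[0] is the IP.
    return bool(infos) and all(is_public_ip(info[4][0]) for info in infos)
```

This keeps the example's policy that every resolved address must be global. It still resolves the name separately from the later `httpx` fetch, though, so it does not cover a DNS answer that changes between the two lookups. Connecting to the already-validated IP would close that gap.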

### GPU app: model.forward only

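The GPU app loads the model once at startup in a FastAPI lifespan handler, so
the weights stay resident across requests instead of being reloaded per call.
It falls back to CPU when no CUDA device is visible, so the same code also runs
on a machine without a GPU: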

```
import logging
from contextlib import asynccontextmanager

from fastapi import FastAPI

@asynccontextmanager
async def _gpu_lifespan(app: FastAPI):
    # Imported lazily so the CPU app never has to import torch.
    import torch
    from torchvision.models import ResNet18_Weights, resnet18

    weights = ResNet18_Weights.IMAGENET1K_V1
    model = resnet18(weights=weights).eval()
    device = "cuda" if torch.cuda.is_available() else "cpu"
    if device == "cuda":
        model = model.to("cuda")
    app.state.model = model
    app.state.device = device
    app.state.categories = list(weights.meta["categories"])
    logging.getLogger(__name__).info("model loaded on %s", device)
    yield

gpu_app = FastAPI(
    title="inference-gpu",
    description="ResNet18 forward pass.",
    lifespan=_gpu_lifespan,
)
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/serving_graphs/image_classification.py*
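
FastAPI runs the code before `yield` in the lifespan once, before the replica starts serving. It runs the code after `yield` once, on shutdown. The stdlib-only sketch below imitates that contract so the load-once behavior is visible without torch or a server. The `SimpleNamespace` app, the lambda "model", and the `serve` and `handle_request` helpers are hypothetical stand-ins, not FastAPI or Flyte APIs:

```python
import asyncio
from contextlib import asynccontextmanager
from types import SimpleNamespace

load_count = 0  # how many times the stand-in "weights" were loaded


@asynccontextmanager
async def lifespan(app):
    # Startup: runs once, before any request is handled.
    global load_count
    load_count += 1
    app.state.model = lambda x: x * 2  # hypothetical stand-in for a real network
    yield
    # Shutdown: runs once, after the server stops taking requests.
    app.state.model = None


async def handle_request(app, x):
    # Handlers only read the shared state; they never reload the model.
    return app.state.model(x)


async def serve(app, inputs):
    # Hypothetical stand-in for the ASGI server: enter the lifespan once,
    # handle every request inside it, then exit.
    async with lifespan(app):
        return [await handle_request(app, x) for x in inputs]


app = SimpleNamespace(state=SimpleNamespace())
results = asyncio.run(serve(app, [1, 2, 3]))
```

Each replica is its own process with its own lifespan. With `replicas=(1, 2)`, the weights are therefore loaded once per running replica, not once per request.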

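The inference endpoint exchanges raw `float32` bytes over
`application/octet-stream` rather than JSON. For anything tensor-shaped this
is the single biggest performance knob. A batch of 32 images at 3×224×224 is
about 19 MB (32 × 3 × 224 × 224 × 4 bytes), and JSON-serializing that
dominates end-to-end latency. The body carries no shape metadata, so the
handler rejects any payload that is not a whole number of image tensors
before reshaping it: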

```
import numpy as np
from fastapi import HTTPException, Request, Response

# Shared tensor layout; both apps must agree on it.
INPUT_C, INPUT_H, INPUT_W = 3, 224, 224
TENSOR_DTYPE = np.float32

# `gpu_app` is the FastAPI instance created with the lifespan handler.
@gpu_app.post("/infer")
async def infer(request: Request) -> Response:
    """Run a batched forward pass.

    Request body:  raw float32 bytes, shape (B, 3, 224, 224), C-contiguous.
    Response body: raw float32 bytes, shape (B, 1000) — raw logits.

    We deliberately do NOT use JSON here. For a batch of 32 images the tensor
    is ~19MB; JSON-serializing that is the dominant cost end-to-end.
    """
    import torch

    raw = await request.body()
    item_bytes = INPUT_C * INPUT_H * INPUT_W * np.dtype(TENSOR_DTYPE).itemsize
    # Check the byte length up front: np.frombuffer raises (surfacing as HTTP 500)
    # if the body ends mid-element, and an empty body would yield an empty batch.
    if not raw or len(raw) % item_bytes != 0:
        raise HTTPException(400, "payload size is not a non-zero multiple of one image tensor")
    batch = np.frombuffer(raw, dtype=TENSOR_DTYPE).reshape(-1, INPUT_C, INPUT_H, INPUT_W)

    x = torch.from_numpy(batch).to(gpu_app.state.device)
    with torch.inference_mode():
        logits = gpu_app.state.model(x)
    out = logits.detach().to("cpu").numpy().astype(TENSOR_DTYPE, copy=False)
    return Response(content=out.tobytes(), media_type="application/octet-stream")
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/serving_graphs/image_classification.py*
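
The NumPy sketch below makes the byte contract concrete by showing both ends of it outside FastAPI. `encode_batch`, `decode_batch`, and the explicit little-endian dtype `"<f4"` are illustrative choices, not code from the example. The example uses native `np.float32`, which matches `"<f4"` on little-endian hosts. The sketch also compares the raw payload with a JSON encoding of the same two-image batch:

```python
import json

import numpy as np

# Explicit little-endian float32 so the layout does not depend on host byte order.
WIRE_DTYPE = np.dtype("<f4")
ITEM_SHAPE = (3, 224, 224)


def encode_batch(batch: np.ndarray) -> bytes:
    """Serialize a (B, 3, 224, 224) batch as raw C-contiguous float32 bytes."""
    return np.ascontiguousarray(batch, dtype=WIRE_DTYPE).tobytes()


def decode_batch(raw: bytes, item_shape: tuple[int, ...] = ITEM_SHAPE) -> np.ndarray:
    """Rebuild the batch, rejecting payloads that are not a whole number of items."""
    item_bytes = int(np.prod(item_shape)) * WIRE_DTYPE.itemsize
    # Validate the byte length before np.frombuffer, which raises on partial elements.
    if not raw or len(raw) % item_bytes != 0:
        raise ValueError("payload is not a whole number of item tensors")
    # frombuffer is zero-copy, so the result is a read-only view of `raw`.
    return np.frombuffer(raw, dtype=WIRE_DTYPE).reshape(-1, *item_shape)


rng = np.random.default_rng(0)
batch = rng.standard_normal((2, *ITEM_SHAPE)).astype(np.float32)
raw = encode_batch(batch)
restored = decode_batch(raw)
as_json = json.dumps(batch.tolist())  # the same batch as JSON text, for size comparison
print(len(raw), len(as_json))
```

Validating the length first turns a malformed body into a clean, explicit error instead of an exception from inside `np.frombuffer`. Because `np.frombuffer` does not copy, the decoded array is read-only. Call `.copy()` if a consumer needs to write to it.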

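The GPU environment requests a single A10G GPU and keeps the replica range
narrow. One replica always stays warm so the weights remain loaded, and
scaling is capped at two because every extra replica holds its own GPU: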

```
"""
Serving graph — CPU pre/post split from a GPU forward pass.

This example shows the canonical "two-app" inference graph: heavy CPU work
on one app, the GPU forward pass on another, talking to each other over HTTP
inside the cluster.

Why split? In a typical vision/audio/feature-engineering pipeline the GPU
forward pass is fast (millis) but is sandwiched between slow CPU work
(image decode, resize, denoise, NMS, label lookup, etc.). If you put both
stages in one process you pay for an idle GPU during preprocessing. Splitting
them lets each side scale independently:

    client ──► [cpu_app  x N replicas]  ──► [gpu_app x M replicas] ──► back
                preprocess + postprocess        model.forward only
                cheap CPU, scale wide           expensive GPU, scale narrow

Wire format between the two apps is raw float32 bytes (not JSON) — for
anything tensor-shaped this is the single biggest perf knob.
"""

import io
import ipaddress
import logging
import pathlib
import socket
from contextlib import asynccontextmanager

import httpx
import numpy as np
from fastapi import FastAPI, HTTPException, Request, Response
from PIL import Image, ImageFilter
from pydantic import BaseModel

import flyte
import flyte.app
from flyte.app.extras import FastAPIAppEnvironment

# ---------------------------------------------------------------------------
# Images
# ---------------------------------------------------------------------------
# Shared base with the deps both apps need (HTTP server + numpy). The CPU and
# GPU images extend it with their own disjoint stacks — the CPU app never
# imports torch and the GPU app never imports PIL. Sharing the base layer
# means the registry only stores one copy of fastapi/uvicorn/numpy.

# {{docs-fragment images}}
base_image = flyte.Image.from_debian_base(python_version=(3, 12)).with_pip_packages(
    "fastapi",
    "uvicorn",
    "numpy",
)

cpu_image = base_image.with_pip_packages(
    "httpx",
    "pillow",
)

gpu_image = base_image.with_pip_packages(
    "torch==2.7.1",
    "torchvision==0.22.1",
)
# {{/docs-fragment images}}

# ---------------------------------------------------------------------------
# Shared tensor layout
# ---------------------------------------------------------------------------

INPUT_C, INPUT_H, INPUT_W = 3, 224, 224
NUM_CLASSES = 1000
TENSOR_DTYPE = np.float32

# ===========================================================================
# GPU app — model.forward only
# ===========================================================================

# {{docs-fragment gpu-lifespan}}
@asynccontextmanager
async def _gpu_lifespan(app: FastAPI):
    # Imported lazily so the CPU app never has to import torch.
    import torch
    from torchvision.models import ResNet18_Weights, resnet18

    weights = ResNet18_Weights.IMAGENET1K_V1
    model = resnet18(weights=weights).eval()
    device = "cuda" if torch.cuda.is_available() else "cpu"
    if device == "cuda":
        model = model.to("cuda")
    app.state.model = model
    app.state.device = device
    app.state.categories = list(weights.meta["categories"])
    logging.getLogger(__name__).info("model loaded on %s", device)
    yield

gpu_app = FastAPI(
    title="inference-gpu",
    description="ResNet18 forward pass.",
    lifespan=_gpu_lifespan,
)
# {{/docs-fragment gpu-lifespan}}

@gpu_app.get("/health")
async def gpu_health() -> dict:
    return {"status": "ok", "device": gpu_app.state.device}

@gpu_app.get("/labels")
async def labels() -> list[str]:
    # Exposed so the CPU side can fetch labels once at startup instead of
    # hard-coding the ImageNet class list.
    return gpu_app.state.categories

# {{docs-fragment gpu-infer}}
@gpu_app.post("/infer")
async def infer(request: Request) -> Response:
    """Run a batched forward pass.

    Request body:  raw float32 bytes, shape (B, 3, 224, 224), C-contiguous.
    Response body: raw float32 bytes, shape (B, 1000) — raw logits.

    We deliberately do NOT use JSON here. For a batch of 32 images the tensor
    is ~19MB; JSON-serializing that is the dominant cost end-to-end.
    """
    import torch

    raw = await request.body()
    arr = np.frombuffer(raw, dtype=TENSOR_DTYPE)
    if arr.size % (INPUT_C * INPUT_H * INPUT_W) != 0:
        raise HTTPException(400, "payload size is not a multiple of one image tensor")
    batch = arr.reshape(-1, INPUT_C, INPUT_H, INPUT_W)

    x = torch.from_numpy(batch).to(gpu_app.state.device)
    with torch.inference_mode():
        logits = gpu_app.state.model(x)
    out = logits.detach().to("cpu").numpy().astype(TENSOR_DTYPE, copy=False)
    return Response(content=out.tobytes(), media_type="application/octet-stream")
# {{/docs-fragment gpu-infer}}

# {{docs-fragment gpu-env}}
gpu_env = FastAPIAppEnvironment(
    name="serving-graph-gpu",
    app=gpu_app,
    image=gpu_image,
    resources=flyte.Resources(cpu=2, memory="8Gi", gpu="A10G:1"),
    # GPU replicas are expensive; keep at least one warm so model weights stay
    # resident, and cap the max. Bump if a single replica saturates.
    scaling=flyte.app.Scaling(replicas=(1, 2)),
    requires_auth=True,
)
# {{/docs-fragment gpu-env}}

# ===========================================================================
# CPU app — pre/postprocess, calls the GPU app
# ===========================================================================

IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=TENSOR_DTYPE).reshape(3, 1, 1)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=TENSOR_DTYPE).reshape(3, 1, 1)

class ClassifyRequest(BaseModel):
    image_url: str
    top_k: int = 5

class Prediction(BaseModel):
    label: str
    score: float

# {{docs-fragment cpu-preprocess}}
def _preprocess(img_bytes: bytes) -> np.ndarray:
    """Decode → denoise → resize → normalize. CPU-bound, deliberately so.

    Real preprocessing stacks (detection, OCR, audio) do substantially more
    than this — sliding window crops, color-space conversion, etc. The point
    is that none of it benefits from a GPU sitting next to it.
    """
    img = Image.open(io.BytesIO(img_bytes)).convert("RGB")
    img = img.filter(ImageFilter.GaussianBlur(radius=1.0))
    img = img.resize((INPUT_W, INPUT_H), Image.BILINEAR)
    arr = np.asarray(img, dtype=TENSOR_DTYPE) / 255.0
    arr = arr.transpose(2, 0, 1)  # HWC → CHW
    arr = (arr - IMAGENET_MEAN) / IMAGENET_STD
    return np.ascontiguousarray(arr, dtype=TENSOR_DTYPE)
# {{/docs-fragment cpu-preprocess}}

def _softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

# {{docs-fragment cpu-lifespan}}
@asynccontextmanager
async def _cpu_lifespan(app: FastAPI):
    # Resolved at serving time via the cluster-internal endpoint pattern,
    # so this stays correct across local/remote deploys without an env var.
    gpu_url = gpu_env.endpoint
    log = logging.getLogger(__name__)
    log.info("resolved GPU endpoint: %s", gpu_url)
    async with httpx.AsyncClient(timeout=30.0) as bootstrap:
        try:
            r = await bootstrap.get(f"{gpu_url}/labels")
            r.raise_for_status()
        except (httpx.HTTPError, OSError) as e:
            # Most common reason on a fresh deploy: GPU replica hasn't finished
            # pulling its image / loading weights yet. Crash-looping is fine —
            # the next attempt will likely succeed — but make the cause obvious.
            log.error("downstream GPU app at %s not ready: %s", gpu_url, e)
            raise
        app.state.labels = r.json()
    # One persistent client per replica — avoids TCP/TLS handshake per request,
    # which matters once you're doing 100s of req/s.
    async with httpx.AsyncClient(
        base_url=gpu_url,
        timeout=httpx.Timeout(30.0, connect=5.0),
        limits=httpx.Limits(max_connections=64, max_keepalive_connections=32),
    ) as client:
        app.state.client = client
        yield

cpu_app = FastAPI(
    title="inference-cpu",
    description="Pre/post around the GPU forward pass.",
    lifespan=_cpu_lifespan,
)
# {{/docs-fragment cpu-lifespan}}

@cpu_app.get("/health")
async def cpu_health() -> dict:
    return {"status": "ok", "labels_loaded": len(cpu_app.state.labels)}

# {{docs-fragment cpu-classify}}
async def validate_public_image_url(image_url: str) -> str:
    """Basic SSRF guard: accept only http(s) URLs whose host resolves exclusively to public (global) IP addresses."""
    try:
        parsed = httpx.URL(image_url)
    except Exception as exc:
        raise HTTPException(status_code=400, detail="Invalid image_url.") from exc
    if parsed.scheme not in {"http", "https"}:
        raise HTTPException(status_code=400, detail="image_url must use http or https.")
    host = parsed.host
    if not host:
        raise HTTPException(status_code=400, detail="image_url must include a hostname.")
    try:
        addr_info = socket.getaddrinfo(host, parsed.port or (443 if parsed.scheme == "https" else 80))
    except socket.gaierror as exc:
        raise HTTPException(status_code=400, detail="image_url host could not be resolved.") from exc
    for info in addr_info:
        ip_text = info[4][0]
        ip_obj = ipaddress.ip_address(ip_text)
        if not ip_obj.is_global:
            raise HTTPException(status_code=400, detail="image_url host resolves to a non-public address.")
    return str(parsed)

@cpu_app.post("/classify", response_model=list[Prediction])
async def classify(req: ClassifyRequest) -> list[Prediction]:
    async with httpx.AsyncClient(timeout=30.0) as client:
        img_resp = await client.get(await validate_public_image_url(req.image_url))
        img_resp.raise_for_status()

    tensor = _preprocess(img_resp.content)  # heavy CPU
    batch = tensor[np.newaxis, ...]  # add batch dim

    gpu_resp = await cpu_app.state.client.post(
        "/infer",
        content=batch.tobytes(),
        headers={"content-type": "application/octet-stream"},
    )
    gpu_resp.raise_for_status()
    logits = np.frombuffer(gpu_resp.content, dtype=TENSOR_DTYPE).reshape(1, NUM_CLASSES)

    probs = _softmax(logits, axis=-1)[0]  # back to CPU work
    top_idx = np.argsort(-probs)[: req.top_k]
    return [Prediction(label=cpu_app.state.labels[i], score=float(probs[i])) for i in top_idx]
# {{/docs-fragment cpu-classify}}

# {{docs-fragment cpu-env}}
cpu_env = FastAPIAppEnvironment(
    name="serving-graph-cpu",
    app=cpu_app,
    image=cpu_image,
    resources=flyte.Resources(cpu=4, memory="4Gi"),
    # Cheap, so scale wide. Use scale-to-zero (replicas=(0, 8)) for bursty
    # traffic; keep replicas=(1, 8) here to avoid cold starts in the demo.
    scaling=flyte.app.Scaling(replicas=(1, 8)),
    requires_auth=True,
    depends_on=[gpu_env],
)
# {{/docs-fragment cpu-env}}

# ===========================================================================
# Deploy
# ===========================================================================

# {{docs-fragment deploy}}
if __name__ == "__main__":
    flyte.init_from_config(
        root_dir=pathlib.Path(__file__).parent,
        log_level=logging.INFO,
    )
    app = flyte.serve(cpu_env)
    print(f"Deployed serving graph; public CPU endpoint: {app.url}")
    print("Try: curl -X POST $URL/classify -H 'content-type: application/json' \\")
    print(
        '       -d \'{"image_url": "https://upload.wikimedia.org/wikipedia/commons/4/41/Sunflower_from_Silesia2.jpg"}\''
    )
# {{/docs-fragment deploy}}
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/serving_graphs/image_classification.py*

### CPU app: pre/postprocess + call GPU

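Preprocessing is deliberately CPU-bound. `_preprocess` decodes the image, denoises it with a Gaussian blur, resizes it to 224×224, and normalizes it with the ImageNet per-channel mean and standard deviation: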

```
"""
Serving graph — CPU pre/post split from a GPU forward pass.

This example shows the canonical "two-app" inference graph: heavy CPU work
on one app, the GPU forward pass on another, talking to each other over HTTP
inside the cluster.

Why split? In a typical vision/audio/feature-engineering pipeline the GPU
forward pass is fast (millis) but is sandwiched between slow CPU work
(image decode, resize, denoise, NMS, label lookup, etc.). If you put both
stages in one process you pay for an idle GPU during preprocessing. Splitting
them lets each side scale independently:

    client ──► [cpu_app  x N replicas]  ──► [gpu_app x M replicas] ──► back
                preprocess + postprocess        model.forward only
                cheap CPU, scale wide           expensive GPU, scale narrow

Wire format between the two apps is raw float32 bytes (not JSON) — for
anything tensor-shaped this is the single biggest perf knob.
"""

import io
import ipaddress
import logging
import pathlib
import socket
from contextlib import asynccontextmanager

import httpx
import numpy as np
from fastapi import FastAPI, HTTPException, Request, Response
from PIL import Image, ImageFilter
from pydantic import BaseModel

import flyte
import flyte.app
from flyte.app.extras import FastAPIAppEnvironment

# ---------------------------------------------------------------------------
# Images
# ---------------------------------------------------------------------------
# Shared base with the deps both apps need (HTTP server + numpy). The CPU and
# GPU images extend it with their own disjoint stacks — the CPU app never
# imports torch and the GPU app never imports PIL. Sharing the base layer
# means the registry only stores one copy of fastapi/uvicorn/numpy.

# {{docs-fragment images}}
base_image = flyte.Image.from_debian_base(python_version=(3, 12)).with_pip_packages(
    "fastapi",
    "uvicorn",
    "numpy",
)

cpu_image = base_image.with_pip_packages(
    "httpx",
    "pillow",
)

gpu_image = base_image.with_pip_packages(
    "torch==2.7.1",
    "torchvision==0.22.1",
)
# {{/docs-fragment images}}

# ---------------------------------------------------------------------------
# Shared tensor layout
# ---------------------------------------------------------------------------

INPUT_C, INPUT_H, INPUT_W = 3, 224, 224
NUM_CLASSES = 1000
TENSOR_DTYPE = np.float32

# ===========================================================================
# GPU app — model.forward only
# ===========================================================================

# {{docs-fragment gpu-lifespan}}
@asynccontextmanager
async def _gpu_lifespan(app: FastAPI):
    # Imported lazily so the CPU app never has to import torch.
    import torch
    from torchvision.models import ResNet18_Weights, resnet18

    weights = ResNet18_Weights.IMAGENET1K_V1
    model = resnet18(weights=weights).eval()
    device = "cuda" if torch.cuda.is_available() else "cpu"
    if device == "cuda":
        model = model.to("cuda")
    app.state.model = model
    app.state.device = device
    app.state.categories = list(weights.meta["categories"])
    logging.getLogger(__name__).info("model loaded on %s", device)
    yield

gpu_app = FastAPI(
    title="inference-gpu",
    description="ResNet18 forward pass.",
    lifespan=_gpu_lifespan,
)
# {{/docs-fragment gpu-lifespan}}

@gpu_app.get("/health")
async def gpu_health() -> dict:
    return {"status": "ok", "device": gpu_app.state.device}

@gpu_app.get("/labels")
async def labels() -> list[str]:
    # Exposed so the CPU side can fetch labels once at startup instead of
    # hard-coding the ImageNet class list.
    return gpu_app.state.categories

# {{docs-fragment gpu-infer}}
@gpu_app.post("/infer")
async def infer(request: Request) -> Response:
    """Run a batched forward pass.

    Request body:  raw float32 bytes, shape (B, 3, 224, 224), C-contiguous.
    Response body: raw float32 bytes, shape (B, 1000) — raw logits.

    We deliberately do NOT use JSON here. For a batch of 32 images the tensor
    is ~19MB; JSON-serializing that is the dominant cost end-to-end.
    """
    import torch

    raw = await request.body()
    arr = np.frombuffer(raw, dtype=TENSOR_DTYPE)
    if arr.size % (INPUT_C * INPUT_H * INPUT_W) != 0:
        raise HTTPException(400, "payload size is not a multiple of one image tensor")
    batch = arr.reshape(-1, INPUT_C, INPUT_H, INPUT_W)

    x = torch.from_numpy(batch).to(gpu_app.state.device)
    with torch.inference_mode():
        logits = gpu_app.state.model(x)
    out = logits.detach().to("cpu").numpy().astype(TENSOR_DTYPE, copy=False)
    return Response(content=out.tobytes(), media_type="application/octet-stream")
# {{/docs-fragment gpu-infer}}

# {{docs-fragment gpu-env}}
gpu_env = FastAPIAppEnvironment(
    name="serving-graph-gpu",
    app=gpu_app,
    image=gpu_image,
    resources=flyte.Resources(cpu=2, memory="8Gi", gpu="A10G:1"),
    # GPU replicas are expensive; keep at least one warm so model weights stay
    # resident, and cap the max. Bump if a single replica saturates.
    scaling=flyte.app.Scaling(replicas=(1, 2)),
    requires_auth=True,
)
# {{/docs-fragment gpu-env}}

# ===========================================================================
# CPU app — pre/postprocess, calls the GPU app
# ===========================================================================

IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=TENSOR_DTYPE).reshape(3, 1, 1)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=TENSOR_DTYPE).reshape(3, 1, 1)

class ClassifyRequest(BaseModel):
    image_url: str
    top_k: int = 5

class Prediction(BaseModel):
    label: str
    score: float

# {{docs-fragment cpu-preprocess}}
def _preprocess(img_bytes: bytes) -> np.ndarray:
    """Decode → denoise → resize → normalize. CPU-bound, deliberately so.

    Real preprocessing stacks (detection, OCR, audio) do substantially more
    than this — sliding window crops, color-space conversion, etc. The point
    is that none of it benefits from a GPU sitting next to it.
    """
    img = Image.open(io.BytesIO(img_bytes)).convert("RGB")
    img = img.filter(ImageFilter.GaussianBlur(radius=1.0))
    img = img.resize((INPUT_W, INPUT_H), Image.BILINEAR)
    arr = np.asarray(img, dtype=TENSOR_DTYPE) / 255.0
    arr = arr.transpose(2, 0, 1)  # HWC → CHW
    arr = (arr - IMAGENET_MEAN) / IMAGENET_STD
    return np.ascontiguousarray(arr, dtype=TENSOR_DTYPE)
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/serving_graphs/image_classification.py*
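The two apps only have to agree on a tensor layout: C-contiguous float32 of shape (B, 3, 224, 224) going to the GPU app, and (B, 1000) logits coming back. The sketch below exercises the CPU side of that contract with NumPy alone, using a synthetic uint8 image in place of a decoded JPEG so it runs without Pillow or torch. `normalize_hwc`, `encode_batch`, and `decode_batch` are hypothetical helpers that mirror the last steps of `_preprocess` and the payload check in the GPU app's `/infer`; they are not part of the example or of Flyte.

```python
import numpy as np

# Same layout constants as the example's "Shared tensor layout" section.
INPUT_C, INPUT_H, INPUT_W = 3, 224, 224
TENSOR_DTYPE = np.float32
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=TENSOR_DTYPE).reshape(3, 1, 1)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=TENSOR_DTYPE).reshape(3, 1, 1)

# Bytes for one image on the wire: 3 * 224 * 224 values * 4 bytes per float32.
BYTES_PER_IMAGE = INPUT_C * INPUT_H * INPUT_W * np.dtype(TENSOR_DTYPE).itemsize


def normalize_hwc(pixels: np.ndarray) -> np.ndarray:
    """Hypothetical helper: uint8 HWC pixels -> C-contiguous float32 CHW tensor."""
    arr = pixels.astype(TENSOR_DTYPE) / 255.0  # scale to [0, 1]
    arr = arr.transpose(2, 0, 1)  # HWC -> CHW (a strided view, not yet contiguous)
    arr = (arr - IMAGENET_MEAN) / IMAGENET_STD  # (3, 1, 1) broadcasts over H and W
    return np.ascontiguousarray(arr, dtype=TENSOR_DTYPE)


def encode_batch(tensors: list[np.ndarray]) -> bytes:
    """Hypothetical helper: stack CHW tensors into (B, C, H, W) raw float32 bytes."""
    batch = np.stack(tensors).astype(TENSOR_DTYPE, copy=False)
    return batch.tobytes()  # tobytes() emits C order regardless of memory layout


def decode_batch(raw: bytes) -> np.ndarray:
    """Hypothetical helper: byte-level version of the /infer size check.

    Unlike the example's /infer, this also rejects an empty body.
    """
    if len(raw) == 0 or len(raw) % BYTES_PER_IMAGE != 0:
        raise ValueError("payload size is not a multiple of one image tensor")
    # frombuffer is zero-copy, so the returned array is read-only.
    return np.frombuffer(raw, dtype=TENSOR_DTYPE).reshape(-1, INPUT_C, INPUT_H, INPUT_W)


rng = np.random.default_rng(0)
image = rng.integers(0, 256, size=(INPUT_H, INPUT_W, INPUT_C), dtype=np.uint8)
tensor = normalize_hwc(image)
payload = encode_batch([tensor, tensor])  # a batch of two
decoded = decode_batch(payload)
print(tensor.shape, len(payload), decoded.shape)  # (3, 224, 224) 1204224 (2, 3, 224, 224)
```

At 602,112 bytes per image, a batch of 32 is 19,267,584 bytes (about 19 MB), consistent with the ~19MB figure in the `/infer` docstring.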

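The CPU app's lifespan handler does three things at startup: it resolves the GPU
app's cluster-internal URL from `gpu_env.endpoint`, fetches the class labels
once from the GPU app's `/labels` route, and opens one persistent
`httpx.AsyncClient` per replica. Reusing that client keeps connections alive
across requests instead of paying a TCP/TLS handshake on every call, which adds
up at hundreds of requests per second. If the GPU app is not reachable yet, the
handler logs the error and re-raises, so the replica crash-loops until the GPU
app is up: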

```
"""
Serving graph — CPU pre/post split from a GPU forward pass.

This example shows the canonical "two-app" inference graph: heavy CPU work
on one app, the GPU forward pass on another, talking to each other over HTTP
inside the cluster.

Why split? In a typical vision/audio/feature-engineering pipeline the GPU
forward pass is fast (millis) but is sandwiched between slow CPU work
(image decode, resize, denoise, NMS, label lookup, etc.). If you put both
stages in one process you pay for an idle GPU during preprocessing. Splitting
them lets each side scale independently:

    client ──► [cpu_app  x N replicas]  ──► [gpu_app x M replicas] ──► back
                preprocess + postprocess        model.forward only
                cheap CPU, scale wide           expensive GPU, scale narrow

Wire format between the two apps is raw float32 bytes (not JSON) — for
anything tensor-shaped this is the single biggest perf knob.
"""

import io
import ipaddress
import logging
import pathlib
import socket
from contextlib import asynccontextmanager

import httpx
import numpy as np
from fastapi import FastAPI, HTTPException, Request, Response
from PIL import Image, ImageFilter
from pydantic import BaseModel

import flyte
import flyte.app
from flyte.app.extras import FastAPIAppEnvironment

# ---------------------------------------------------------------------------
# Images
# ---------------------------------------------------------------------------
# Shared base with the deps both apps need (HTTP server + numpy). The CPU and
# GPU images extend it with their own disjoint stacks — the CPU app never
# imports torch and the GPU app never imports PIL. Sharing the base layer
# means the registry only stores one copy of fastapi/uvicorn/numpy.

# {{docs-fragment images}}
base_image = flyte.Image.from_debian_base(python_version=(3, 12)).with_pip_packages(
    "fastapi",
    "uvicorn",
    "numpy",
)

cpu_image = base_image.with_pip_packages(
    "httpx",
    "pillow",
)

gpu_image = base_image.with_pip_packages(
    "torch==2.7.1",
    "torchvision==0.22.1",
)
# {{/docs-fragment images}}

# ---------------------------------------------------------------------------
# Shared tensor layout
# ---------------------------------------------------------------------------

INPUT_C, INPUT_H, INPUT_W = 3, 224, 224
NUM_CLASSES = 1000
TENSOR_DTYPE = np.float32

# ===========================================================================
# GPU app — model.forward only
# ===========================================================================

# {{docs-fragment gpu-lifespan}}
@asynccontextmanager
async def _gpu_lifespan(app: FastAPI):
    # Imported lazily so the CPU app never has to import torch.
    import torch
    from torchvision.models import ResNet18_Weights, resnet18

    weights = ResNet18_Weights.IMAGENET1K_V1
    model = resnet18(weights=weights).eval()
    device = "cuda" if torch.cuda.is_available() else "cpu"
    if device == "cuda":
        model = model.to("cuda")
    app.state.model = model
    app.state.device = device
    app.state.categories = list(weights.meta["categories"])
    logging.getLogger(__name__).info("model loaded on %s", device)
    yield

gpu_app = FastAPI(
    title="inference-gpu",
    description="ResNet18 forward pass.",
    lifespan=_gpu_lifespan,
)
# {{/docs-fragment gpu-lifespan}}

@gpu_app.get("/health")
async def gpu_health() -> dict:
    return {"status": "ok", "device": gpu_app.state.device}

@gpu_app.get("/labels")
async def labels() -> list[str]:
    # Exposed so the CPU side can fetch labels once at startup instead of
    # hard-coding the ImageNet class list.
    return gpu_app.state.categories

# {{docs-fragment gpu-infer}}
@gpu_app.post("/infer")
async def infer(request: Request) -> Response:
    """Run a batched forward pass.

    Request body:  raw float32 bytes, shape (B, 3, 224, 224), C-contiguous.
    Response body: raw float32 bytes, shape (B, 1000) — raw logits.

    We deliberately do NOT use JSON here. For a batch of 32 images the tensor
    is ~19MB; JSON-serializing that is the dominant cost end-to-end.
    """
    import torch

    raw = await request.body()
    arr = np.frombuffer(raw, dtype=TENSOR_DTYPE)
    if arr.size % (INPUT_C * INPUT_H * INPUT_W) != 0:
        raise HTTPException(400, "payload size is not a multiple of one image tensor")
    batch = arr.reshape(-1, INPUT_C, INPUT_H, INPUT_W)

    x = torch.from_numpy(batch).to(gpu_app.state.device)
    with torch.inference_mode():
        logits = gpu_app.state.model(x)
    out = logits.detach().to("cpu").numpy().astype(TENSOR_DTYPE, copy=False)
    return Response(content=out.tobytes(), media_type="application/octet-stream")
# {{/docs-fragment gpu-infer}}

# {{docs-fragment gpu-env}}
gpu_env = FastAPIAppEnvironment(
    name="serving-graph-gpu",
    app=gpu_app,
    image=gpu_image,
    resources=flyte.Resources(cpu=2, memory="8Gi", gpu="A10G:1"),
    # GPU replicas are expensive; keep at least one warm so model weights stay
    # resident, and cap the max. Bump if a single replica saturates.
    scaling=flyte.app.Scaling(replicas=(1, 2)),
    requires_auth=True,
)
# {{/docs-fragment gpu-env}}

# ===========================================================================
# CPU app — pre/postprocess, calls the GPU app
# ===========================================================================

IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=TENSOR_DTYPE).reshape(3, 1, 1)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=TENSOR_DTYPE).reshape(3, 1, 1)

class ClassifyRequest(BaseModel):
    image_url: str
    top_k: int = 5

class Prediction(BaseModel):
    label: str
    score: float

# {{docs-fragment cpu-preprocess}}
def _preprocess(img_bytes: bytes) -> np.ndarray:
    """Decode → denoise → resize → normalize. CPU-bound, deliberately so.

    Real preprocessing stacks (detection, OCR, audio) do substantially more
    than this — sliding window crops, color-space conversion, etc. The point
    is that none of it benefits from a GPU sitting next to it.
    """
    img = Image.open(io.BytesIO(img_bytes)).convert("RGB")
    img = img.filter(ImageFilter.GaussianBlur(radius=1.0))
    img = img.resize((INPUT_W, INPUT_H), Image.BILINEAR)
    arr = np.asarray(img, dtype=TENSOR_DTYPE) / 255.0
    arr = arr.transpose(2, 0, 1)  # HWC → CHW
    arr = (arr - IMAGENET_MEAN) / IMAGENET_STD
    return np.ascontiguousarray(arr, dtype=TENSOR_DTYPE)
# {{/docs-fragment cpu-preprocess}}

def _softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

# {{docs-fragment cpu-lifespan}}
@asynccontextmanager
async def _cpu_lifespan(app: FastAPI):
    # Resolved at serving time via the cluster-internal endpoint pattern,
    # so this stays correct across local/remote deploys without an env var.
    gpu_url = gpu_env.endpoint
    log = logging.getLogger(__name__)
    log.info("resolved GPU endpoint: %s", gpu_url)
    async with httpx.AsyncClient(timeout=30.0) as bootstrap:
        try:
            r = await bootstrap.get(f"{gpu_url}/labels")
            r.raise_for_status()
        except (httpx.HTTPError, OSError) as e:
            # Most common reason on a fresh deploy: GPU replica hasn't finished
            # pulling its image / loading weights yet. Crash-looping is fine —
            # the next attempt will likely succeed — but make the cause obvious.
            log.error("downstream GPU app at %s not ready: %s", gpu_url, e)
            raise
        app.state.labels = r.json()
    # One persistent client per replica — avoids TCP/TLS handshake per request,
    # which matters once you're doing 100s of req/s.
    async with httpx.AsyncClient(
        base_url=gpu_url,
        timeout=httpx.Timeout(30.0, connect=5.0),
        limits=httpx.Limits(max_connections=64, max_keepalive_connections=32),
    ) as client:
        app.state.client = client
        yield

cpu_app = FastAPI(
    title="inference-cpu",
    description="Pre/post around the GPU forward pass.",
    lifespan=_cpu_lifespan,
)
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/serving_graphs/image_classification.py*
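The lifespan above fails fast: if `/labels` is unreachable it logs and re-raises, and the platform restarts the replica. If you would rather absorb a short GPU cold start within a single startup attempt, one option is to wrap the bootstrap call in a bounded exponential backoff. The sketch below is plain `asyncio`; `retry_with_backoff` is a hypothetical helper, not a Flyte or httpx API, and the attempt count and delays are placeholder values. Keep the total wait below any startup or health-check deadline your deployment enforces.

```python
import asyncio
from typing import Awaitable, Callable, TypeVar

T = TypeVar("T")


async def retry_with_backoff(
    fn: Callable[[], Awaitable[T]],
    *,
    attempts: int = 5,
    base_delay: float = 0.5,
    max_delay: float = 8.0,
    retry_on: tuple[type[BaseException], ...] = (OSError,),
) -> T:
    """Await fn() until it succeeds, sleeping base_delay, 2*base_delay, ... between tries.

    Delays are capped at max_delay. Exceptions not listed in retry_on propagate
    immediately; after the final failed attempt the last exception is re-raised,
    so startup still fails loudly if the dependency never comes up.
    """
    if attempts < 1:
        raise ValueError("attempts must be >= 1")
    delay = base_delay
    for attempt in range(1, attempts + 1):
        try:
            return await fn()
        except retry_on:
            if attempt == attempts:
                raise  # out of attempts: surface the original error
            await asyncio.sleep(delay)
            delay = min(delay * 2, max_delay)
    raise AssertionError("unreachable")  # the loop always returns or raises
```

In `_cpu_lifespan`, the bootstrap request would move into a small `async def fetch_labels()` that performs the `get` and `raise_for_status()`, awaited as `retry_with_backoff(fetch_labels, retry_on=(httpx.HTTPError, OSError))`. `httpx.HTTPStatusError`, raised by `raise_for_status()`, is a subclass of `httpx.HTTPError`, so non-2xx responses are retried too.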

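The `/classify` endpoint ties it all together. It first checks that `image_url`
is an http(s) URL whose host resolves only to public addresses, then downloads
the image and runs the heavy preprocessing in this process. The GPU forward
pass is delegated to the GPU app's `/infer` route over HTTP using the raw
float32 wire format, and the softmax and top-k selection run back on the CPU: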

```
"""
Serving graph — CPU pre/post split from a GPU forward pass.

This example shows the canonical "two-app" inference graph: heavy CPU work
on one app, the GPU forward pass on another, talking to each other over HTTP
inside the cluster.

Why split? In a typical vision/audio/feature-engineering pipeline the GPU
forward pass is fast (millis) but is sandwiched between slow CPU work
(image decode, resize, denoise, NMS, label lookup, etc.). If you put both
stages in one process you pay for an idle GPU during preprocessing. Splitting
them lets each side scale independently:

    client ──► [cpu_app  x N replicas]  ──► [gpu_app x M replicas] ──► back
                preprocess + postprocess        model.forward only
                cheap CPU, scale wide           expensive GPU, scale narrow

Wire format between the two apps is raw float32 bytes (not JSON) — for
anything tensor-shaped this is the single biggest perf knob.
"""

import io
import ipaddress
import logging
import pathlib
import socket
from contextlib import asynccontextmanager

import httpx
import numpy as np
from fastapi import FastAPI, HTTPException, Request, Response
from PIL import Image, ImageFilter
from pydantic import BaseModel

import flyte
import flyte.app
from flyte.app.extras import FastAPIAppEnvironment

# ---------------------------------------------------------------------------
# Images
# ---------------------------------------------------------------------------
# Shared base with the deps both apps need (HTTP server + numpy). The CPU and
# GPU images extend it with their own disjoint stacks — the CPU app never
# imports torch and the GPU app never imports PIL. Sharing the base layer
# means the registry only stores one copy of fastapi/uvicorn/numpy.

# {{docs-fragment images}}
base_image = flyte.Image.from_debian_base(python_version=(3, 12)).with_pip_packages(
    "fastapi",
    "uvicorn",
    "numpy",
)

cpu_image = base_image.with_pip_packages(
    "httpx",
    "pillow",
)

gpu_image = base_image.with_pip_packages(
    "torch==2.7.1",
    "torchvision==0.22.1",
)
# {{/docs-fragment images}}

# ---------------------------------------------------------------------------
# Shared tensor layout
# ---------------------------------------------------------------------------

INPUT_C, INPUT_H, INPUT_W = 3, 224, 224
NUM_CLASSES = 1000
TENSOR_DTYPE = np.float32

# ===========================================================================
# GPU app — model.forward only
# ===========================================================================

# {{docs-fragment gpu-lifespan}}
@asynccontextmanager
async def _gpu_lifespan(app: FastAPI):
    # Imported lazily so the CPU app never has to import torch.
    import torch
    from torchvision.models import ResNet18_Weights, resnet18

    weights = ResNet18_Weights.IMAGENET1K_V1
    model = resnet18(weights=weights).eval()
    device = "cuda" if torch.cuda.is_available() else "cpu"
    if device == "cuda":
        model = model.to("cuda")
    app.state.model = model
    app.state.device = device
    app.state.categories = list(weights.meta["categories"])
    logging.getLogger(__name__).info("model loaded on %s", device)
    yield

gpu_app = FastAPI(
    title="inference-gpu",
    description="ResNet18 forward pass.",
    lifespan=_gpu_lifespan,
)
# {{/docs-fragment gpu-lifespan}}

@gpu_app.get("/health")
async def gpu_health() -> dict:
    return {"status": "ok", "device": gpu_app.state.device}

@gpu_app.get("/labels")
async def labels() -> list[str]:
    # Exposed so the CPU side can fetch labels once at startup instead of
    # hard-coding the ImageNet class list.
    return gpu_app.state.categories

# {{docs-fragment gpu-infer}}
@gpu_app.post("/infer")
async def infer(request: Request) -> Response:
    """Run a batched forward pass.

    Request body:  raw float32 bytes, shape (B, 3, 224, 224), C-contiguous.
    Response body: raw float32 bytes, shape (B, 1000) — raw logits.

    We deliberately do NOT use JSON here. For a batch of 32 images the tensor
    is ~19MB; JSON-serializing that is the dominant cost end-to-end.
    """
    import torch

    raw = await request.body()
    arr = np.frombuffer(raw, dtype=TENSOR_DTYPE)
    if arr.size % (INPUT_C * INPUT_H * INPUT_W) != 0:
        raise HTTPException(400, "payload size is not a multiple of one image tensor")
    batch = arr.reshape(-1, INPUT_C, INPUT_H, INPUT_W)

    x = torch.from_numpy(batch).to(gpu_app.state.device)
    with torch.inference_mode():
        logits = gpu_app.state.model(x)
    out = logits.detach().to("cpu").numpy().astype(TENSOR_DTYPE, copy=False)
    return Response(content=out.tobytes(), media_type="application/octet-stream")
# {{/docs-fragment gpu-infer}}

# {{docs-fragment gpu-env}}
gpu_env = FastAPIAppEnvironment(
    name="serving-graph-gpu",
    app=gpu_app,
    image=gpu_image,
    resources=flyte.Resources(cpu=2, memory="8Gi", gpu="A10G:1"),
    # GPU replicas are expensive; keep at least one warm so model weights stay
    # resident, and cap the max. Bump if a single replica saturates.
    scaling=flyte.app.Scaling(replicas=(1, 2)),
    requires_auth=True,
)
# {{/docs-fragment gpu-env}}

# ===========================================================================
# CPU app — pre/postprocess, calls the GPU app
# ===========================================================================

IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=TENSOR_DTYPE).reshape(3, 1, 1)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=TENSOR_DTYPE).reshape(3, 1, 1)

class ClassifyRequest(BaseModel):
    image_url: str
    top_k: int = 5

class Prediction(BaseModel):
    label: str
    score: float

# {{docs-fragment cpu-preprocess}}
def _preprocess(img_bytes: bytes) -> np.ndarray:
    """Decode → denoise → resize → normalize. CPU-bound, deliberately so.

    Real preprocessing stacks (detection, OCR, audio) do substantially more
    than this — sliding window crops, color-space conversion, etc. The point
    is that none of it benefits from a GPU sitting next to it.
    """
    img = Image.open(io.BytesIO(img_bytes)).convert("RGB")
    img = img.filter(ImageFilter.GaussianBlur(radius=1.0))
    img = img.resize((INPUT_W, INPUT_H), Image.BILINEAR)
    arr = np.asarray(img, dtype=TENSOR_DTYPE) / 255.0
    arr = arr.transpose(2, 0, 1)  # HWC → CHW
    arr = (arr - IMAGENET_MEAN) / IMAGENET_STD
    return np.ascontiguousarray(arr, dtype=TENSOR_DTYPE)
# {{/docs-fragment cpu-preprocess}}

def _softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

# {{docs-fragment cpu-lifespan}}
@asynccontextmanager
async def _cpu_lifespan(app: FastAPI):
    # Resolved at serving time via the cluster-internal endpoint pattern,
    # so this stays correct across local/remote deploys without an env var.
    gpu_url = gpu_env.endpoint
    log = logging.getLogger(__name__)
    log.info("resolved GPU endpoint: %s", gpu_url)
    async with httpx.AsyncClient(timeout=30.0) as bootstrap:
        try:
            r = await bootstrap.get(f"{gpu_url}/labels")
            r.raise_for_status()
        except (httpx.HTTPError, OSError) as e:
            # Most common reason on a fresh deploy: GPU replica hasn't finished
            # pulling its image / loading weights yet. Crash-looping is fine —
            # the next attempt will likely succeed — but make the cause obvious.
            log.error("downstream GPU app at %s not ready: %s", gpu_url, e)
            raise
        app.state.labels = r.json()
    # One persistent client per replica — avoids TCP/TLS handshake per request,
    # which matters once you're doing 100s of req/s.
    async with httpx.AsyncClient(
        base_url=gpu_url,
        timeout=httpx.Timeout(30.0, connect=5.0),
        limits=httpx.Limits(max_connections=64, max_keepalive_connections=32),
    ) as client:
        app.state.client = client
        yield

cpu_app = FastAPI(
    title="inference-cpu",
    description="Pre/post around the GPU forward pass.",
    lifespan=_cpu_lifespan,
)
# {{/docs-fragment cpu-lifespan}}

@cpu_app.get("/health")
async def cpu_health() -> dict:
    return {"status": "ok", "labels_loaded": len(cpu_app.state.labels)}

# {{docs-fragment cpu-classify}}
async def validate_public_image_url(image_url: str) -> str:
    """Reject URLs that could make this app fetch from a private or internal address (SSRF).

    The host is resolved again when the image is actually fetched, so this check
    alone does not stop DNS-rebinding tricks; it blocks the straightforward cases.
    """
    try:
        parsed = httpx.URL(image_url)
    except Exception as exc:
        raise HTTPException(status_code=400, detail="Invalid image_url.") from exc
    if parsed.scheme not in {"http", "https"}:
        raise HTTPException(status_code=400, detail="image_url must use http or https.")
    host = parsed.host
    if not host:
        raise HTTPException(status_code=400, detail="image_url must include a hostname.")
    try:
        addr_info = socket.getaddrinfo(host, parsed.port or (443 if parsed.scheme == "https" else 80))
    except socket.gaierror as exc:
        raise HTTPException(status_code=400, detail="image_url host could not be resolved.") from exc
    # Every address the name resolves to must be publicly routable.
    for info in addr_info:
        ip_text = info[4][0]
        ip_obj = ipaddress.ip_address(ip_text)
        if not ip_obj.is_global:
            raise HTTPException(status_code=400, detail="image_url host resolves to a non-public address.")
    return str(parsed)

@cpu_app.post("/classify", response_model=list[Prediction])
async def classify(req: ClassifyRequest) -> list[Prediction]:
    async with httpx.AsyncClient(timeout=30.0) as client:
        img_resp = await client.get(await validate_public_image_url(req.image_url))
        img_resp.raise_for_status()

    tensor = _preprocess(img_resp.content)  # heavy CPU
    batch = tensor[np.newaxis, ...]  # add batch dim

    gpu_resp = await cpu_app.state.client.post(
        "/infer",
        content=batch.tobytes(),
        headers={"content-type": "application/octet-stream"},
    )
    gpu_resp.raise_for_status()
    logits = np.frombuffer(gpu_resp.content, dtype=TENSOR_DTYPE).reshape(1, NUM_CLASSES)

    probs = _softmax(logits, axis=-1)[0]  # back to CPU work
    top_idx = np.argsort(-probs)[: req.top_k]
    return [Prediction(label=cpu_app.state.labels[i], score=float(probs[i])) for i in top_idx]
# {{/docs-fragment cpu-classify}}

# {{docs-fragment cpu-env}}
cpu_env = FastAPIAppEnvironment(
    name="serving-graph-cpu",
    app=cpu_app,
    image=cpu_image,
    resources=flyte.Resources(cpu=4, memory="4Gi"),
    # Cheap, so scale wide. Use scale-to-zero (replicas=(0, 8)) for bursty
    # traffic; keep replicas=(1, 8) here to avoid cold starts in the demo.
    scaling=flyte.app.Scaling(replicas=(1, 8)),
    requires_auth=True,
    depends_on=[gpu_env],
)
# {{/docs-fragment cpu-env}}

# ===========================================================================
# Deploy
# ===========================================================================

# {{docs-fragment deploy}}
if __name__ == "__main__":
    flyte.init_from_config(
        root_dir=pathlib.Path(__file__).parent,
        log_level=logging.INFO,
    )
    app = flyte.serve(cpu_env)
    print(f"Deployed serving graph; public CPU endpoint: {app.url}")
    print("Try: curl -X POST $URL/classify -H 'content-type: application/json' \\")
    print(
        '       -d \'{"image_url": "https://upload.wikimedia.org/wikipedia/commons/4/41/Sunflower_from_Silesia2.jpg"}\''
    )
# {{/docs-fragment deploy}}
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/serving_graphs/image_classification.py*
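The contract between the two apps is nothing more than a byte layout: a C-contiguous float32 array of shape `(B, 3, 224, 224)` goes into `/infer`, and `(B, 1000)` logits come back. The sketch below round-trips a batch through that layout with NumPy alone and compares the payload size against a naive JSON encoding. `encode_batch` and `decode_batch` are hypothetical helper names, not part of the example, and the random data stands in for real preprocessed images.

```python
import json

import numpy as np

INPUT_C, INPUT_H, INPUT_W = 3, 224, 224
TENSOR_DTYPE = np.float32


def encode_batch(batch: np.ndarray) -> bytes:
    """Serialize a (B, C, H, W) batch to raw bytes, as the CPU app does before POSTing."""
    if batch.ndim != 4 or batch.shape[1:] != (INPUT_C, INPUT_H, INPUT_W):
        raise ValueError(f"unexpected batch shape {batch.shape}")
    # Guarantee C order and float32 so the receiver can reinterpret the bytes directly.
    return np.ascontiguousarray(batch, dtype=TENSOR_DTYPE).tobytes()


def decode_batch(raw: bytes) -> np.ndarray:
    """Inverse of encode_batch, with a length check similar to the GPU app's."""
    bytes_per_image = INPUT_C * INPUT_H * INPUT_W * np.dtype(TENSOR_DTYPE).itemsize
    if not raw or len(raw) % bytes_per_image != 0:
        raise ValueError("payload size is not a positive multiple of one image tensor")
    return np.frombuffer(raw, dtype=TENSOR_DTYPE).reshape(-1, INPUT_C, INPUT_H, INPUT_W)


rng = np.random.default_rng(0)
batch = rng.standard_normal((2, INPUT_C, INPUT_H, INPUT_W)).astype(TENSOR_DTYPE)

raw = encode_batch(batch)
restored = decode_batch(raw)

print(len(raw))                         # 1204224 bytes = 2 * 3 * 224 * 224 * 4
print(np.array_equal(batch, restored))  # True: the round trip is lossless
# JSON spells each float out as text, so the same batch is several times larger.
print(len(json.dumps(batch.tolist())) / len(raw))
```

Besides the size difference, the raw layout skips a text parse on both sides: the receiver reinterprets the buffer in place with `np.frombuffer` instead of building a nested Python list first.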

The CPU environment scales wide and declares `depends_on=[gpu_env]` so both
sides deploy together:

```
cpu_env = FastAPIAppEnvironment(
    name="serving-graph-cpu",
    app=cpu_app,
    image=cpu_image,
    resources=flyte.Resources(cpu=4, memory="4Gi"),
    # Cheap, so scale wide. Use scale-to-zero (replicas=(0, 8)) for bursty
    # traffic; keep replicas=(1, 8) here to avoid cold starts in the demo.
    scaling=flyte.app.Scaling(replicas=(1, 8)),
    requires_auth=True,
    depends_on=[gpu_env],
)
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/serving_graphs/image_classification.py*
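Scaling the CPU app wide only pays off fully if each replica can also overlap requests. In the `classify` handler above, `_preprocess` runs synchronously inside an `async def`, so while one image is being decoded and resized, that replica's event loop cannot make progress on any other request. One common remedy, sketched here as an assumption rather than what the example does, is to move the CPU-bound call onto a worker thread with `asyncio.to_thread` (Python 3.9+). `fake_preprocess` is a hypothetical stand-in that sleeps to simulate blocking work.

```python
import asyncio
import time


def fake_preprocess(x: int) -> int:
    """Hypothetical stand-in for _preprocess: 0.2 s of blocking work."""
    time.sleep(0.2)
    return x * 2


async def handle_inline(x: int) -> int:
    # Blocks the event loop: every other request waits until this returns.
    return fake_preprocess(x)


async def handle_offloaded(x: int) -> int:
    # Runs in the default thread pool; the event loop stays free meanwhile.
    return await asyncio.to_thread(fake_preprocess, x)


async def timed(handler, n: int = 4) -> tuple[list[int], float]:
    """Run n concurrent 'requests' through a handler and time the whole batch."""
    start = time.perf_counter()
    results = await asyncio.gather(*(handler(i) for i in range(n)))
    return list(results), time.perf_counter() - start


inline_results, inline_s = asyncio.run(timed(handle_inline))
offloaded_results, offloaded_s = asyncio.run(timed(handle_offloaded))
print(f"inline:    {inline_s:.2f}s")     # about 4 * 0.2 s: the requests serialize
print(f"offloaded: {offloaded_s:.2f}s")  # about 0.2 s: the requests overlap
```

Applied to this example, the change would be `tensor = await asyncio.to_thread(_preprocess, img_resp.content)` in `classify`, with `import asyncio` at module level. The blocking `socket.getaddrinfo` call in `validate_public_image_url` has the same issue; `asyncio.get_running_loop().getaddrinfo(...)` is its non-blocking counterpart. How much real preprocessing overlaps depends on how much of it releases the GIL: Pillow and NumPy do so for many heavy operations, pure-Python code does not.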

### Deploy

`flyte.serve(cpu_env)` deploys both apps. The CPU app is the public entry
point; the GPU app is reached only via the cluster-internal endpoint:

```
if __name__ == "__main__":
    flyte.init_from_config(
        root_dir=pathlib.Path(__file__).parent,
        log_level=logging.INFO,
    )
    app = flyte.serve(cpu_env)
    print(f"Deployed serving graph; public CPU endpoint: {app.url}")
    print("Try: curl -X POST $URL/classify -H 'content-type: application/json' \\")
    print(
        '       -d \'{"image_url": "https://upload.wikimedia.org/wikipedia/commons/4/41/Sunflower_from_Silesia2.jpg"}\''
    )
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/serving_graphs/image_classification.py*
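Conceptually, `depends_on` turns the environments into a dependency graph, and serving the root brings along everything reachable from it. The sketch below models that with a hypothetical `AppSpec` dataclass and computes a dependencies-first order with a depth-first topological sort. It illustrates the idea only: it is not how Flyte resolves `depends_on` internally, and Flyte may well deploy independent apps concurrently rather than in a fixed sequence.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class AppSpec:
    """Hypothetical stand-in for an AppEnvironment: a name plus its dependencies."""
    name: str
    depends_on: tuple["AppSpec", ...] = ()


def deploy_order(root: AppSpec) -> list[str]:
    """Return every app reachable from root, dependencies before dependents."""
    order: list[str] = []
    done: set[str] = set()
    in_progress: set[str] = set()

    def visit(app: AppSpec) -> None:
        if app.name in done:
            return  # already placed (shared dependency)
        if app.name in in_progress:
            raise ValueError(f"dependency cycle through {app.name!r}")
        in_progress.add(app.name)
        for dep in app.depends_on:
            visit(dep)
        in_progress.remove(app.name)
        done.add(app.name)
        order.append(app.name)

    visit(root)
    return order


gpu = AppSpec("serving-graph-gpu")
cpu = AppSpec("serving-graph-cpu", depends_on=(gpu,))
print(deploy_order(cpu))  # ['serving-graph-gpu', 'serving-graph-cpu']

# The A/B example below has the same shape: two leaves under one root.
a = AppSpec("app-a-variant")
b = AppSpec("app-b-variant")
root = AppSpec("root-ab-testing-app", depends_on=(a, b))
print(deploy_order(root))  # ['app-a-variant', 'app-b-variant', 'root-ab-testing-app']
```

Only the root is named at serve time; every other app is picked up because something reachable from the root depends on it.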

## Example: A/B testing with Statsig

A serving graph also lets you shape traffic. A root app routes each incoming
request to one of two variant apps using a [Statsig](https://www.statsig.com/)
feature gate, with consistent per-user bucketing.

```mermaid
flowchart LR
    client["client"] --> root["root_app<br/>(check_gate)"]
    root -->|"gate off"| a["app_a<br/>fast-processing"]
    root -->|"gate on"| b["app_b<br/>enhanced-processing"]
```
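Consistent bucketing means the same user always lands on the same side of the gate, across requests and across root-app replicas, without the root app storing per-user state. Statsig does this inside its SDK; the sketch below shows the general technique of deterministic hash bucketing (hash a stable user ID with a per-gate salt, map it to a bucket, compare against a rollout percentage). It is not Statsig's actual algorithm, and `gate_salt`, `rollout_pct`, and the salt string used are hypothetical.

```python
import hashlib


def bucket(user_id: str, gate_salt: str, num_buckets: int = 10_000) -> int:
    """Map (salt, user) to a stable bucket in [0, num_buckets)."""
    digest = hashlib.sha256(f"{gate_salt}:{user_id}".encode()).digest()
    # A cryptographic digest is identical in every process and replica, unlike
    # Python's built-in hash(), which is randomized per process for str.
    return int.from_bytes(digest[:8], "big") % num_buckets


def gate_passes(user_id: str, gate_salt: str, rollout_pct: float) -> bool:
    """True if this user falls inside the rollout percentage (0 to 100)."""
    return bucket(user_id, gate_salt) < rollout_pct * 100


def route(user_id: str) -> str:
    # Mirrors the diagram: gate on -> app_b, gate off -> app_a.
    return "app_b" if gate_passes(user_id, "enhanced-processing-gate", 50.0) else "app_a"


users = [f"user-{i}" for i in range(10_000)]
first = [route(u) for u in users]
second = [route(u) for u in users]
print(first == second)                     # True: same user, same variant
print(first.count("app_b") / len(users))   # close to 0.5 for a 50% rollout
```

Using a different salt per gate keeps assignments independent across experiments, so being in the treatment group of one gate says nothing about another.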

### Statsig client singleton

The variant routing logic needs a single Statsig client per process. Wrap it
in a singleton so lifespan startup/shutdown is the only place that touches its
lifecycle:

```
import os
import typing
from contextlib import asynccontextmanager

import httpx
from fastapi import FastAPI

import flyte
from flyte.app.extras import FastAPIAppEnvironment

# {{docs-fragment statsig-client}}
class StatsigClient:
    """Singleton to manage Statsig client lifecycle."""

    _instance: "StatsigClient | None" = None
    _statsig = None

    @classmethod
    def initialize(cls, api_key: str):
        """Initialize Statsig client (call during lifespan startup)."""
        if cls._instance is None:
            cls._instance = cls()

        # Import statsig at runtime (only available in container)
        from statsig_python_core import Statsig

        cls._statsig = Statsig(api_key)
        cls._statsig.initialize().wait()

    @classmethod
    def get_client(cls):
        """Get the initialized Statsig instance."""
        if cls._statsig is None:
            raise RuntimeError("StatsigClient not initialized. Call initialize() first.")
        return cls._statsig

    @classmethod
    def shutdown(cls):
        """Shutdown Statsig client (call during lifespan shutdown)."""
        if cls._statsig is not None:
            cls._statsig.shutdown()
            cls._statsig = None
            cls._instance = None
# {{/docs-fragment statsig-client}}

# {{docs-fragment variant-apps}}
# Image with statsig-python-core for A/B testing
image = flyte.Image.from_debian_base().with_pip_packages("fastapi", "uvicorn", "httpx", "statsig-python-core")

# App A - First variant
app_a = FastAPI(
    title="App A",
    description="Variant A for A/B testing",
)

# App B - Second variant
app_b = FastAPI(
    title="App B",
    description="Variant B for A/B testing",
)
# {{/docs-fragment variant-apps}}

# {{docs-fragment root-lifespan}}
@asynccontextmanager
async def lifespan(_app: FastAPI):
    """Initialize and shutdown Statsig for A/B testing."""
    # Startup: Initialize Statsig using singleton
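    api_key = os.getenv("STATSIG_API_KEY")
    if not api_key:
        # Don't include os.environ in the message: it would leak every secret in the container.
        raise RuntimeError("STATSIG_API_KEY is not set. Attach the 'statsig-api-key' secret to the root app.")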
    StatsigClient.initialize(api_key)

    yield

    # Shutdown: Cleanup Statsig
    StatsigClient.shutdown()

# Root App - Performs A/B testing and routes to A or B
root_app = FastAPI(
    title="Root App - A/B Testing",
    description="Routes requests to App A or App B based on Statsig A/B test",
    lifespan=lifespan,
)
# {{/docs-fragment root-lifespan}}

# {{docs-fragment variant-envs}}
env_a = FastAPIAppEnvironment(
    name="app-a-variant",
    app=app_a,
    image=image,
    resources=flyte.Resources(cpu=1, memory="512Mi"),
)

env_b = FastAPIAppEnvironment(
    name="app-b-variant",
    app=app_b,
    image=image,
    resources=flyte.Resources(cpu=1, memory="512Mi"),
)
# {{/docs-fragment variant-envs}}

# {{docs-fragment root-env}}
env_root = FastAPIAppEnvironment(
    name="root-ab-testing-app",
    app=root_app,
    image=image,
    resources=flyte.Resources(cpu=1, memory="512Mi"),
    depends_on=[env_a, env_b],
    secrets=flyte.Secret("statsig-api-key", as_env_var="STATSIG_API_KEY"),
)
# {{/docs-fragment root-env}}

# {{docs-fragment variant-endpoints}}
# App A endpoints
@app_a.get("/process/{message}")
async def process_a(message: str) -> dict[str, str]:
    return {
        "variant": "A",
        "message": f"App A processed: {message}",
        "algorithm": "fast-processing",
    }

# App B endpoints
@app_b.get("/process/{message}")
async def process_b(message: str) -> dict[str, str]:
    return {
        "variant": "B",
        "message": f"App B processed: {message}",
        "algorithm": "enhanced-processing",
    }
# {{/docs-fragment variant-endpoints}}

# {{docs-fragment routing-endpoint}}
# Root app A/B testing endpoint
@root_app.get("/process/{message}")
async def process_with_ab_test(message: str, user_key: str) -> dict[str, typing.Any]:
    """
    Process a message using A/B testing to determine which app to call.

    Args:
        message: The message to process
        user_key: User identifier for A/B test bucketing (e.g., user_id, session_id)

    Returns:
        Response from either App A or App B, plus metadata about which variant was used
    """
    # Import StatsigUser at runtime (only available in container)
    from statsig_python_core import StatsigUser

    # Get statsig client from singleton
    statsig = StatsigClient.get_client()

    # Create Statsig user with the provided key
    user = StatsigUser(user_id=user_key)

    # Check the feature gate "variant_b" to determine which variant
    # If gate is enabled, use App B; otherwise use App A
    use_variant_b = statsig.check_gate(user, "variant_b")

    # Call the variant selected by the gate. Percent-encode the message so reserved
    # characters such as "?" or "#" stay inside the path segment.
    from urllib.parse import quote

    target = env_b if use_variant_b else env_a
    endpoint = f"{target.endpoint}/process/{quote(message, safe='')}"
    async with httpx.AsyncClient() as client:
        response = await client.get(endpoint)
        # Fail loudly on upstream errors instead of forwarding an error body as a result.
        response.raise_for_status()
        result = response.json()

    # Add A/B test metadata to response
    return {
        "ab_test_result": {
            "user_key": user_key,
            "selected_variant": "B" if use_variant_b else "A",
            "gate_name": "variant_b",
        },
        "response": result,
    }
# {{/docs-fragment routing-endpoint}}

@root_app.get("/endpoints")
async def get_endpoints() -> dict[str, str]:
    """Get the endpoints for App A and App B."""
    return {
        "app_a_endpoint": env_a.endpoint,
        "app_b_endpoint": env_b.endpoint,
    }

@root_app.get("/")
async def index():
    """Serve the A/B testing demo HTML page."""
    from fastapi.responses import HTMLResponse

    html_content = """
    <!DOCTYPE html>
    <html lang="en">
    <head>
        <meta charset="UTF-8">
        <meta name="viewport" content="width=device-width, initial-scale=1.0">
        <title>A/B Testing Demo - Statsig</title>
        <style>
            * {
                margin: 0;
                padding: 0;
                box-sizing: border-box;
            }

            body {
                font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, Cantarell,
                sans-serif;
                background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
                min-height: 100vh;
                display: flex;
                justify-content: center;
                align-items: center;
                padding: 20px;
            }

            .container {
                background: white;
                border-radius: 20px;
                box-shadow: 0 20px 60px rgba(0, 0, 0, 0.3);
                padding: 40px;
                max-width: 600px;
                width: 100%;
            }

            h1 {
                color: #333;
                margin-bottom: 10px;
                font-size: 28px;
            }

            .subtitle {
                color: #666;
                margin-bottom: 30px;
                font-size: 14px;
            }

            .form-group {
                margin-bottom: 20px;
            }

            label {
                display: block;
                margin-bottom: 8px;
                color: #555;
                font-weight: 500;
                font-size: 14px;
            }

            input {
                width: 100%;
                padding: 12px 16px;
                border: 2px solid #e0e0e0;
                border-radius: 8px;
                font-size: 14px;
                transition: border-color 0.3s;
            }

            input:focus {
                outline: none;
                border-color: #667eea;
            }

            button {
                width: 100%;
                padding: 14px;
                background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
                color: white;
                border: none;
                border-radius: 8px;
                font-size: 16px;
                font-weight: 600;
                cursor: pointer;
                transition: transform 0.2s, box-shadow 0.2s;
            }

            button:hover {
                transform: translateY(-2px);
                box-shadow: 0 10px 20px rgba(102, 126, 234, 0.4);
            }

            button:active {
                transform: translateY(0);
            }

            button:disabled {
                opacity: 0.6;
                cursor: not-allowed;
            }

            .result {
                margin-top: 30px;
                padding: 20px;
                border-radius: 12px;
                display: none;
            }

            .result.show {
                display: block;
            }

            .result.variant-a {
                background: #e3f2fd;
                border: 2px solid #2196f3;
            }

            .result.variant-b {
                background: #f3e5f5;
                border: 2px solid #9c27b0;
            }

            .result-header {
                font-size: 18px;
                font-weight: 600;
                margin-bottom: 15px;
                display: flex;
                align-items: center;
                gap: 10px;
            }

            .variant-badge {
                display: inline-block;
                padding: 4px 12px;
                border-radius: 12px;
                font-size: 12px;
                font-weight: 700;
            }

            .variant-a .variant-badge {
                background: #2196f3;
                color: white;
            }

            .variant-b .variant-badge {
                background: #9c27b0;
                color: white;
            }

            .result-content {
                margin-top: 10px;
            }

            .result-item {
                margin-bottom: 10px;
                padding: 10px;
                background: rgba(255, 255, 255, 0.8);
                border-radius: 6px;
            }

            .result-label {
                font-weight: 600;
                color: #555;
                font-size: 13px;
            }

            .result-value {
                color: #333;
                margin-top: 4px;
            }

            .error {
                background: #ffebee;
                border: 2px solid #f44336;
                color: #c62828;
                padding: 16px;
                border-radius: 8px;
                margin-top: 20px;
                display: none;
            }

            .error.show {
                display: block;
            }

            .info {
                background: #fff3e0;
                border-left: 4px solid #ff9800;
                padding: 12px 16px;
                margin-top: 20px;
                border-radius: 4px;
                font-size: 13px;
                color: #e65100;
            }
        </style>
    </head>
    <body>
        <div class="container">
            <h1>🎯 A/B Testing Demo</h1>
            <p class="subtitle">Test Statsig-powered variant selection</p>

            <form id="abTestForm">
                <div class="form-group">
                    <label for="message">Message to Process</label>
                    <input
                        type="text"
                        id="message"
                        name="message"
                        placeholder="e.g., hello, world, test"
                        required
                        value="hello"
                    >
                </div>

                <div class="form-group">
                    <label for="userKey">User Key (for A/B bucketing)</label>
                    <input
                        type="text"
                        id="userKey"
                        name="userKey"
                        placeholder="e.g., user123, session456"
                        required
                        value="user123"
                    >
                </div>

                <button type="submit" id="submitBtn">Run A/B Test</button>
            </form>

            <div id="result" class="result"></div>
            <div id="error" class="error"></div>

            <div class="info">
                💡 <strong>Tip:</strong> Try different user keys to see how Statsig routes to different variants.
                The same user key will always get the same variant (consistent bucketing).
            </div>
        </div>

        <script>
            const form = document.getElementById('abTestForm');
            const resultDiv = document.getElementById('result');
            const errorDiv = document.getElementById('error');
            const submitBtn = document.getElementById('submitBtn');

            form.addEventListener('submit', async (e) => {
                e.preventDefault();

                const message = document.getElementById('message').value;
                const userKey = document.getElementById('userKey').value;

                // Reset previous results
                resultDiv.classList.remove('show', 'variant-a', 'variant-b');
                errorDiv.classList.remove('show');
                submitBtn.disabled = true;
                submitBtn.textContent = 'Processing...';

                try {
                    const response =
                        await fetch(`/process/${encodeURIComponent(message)}?user_key=${encodeURIComponent(userKey)}`);

                    if (!response.ok) {
                        throw new Error(`HTTP error! status: ${response.status}`);
                    }

                    const data = await response.json();

                    // Display result
                    const variant = data.ab_test_result.selected_variant;
                    const variantClass = `variant-${variant.toLowerCase()}`;

                    resultDiv.className = `result show ${variantClass}`;
                    resultDiv.innerHTML = `
                        <div class="result-header">
                            <span>A/B Test Result</span>
                            <span class="variant-badge">Variant ${variant}</span>
                        </div>
                        <div class="result-content">
                            <div class="result-item">
                                <div class="result-label">User Key</div>
                                <div class="result-value">${data.ab_test_result.user_key}</div>
                            </div>
                            <div class="result-item">
                                <div class="result-label">Selected Variant</div>
                                <div class="result-value">Variant ${variant}
                                    (Gate: ${data.ab_test_result.gate_name})</div>
                            </div>
                            <div class="result-item">
                                <div class="result-label">Response from App ${variant}</div>
                                <div class="result-value">${data.response.message}</div>
                            </div>
                            <div class="result-item">
                                <div class="result-label">Algorithm</div>
                                <div class="result-value">${data.response.algorithm}</div>
                            </div>
                        </div>
                    `;

                } catch (error) {
                    errorDiv.textContent = `Error: ${error.message}`;
                    errorDiv.classList.add('show');
                } finally {
                    submitBtn.disabled = false;
                    submitBtn.textContent = 'Run A/B Test';
                }
            });
        </script>
    </body>
    </html>
    """
    return HTMLResponse(content=html_content)

# {{docs-fragment deploy}}
if __name__ == "__main__":
    flyte.init_from_config()
    flyte.deploy(env_root)
    print("Deployed A/B Testing Root App")
    print("\nUsage:")
    print("  Open your browser to '<endpoint>/' to access the interactive demo")
    print("  Or use curl: curl '<endpoint>/process/hello?user_key=user123'")
    print("\nNote: The root app reads STATSIG_API_KEY from the 'statsig-api-key' secret and fails to start without it.")
    print("      Create a feature gate named 'variant_b' in your Statsig dashboard.")
# {{/docs-fragment deploy}}
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/serving_graphs/ab_testing.py*
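To check the singleton-plus-lifespan pattern without a Statsig account, you can exercise the same lifecycle with a stand-in client. In this minimal sketch, `FakeGateClient`, `GateClientHolder`, and the toy gate rule are hypothetical; only the `asynccontextmanager` lifespan shape mirrors the example above.

```python
import asyncio
from contextlib import asynccontextmanager
from typing import Optional


class FakeGateClient:
    """Hypothetical stand-in for a feature-gate SDK client (not the Statsig API)."""

    def __init__(self, api_key: str) -> None:
        self.api_key = api_key
        self.closed = False

    def check_gate(self, user_id: str, gate_name: str) -> bool:
        # Toy rule for the sketch: keys ending in an odd digit pass the gate.
        return user_id.endswith(("1", "3", "5", "7", "9"))

    def shutdown(self) -> None:
        self.closed = True


class GateClientHolder:
    """Process-wide holder: one client per process, managed only by the lifespan."""

    _client: Optional[FakeGateClient] = None

    @classmethod
    def initialize(cls, api_key: str) -> None:
        if cls._client is None:
            cls._client = FakeGateClient(api_key)

    @classmethod
    def get_client(cls) -> FakeGateClient:
        if cls._client is None:
            raise RuntimeError("Client not initialized; is the lifespan running?")
        return cls._client

    @classmethod
    def shutdown(cls) -> None:
        if cls._client is not None:
            cls._client.shutdown()
            cls._client = None


@asynccontextmanager
async def lifespan(_app: object):
    GateClientHolder.initialize("test-key")
    try:
        yield
    finally:
        # Runs even if an exception escapes the app, so the client is always released.
        GateClientHolder.shutdown()


async def demo() -> tuple[bool, bool]:
    """Enter the lifespan, check a gate, and report whether the client was closed."""
    async with lifespan(None):
        client = GateClientHolder.get_client()
        gate_result = client.check_gate("user123", "variant_b")
    return gate_result, client.closed
```

Running `asyncio.run(demo())` enters the lifespan, checks the gate, and exits. The `try`/`finally` around `yield` is a deliberate deviation from the example above: without it, an exception raised inside the `async with` body skips the shutdown call.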

### Variant apps

The two variants are independent FastAPI apps, each with its own
`FastAPIAppEnvironment`, so they are deployed and scaled separately. Both expose
the same `/process/{message}` route and label their response with the variant
name and the algorithm it used:

```
from fastapi import FastAPI

import flyte
from flyte.app.extras import FastAPIAppEnvironment

# Image with statsig-python-core for A/B testing
image = flyte.Image.from_debian_base().with_pip_packages("fastapi", "uvicorn", "httpx", "statsig-python-core")

# App A - First variant
app_a = FastAPI(
    title="App A",
    description="Variant A for A/B testing",
)

# App B - Second variant
app_b = FastAPI(
    title="App B",
    description="Variant B for A/B testing",
)

env_a = FastAPIAppEnvironment(
    name="app-a-variant",
    app=app_a,
    image=image,
    resources=flyte.Resources(cpu=1, memory="512Mi"),
)

env_b = FastAPIAppEnvironment(
    name="app-b-variant",
    app=app_b,
    image=image,
    resources=flyte.Resources(cpu=1, memory="512Mi"),
)

# App A endpoints
@app_a.get("/process/{message}")
async def process_a(message: str) -> dict[str, str]:
    return {
        "variant": "A",
        "message": f"App A processed: {message}",
        "algorithm": "fast-processing",
    }

# App B endpoints
@app_b.get("/process/{message}")
async def process_b(message: str) -> dict[str, str]:
    return {
        "variant": "B",
        "message": f"App B processed: {message}",
        "algorithm": "enhanced-processing",
    }
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/serving_graphs/ab_testing.py*
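When the root app forwards a request, `message` becomes a path segment on the variant's URL. Reserved characters such as `?` or `#` must be percent-encoded first, or they change the URL's meaning. The minimal sketch below separates variant selection from URL building so both are easy to unit test. The helper names `pick_variant`, `variant_url`, and `root_request_url` are hypothetical, and the endpoints in the tests are placeholders rather than real app URLs.

```python
from urllib.parse import quote, urlencode


def pick_variant(use_variant_b: bool, endpoint_a: str, endpoint_b: str) -> tuple[str, str]:
    """Return (variant label, base endpoint) for a gate result."""
    return ("B", endpoint_b) if use_variant_b else ("A", endpoint_a)


def variant_url(base_endpoint: str, message: str) -> str:
    """Build the URL of a variant's /process route for one message."""
    # safe="" also encodes "/", so the whole message stays a single path segment.
    return f"{base_endpoint.rstrip('/')}/process/{quote(message, safe='')}"


def root_request_url(root_endpoint: str, message: str, user_key: str) -> str:
    """Build a client request to the root app: message in the path, user_key in the query."""
    return f"{variant_url(root_endpoint, message)}?{urlencode({'user_key': user_key})}"
```

In a routing endpoint, `pick_variant` would stand in for the `if`/`else` on the gate result, and `root_request_url` builds the same kind of request the deploy script suggests sending with `curl`.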


### Root app with Statsig in its lifespan

The root app's lifespan initializes the Statsig client at startup and shuts it down when the app stops. The API key arrives as the `STATSIG_API_KEY` environment variable because the root env attaches a Flyte secret (see below):

```
import os
import typing
from contextlib import asynccontextmanager

import httpx
from fastapi import FastAPI

import flyte
from flyte.app.extras import FastAPIAppEnvironment

# {{docs-fragment statsig-client}}
class StatsigClient:
    """Singleton to manage Statsig client lifecycle."""

    _instance: "StatsigClient | None" = None
    _statsig = None

    @classmethod
    def initialize(cls, api_key: str):
        """Initialize Statsig client (call during lifespan startup)."""
        if cls._instance is None:
            cls._instance = cls()

        # Import statsig at runtime (only available in container)
        from statsig_python_core import Statsig

        cls._statsig = Statsig(api_key)
        cls._statsig.initialize().wait()

    @classmethod
    def get_client(cls):
        """Get the initialized Statsig instance."""
        if cls._statsig is None:
            raise RuntimeError("StatsigClient not initialized. Call initialize() first.")
        return cls._statsig

    @classmethod
    def shutdown(cls):
        """Shutdown Statsig client (call during lifespan shutdown)."""
        if cls._statsig is not None:
            cls._statsig.shutdown()
            cls._statsig = None
            cls._instance = None
# {{/docs-fragment statsig-client}}

# {{docs-fragment variant-apps}}
# Image with statsig-python-core for A/B testing
image = flyte.Image.from_debian_base().with_pip_packages("fastapi", "uvicorn", "httpx", "statsig-python-core")

# App A - First variant
app_a = FastAPI(
    title="App A",
    description="Variant A for A/B testing",
)

# App B - Second variant
app_b = FastAPI(
    title="App B",
    description="Variant B for A/B testing",
)
# {{/docs-fragment variant-apps}}

# {{docs-fragment root-lifespan}}
@asynccontextmanager
async def lifespan(_app: FastAPI):
    """Initialize and shutdown Statsig for A/B testing."""
    # Startup: Initialize Statsig using singleton
    api_key = os.getenv("STATSIG_API_KEY", None)
    if api_key is None:
        raise RuntimeError(f"StatsigClient API Key not set. ENV vars {os.environ}")
    StatsigClient.initialize(api_key)

    yield

    # Shutdown: Cleanup Statsig
    StatsigClient.shutdown()

# Root App - Performs A/B testing and routes to A or B
root_app = FastAPI(
    title="Root App - A/B Testing",
    description="Routes requests to App A or App B based on Statsig A/B test",
    lifespan=lifespan,
)
# {{/docs-fragment root-lifespan}}

# {{docs-fragment variant-envs}}
env_a = FastAPIAppEnvironment(
    name="app-a-variant",
    app=app_a,
    image=image,
    resources=flyte.Resources(cpu=1, memory="512Mi"),
)

env_b = FastAPIAppEnvironment(
    name="app-b-variant",
    app=app_b,
    image=image,
    resources=flyte.Resources(cpu=1, memory="512Mi"),
)
# {{/docs-fragment variant-envs}}

# {{docs-fragment root-env}}
env_root = FastAPIAppEnvironment(
    name="root-ab-testing-app",
    app=root_app,
    image=image,
    resources=flyte.Resources(cpu=1, memory="512Mi"),
    depends_on=[env_a, env_b],
    secrets=flyte.Secret("statsig-api-key", as_env_var="STATSIG_API_KEY"),
)
# {{/docs-fragment root-env}}

# {{docs-fragment variant-endpoints}}
# App A endpoints
@app_a.get("/process/{message}")
async def process_a(message: str) -> dict[str, str]:
    return {
        "variant": "A",
        "message": f"App A processed: {message}",
        "algorithm": "fast-processing",
    }

# App B endpoints
@app_b.get("/process/{message}")
async def process_b(message: str) -> dict[str, str]:
    return {
        "variant": "B",
        "message": f"App B processed: {message}",
        "algorithm": "enhanced-processing",
    }
# {{/docs-fragment variant-endpoints}}

# {{docs-fragment routing-endpoint}}
# Root app A/B testing endpoint
@root_app.get("/process/{message}")
async def process_with_ab_test(message: str, user_key: str) -> dict[str, typing.Any]:
    """
    Process a message using A/B testing to determine which app to call.

    Args:
        message: The message to process
        user_key: User identifier for A/B test bucketing (e.g., user_id, session_id)

    Returns:
        Response from either App A or App B, plus metadata about which variant was used
    """
    # Import StatsigUser at runtime (only available in container)
    from statsig_python_core import StatsigUser

    # Get statsig client from singleton
    statsig = StatsigClient.get_client()

    # Create Statsig user with the provided key
    user = StatsigUser(user_id=user_key)

    # Check the feature gate "variant_b" to determine which variant
    # If gate is enabled, use App B; otherwise use App A
    use_variant_b = statsig.check_gate(user, "variant_b")

    # Call the appropriate app based on A/B test result
    async with httpx.AsyncClient() as client:
        if use_variant_b:
            endpoint = f"{env_b.endpoint}/process/{message}"
            response = await client.get(endpoint)
            result = response.json()
        else:
            endpoint = f"{env_a.endpoint}/process/{message}"
            response = await client.get(endpoint)
            result = response.json()

    # Add A/B test metadata to response
    return {
        "ab_test_result": {
            "user_key": user_key,
            "selected_variant": "B" if use_variant_b else "A",
            "gate_name": "variant_b",
        },
        "response": result,
    }
# {{/docs-fragment routing-endpoint}}

@root_app.get("/endpoints")
async def get_endpoints() -> dict[str, str]:
    """Get the endpoints for App A and App B."""
    return {
        "app_a_endpoint": env_a.endpoint,
        "app_b_endpoint": env_b.endpoint,
    }

@root_app.get("/")
async def index():
    """Serve the A/B testing demo HTML page."""
    from fastapi.responses import HTMLResponse

    html_content = """
    <!DOCTYPE html>
    <html lang="en">
    <head>
        <meta charset="UTF-8">
        <meta name="viewport" content="width=device-width, initial-scale=1.0">
        <title>A/B Testing Demo - Statsig</title>
        <style>
            * {
                margin: 0;
                padding: 0;
                box-sizing: border-box;
            }

            body {
                font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, Cantarell,
                sans-serif;
                background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
                min-height: 100vh;
                display: flex;
                justify-content: center;
                align-items: center;
                padding: 20px;
            }

            .container {
                background: white;
                border-radius: 20px;
                box-shadow: 0 20px 60px rgba(0, 0, 0, 0.3);
                padding: 40px;
                max-width: 600px;
                width: 100%;
            }

            h1 {
                color: #333;
                margin-bottom: 10px;
                font-size: 28px;
            }

            .subtitle {
                color: #666;
                margin-bottom: 30px;
                font-size: 14px;
            }

            .form-group {
                margin-bottom: 20px;
            }

            label {
                display: block;
                margin-bottom: 8px;
                color: #555;
                font-weight: 500;
                font-size: 14px;
            }

            input {
                width: 100%;
                padding: 12px 16px;
                border: 2px solid #e0e0e0;
                border-radius: 8px;
                font-size: 14px;
                transition: border-color 0.3s;
            }

            input:focus {
                outline: none;
                border-color: #667eea;
            }

            button {
                width: 100%;
                padding: 14px;
                background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
                color: white;
                border: none;
                border-radius: 8px;
                font-size: 16px;
                font-weight: 600;
                cursor: pointer;
                transition: transform 0.2s, box-shadow 0.2s;
            }

            button:hover {
                transform: translateY(-2px);
                box-shadow: 0 10px 20px rgba(102, 126, 234, 0.4);
            }

            button:active {
                transform: translateY(0);
            }

            button:disabled {
                opacity: 0.6;
                cursor: not-allowed;
            }

            .result {
                margin-top: 30px;
                padding: 20px;
                border-radius: 12px;
                display: none;
            }

            .result.show {
                display: block;
            }

            .result.variant-a {
                background: #e3f2fd;
                border: 2px solid #2196f3;
            }

            .result.variant-b {
                background: #f3e5f5;
                border: 2px solid #9c27b0;
            }

            .result-header {
                font-size: 18px;
                font-weight: 600;
                margin-bottom: 15px;
                display: flex;
                align-items: center;
                gap: 10px;
            }

            .variant-badge {
                display: inline-block;
                padding: 4px 12px;
                border-radius: 12px;
                font-size: 12px;
                font-weight: 700;
            }

            .variant-a .variant-badge {
                background: #2196f3;
                color: white;
            }

            .variant-b .variant-badge {
                background: #9c27b0;
                color: white;
            }

            .result-content {
                margin-top: 10px;
            }

            .result-item {
                margin-bottom: 10px;
                padding: 10px;
                background: rgba(255, 255, 255, 0.8);
                border-radius: 6px;
            }

            .result-label {
                font-weight: 600;
                color: #555;
                font-size: 13px;
            }

            .result-value {
                color: #333;
                margin-top: 4px;
            }

            .error {
                background: #ffebee;
                border: 2px solid #f44336;
                color: #c62828;
                padding: 16px;
                border-radius: 8px;
                margin-top: 20px;
                display: none;
            }

            .error.show {
                display: block;
            }

            .info {
                background: #fff3e0;
                border-left: 4px solid #ff9800;
                padding: 12px 16px;
                margin-top: 20px;
                border-radius: 4px;
                font-size: 13px;
                color: #e65100;
            }
        </style>
    </head>
    <body>
        <div class="container">
            <h1>🎯 A/B Testing Demo</h1>
            <p class="subtitle">Test Statsig-powered variant selection</p>

            <form id="abTestForm">
                <div class="form-group">
                    <label for="message">Message to Process</label>
                    <input
                        type="text"
                        id="message"
                        name="message"
                        placeholder="e.g., hello, world, test"
                        required
                        value="hello"
                    >
                </div>

                <div class="form-group">
                    <label for="userKey">User Key (for A/B bucketing)</label>
                    <input
                        type="text"
                        id="userKey"
                        name="userKey"
                        placeholder="e.g., user123, session456"
                        required
                        value="user123"
                    >
                </div>

                <button type="submit" id="submitBtn">Run A/B Test</button>
            </form>

            <div id="result" class="result"></div>
            <div id="error" class="error"></div>

            <div class="info">
                💡 <strong>Tip:</strong> Try different user keys to see how Statsig routes to different variants.
                The same user key will always get the same variant (consistent bucketing).
            </div>
        </div>

        <script>
            const form = document.getElementById('abTestForm');
            const resultDiv = document.getElementById('result');
            const errorDiv = document.getElementById('error');
            const submitBtn = document.getElementById('submitBtn');

            form.addEventListener('submit', async (e) => {
                e.preventDefault();

                const message = document.getElementById('message').value;
                const userKey = document.getElementById('userKey').value;

                // Reset previous results
                resultDiv.classList.remove('show', 'variant-a', 'variant-b');
                errorDiv.classList.remove('show');
                submitBtn.disabled = true;
                submitBtn.textContent = 'Processing...';

                try {
                    const response =
                        await fetch(`/process/${encodeURIComponent(message)}?user_key=${encodeURIComponent(userKey)}`);

                    if (!response.ok) {
                        throw new Error(`HTTP error! status: ${response.status}`);
                    }

                    const data = await response.json();

                    // Display result
                    const variant = data.ab_test_result.selected_variant;
                    const variantClass = `variant-${variant.toLowerCase()}`;

                    resultDiv.className = `result show ${variantClass}`;
                    resultDiv.innerHTML = `
                        <div class="result-header">
                            <span>A/B Test Result</span>
                            <span class="variant-badge">Variant ${variant}</span>
                        </div>
                        <div class="result-content">
                            <div class="result-item">
                                <div class="result-label">User Key</div>
                                <div class="result-value">${data.ab_test_result.user_key}</div>
                            </div>
                            <div class="result-item">
                                <div class="result-label">Selected Variant</div>
                                <div class="result-value">Variant ${variant}
                                    (Gate: ${data.ab_test_result.gate_name})</div>
                            </div>
                            <div class="result-item">
                                <div class="result-label">Response from App ${variant}</div>
                                <div class="result-value">${data.response.message}</div>
                            </div>
                            <div class="result-item">
                                <div class="result-label">Algorithm</div>
                                <div class="result-value">${data.response.algorithm}</div>
                            </div>
                        </div>
                    `;

                } catch (error) {
                    errorDiv.textContent = `Error: ${error.message}`;
                    errorDiv.classList.add('show');
                } finally {
                    submitBtn.disabled = false;
                    submitBtn.textContent = 'Run A/B Test';
                }
            });
        </script>
    </body>
    </html>
    """
    return HTMLResponse(content=html_content)

# {{docs-fragment deploy}}
if __name__ == "__main__":
    flyte.init_from_config()
    flyte.deploy(env_root)
    print("Deployed A/B Testing Root App")
    print("\nUsage:")
    print("  Open your browser to '<endpoint>/' to access the interactive demo")
    print("  Or use curl: curl '<endpoint>/process/hello?user_key=user123'")
    print("\nNote: The root app needs the 'statsig-api-key' secret (exposed as STATSIG_API_KEY) to start.")
    print("      Create a feature gate named 'variant_b' in your Statsig dashboard.")
# {{/docs-fragment deploy}}
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/serving_graphs/ab_testing.py*
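To exercise the deployed root app from Python instead of curl or the browser, a small standard-library client can send the same request that the demo page's `fetch` call makes. This is a minimal sketch. `build_process_url` and `call_root_app` are hypothetical helpers, the endpoint is whatever URL the deploy step reports, and the sketch assumes the root app is reachable without extra authentication headers (add whatever credentials your deployment requires if it is not).

```python
"""Minimal stdlib client for the deployed A/B root app (illustrative sketch)."""

import json
import urllib.parse
import urllib.request
from typing import Any


def build_process_url(endpoint: str, message: str, user_key: str) -> str:
    """Build '<endpoint>/process/<message>?user_key=<user_key>'.

    The message is percent-encoded as a single path segment (safe="" also
    encodes '/'), mirroring the demo page's encodeURIComponent() call.
    """
    path = "/process/" + urllib.parse.quote(message, safe="")
    query = urllib.parse.urlencode({"user_key": user_key})
    return endpoint.rstrip("/") + path + "?" + query


def call_root_app(endpoint: str, message: str, user_key: str, timeout: float = 10.0) -> dict[str, Any]:
    """GET the root app's /process route and decode the JSON body."""
    url = build_process_url(endpoint, message, user_key)
    with urllib.request.urlopen(url, timeout=timeout) as resp:
        return json.load(resp)
```

Encoding the message with `safe=""` keeps characters such as `/` or `?` inside the single `{message}` path segment instead of altering the route. Calling `call_root_app` twice with the same `user_key` should report the same `selected_variant`, which is a quick way to confirm consistent bucketing.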

### App environments

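The two variant environments are minimal. Each wraps its FastAPI app with the shared image and a fixed 1 CPU / 512Mi allocation, and neither declares dependencies or secrets: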

```
env_a = FastAPIAppEnvironment(
    name="app-a-variant",
    app=app_a,
    image=image,
    resources=flyte.Resources(cpu=1, memory="512Mi"),
)

env_b = FastAPIAppEnvironment(
    name="app-b-variant",
    app=app_b,
    image=image,
    resources=flyte.Resources(cpu=1, memory="512Mi"),
)
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/serving_graphs/ab_testing.py*
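Because `env_a` and `env_b` expose the same `/process/{message}` route, the root app's routing decision reduces to one boolean per user key. If you want to exercise the graph before a Statsig gate exists (for example, while testing locally), a deterministic hash can play the role of `statsig.check_gate(user, "variant_b")`. This is a sketch under stated assumptions: `local_gate_check`, `pick_variant`, and the 50% rollout are hypothetical, and the hashing scheme is not Statsig's bucketing algorithm; it only reproduces the "same key, same variant" property.

```python
"""Local stand-in for a Statsig gate check (illustrative only)."""

import hashlib


def local_gate_check(user_key: str, gate_name: str = "variant_b", rollout_pct: float = 50.0) -> bool:
    """Return True if user_key falls inside the gate's rollout percentage.

    Hashes '<gate_name>:<user_key>' with SHA-256 and maps the first 8 bytes
    to a bucket in [0, 10000). The same (gate, user) pair always lands in
    the same bucket, so the decision is sticky across requests and processes.
    NOTE: this is not Statsig's algorithm; it only mimics consistent bucketing.
    """
    if not 0.0 <= rollout_pct <= 100.0:
        raise ValueError("rollout_pct must be between 0 and 100")
    digest = hashlib.sha256(f"{gate_name}:{user_key}".encode("utf-8")).digest()
    bucket = int.from_bytes(digest[:8], "big") % 10_000
    return bucket < rollout_pct * 100


def pick_variant(user_key: str) -> str:
    """Map the gate result onto the two variant apps, as the root app does."""
    return "B" if local_gate_check(user_key) else "A"
```

Salting the hash with the gate name keeps assignments for different gates independent of each other, while the fixed bucket range makes it easy to change the rollout percentage without reshuffling users already inside it.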

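The root env lists both variants in `depends_on=[env_a, env_b]`, so deploying it deploys all three apps together. It also attaches the `statsig-api-key` Flyte secret as the `STATSIG_API_KEY` environment variable, which the root app's lifespan handler reads at startup: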

```
import os
import typing
from contextlib import asynccontextmanager

import httpx
from fastapi import FastAPI

import flyte
from flyte.app.extras import FastAPIAppEnvironment

# {{docs-fragment statsig-client}}
class StatsigClient:
    """Singleton to manage Statsig client lifecycle."""

    _instance: "StatsigClient | None" = None
    _statsig = None

    @classmethod
    def initialize(cls, api_key: str):
        """Initialize Statsig client (call during lifespan startup)."""
        if cls._instance is None:
            cls._instance = cls()

        # Import statsig at runtime (only available in container)
        from statsig_python_core import Statsig

        cls._statsig = Statsig(api_key)
        cls._statsig.initialize().wait()

    @classmethod
    def get_client(cls):
        """Get the initialized Statsig instance."""
        if cls._statsig is None:
            raise RuntimeError("StatsigClient not initialized. Call initialize() first.")
        return cls._statsig

    @classmethod
    def shutdown(cls):
        """Shutdown Statsig client (call during lifespan shutdown)."""
        if cls._statsig is not None:
            cls._statsig.shutdown()
            cls._statsig = None
            cls._instance = None
# {{/docs-fragment statsig-client}}

# {{docs-fragment variant-apps}}
# Image with statsig-python-core for A/B testing
image = flyte.Image.from_debian_base().with_pip_packages("fastapi", "uvicorn", "httpx", "statsig-python-core")

# App A - First variant
app_a = FastAPI(
    title="App A",
    description="Variant A for A/B testing",
)

# App B - Second variant
app_b = FastAPI(
    title="App B",
    description="Variant B for A/B testing",
)
# {{/docs-fragment variant-apps}}

# {{docs-fragment root-lifespan}}
@asynccontextmanager
async def lifespan(_app: FastAPI):
    """Initialize and shutdown Statsig for A/B testing."""
    # Startup: Initialize Statsig using singleton
    api_key = os.getenv("STATSIG_API_KEY", None)
    if api_key is None:
        raise RuntimeError(f"StatsigClient API Key not set. ENV vars {os.environ}")
    StatsigClient.initialize(api_key)

    yield

    # Shutdown: Cleanup Statsig
    StatsigClient.shutdown()

# Root App - Performs A/B testing and routes to A or B
root_app = FastAPI(
    title="Root App - A/B Testing",
    description="Routes requests to App A or App B based on Statsig A/B test",
    lifespan=lifespan,
)
# {{/docs-fragment root-lifespan}}

# {{docs-fragment variant-envs}}
env_a = FastAPIAppEnvironment(
    name="app-a-variant",
    app=app_a,
    image=image,
    resources=flyte.Resources(cpu=1, memory="512Mi"),
)

env_b = FastAPIAppEnvironment(
    name="app-b-variant",
    app=app_b,
    image=image,
    resources=flyte.Resources(cpu=1, memory="512Mi"),
)
# {{/docs-fragment variant-envs}}

# {{docs-fragment root-env}}
env_root = FastAPIAppEnvironment(
    name="root-ab-testing-app",
    app=root_app,
    image=image,
    resources=flyte.Resources(cpu=1, memory="512Mi"),
    depends_on=[env_a, env_b],
    secrets=flyte.Secret("statsig-api-key", as_env_var="STATSIG_API_KEY"),
)
# {{/docs-fragment root-env}}

# {{docs-fragment variant-endpoints}}
# App A endpoints
@app_a.get("/process/{message}")
async def process_a(message: str) -> dict[str, str]:
    return {
        "variant": "A",
        "message": f"App A processed: {message}",
        "algorithm": "fast-processing",
    }

# App B endpoints
@app_b.get("/process/{message}")
async def process_b(message: str) -> dict[str, str]:
    return {
        "variant": "B",
        "message": f"App B processed: {message}",
        "algorithm": "enhanced-processing",
    }
# {{/docs-fragment variant-endpoints}}

# {{docs-fragment routing-endpoint}}
# Root app A/B testing endpoint
@root_app.get("/process/{message}")
async def process_with_ab_test(message: str, user_key: str) -> dict[str, typing.Any]:
    """
    Process a message using A/B testing to determine which app to call.

    Args:
        message: The message to process
        user_key: User identifier for A/B test bucketing (e.g., user_id, session_id)

    Returns:
        Response from either App A or App B, plus metadata about which variant was used
    """
    # Import StatsigUser at runtime (only available in container)
    from statsig_python_core import StatsigUser

    # Get statsig client from singleton
    statsig = StatsigClient.get_client()

    # Create Statsig user with the provided key
    user = StatsigUser(user_id=user_key)

    # Check the feature gate "variant_b" to determine which variant
    # If gate is enabled, use App B; otherwise use App A
    use_variant_b = statsig.check_gate(user, "variant_b")

    # Call the appropriate app based on A/B test result
    async with httpx.AsyncClient() as client:
        if use_variant_b:
            endpoint = f"{env_b.endpoint}/process/{message}"
            response = await client.get(endpoint)
            result = response.json()
        else:
            endpoint = f"{env_a.endpoint}/process/{message}"
            response = await client.get(endpoint)
            result = response.json()

    # Add A/B test metadata to response
    return {
        "ab_test_result": {
            "user_key": user_key,
            "selected_variant": "B" if use_variant_b else "A",
            "gate_name": "variant_b",
        },
        "response": result,
    }
# {{/docs-fragment routing-endpoint}}

@root_app.get("/endpoints")
async def get_endpoints() -> dict[str, str]:
    """Get the endpoints for App A and App B."""
    return {
        "app_a_endpoint": env_a.endpoint,
        "app_b_endpoint": env_b.endpoint,
    }

@root_app.get("/")
async def index():
    """Serve the A/B testing demo HTML page."""
    from fastapi.responses import HTMLResponse

    html_content = """
    <!DOCTYPE html>
    <html lang="en">
    <head>
        <meta charset="UTF-8">
        <meta name="viewport" content="width=device-width, initial-scale=1.0">
        <title>A/B Testing Demo - Statsig</title>
        <style>
            * {
                margin: 0;
                padding: 0;
                box-sizing: border-box;
            }

            body {
                font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, Cantarell,
                sans-serif;
                background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
                min-height: 100vh;
                display: flex;
                justify-content: center;
                align-items: center;
                padding: 20px;
            }

            .container {
                background: white;
                border-radius: 20px;
                box-shadow: 0 20px 60px rgba(0, 0, 0, 0.3);
                padding: 40px;
                max-width: 600px;
                width: 100%;
            }

            h1 {
                color: #333;
                margin-bottom: 10px;
                font-size: 28px;
            }

            .subtitle {
                color: #666;
                margin-bottom: 30px;
                font-size: 14px;
            }

            .form-group {
                margin-bottom: 20px;
            }

            label {
                display: block;
                margin-bottom: 8px;
                color: #555;
                font-weight: 500;
                font-size: 14px;
            }

            input {
                width: 100%;
                padding: 12px 16px;
                border: 2px solid #e0e0e0;
                border-radius: 8px;
                font-size: 14px;
                transition: border-color 0.3s;
            }

            input:focus {
                outline: none;
                border-color: #667eea;
            }

            button {
                width: 100%;
                padding: 14px;
                background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
                color: white;
                border: none;
                border-radius: 8px;
                font-size: 16px;
                font-weight: 600;
                cursor: pointer;
                transition: transform 0.2s, box-shadow 0.2s;
            }

            button:hover {
                transform: translateY(-2px);
                box-shadow: 0 10px 20px rgba(102, 126, 234, 0.4);
            }

            button:active {
                transform: translateY(0);
            }

            button:disabled {
                opacity: 0.6;
                cursor: not-allowed;
            }

            .result {
                margin-top: 30px;
                padding: 20px;
                border-radius: 12px;
                display: none;
            }

            .result.show {
                display: block;
            }

            .result.variant-a {
                background: #e3f2fd;
                border: 2px solid #2196f3;
            }

            .result.variant-b {
                background: #f3e5f5;
                border: 2px solid #9c27b0;
            }

            .result-header {
                font-size: 18px;
                font-weight: 600;
                margin-bottom: 15px;
                display: flex;
                align-items: center;
                gap: 10px;
            }

            .variant-badge {
                display: inline-block;
                padding: 4px 12px;
                border-radius: 12px;
                font-size: 12px;
                font-weight: 700;
            }

            .variant-a .variant-badge {
                background: #2196f3;
                color: white;
            }

            .variant-b .variant-badge {
                background: #9c27b0;
                color: white;
            }

            .result-content {
                margin-top: 10px;
            }

            .result-item {
                margin-bottom: 10px;
                padding: 10px;
                background: rgba(255, 255, 255, 0.8);
                border-radius: 6px;
            }

            .result-label {
                font-weight: 600;
                color: #555;
                font-size: 13px;
            }

            .result-value {
                color: #333;
                margin-top: 4px;
            }

            .error {
                background: #ffebee;
                border: 2px solid #f44336;
                color: #c62828;
                padding: 16px;
                border-radius: 8px;
                margin-top: 20px;
                display: none;
            }

            .error.show {
                display: block;
            }

            .info {
                background: #fff3e0;
                border-left: 4px solid #ff9800;
                padding: 12px 16px;
                margin-top: 20px;
                border-radius: 4px;
                font-size: 13px;
                color: #e65100;
            }
        </style>
    </head>
    <body>
        <div class="container">
            <h1>🎯 A/B Testing Demo</h1>
            <p class="subtitle">Test Statsig-powered variant selection</p>

            <form id="abTestForm">
                <div class="form-group">
                    <label for="message">Message to Process</label>
                    <input
                        type="text"
                        id="message"
                        name="message"
                        placeholder="e.g., hello, world, test"
                        required
                        value="hello"
                    >
                </div>

                <div class="form-group">
                    <label for="userKey">User Key (for A/B bucketing)</label>
                    <input
                        type="text"
                        id="userKey"
                        name="userKey"
                        placeholder="e.g., user123, session456"
                        required
                        value="user123"
                    >
                </div>

                <button type="submit" id="submitBtn">Run A/B Test</button>
            </form>

            <div id="result" class="result"></div>
            <div id="error" class="error"></div>

            <div class="info">
                💡 <strong>Tip:</strong> Try different user keys to see how Statsig routes to different variants.
                The same user key will always get the same variant (consistent bucketing).
            </div>
        </div>

        <script>
            const form = document.getElementById('abTestForm');
            const resultDiv = document.getElementById('result');
            const errorDiv = document.getElementById('error');
            const submitBtn = document.getElementById('submitBtn');

            form.addEventListener('submit', async (e) => {
                e.preventDefault();

                const message = document.getElementById('message').value;
                const userKey = document.getElementById('userKey').value;

                // Reset previous results
                resultDiv.classList.remove('show', 'variant-a', 'variant-b');
                errorDiv.classList.remove('show');
                submitBtn.disabled = true;
                submitBtn.textContent = 'Processing...';

                try {
                    const response =
                        await fetch(`/process/${encodeURIComponent(message)}?user_key=${encodeURIComponent(userKey)}`);

                    if (!response.ok) {
                        throw new Error(`HTTP error! status: ${response.status}`);
                    }

                    const data = await response.json();

                    // Display result
                    const variant = data.ab_test_result.selected_variant;
                    const variantClass = `variant-${variant.toLowerCase()}`;

                    resultDiv.className = `result show ${variantClass}`;
                    resultDiv.innerHTML = `
                        <div class="result-header">
                            <span>A/B Test Result</span>
                            <span class="variant-badge">Variant ${variant}</span>
                        </div>
                        <div class="result-content">
                            <div class="result-item">
                                <div class="result-label">User Key</div>
                                <div class="result-value">${data.ab_test_result.user_key}</div>
                            </div>
                            <div class="result-item">
                                <div class="result-label">Selected Variant</div>
                                <div class="result-value">Variant ${variant}
                                    (Gate: ${data.ab_test_result.gate_name})</div>
                            </div>
                            <div class="result-item">
                                <div class="result-label">Response from App ${variant}</div>
                                <div class="result-value">${data.response.message}</div>
                            </div>
                            <div class="result-item">
                                <div class="result-label">Algorithm</div>
                                <div class="result-value">${data.response.algorithm}</div>
                            </div>
                        </div>
                    `;

                } catch (error) {
                    errorDiv.textContent = `Error: ${error.message}`;
                    errorDiv.classList.add('show');
                } finally {
                    submitBtn.disabled = false;
                    submitBtn.textContent = 'Run A/B Test';
                }
            });
        </script>
    </body>
    </html>
    """
    return HTMLResponse(content=html_content)

# {{docs-fragment deploy}}
if __name__ == "__main__":
    flyte.init_from_config()
    flyte.deploy(env_root)
    print("Deployed A/B Testing Root App")
    print("\nUsage:")
    print("  Open your browser to '<endpoint>/' to access the interactive demo")
    print("  Or use curl: curl '<endpoint>/process/hello?user_key=user123'")
    print("\nNote: The app requires a 'statsig-api-key' secret (exposed as STATSIG_API_KEY).")
    print("      Create a feature gate named 'variant_b' in your Statsig dashboard.")
# {{/docs-fragment deploy}}
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/serving_graphs/ab_testing.py*

### Routing endpoint

The root app checks the `variant_b` Statsig feature gate for the caller's
`user_key` and proxies the request to the selected variant through that
variant environment's `endpoint` property:

```
# {{docs-fragment routing-endpoint}}
# Root app A/B testing endpoint
@root_app.get("/process/{message}")
async def process_with_ab_test(message: str, user_key: str) -> dict[str, typing.Any]:
    """
    Process a message using A/B testing to determine which app to call.

    Args:
        message: The message to process
        user_key: User identifier for A/B test bucketing (e.g., user_id, session_id)

    Returns:
        Response from either App A or App B, plus metadata about which variant was used
    """
    # Import StatsigUser at runtime (only available in container)
    from statsig_python_core import StatsigUser

    # Get statsig client from singleton
    statsig = StatsigClient.get_client()

    # Create Statsig user with the provided key
    user = StatsigUser(user_id=user_key)

    # Check the feature gate "variant_b" to determine which variant:
    # if the gate passes, use App B; otherwise use App A
    use_variant_b = statsig.check_gate(user, "variant_b")
    target_env = env_b if use_variant_b else env_a

    # Call the selected variant through its endpoint
    async with httpx.AsyncClient() as client:
        response = await client.get(f"{target_env.endpoint}/process/{message}")
        # Surface upstream failures instead of parsing an error body as the result
        response.raise_for_status()
        result = response.json()

    # Add A/B test metadata to response
    return {
        "ab_test_result": {
            "user_key": user_key,
            "selected_variant": "B" if use_variant_b else "A",
            "gate_name": "variant_b",
        },
        "response": result,
    }
# {{/docs-fragment routing-endpoint}}
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/serving_graphs/ab_testing.py*
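The routing decision itself is small enough to unit-test without a Statsig key if you put the gate check behind a callable. The sketch below is a hypothetical refactor, not part of the example repo. `select_variant` and `GateFn` are illustrative names, and the fake gate stands in for `statsig.check_gate`. It also percent-encodes the message so characters like `/` cannot change the upstream path:

```python
from typing import Callable
from urllib.parse import quote

# Hypothetical gate signature: (user_key, gate_name) -> bool, standing in for
# statsig.check_gate(StatsigUser(user_id=user_key), gate_name)
GateFn = Callable[[str, str], bool]


def select_variant(
    message: str,
    user_key: str,
    gate_fn: GateFn,
    endpoint_a: str,
    endpoint_b: str,
    gate_name: str = "variant_b",
) -> tuple[str, str]:
    """Return (variant label, upstream URL) for one request."""
    use_variant_b = gate_fn(user_key, gate_name)
    base = endpoint_b if use_variant_b else endpoint_a
    # Percent-encode the message so "/" or "?" cannot alter the upstream path
    url = f"{base.rstrip('/')}/process/{quote(message, safe='')}"
    return ("B" if use_variant_b else "A"), url


# Fake gate for tests: enables variant B for a single user only
def fake_gate(user_key: str, gate_name: str) -> bool:
    return gate_name == "variant_b" and user_key == "beta-user"


print(select_variant("hello", "user123", fake_gate, "http://app-a", "http://app-b"))
# ('A', 'http://app-a/process/hello')
```

In the real endpoint, `gate_fn` would wrap the Statsig client, and `endpoint_a` / `endpoint_b` would come from `env_a.endpoint` and `env_b.endpoint`.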

Use stable identifiers (user ID, session ID) for `user_key` so the same user
always lands in the same bucket. To drive the split from an experiment or
dynamic config instead of the `check_gate` call, read a parameter from it:

```python
experiment = statsig.get_experiment(user, "my_experiment")
variant = experiment.get("variant", "A")
```
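Statsig performs bucketing on its side. The reason stable keys matter can still be shown with a generic deterministic hash. This is an illustrative sketch of the technique only, **not** Statsig's algorithm. `hash_bucket`, its salt, and the 10,000-slot granularity are assumptions:

```python
import hashlib


def hash_bucket(user_key: str, salt: str = "variant_b", rollout_pct: float = 50.0) -> str:
    """Deterministically assign a user key to variant "A" or "B".

    Illustrative only: this is not Statsig's bucketing algorithm.
    """
    # Salting with the gate/experiment name keeps assignments independent across tests
    digest = hashlib.sha256(f"{salt}:{user_key}".encode("utf-8")).digest()
    # Map the first 8 bytes onto 10,000 slots (0.01% rollout granularity)
    slot = int.from_bytes(digest[:8], "big") % 10_000
    return "B" if slot < rollout_pct * 100 else "A"


# The same key always maps to the same variant; a random per-request ID would not
assert hash_bucket("user123") == hash_bucket("user123")
share_b = sum(hash_bucket(f"user{i}") == "B" for i in range(10_000)) / 10_000
print(f"share in B: {share_b:.2%}")  # close to 50% for a 50% rollout
```

Because the assignment depends only on the salt and the key, a session ID gives consistency within a session, and a user ID gives it across sessions.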

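### Deploy

The `__main__` block deploys only `env_root`. Because `env_root` lists `env_a`
and `env_b` in `depends_on`, the variant apps are deployed along with it: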

```
import os
import typing
from contextlib import asynccontextmanager

import httpx
from fastapi import FastAPI

import flyte
from flyte.app.extras import FastAPIAppEnvironment

# {{docs-fragment statsig-client}}
class StatsigClient:
    """Singleton to manage Statsig client lifecycle."""

    _instance: "StatsigClient | None" = None
    _statsig = None

    @classmethod
    def initialize(cls, api_key: str):
        """Initialize Statsig client (call during lifespan startup)."""
        if cls._instance is None:
            cls._instance = cls()

        # Import statsig at runtime (only available in container)
        from statsig_python_core import Statsig

        cls._statsig = Statsig(api_key)
        cls._statsig.initialize().wait()

    @classmethod
    def get_client(cls):
        """Get the initialized Statsig instance."""
        if cls._statsig is None:
            raise RuntimeError("StatsigClient not initialized. Call initialize() first.")
        return cls._statsig

    @classmethod
    def shutdown(cls):
        """Shutdown Statsig client (call during lifespan shutdown)."""
        if cls._statsig is not None:
            cls._statsig.shutdown()
            cls._statsig = None
            cls._instance = None
# {{/docs-fragment statsig-client}}

# {{docs-fragment variant-apps}}
# Image with statsig-python-core for A/B testing
image = flyte.Image.from_debian_base().with_pip_packages("fastapi", "uvicorn", "httpx", "statsig-python-core")

# App A - First variant
app_a = FastAPI(
    title="App A",
    description="Variant A for A/B testing",
)

# App B - Second variant
app_b = FastAPI(
    title="App B",
    description="Variant B for A/B testing",
)
# {{/docs-fragment variant-apps}}

# {{docs-fragment root-lifespan}}
@asynccontextmanager
async def lifespan(_app: FastAPI):
    """Initialize and shutdown Statsig for A/B testing."""
    # Startup: Initialize Statsig using singleton
    api_key = os.getenv("STATSIG_API_KEY", None)
    if api_key is None:
        # Don't echo os.environ here: it would leak secrets into logs
        raise RuntimeError("STATSIG_API_KEY is not set; attach the 'statsig-api-key' secret to this app")
    StatsigClient.initialize(api_key)

    yield

    # Shutdown: Cleanup Statsig
    StatsigClient.shutdown()

# Root App - Performs A/B testing and routes to A or B
root_app = FastAPI(
    title="Root App - A/B Testing",
    description="Routes requests to App A or App B based on Statsig A/B test",
    lifespan=lifespan,
)
# {{/docs-fragment root-lifespan}}

# {{docs-fragment variant-envs}}
env_a = FastAPIAppEnvironment(
    name="app-a-variant",
    app=app_a,
    image=image,
    resources=flyte.Resources(cpu=1, memory="512Mi"),
)

env_b = FastAPIAppEnvironment(
    name="app-b-variant",
    app=app_b,
    image=image,
    resources=flyte.Resources(cpu=1, memory="512Mi"),
)
# {{/docs-fragment variant-envs}}

# {{docs-fragment root-env}}
env_root = FastAPIAppEnvironment(
    name="root-ab-testing-app",
    app=root_app,
    image=image,
    resources=flyte.Resources(cpu=1, memory="512Mi"),
    depends_on=[env_a, env_b],
    secrets=flyte.Secret("statsig-api-key", as_env_var="STATSIG_API_KEY"),
)
# {{/docs-fragment root-env}}

# {{docs-fragment variant-endpoints}}
# App A endpoints
@app_a.get("/process/{message}")
async def process_a(message: str) -> dict[str, str]:
    return {
        "variant": "A",
        "message": f"App A processed: {message}",
        "algorithm": "fast-processing",
    }

# App B endpoints
@app_b.get("/process/{message}")
async def process_b(message: str) -> dict[str, str]:
    return {
        "variant": "B",
        "message": f"App B processed: {message}",
        "algorithm": "enhanced-processing",
    }
# {{/docs-fragment variant-endpoints}}

# {{docs-fragment routing-endpoint}}
# Root app A/B testing endpoint
@root_app.get("/process/{message}")
async def process_with_ab_test(message: str, user_key: str) -> dict[str, typing.Any]:
    """
    Process a message using A/B testing to determine which app to call.

    Args:
        message: The message to process
        user_key: User identifier for A/B test bucketing (e.g., user_id, session_id)

    Returns:
        Response from either App A or App B, plus metadata about which variant was used
    """
    # Import StatsigUser at runtime (only available in container)
    from statsig_python_core import StatsigUser

    # Get statsig client from singleton
    statsig = StatsigClient.get_client()

    # Create Statsig user with the provided key
    user = StatsigUser(user_id=user_key)

    # Check the feature gate "variant_b" to determine which variant:
    # if the gate passes, use App B; otherwise use App A
    use_variant_b = statsig.check_gate(user, "variant_b")
    target_env = env_b if use_variant_b else env_a

    # Call the selected variant through its endpoint
    async with httpx.AsyncClient() as client:
        response = await client.get(f"{target_env.endpoint}/process/{message}")
        # Surface upstream failures instead of parsing an error body as the result
        response.raise_for_status()
        result = response.json()

    # Add A/B test metadata to response
    return {
        "ab_test_result": {
            "user_key": user_key,
            "selected_variant": "B" if use_variant_b else "A",
            "gate_name": "variant_b",
        },
        "response": result,
    }
# {{/docs-fragment routing-endpoint}}

@root_app.get("/endpoints")
async def get_endpoints() -> dict[str, str]:
    """Get the endpoints for App A and App B."""
    return {
        "app_a_endpoint": env_a.endpoint,
        "app_b_endpoint": env_b.endpoint,
    }

@root_app.get("/")
async def index():
    """Serve the A/B testing demo HTML page."""
    from fastapi.responses import HTMLResponse

    html_content = """
    <!DOCTYPE html>
    <html lang="en">
    <head>
        <meta charset="UTF-8">
        <meta name="viewport" content="width=device-width, initial-scale=1.0">
        <title>A/B Testing Demo - Statsig</title>
        <style>
            * {
                margin: 0;
                padding: 0;
                box-sizing: border-box;
            }

            body {
                font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, Cantarell,
                sans-serif;
                background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
                min-height: 100vh;
                display: flex;
                justify-content: center;
                align-items: center;
                padding: 20px;
            }

            .container {
                background: white;
                border-radius: 20px;
                box-shadow: 0 20px 60px rgba(0, 0, 0, 0.3);
                padding: 40px;
                max-width: 600px;
                width: 100%;
            }

            h1 {
                color: #333;
                margin-bottom: 10px;
                font-size: 28px;
            }

            .subtitle {
                color: #666;
                margin-bottom: 30px;
                font-size: 14px;
            }

            .form-group {
                margin-bottom: 20px;
            }

            label {
                display: block;
                margin-bottom: 8px;
                color: #555;
                font-weight: 500;
                font-size: 14px;
            }

            input {
                width: 100%;
                padding: 12px 16px;
                border: 2px solid #e0e0e0;
                border-radius: 8px;
                font-size: 14px;
                transition: border-color 0.3s;
            }

            input:focus {
                outline: none;
                border-color: #667eea;
            }

            button {
                width: 100%;
                padding: 14px;
                background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
                color: white;
                border: none;
                border-radius: 8px;
                font-size: 16px;
                font-weight: 600;
                cursor: pointer;
                transition: transform 0.2s, box-shadow 0.2s;
            }

            button:hover {
                transform: translateY(-2px);
                box-shadow: 0 10px 20px rgba(102, 126, 234, 0.4);
            }

            button:active {
                transform: translateY(0);
            }

            button:disabled {
                opacity: 0.6;
                cursor: not-allowed;
            }

            .result {
                margin-top: 30px;
                padding: 20px;
                border-radius: 12px;
                display: none;
            }

            .result.show {
                display: block;
            }

            .result.variant-a {
                background: #e3f2fd;
                border: 2px solid #2196f3;
            }

            .result.variant-b {
                background: #f3e5f5;
                border: 2px solid #9c27b0;
            }

            .result-header {
                font-size: 18px;
                font-weight: 600;
                margin-bottom: 15px;
                display: flex;
                align-items: center;
                gap: 10px;
            }

            .variant-badge {
                display: inline-block;
                padding: 4px 12px;
                border-radius: 12px;
                font-size: 12px;
                font-weight: 700;
            }

            .variant-a .variant-badge {
                background: #2196f3;
                color: white;
            }

            .variant-b .variant-badge {
                background: #9c27b0;
                color: white;
            }

            .result-content {
                margin-top: 10px;
            }

            .result-item {
                margin-bottom: 10px;
                padding: 10px;
                background: rgba(255, 255, 255, 0.8);
                border-radius: 6px;
            }

            .result-label {
                font-weight: 600;
                color: #555;
                font-size: 13px;
            }

            .result-value {
                color: #333;
                margin-top: 4px;
            }

            .error {
                background: #ffebee;
                border: 2px solid #f44336;
                color: #c62828;
                padding: 16px;
                border-radius: 8px;
                margin-top: 20px;
                display: none;
            }

            .error.show {
                display: block;
            }

            .info {
                background: #fff3e0;
                border-left: 4px solid #ff9800;
                padding: 12px 16px;
                margin-top: 20px;
                border-radius: 4px;
                font-size: 13px;
                color: #e65100;
            }
        </style>
    </head>
    <body>
        <div class="container">
            <h1>🎯 A/B Testing Demo</h1>
            <p class="subtitle">Test Statsig-powered variant selection</p>

            <form id="abTestForm">
                <div class="form-group">
                    <label for="message">Message to Process</label>
                    <input
                        type="text"
                        id="message"
                        name="message"
                        placeholder="e.g., hello, world, test"
                        required
                        value="hello"
                    >
                </div>

                <div class="form-group">
                    <label for="userKey">User Key (for A/B bucketing)</label>
                    <input
                        type="text"
                        id="userKey"
                        name="userKey"
                        placeholder="e.g., user123, session456"
                        required
                        value="user123"
                    >
                </div>

                <button type="submit" id="submitBtn">Run A/B Test</button>
            </form>

            <div id="result" class="result"></div>
            <div id="error" class="error"></div>

            <div class="info">
                💡 <strong>Tip:</strong> Try different user keys to see how Statsig routes to different variants.
                The same user key will always get the same variant (consistent bucketing).
            </div>
        </div>

        <script>
            const form = document.getElementById('abTestForm');
            const resultDiv = document.getElementById('result');
            const errorDiv = document.getElementById('error');
            const submitBtn = document.getElementById('submitBtn');

            form.addEventListener('submit', async (e) => {
                e.preventDefault();

                const message = document.getElementById('message').value;
                const userKey = document.getElementById('userKey').value;

                // Reset previous results
                resultDiv.classList.remove('show', 'variant-a', 'variant-b');
                errorDiv.classList.remove('show');
                submitBtn.disabled = true;
                submitBtn.textContent = 'Processing...';

                try {
                    const response =
                        await fetch(`/process/${encodeURIComponent(message)}?user_key=${encodeURIComponent(userKey)}`);

                    if (!response.ok) {
                        throw new Error(`HTTP error! status: ${response.status}`);
                    }

                    const data = await response.json();

                    // Display result
                    const variant = data.ab_test_result.selected_variant;
                    const variantClass = `variant-${variant.toLowerCase()}`;

                    resultDiv.className = `result show ${variantClass}`;
                    resultDiv.innerHTML = `
                        <div class="result-header">
                            <span>A/B Test Result</span>
                            <span class="variant-badge">Variant ${variant}</span>
                        </div>
                        <div class="result-content">
                            <div class="result-item">
                                <div class="result-label">User Key</div>
                                <div class="result-value">${data.ab_test_result.user_key}</div>
                            </div>
                            <div class="result-item">
                                <div class="result-label">Selected Variant</div>
                                <div class="result-value">Variant ${variant}
                                    (Gate: ${data.ab_test_result.gate_name})</div>
                            </div>
                            <div class="result-item">
                                <div class="result-label">Response from App ${variant}</div>
                                <div class="result-value">${data.response.message}</div>
                            </div>
                            <div class="result-item">
                                <div class="result-label">Algorithm</div>
                                <div class="result-value">${data.response.algorithm}</div>
                            </div>
                        </div>
                    `;

                } catch (error) {
                    errorDiv.textContent = `Error: ${error.message}`;
                    errorDiv.classList.add('show');
                } finally {
                    submitBtn.disabled = false;
                    submitBtn.textContent = 'Run A/B Test';
                }
            });
        </script>
    </body>
    </html>
    """
    return HTMLResponse(content=html_content)

# {{docs-fragment deploy}}
if __name__ == "__main__":
    flyte.init_from_config()
    flyte.deploy(env_root)
    print("Deployed A/B Testing Root App")
    print("\nUsage:")
    print("  Open your browser to '<endpoint>/' to access the interactive demo")
    print("  Or use curl: curl '<endpoint>/process/hello?user_key=user123'")
    print("\nNote: Set STATSIG_API_KEY secret to use real Statsig A/B testing.")
    print("      Create a feature gate named 'variant_b' in your Statsig dashboard.")
# {{/docs-fragment deploy}}
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/serving_graphs/ab_testing.py*
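
The demo page's tip says the same user key always gets the same variant ("consistent bucketing"). Statsig performs that assignment on its side. The sketch below only illustrates the general idea. It assumes a simple scheme that hashes the key, salted with the gate name, into 100 buckets, and it is not Statsig's actual algorithm. `bucket` and `select_variant` are hypothetical helpers.

```python
import hashlib


def bucket(user_key: str, gate_name: str = "variant_b", num_buckets: int = 100) -> int:
    """Map a user key to a stable bucket in [0, num_buckets).

    Mixing in the gate name as a salt is an assumption of this sketch: it lets
    different gates bucket the same user independently.
    """
    digest = hashlib.sha256(f"{gate_name}:{user_key}".encode("utf-8")).digest()
    # Interpret the first 8 bytes as an unsigned integer, then reduce to a bucket.
    return int.from_bytes(digest[:8], "big") % num_buckets


def select_variant(user_key: str, rollout_percent: int = 50) -> str:
    """Return "B" for users whose bucket falls inside the rollout, else "A"."""
    return "B" if bucket(user_key) < rollout_percent else "A"
```

Because the assignment is a pure function of the key, every replica of the root app agrees on the variant without shared state. Raising `rollout_percent` only moves users from A to B, never back.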

**Setup before running:**

1. Get a Server Secret Key from [statsig.com](https://www.statsig.com/) under **Settings → API Keys**.
2. Create a feature gate named `variant_b` in your Statsig dashboard (for example, with a 50% rollout).
3. Set the Flyte secret:
   ```bash
   flyte create secret statsig-api-key <your-secret-key-here>
   ```

## When to split into a serving graph

Split when stages have:

- **Different bottlenecks**: CPU vs. GPU vs. memory
- **Different scaling needs**: bursty vs. steady traffic, many replicas vs. few
- **Different lifecycles**: for example, model weights you don't want to reload every time a neighboring stage is redeployed, or a stage with an expensive cold start
- **Different routing concerns**: A/B tests, canaries, proxies, gateways

Don't split just to separate code — a single app with a few endpoints is
simpler to operate.

## Best practices

1. **Use `depends_on`**: Declare every upstream app in the calling app's `depends_on` so that deploying the entry-point app deploys its full dependency closure in one shot.
2. **Persistent HTTP clients**: Open one `httpx.AsyncClient` per replica in the app's lifespan rather than per request, to avoid TCP/TLS setup overhead.
3. **Pick the right wire format**: For tensor-shaped payloads, send raw bytes over `application/octet-stream` instead of JSON.
4. **Size each node independently**: Keep GPU apps narrow (few replicas) and let CPU apps scale wide; use scale-to-zero (`replicas=(0, N)`) for bursty downstream services.
5. **Authentication**: Use `requires_auth=True` on internal-only apps so unauthenticated callers on the public internet can't reach them, and put public-facing auth on the entry-point app.
6. **Endpoint access**: Prefer `app_env.endpoint` for in-module references; use `flyte.app.AppEndpoint` parameters when the upstream env isn't importable.

---
**Source**: https://github.com/unionai/unionai-docs/blob/main/content/user-guide/build-apps/serving-graphs.md
**HTML**: https://www.union.ai/docs/v2/union/user-guide/build-apps/serving-graphs/
