Serving graphs
A serving graph is a set of Flyte apps that talk to each other inside the
cluster. Instead of putting every stage of a request into one process, you
split the work across multiple AppEnvironments that you deploy together —
each one sized for its own bottleneck, with its own image and scaling policy.
This pattern is useful for:
- Heterogeneous resource requirements: CPU pre/postprocessing in front of a GPU forward pass
- Microservice architectures: Independent components with distinct lifecycles
- A/B testing and canary rollouts: A root app routes traffic across variant apps
- Proxy / gateway patterns: One app fronts several backends
Core concepts: a minimal two-app chain
The simplest serving graph — app2 proxies HTTP calls to app1 — is enough
to introduce every core concept: deploying multiple apps together, discovering
an upstream app’s endpoint, and sizing each app independently.
Both apps share an image and live in the same Python file:
image = flyte.Image.from_debian_base(python_version=(3, 12)).with_pip_packages("fastapi", "uvicorn", "httpx")
app1 = FastAPI(
    title="App 1",
    description="A FastAPI app that runs some computations",
)
app2 = FastAPI(
    title="App 2",
    description="A FastAPI app that proxies requests to another FastAPI app",
)
Deploying multiple apps together with depends_on
The callee env is straightforward — it has no upstream dependencies of its own:
env1 = FastAPIAppEnvironment(
    name="app1-is-called-by-app2",
    app=app1,
    image=image,
    resources=flyte.Resources(cpu=1, memory="512Mi"),
    requires_auth=True,
)
The caller declares depends_on=[env1], which tells Flyte that env1 must
be deployed alongside this one:
env2 = FastAPIAppEnvironment(
    name="app2-calls-app1",
    app=app2,
    image=image,
    resources=flyte.Resources(cpu=1, memory="512Mi"),
    requires_auth=True,
    parameters=[
        flyte.app.Parameter(
            name="app1_url",
            value=flyte.app.AppEndpoint(app_name="app1-is-called-by-app2"),
            env_var="APP1_URL",
        ),
    ],
    depends_on=[env1],
    env_vars={"LOG_LEVEL": "10"},
)
Calling flyte.serve(env2) then deploys the whole dependency closure
transitively, so you only ever name the entry-point app:
if __name__ == "__main__":
    flyte.init_from_config(
        root_dir=pathlib.Path(__file__).parent,
        log_level=logging.DEBUG,
    )
    app = flyte.serve(env2)
    print(f"Deployed FastAPI app: {app.url}")
depends_on is about deployment co-scheduling, not request-time ordering —
at runtime each app is independent.
Getting an upstream app’s endpoint
There are two ways for one app to discover another app’s URL. Both resolve correctly across local, in-cluster, and external contexts.
Pattern A — env.endpoint (Python property). When both apps live in the
same Python module, the upstream env object is in scope and you can read
env.endpoint directly. The example above uses this pattern in app2’s
proxy endpoint:
@app2.get("/app1-endpoint")
async def get_app1_endpoint() -> str:
    return env1.endpoint

@app2.get("/greeting/{name}")
async def greeting_proxy(name: str) -> typing.Any:
    async with httpx.AsyncClient() as client:
        response = await client.get(f"{env1.endpoint}/greeting/{name}")
        return response.json()
Pattern B — flyte.app.AppEndpoint as a parameter. When the upstream env
object isn’t importable (different file, different process, looking it up by
name), declare it as a flyte.app.Parameter and have Flyte inject the
resolved URL via an environment variable. The env2 definition above shows
this — app1_url becomes available as os.getenv("APP1_URL") at runtime:
@app2.get("/app1-url")
async def get_app1_url() -> str:
    # Default to "" so the declared str return type holds even if unset.
    return os.getenv("APP1_URL", "")
Sizing each node independently
Each AppEnvironment carries its own image, resources, and scaling. That’s
the entire point of splitting — for example, the GPU side of an inference
graph can stay narrow with scaling=Scaling(replicas=(1, 2)) while the CPU
side scales wide with scaling=Scaling(replicas=(1, 8)), with no shared
autoscaling policy between them. The next example shows this in practice.
Example: CPU / GPU inference split
The canonical heterogeneous-resource pipeline: heavy CPU preprocessing in front of a fast GPU forward pass, talking to each other over HTTP inside the cluster.
flowchart LR
    client["client"] --> cpu["cpu_app (×N replicas)<br/>decode + resize<br/>+ softmax"]
    cpu --> gpu["gpu_app (×M replicas)<br/>ResNet18 forward only"]
    gpu --> cpu
    cpu --> client
In a typical vision/audio pipeline, the GPU forward pass takes milliseconds but is sandwiched between slow CPU work (image decode, resize, normalization, softmax, label lookup). If both stages share one process you pay for an idle GPU during preprocessing. Splitting them lets each side scale independently: cheap CPU wide, expensive GPU narrow.
Disjoint images per node
The two apps share a small base image and add their own disjoint stacks. The
CPU app never imports torch; the GPU app never imports PIL:
base_image = flyte.Image.from_debian_base(python_version=(3, 12)).with_pip_packages(
    "fastapi",
    "uvicorn",
    "numpy",
)
cpu_image = base_image.with_pip_packages(
    "httpx",
    "pillow",
)
gpu_image = base_image.with_pip_packages(
    "torch==2.7.1",
    "torchvision==0.22.1",
)
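The snippets below reference module-level constants (TENSOR_DTYPE, INPUT_C/H/W, NUM_CLASSES, IMAGENET_MEAN/STD) whose definitions aren't shown. Plausible values, matching ResNet18's ImageNet preprocessing — treat the exact numbers as assumptions:

```python
import numpy as np

# Shared wire-format and preprocessing constants (assumed; the originals
# aren't shown above). Shapes match ResNet18's ImageNet input and output.
TENSOR_DTYPE = np.float32
INPUT_C, INPUT_H, INPUT_W = 3, 224, 224
NUM_CLASSES = 1000  # ImageNet-1k logits

# Standard ImageNet normalization, shaped (3, 1, 1) so it broadcasts
# against the CHW arrays produced in _preprocess.
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=TENSOR_DTYPE).reshape(3, 1, 1)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=TENSOR_DTYPE).reshape(3, 1, 1)
```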
GPU app: model.forward only
The GPU app loads the model once at startup using FastAPI’s lifespan, so model weights stay resident across requests:
@asynccontextmanager
async def _gpu_lifespan(app: FastAPI):
    # Imported lazily so the CPU app never has to import torch.
    import torch
    from torchvision.models import ResNet18_Weights, resnet18

    weights = ResNet18_Weights.IMAGENET1K_V1
    model = resnet18(weights=weights).eval()
    device = "cuda" if torch.cuda.is_available() else "cpu"
    if device == "cuda":
        model = model.to("cuda")
    app.state.model = model
    app.state.device = device
    app.state.categories = list(weights.meta["categories"])
    logging.getLogger(__name__).info("model loaded on %s", device)
    yield
gpu_app = FastAPI(
    title="inference-gpu",
    description="ResNet18 forward pass.",
    lifespan=_gpu_lifespan,
)
The inference endpoint speaks raw float32 bytes over
application/octet-stream. For anything tensor-shaped this is the single
biggest perf knob — JSON-serializing a 19MB batch dominates end-to-end
latency:
@gpu_app.post("/infer")
async def infer(request: Request) -> Response:
    """Run a batched forward pass.

    Request body: raw float32 bytes, shape (B, 3, 224, 224), C-contiguous.
    Response body: raw float32 bytes, shape (B, 1000) — raw logits.

    We deliberately do NOT use JSON here. For a batch of 32 images the tensor
    is ~19MB; JSON-serializing that is the dominant cost end-to-end.
    """
    import torch

    raw = await request.body()
    arr = np.frombuffer(raw, dtype=TENSOR_DTYPE)
    if arr.size % (INPUT_C * INPUT_H * INPUT_W) != 0:
        raise HTTPException(400, "payload size is not a multiple of one image tensor")
    batch = arr.reshape(-1, INPUT_C, INPUT_H, INPUT_W)
    x = torch.from_numpy(batch).to(gpu_app.state.device)
    with torch.inference_mode():
        logits = gpu_app.state.model(x)
    out = logits.detach().to("cpu").numpy().astype(TENSOR_DTYPE, copy=False)
    return Response(content=out.tobytes(), media_type="application/octet-stream")
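The raw-bytes contract is easy to sanity-check locally with NumPy alone (assuming TENSOR_DTYPE is np.float32, as used throughout):

```python
import numpy as np

TENSOR_DTYPE = np.float32  # assumed to match the wire format above
B, C, H, W = 2, 3, 224, 224

# Client side: serialize a C-contiguous float32 batch into the request body.
batch = np.random.rand(B, C, H, W).astype(TENSOR_DTYPE)
payload = np.ascontiguousarray(batch).tobytes()
assert len(payload) == B * C * H * W * 4  # 4 bytes per float32 (~1.2MB here)

# Server side: recover the batch exactly, inferring B from the payload size.
decoded = np.frombuffer(payload, dtype=TENSOR_DTYPE).reshape(-1, C, H, W)
assert decoded.shape == (B, C, H, W)
assert np.array_equal(decoded, batch)  # lossless round-trip
```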
The GPU environment requests a GPU and keeps replicas narrow:
gpu_env = FastAPIAppEnvironment(
    name="serving-graph-gpu",
    app=gpu_app,
    image=gpu_image,
    resources=flyte.Resources(cpu=2, memory="8Gi", gpu="A10G:1"),
    # GPU replicas are expensive; keep at least one warm so model weights stay
    # resident, and cap the max. Bump if a single replica saturates.
    scaling=flyte.app.Scaling(replicas=(1, 2)),
    requires_auth=True,
)
CPU app: pre/postprocess + call GPU
Preprocessing is deliberately CPU-bound — decode, denoise, resize, normalize:
def _preprocess(img_bytes: bytes) -> np.ndarray:
    """Decode → denoise → resize → normalize. CPU-bound, deliberately so.

    Real preprocessing stacks (detection, OCR, audio) do substantially more
    than this — sliding window crops, color-space conversion, etc. The point
    is that none of it benefits from a GPU sitting next to it.
    """
    img = Image.open(io.BytesIO(img_bytes)).convert("RGB")
    img = img.filter(ImageFilter.GaussianBlur(radius=1.0))
    img = img.resize((INPUT_W, INPUT_H), Image.BILINEAR)
    arr = np.asarray(img, dtype=TENSOR_DTYPE) / 255.0
    arr = arr.transpose(2, 0, 1)  # HWC → CHW
    arr = (arr - IMAGENET_MEAN) / IMAGENET_STD
    return np.ascontiguousarray(arr, dtype=TENSOR_DTYPE)
The CPU app uses its lifespan to resolve the GPU endpoint via gpu_env.endpoint,
fetch labels once at startup, and build one persistent httpx.AsyncClient per
replica. Persistent clients avoid a TCP/TLS handshake per request, which
matters at high request rates:
@asynccontextmanager
async def _cpu_lifespan(app: FastAPI):
    # Resolved at serving time via the cluster-internal endpoint pattern,
    # so this stays correct across local/remote deploys without an env var.
    gpu_url = gpu_env.endpoint
    log = logging.getLogger(__name__)
    log.info("resolved GPU endpoint: %s", gpu_url)
    async with httpx.AsyncClient(timeout=30.0) as bootstrap:
        try:
            r = await bootstrap.get(f"{gpu_url}/labels")
            r.raise_for_status()
        except (httpx.HTTPError, OSError) as e:
            # Most common reason on a fresh deploy: GPU replica hasn't finished
            # pulling its image / loading weights yet. Crash-looping is fine —
            # the next attempt will likely succeed — but make the cause obvious.
            log.error("downstream GPU app at %s not ready: %s", gpu_url, e)
            raise
        app.state.labels = r.json()
    # One persistent client per replica — avoids TCP/TLS handshake per request,
    # which matters once you're doing 100s of req/s.
    async with httpx.AsyncClient(
        base_url=gpu_url,
        timeout=httpx.Timeout(30.0, connect=5.0),
        limits=httpx.Limits(max_connections=64, max_keepalive_connections=32),
    ) as client:
        app.state.client = client
        yield
cpu_app = FastAPI(
    title="inference-cpu",
    description="Pre/post around the GPU forward pass.",
    lifespan=_cpu_lifespan,
)
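The ClassifyRequest and Prediction models used by /classify below aren't shown in the snippets. A plausible sketch, with field names inferred from how the endpoint uses them (the top_k default is an assumption):

```python
from pydantic import BaseModel

class ClassifyRequest(BaseModel):
    # Fields inferred from the handler below: req.image_url and req.top_k.
    image_url: str
    top_k: int = 5  # default value is an assumption

class Prediction(BaseModel):
    label: str
    score: float
```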
The /classify endpoint glues it all together. Heavy CPU work runs in this
process; the GPU forward pass is delegated over HTTP using the raw-bytes wire
format:
async def validate_public_image_url(image_url: str) -> str:
    try:
        parsed = httpx.URL(image_url)
    except Exception as exc:
        raise HTTPException(status_code=400, detail="Invalid image_url.") from exc
    if parsed.scheme not in {"http", "https"}:
        raise HTTPException(status_code=400, detail="image_url must use http or https.")
    host = parsed.host
    if not host:
        raise HTTPException(status_code=400, detail="image_url must include a hostname.")
    try:
        addr_info = socket.getaddrinfo(host, parsed.port or (443 if parsed.scheme == "https" else 80))
    except socket.gaierror as exc:
        raise HTTPException(status_code=400, detail="image_url host could not be resolved.") from exc
    for info in addr_info:
        ip_text = info[4][0]
        ip_obj = ipaddress.ip_address(ip_text)
        if not ip_obj.is_global:
            raise HTTPException(status_code=400, detail="image_url host resolves to a non-public address.")
    return str(parsed)
@cpu_app.post("/classify", response_model=list[Prediction])
async def classify(req: ClassifyRequest) -> list[Prediction]:
    async with httpx.AsyncClient(timeout=30.0) as client:
        img_resp = await client.get(await validate_public_image_url(req.image_url))
        img_resp.raise_for_status()
    tensor = _preprocess(img_resp.content)  # heavy CPU
    batch = tensor[np.newaxis, ...]  # add batch dim
    gpu_resp = await cpu_app.state.client.post(
        "/infer",
        content=batch.tobytes(),
        headers={"content-type": "application/octet-stream"},
    )
    gpu_resp.raise_for_status()
    logits = np.frombuffer(gpu_resp.content, dtype=TENSOR_DTYPE).reshape(1, NUM_CLASSES)
    probs = _softmax(logits, axis=-1)[0]  # back to CPU work
    top_idx = np.argsort(-probs)[: req.top_k]
    return [Prediction(label=cpu_app.state.labels[i], score=float(probs[i])) for i in top_idx]
The CPU environment scales wide and declares depends_on=[gpu_env] so both
sides deploy together:
cpu_env = FastAPIAppEnvironment(
    name="serving-graph-cpu",
    app=cpu_app,
    image=cpu_image,
    resources=flyte.Resources(cpu=4, memory="4Gi"),
    # Cheap, so scale wide. Use scale-to-zero (replicas=(0, 8)) for bursty
    # traffic; keep replicas=(1, 8) here to avoid cold starts in the demo.
    scaling=flyte.app.Scaling(replicas=(1, 8)),
    requires_auth=True,
    depends_on=[gpu_env],
)
Deploy
flyte.serve(cpu_env) deploys both apps. The CPU app is the public entry
point; the GPU app is reached only via the cluster-internal endpoint:
if __name__ == "__main__":
    flyte.init_from_config(
        root_dir=pathlib.Path(__file__).parent,
        log_level=logging.INFO,
    )
    app = flyte.serve(cpu_env)
    print(f"Deployed serving graph; public CPU endpoint: {app.url}")
    print("Try: curl -X POST $URL/classify -H 'content-type: application/json' \\")
    print(
        '  -d \'{"image_url": "https://upload.wikimedia.org/wikipedia/commons/4/41/Sunflower_from_Silesia2.jpg"}\''
    )
Example: A/B testing with Statsig
A serving graph also lets you shape traffic. A root app routes each incoming request to one of two variant apps using a Statsig feature gate, with consistent per-user bucketing.
flowchart LR
    client["client"] --> root["root_app<br/>(check_gate)"]
    root -->|"gate off"| a["app_a<br/>fast-processing"]
    root -->|"gate on"| b["app_b<br/>enhanced-processing"]
Statsig client singleton
The variant routing logic needs a single Statsig client per process. Wrap it in a singleton so lifespan startup/shutdown is the only place that touches its lifecycle:
class StatsigClient:
    """Singleton to manage Statsig client lifecycle."""

    _instance: "StatsigClient | None" = None
    _statsig = None

    @classmethod
    def initialize(cls, api_key: str):
        """Initialize Statsig client (call during lifespan startup)."""
        if cls._instance is None:
            cls._instance = cls()
            # Import statsig at runtime (only available in container)
            from statsig_python_core import Statsig

            cls._statsig = Statsig(api_key)
            cls._statsig.initialize().wait()

    @classmethod
    def get_client(cls):
        """Get the initialized Statsig instance."""
        if cls._statsig is None:
            raise RuntimeError("StatsigClient not initialized. Call initialize() first.")
        return cls._statsig

    @classmethod
    def shutdown(cls):
        """Shutdown Statsig client (call during lifespan shutdown)."""
        if cls._statsig is not None:
            cls._statsig.shutdown()
            cls._statsig = None
            cls._instance = None
Variant apps
The two variants are independent FastAPI apps with their own endpoints. Each variant returns a payload labeled with its identity, but they’re otherwise deployed and scaled independently:
# Image with statsig-python-core for A/B testing
image = flyte.Image.from_debian_base().with_pip_packages("fastapi", "uvicorn", "httpx", "statsig-python-core")

# App A - first variant
app_a = FastAPI(
    title="App A",
    description="Variant A for A/B testing",
)

# App B - second variant
app_b = FastAPI(
    title="App B",
    description="Variant B for A/B testing",
)
# App A endpoints
@app_a.get("/process/{message}")
async def process_a(message: str) -> dict[str, str]:
    return {
        "variant": "A",
        "message": f"App A processed: {message}",
        "algorithm": "fast-processing",
    }

# App B endpoints
@app_b.get("/process/{message}")
async def process_b(message: str) -> dict[str, str]:
    return {
        "variant": "B",
        "message": f"App B processed: {message}",
        "algorithm": "enhanced-processing",
    }
Root app with Statsig in its lifespan
The root app’s lifespan initializes Statsig at startup and shuts it down cleanly. The API key arrives as an env var because the env is configured with a Flyte secret (see below):
@asynccontextmanager
async def lifespan(_app: FastAPI):
    """Initialize and shutdown Statsig for A/B testing."""
    # Startup: initialize Statsig via the singleton
    api_key = os.getenv("STATSIG_API_KEY", None)
    if api_key is None:
        # Don't dump os.environ into the error — it can leak other secrets.
        raise RuntimeError("STATSIG_API_KEY environment variable is not set.")
    StatsigClient.initialize(api_key)
    yield
    # Shutdown: clean up Statsig
    StatsigClient.shutdown()

# Root app - performs A/B testing and routes to A or B
root_app = FastAPI(
    title="Root App - A/B Testing",
    description="Routes requests to App A or App B based on Statsig A/B test",
    lifespan=lifespan,
)
App environments
Variant envs are minimal:
env_a = FastAPIAppEnvironment(
    name="app-a-variant",
    app=app_a,
    image=image,
    resources=flyte.Resources(cpu=1, memory="512Mi"),
)
env_b = FastAPIAppEnvironment(
    name="app-b-variant",
    app=app_b,
    image=image,
    resources=flyte.Resources(cpu=1, memory="512Mi"),
)
The root env declares depends_on=[env_a, env_b] so all three deploy
together, and pulls the Statsig API key from a Flyte secret:
env_root = FastAPIAppEnvironment(
    name="root-ab-testing-app",
    app=root_app,
    image=image,
    resources=flyte.Resources(cpu=1, memory="512Mi"),
    depends_on=[env_a, env_b],
    secrets=flyte.Secret("statsig-api-key", as_env_var="STATSIG_API_KEY"),
)
Routing endpoint
The root app checks the variant_b feature gate against a user key and
proxies to the matching variant using its endpoint property:
# Root app A/B testing endpoint
@root_app.get("/process/{message}")
async def process_with_ab_test(message: str, user_key: str) -> dict[str, typing.Any]:
    """Process a message using A/B testing to determine which app to call.

    Args:
        message: The message to process
        user_key: User identifier for A/B test bucketing (e.g., user_id, session_id)

    Returns:
        Response from either App A or App B, plus metadata about which variant was used
    """
    # Import StatsigUser at runtime (only available in container)
    from statsig_python_core import StatsigUser

    # Get the Statsig client from the singleton
    statsig = StatsigClient.get_client()
    # Create a Statsig user with the provided key
    user = StatsigUser(user_id=user_key)
    # Check the "variant_b" feature gate: enabled → App B, otherwise App A
    use_variant_b = statsig.check_gate(user, "variant_b")
    # Call the selected variant
    async with httpx.AsyncClient() as client:
        target = env_b if use_variant_b else env_a
        response = await client.get(f"{target.endpoint}/process/{message}")
        result = response.json()
    # Add A/B test metadata to the response
    return {
        "ab_test_result": {
            "user_key": user_key,
            "selected_variant": "B" if use_variant_b else "A",
            "gate_name": "variant_b",
        },
        "response": result,
    }
Use stable identifiers (user ID, session ID) for user_key so the same user
always lands in the same bucket. To swap check_gate for an experiment or
dynamic config:
experiment = statsig.get_experiment(user, "my_experiment")
variant = experiment.get("variant", "A")
Deploy
Deploying the root env brings along both variants through depends_on:
if __name__ == "__main__":
    flyte.init_from_config()
    flyte.deploy(env_root)
    print("Deployed A/B Testing Root App")
    print("\nUsage:")
    print("  Open your browser to '<endpoint>/' to access the interactive demo")
    print("  Or use curl: curl '<endpoint>/process/hello?user_key=user123'")
    print("\nNote: Set STATSIG_API_KEY secret to use real Statsig A/B testing.")
    print("      Create a feature gate named 'variant_b' in your Statsig dashboard.")
Setup before running:
- Get a Server Secret Key at statsig.com → Settings → API Keys.
- Create a feature gate named variant_b (e.g. 50% rollout).
- Set the Flyte secret: flyte create secret statsig-api-key <your-secret-key-here>
When to split into a serving graph
Split when stages have:
- Different bottlenecks — CPU vs GPU vs memory
- Different scaling needs — bursty vs steady, wide vs narrow
- Different lifecycles — model weights you don’t want to reload, expensive cold starts
- Different routing concerns — A/B, canary, proxy, gateway
Don’t split just to separate code — a single app with a few endpoints is simpler to operate.
Best practices
- Use depends_on: Always declare dependencies so the full dependency closure is deployed in one shot.
- Persistent HTTP clients: Open one httpx.AsyncClient per replica in the app's lifespan rather than per request, to avoid TCP/TLS setup overhead.
- Pick the right wire format: For tensor-shaped payloads, send raw bytes over application/octet-stream instead of JSON.
- Size each node independently: GPU narrow, CPU wide; use scale-to-zero (replicas=(0, N)) for bursty downstream services.
- Authentication: Use requires_auth=True on internal-only apps so they can't be reached from the public internet, and put public-facing auth on the entry-point app.
- Endpoint access: Prefer app_env.endpoint for in-module references; use flyte.app.AppEndpoint parameters when the upstream env isn't importable.