Maximize GPU utilization for batch inference
GPUs are expensive. When running batch inference, the single biggest cost driver is idle GPU time — cycles where the GPU sits waiting with nothing to do. Understanding why this happens and how to fix it is the key to cost-effective batch inference.
Why GPU utilization drops
A typical inference task does three things:
- Load data — read from storage, deserialize, preprocess (CPU/IO-bound)
- Run inference — forward pass through the model (GPU-bound)
- Post-process — format results, write outputs (CPU/IO-bound)
When these steps run sequentially, the GPU sits idle while data is loaded and while results are post-processed. For many workloads, data loading and preprocessing dominate wall-clock time, leaving the GPU busy for only a fraction of the total:
gantt
title Sequential execution — GPU idle during CPU/IO work
dateFormat X
axisFormat %s
section Task 1
Load data (CPU/IO) :a1, 0, 3
Inference (GPU) :a2, after a1, 2
Post-process (CPU/IO) :a3, after a2, 1
section Task 2
Load data (CPU/IO) :b1, after a3, 3
Inference (GPU) :b2, after b1, 2
Post-process (CPU/IO) :b3, after b2, 1
section GPU
Idle :crit, g1, 0, 3
Busy :active, g2, 3, 5
Idle :crit, g3, 5, 9
Busy :active, g4, 9, 11
Idle :crit, g5, 11, 12
In this example, the GPU is busy for only 4 out of 12 time units — 33% utilization. The rest is wasted waiting for CPU and IO operations.
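The 33% figure comes straight from the GPU row of the chart. Here is a quick check in Python, with the busy intervals copied from that row:

```python
# GPU busy intervals from the sequential-execution chart, as (start, end) time units.
busy = [(3, 5), (9, 11)]
total_span = 12  # Task 2's post-processing ends at t=12

busy_time = sum(end - start for start, end in busy)
utilization = busy_time / total_span
print(f"busy={busy_time}, utilization={utilization:.0%}")  # busy=4, utilization=33%
```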
Serving vs in-process batch inference
There are two common approaches to batch inference: sending requests to a hosted model server (serving), or running the model in-process alongside data loading. Each has distinct trade-offs:
| | Hosted serving | In-process (Flyte) |
|---|---|---|
| Architecture | Separate inference server (e.g. Triton, vLLM server, TGI) accessed over the network | Model loaded directly in the task process, inference via DynamicBatcher |
| Data transfer | Every request serialized over the network; large payloads add latency | Zero-copy — data stays in-process, no serialization overhead |
| Backpressure | Hard to implement; push-based architecture can overwhelm the server or drop requests | Two levels: DynamicBatcher queue blocks producers when full, and Flyte’s task scheduling automatically queues new inference tasks when replicas are busy — backpressure propagates end-to-end without any extra code |
| Utilization | Servers are often over-provisioned to maintain availability, leading to low average utilization | Batcher continuously fills the GPU with work from concurrent producers |
| Multi-model | Each model needs its own serving deployment, load balancer, and scaling config | Multiple models can time-share the same GPU — when one model finishes, the next is loaded automatically via reusable containers, no container orchestration required |
| Scaling | Requires separate infrastructure for the serving layer (load balancers, autoscalers, health checks) | Scales with Flyte — replicas auto-scale based on demand |
| Cost | Pay for always-on serving infrastructure even during low-traffic periods | Pay only for the duration of the batch job |
| Fault tolerance | Need retries, circuit breakers, and timeout handling for network failures | Failures are local; Flyte handles retries and recovery at the task level |
| Best for | Real-time / low-latency serving with unpredictable request patterns | Large-scale batch processing with known datasets |
For batch workloads, in-process inference eliminates the network overhead and infrastructure complexity of a serving layer while achieving higher GPU utilization through intelligent batching.
Solution: DynamicBatcher
DynamicBatcher from flyte.extras solves the utilization problem by separating data loading from inference and running the two concurrently. Multiple async producers load and preprocess data while a single consumer feeds the GPU in optimally sized batches:
flowchart LR
subgraph producers ["Concurrent producers (CPU/IO)"]
P1["Stream 1: load + preprocess"]
P2["Stream 2: load + preprocess"]
P3["Stream N: load + preprocess"]
end
subgraph batcher ["DynamicBatcher"]
Q["Queue with backpressure"]
A["Aggregation loop<br/>(assembles cost-budgeted batches)"]
Q --> A
end
subgraph consumer ["Processing loop (GPU)"]
G["process_fn / inference_fn<br/>(batched forward pass)"]
end
P1 --> Q
P2 --> Q
P3 --> Q
A --> G
The batcher runs two internal loops:
- Aggregation loop: drains the submission queue and assembles batches that respect a cost budget (target_batch_cost), a maximum size (max_batch_size), and a timeout (batch_timeout_s). The GPU therefore receives well-sized batches instead of one record at a time.
- Processing loop: pulls assembled batches and calls your processing function, resolving each record’s future with its result.
This pipelining lets the GPU process batch N while the data for batch N+1 is being loaded and assembled, which minimizes idle time.
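You can sketch the two-loop design with nothing but asyncio. The code below is not DynamicBatcher's implementation. MiniBatcher and its methods are hypothetical names, the sketch assumes process_fn returns results in input order, and it omits error handling.

It works as follows:
- A bounded asyncio.Queue provides backpressure: submit() waits only when the queue is full.
- The aggregation loop keeps adding records until it hits the size cap, the timeout, or the cost budget. The last record added may push the batch slightly over the budget.
- The processing loop resolves each record's future.
- The ready-batch queue holds one batch, so the next batch is assembled while the current one is being processed.

```python
import asyncio
from typing import Any, Awaitable, Callable


class MiniBatcher:
    """Hypothetical two-loop batcher illustrating the design; not Flyte's implementation."""

    def __init__(
        self,
        process_fn: Callable[[list[Any]], Awaitable[list[Any]]],
        target_batch_cost: int,
        max_batch_size: int,
        batch_timeout_s: float,
        max_queue_size: int,
    ) -> None:
        self.process_fn = process_fn
        self.target_batch_cost = target_batch_cost
        self.max_batch_size = max_batch_size
        self.batch_timeout_s = batch_timeout_s
        # Bounded submission queue: submit() waits here when producers outpace the GPU.
        self._inbox: asyncio.Queue = asyncio.Queue(maxsize=max_queue_size)
        # One assembled batch can wait while the current one is processed (pipelining).
        self._ready: asyncio.Queue = asyncio.Queue(maxsize=1)
        self._tasks: list[asyncio.Task] = []
        self.batch_sizes: list[int] = []

    async def submit(self, record: Any, cost: int = 1) -> asyncio.Future:
        fut = asyncio.get_running_loop().create_future()
        await self._inbox.put((record, cost, fut))
        return fut

    async def _aggregate(self) -> None:
        loop = asyncio.get_running_loop()
        while True:
            batch = [await self._inbox.get()]  # block until the first record arrives
            cost = batch[0][1]
            deadline = loop.time() + self.batch_timeout_s
            # Keep adding records until the size cap, cost budget, or timeout is hit.
            while len(batch) < self.max_batch_size and cost < self.target_batch_cost:
                remaining = deadline - loop.time()
                if remaining <= 0:
                    break
                try:
                    item = await asyncio.wait_for(self._inbox.get(), remaining)
                except asyncio.TimeoutError:
                    break  # dispatch a partial batch
                batch.append(item)
                cost += item[1]
            await self._ready.put(batch)

    async def _process(self) -> None:
        while True:
            batch = await self._ready.get()
            outputs = await self.process_fn([record for record, _, _ in batch])
            self.batch_sizes.append(len(batch))
            for (_, _, fut), output in zip(batch, outputs):
                fut.set_result(output)

    async def __aenter__(self) -> "MiniBatcher":
        self._tasks = [
            asyncio.create_task(self._aggregate()),
            asyncio.create_task(self._process()),
        ]
        return self

    async def __aexit__(self, *exc_info: object) -> None:
        for task in self._tasks:
            task.cancel()
        await asyncio.gather(*self._tasks, return_exceptions=True)


async def demo() -> tuple[list[int], list[int]]:
    async def double(batch: list[int]) -> list[int]:
        await asyncio.sleep(0.01)  # stand-in for a batched GPU forward pass
        return [x * 2 for x in batch]

    async with MiniBatcher(double, target_batch_cost=8, max_batch_size=4,
                           batch_timeout_s=0.05, max_queue_size=16) as batcher:
        futures = [await batcher.submit(i) for i in range(10)]
        results = await asyncio.gather(*futures)
    return list(results), batcher.batch_sizes


results, sizes = asyncio.run(demo())
print(results)
print(sizes)  # [4, 4, 2]
```

In this demo all ten unit-cost records are queued before the loops start. The size cap of 4 therefore produces two full batches. The remaining two records go out as a partial batch once the 0.05 s timeout expires.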
Basic usage
import asyncio

from flyte.extras import DynamicBatcher


async def process(batch: list[dict]) -> list[str]:
    """Your batch processing function. Must return results in the same order as the input."""
    # heavy_computation is a placeholder for your own per-record work
    return [heavy_computation(item) for item in batch]


# Wrapped in a coroutine so `async with` is valid outside a notebook.
async def run(my_records: list[dict]) -> list[str]:
    async with DynamicBatcher(
        process_fn=process,
        target_batch_cost=1000,  # cost budget per batch
        max_batch_size=64,  # hard cap on records per batch
        batch_timeout_s=0.05,  # max wait time before dispatching a partial batch
        max_queue_size=5_000,  # queue size for backpressure
    ) as batcher:
        futures = []
        for record in my_records:
            # submit() returns a Future; it only waits when the queue is full
            future = await batcher.submit(record, estimated_cost=10)
            futures.append(future)
        results = await asyncio.gather(*futures)
    return results
Each call to submit() enqueues the record and returns a Future without waiting for the batch to be processed. When the queue is full, submit() waits until space is available. This provides natural backpressure and keeps producers from overwhelming the GPU.
Cost estimation
The batcher uses cost estimates to decide how many records to group into each batch. You can provide costs in several ways (checked in order of precedence):
- Explicit: pass estimated_cost to submit()
- Estimator function: pass cost_estimator to the constructor
- Protocol: implement estimate_cost() on your record type
- Default: falls back to default_cost (default: 1)
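The precedence order can be illustrated with a small stand-alone function. This is a sketch, not the batcher's internal code. resolve_cost and Doc are hypothetical names, and the sketch assumes cost_estimator is a callable that takes the record and returns a number.

```python
from typing import Any, Callable, Optional


def resolve_cost(
    record: Any,
    estimated_cost: Optional[int] = None,
    cost_estimator: Optional[Callable[[Any], int]] = None,
    default_cost: int = 1,
) -> int:
    """Hypothetical helper: explicit > estimator function > protocol > default."""
    if estimated_cost is not None:
        return estimated_cost
    if cost_estimator is not None:
        return cost_estimator(record)
    estimate = getattr(record, "estimate_cost", None)  # protocol method, if present
    if callable(estimate):
        return estimate()
    return default_cost


class Doc:
    """Hypothetical record type implementing estimate_cost()."""

    def __init__(self, text: str) -> None:
        self.text = text

    def estimate_cost(self) -> int:
        return len(self.text)


doc = Doc("hello world")
print(resolve_cost(doc, estimated_cost=10))           # 10 (explicit wins)
print(resolve_cost(doc, cost_estimator=lambda r: 3))  # 3 (estimator function)
print(resolve_cost(doc))                              # 11 (protocol)
print(resolve_cost({"id": 1}))                        # 1 (default)
```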
TokenBatcher for LLM inference
For LLM workloads, TokenBatcher is a convenience subclass that uses token-aware parameter names:
import asyncio
from dataclasses import dataclass

from flyte.extras import TokenBatcher


@dataclass
class Prompt:
    text: str

    def estimate_tokens(self) -> int:
        """Rough token estimate (~4 chars per token)."""
        return len(self.text) // 4 + 1


async def inference(batch: list[Prompt]) -> list[str]:
    """Run batched inference through your model."""
    texts = [p.text for p in batch]
    # model and sampling_params are placeholders for your loaded model (e.g. vLLM).
    # generate() is synchronous, so run it off the event loop to keep batching going.
    outputs = await asyncio.to_thread(model.generate, texts, sampling_params)
    return [o.outputs[0].text for o in outputs]


async def main() -> str:
    async with TokenBatcher(
        inference_fn=inference,
        target_batch_tokens=32_000,  # token budget per batch
        max_batch_size=256,
    ) as batcher:
        future = await batcher.submit(Prompt(text="What is 2+2?"))
        return await future
TokenBatcher checks the TokenEstimator protocol (estimate_tokens()) in addition to CostEstimator (estimate_cost()), making it natural to work with prompt types.
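You can see the effect of target_batch_tokens offline. The sketch below is a simplified, hypothetical analogue: pack_by_tokens is not part of flyte.extras. It greedily packs an already-available list of prompts, whereas TokenBatcher works on a live queue and also applies batch_timeout_s. It reuses the estimate_tokens() heuristic from the example above.

```python
from dataclasses import dataclass


@dataclass
class Prompt:
    text: str

    def estimate_tokens(self) -> int:
        """Rough token estimate (~4 chars per token), same heuristic as above."""
        return len(self.text) // 4 + 1


def pack_by_tokens(
    prompts: list[Prompt], target_batch_tokens: int, max_batch_size: int
) -> list[list[Prompt]]:
    """Greedily group prompts under a token budget and a size cap (hypothetical helper).

    A single prompt larger than the budget still gets a batch of its own.
    """
    batches: list[list[Prompt]] = []
    current: list[Prompt] = []
    current_tokens = 0
    for prompt in prompts:
        tokens = prompt.estimate_tokens()
        # Close the current batch if this prompt would break the budget or the cap.
        if current and (
            current_tokens + tokens > target_batch_tokens or len(current) >= max_batch_size
        ):
            batches.append(current)
            current, current_tokens = [], 0
        current.append(prompt)
        current_tokens += tokens
    if current:
        batches.append(current)
    return batches


prompts = [Prompt(text="x" * 40) for _ in range(10)]  # 11 estimated tokens each
batches = pack_by_tokens(prompts, target_batch_tokens=50, max_batch_size=8)
print([len(b) for b in batches])  # [4, 4, 2]
```

Here the token budget, not max_batch_size, sets the batch boundaries. Four 11-token prompts fit under 50, but a fifth would exceed it.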
Combining with app environments
DynamicBatcher on its own improves utilization within a single task, but the model has to be loaded from scratch on every invocation. To amortize that cost across many task runs, host the model inside a long-lived AppEnvironment and have driver tasks call it over HTTP:
- Amortized model loading: the model is loaded once when the app starts and stays in memory for the lifetime of the replica
- Cross-task batching: every concurrent HTTP request submits to the same shared TokenBatcher, so records from many tasks are batched together and the GPU stays supplied with work
- Automatic scaling: the app autoscales between min and max replicas based on a concurrency target, and each replica maintains its own model and batcher
flowchart LR
D["Driver task<br/>fans out chunks<br/>(concurrency cap)"]
subgraph calls ["infer_batch tasks (HTTP clients)"]
T1["call 1"]
T2["call 2"]
T3["call N"]
end
D --> T1
D --> T2
D --> T3
subgraph app ["FastAPI app environment (GPU)"]
FA["POST /generate"]
B["Shared TokenBatcher"]
M["vLLM model<br/>(loaded in lifespan)"]
FA --> B --> M
end
T1 --> FA
T2 --> FA
T3 --> FA
The two key techniques are:
- Use FastAPI’s lifespan to load the model and start the TokenBatcher exactly once per replica, then attach the batcher to app.state so request handlers can reach it.
- Cap driver concurrency with flyte.map.aio(..., concurrency=N) so the orchestrator doesn’t send the app more in-flight requests than its scaling target can serve.
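flyte.map.aio applies the cap for you. To see what a cap does to in-flight load, here is a stdlib-only sketch. capped_map and fake_infer are hypothetical names, and fake_infer stands in for one infer_batch HTTP call. The sketch bounds concurrent calls with asyncio.Semaphore and records the peak number in flight.

```python
import asyncio


async def capped_map(fn, items, concurrency: int):
    """Run fn over items with at most `concurrency` calls in flight; keeps input order."""
    sem = asyncio.Semaphore(concurrency)

    async def run_one(item):
        async with sem:
            return await fn(item)

    return await asyncio.gather(*(run_one(i) for i in items))


in_flight = 0
peak = 0


async def fake_infer(chunk_id: int) -> str:
    """Hypothetical stand-in for one infer_batch HTTP call."""
    global in_flight, peak
    in_flight += 1
    peak = max(peak, in_flight)
    await asyncio.sleep(0.01)  # simulated request latency
    in_flight -= 1
    return f"chunk-{chunk_id}"


results = asyncio.run(capped_map(fake_infer, range(25), concurrency=10))
print(len(results), peak)  # 25 10
```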
Example: batch LLM inference with vLLM behind a FastAPI app
This example loads math problems from Hugging Face’s gsm8k dataset and solves them by calling a FastAPI app that runs vLLM with a shared TokenBatcher.
1. Load the model and batcher once via FastAPI lifespan
The FastAPI lifespan runs on startup and shutdown. Use it to load the vLLM model and start the TokenBatcher exactly once per replica, then attach the batcher to app.state so request handlers can reach it:
import asyncio
import logging
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from dataclasses import dataclass

from fastapi import FastAPI, HTTPException, Request
from pydantic import BaseModel

import flyte
import flyte.app
from flyte.app.extras import FastAPIAppEnvironment
from flyte.extras import TokenBatcher

logger = logging.getLogger(__name__)


@dataclass
class Prompt:
    task_id: str
    index: int
    text: str

    def estimate_tokens(self) -> int:
        # Same rough heuristic (~4 chars per token) so target_batch_tokens bounds each batch.
        return len(self.text) // 4 + 1


@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncIterator[None]:
    """Load the vLLM model and start the TokenBatcher once at startup."""
    from vllm import LLM, SamplingParams

    llm = LLM(
        model="Qwen/Qwen2.5-7B-Instruct",
        gpu_memory_utilization=0.9,
        max_model_len=4096,
    )
    params = SamplingParams(temperature=0.7, max_tokens=512)
    logger.info("vLLM model loaded")

    async def inference(batch: list[Prompt]) -> list[str]:
        texts = [p.text for p in batch]
        # llm.generate is synchronous; running it in a worker thread keeps the event
        # loop free to accept requests and assemble the next batch meanwhile.
        outputs = await asyncio.to_thread(llm.generate, texts, params)
        return [o.outputs[0].text for o in outputs]

    batcher = TokenBatcher[Prompt, str](
        inference_fn=inference,
        target_batch_tokens=32_000,
        max_batch_size=256,
        batch_timeout_s=0.05,
        max_queue_size=5_000,
    )
    await batcher.start()
    logger.info("TokenBatcher started")

    app.state.batcher = batcher
    try:
        yield
    finally:
        await batcher.stop()  # stop the batcher even if shutdown is triggered by an error


app = FastAPI(title="Batched Inference Service", lifespan=lifespan)
Stashing the batcher on app.state means every request handler can grab the same shared instance via request.app.state.batcher, so concurrent requests all feed into one queue.
2. Add an endpoint that submits to the shared batcher
Each request just enqueues records — the batcher aggregates records across concurrent requests into token-budgeted batches before hitting the GPU:
class GenerateRequest(BaseModel):
    prompts: list[str]
    task_id: str


@app.post("/generate")
async def generate(request_body: GenerateRequest, request: Request):
    if not request_body.prompts:
        raise HTTPException(status_code=400, detail="No prompts provided")

    batcher: TokenBatcher[Prompt, str] = request.app.state.batcher
    futures: list[asyncio.Future[str]] = []
    for idx, text in enumerate(request_body.prompts):
        record = Prompt(task_id=request_body.task_id, index=idx, text=text)
        # submit() waits only if the shared queue is full (backpressure)
        future = await batcher.submit(record)
        futures.append(future)

    # gather() preserves submission order, so results line up with the prompts
    results = await asyncio.gather(*futures)
    return {"results": results}
3. Define the app environment and driver task environment
The app uses a FastAPIAppEnvironment on a GPU and autoscales via flyte.app.Scaling. The driver runs in a CPU-only TaskEnvironment that declares depends_on=[app_env], so the app is deployed before the driver runs:
image = (
    flyte.Image.from_debian_base()
    # httpx is imported by the driver task; listed explicitly rather than relying on transitive installs
    .with_pip_packages("vllm", "hf-transfer", "fastapi", "uvicorn", "httpx")
    .with_env_vars({"HF_HUB_ENABLE_HF_TRANSFER": "1"})
)

app_env = FastAPIAppEnvironment(
    name="batch-inference-saturate-app",
    app=app,
    image=image,
    resources=flyte.Resources(cpu=6, memory="24Gi", gpu="L4:1", disk="64Gi"),
    scaling=flyte.app.Scaling(
        replicas=(0, 2),  # scale to zero when idle, up to 2 GPU replicas
        metric=flyte.app.Scaling.Concurrency(val=10),  # target in-flight requests per replica
        scaledown_after=300,
    ),
    requires_auth=False,
)

driver_env = flyte.TaskEnvironment(
    name="batch_inference_saturate_app_driver",
    resources=flyte.Resources(cpu=2, memory="2Gi"),
    image=image,
    depends_on=[app_env],  # deploy the app before the driver runs
)
With replicas=(0, 2) and a concurrency target of 10, the app scales between 0 and 2 GPU replicas and aims for about 10 concurrent in-flight requests per replica, so up to about 20 requests can be served in parallel. Because the minimum is 0 replicas, the app scales to zero when idle. The first requests after an idle period then wait for a cold start.
4. Define a driver task that calls the app
The driver task POSTs prompt chunks to the app’s endpoint. Use generous timeouts and retries to absorb cold starts and transient failures during scaling events:
import httpx


@driver_env.task(retries=20)
async def infer_batch(
    endpoint: str,
    prompts: list[str],
    task_id: str,
) -> list[str]:
    url = f"{endpoint}/generate"
    # Long read timeout: the response arrives only after every prompt in the chunk is generated.
    async with httpx.AsyncClient(
        timeout=httpx.Timeout(connect=60.0, read=600.0, write=30.0, pool=10.0),
    ) as client:
        response = await client.post(
            url,
            json={"prompts": prompts, "task_id": task_id},
        )
        response.raise_for_status()
        return response.json()["results"]
5. Fan out chunks with a concurrency cap
The orchestrator chunks the dataset and submits each chunk as a separate infer_batch call. Use flyte.map.aio(..., concurrency=max_concurrency) to cap the number of in-flight HTTP calls so the task doesn’t overload the app with more requests than its scaling target can serve:
@driver_env.task
async def main(
    num_questions: int = 500,
    chunk_size: int = 50,
    max_concurrency: int = 10,
) -> dict[str, list[str]]:
    # fetch_gsm8k_questions (not shown here) returns the question texts from gsm8k
    questions = await fetch_gsm8k_questions(num_questions)
    endpoint = app_env.endpoint

    chunks = [
        questions[i : i + chunk_size]
        for i in range(0, len(questions), chunk_size)
    ]
    task_ids = [f"gsm8k_{i:03d}" for i in range(len(chunks))]

    # At most max_concurrency infer_batch calls are in flight at once
    all_results = [
        result
        async for result in flyte.map.aio(
            infer_batch,
            [endpoint] * len(chunks),
            chunks,
            task_ids,
            concurrency=max_concurrency,
        )
    ]
    return dict(zip(task_ids, all_results))
Match max_concurrency to the app’s scaling configuration. In this example, the app autoscales up to 2 replicas with a concurrency target of 10, so about 20 requests can be in flight at once. Setting max_concurrency=10 keeps the driver within that capacity. If the driver instead queued requests far beyond what the app can absorb, they would stack up behind the batcher’s max_queue_size, exhaust HTTP timeouts, and waste retry budget.
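With the defaults (500 questions, chunk_size=50), the driver produces 10 chunks. max_concurrency=10 therefore lets every chunk be in flight at once while staying under the roughly 20-request capacity. The arithmetic is reproduced below in plain Python. questions is placeholder data, and make_chunks is a hypothetical helper that mirrors the driver's list comprehensions.

```python
def make_chunks(questions: list[str], chunk_size: int) -> tuple[list[list[str]], list[str]]:
    """Split questions into fixed-size chunks and name them the way the driver does."""
    chunks = [questions[i : i + chunk_size] for i in range(0, len(questions), chunk_size)]
    task_ids = [f"gsm8k_{i:03d}" for i in range(len(chunks))]
    return chunks, task_ids


# Capacity implied by the scaling config: max replicas x per-replica concurrency target.
max_replicas, concurrency_target = 2, 10
capacity = max_replicas * concurrency_target

questions = [f"question {i}" for i in range(500)]  # placeholder data
chunks, task_ids = make_chunks(questions, chunk_size=50)
max_concurrency = 10

print(len(chunks), task_ids[0], task_ids[-1])  # 10 gsm8k_000 gsm8k_009
print(capacity, max_concurrency <= capacity)   # 20 True
```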
Monitoring utilization
DynamicBatcher exposes a stats property with real-time metrics:
stats = batcher.stats
print(f"Utilization: {stats.utilization:.1%}") # fraction of time spent processing
print(f"Records processed: {stats.total_completed}")
print(f"Batches dispatched: {stats.total_batches}")
print(f"Avg batch size: {stats.avg_batch_size:.1f}")
print(f"Busy time: {stats.busy_time_s:.1f}s")
print(f"Idle time: {stats.idle_time_s:.1f}s")

| Metric | Description |
|---|---|
| utilization | Fraction of wall-clock time spent inside process_fn (0.0–1.0). Target: > 0.9. |
| total_submitted | Total records submitted via submit() |
| total_completed | Total records whose futures have been resolved |
| total_batches | Number of batches dispatched to process_fn |
| avg_batch_size | Running average records per batch |
| avg_batch_cost | Running average cost per batch |
| busy_time_s | Cumulative seconds spent inside process_fn |
| idle_time_s | Cumulative seconds the processing loop waited for batches |
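The relationships between these metrics are easy to check by hand. The sketch below uses a hypothetical BatchStats class, not the type behind batcher.stats, that accumulates some of the same counters. It assumes the processing loop's time splits into busy time (inside process_fn) and idle time (waiting for a batch), so utilization is busy_time_s / (busy_time_s + idle_time_s). The numbers are illustrative, not measurements.

```python
from dataclasses import dataclass


@dataclass
class BatchStats:
    """Hypothetical accumulator mirroring a few metrics from the table."""

    total_completed: int = 0
    total_batches: int = 0
    busy_time_s: float = 0.0
    idle_time_s: float = 0.0

    def record_wait(self, seconds: float) -> None:
        """The processing loop waited this long for the next batch."""
        self.idle_time_s += seconds

    def record_batch(self, size: int, seconds: float) -> None:
        """One batch of `size` records spent `seconds` inside process_fn."""
        self.total_batches += 1
        self.total_completed += size
        self.busy_time_s += seconds

    @property
    def avg_batch_size(self) -> float:
        return self.total_completed / self.total_batches if self.total_batches else 0.0

    @property
    def utilization(self) -> float:
        total = self.busy_time_s + self.idle_time_s
        return self.busy_time_s / total if total else 0.0


stats = BatchStats()
stats.record_wait(0.5)
stats.record_batch(size=64, seconds=4.5)
stats.record_batch(size=32, seconds=5.0)
print(f"{stats.utilization:.2f} {stats.avg_batch_size:.1f}")  # 0.95 48.0
```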
If utilization is low, consider:
- Increasing concurrency: more concurrent producers give the batcher more records to assemble into batches
- Reducing batch_timeout_s: dispatch partial batches sooner instead of waiting
- Increasing max_queue_size: allow more records to be buffered ahead of the GPU
- Adding more data streams: ensure the GPU always has work queued up