# Maximize GPU utilization for batch inference
GPUs are expensive. When running batch inference, the single biggest cost driver is idle GPU time — cycles where the GPU sits waiting with nothing to do. Understanding why this happens and how to fix it is the key to cost-effective batch inference.
## Why GPU utilization drops
A typical inference task does three things:
- Load data — read from storage, deserialize, preprocess (CPU/IO-bound)
- Run inference — forward pass through the model (GPU-bound)
- Post-process — format results, write outputs (CPU/IO-bound)
When these steps run sequentially, the GPU is idle during steps 1 and 3. For many workloads, data loading and preprocessing dominate wall-clock time, leaving the GPU busy for only a fraction of the total:
```mermaid
gantt
    title Sequential execution — GPU idle during CPU/IO work
    dateFormat X
    axisFormat %s

    section Task 1
    Load data (CPU/IO)    :a1, 0, 3
    Inference (GPU)       :a2, after a1, 2
    Post-process (CPU/IO) :a3, after a2, 1

    section Task 2
    Load data (CPU/IO)    :b1, after a3, 3
    Inference (GPU)       :b2, after b1, 2
    Post-process (CPU/IO) :b3, after b2, 1

    section GPU
    Idle :crit, g1, 0, 3
    Busy :active, g2, 3, 5
    Idle :crit, g3, 5, 9
    Busy :active, g4, 9, 11
    Idle :crit, g5, 11, 12
```
In this example, the GPU is busy for only 4 out of 12 time units — 33% utilization. The rest is wasted waiting for CPU and IO operations.
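The arithmetic above is easy to sketch. A minimal back-of-envelope model, using the illustrative stage durations from the chart (these are not measurements):

```python
def sequential_wall_clock(n_tasks: int, load: float, infer: float, post: float) -> float:
    """Every task runs its three stages back to back; the GPU idles during load and post."""
    return n_tasks * (load + infer + post)

def gpu_utilization(n_tasks: int, load: float, infer: float, post: float) -> float:
    """Fraction of wall-clock time the GPU spends in the inference stage."""
    busy = n_tasks * infer
    return busy / sequential_wall_clock(n_tasks, load, infer, post)

# Values from the Gantt chart: load=3, infer=2, post=1, two tasks.
print(gpu_utilization(2, load=3, infer=2, post=1))  # 4/12 ≈ 0.33
```

Adding more sequential tasks does not help: each one brings its own CPU/IO time, so utilization stays pinned at the same fraction.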
## Serving vs in-process batch inference
There are two common approaches to batch inference: sending requests to a hosted model server (serving), or running the model in-process alongside data loading. Each has distinct trade-offs:
| | Hosted serving | In-process (Flyte) |
|---|---|---|
| Architecture | Separate inference server (e.g. Triton, vLLM server, TGI) accessed over the network | Model loaded directly in the task process, inference via DynamicBatcher |
| Data transfer | Every request serialized over the network; large payloads add latency | Zero-copy — data stays in-process, no serialization overhead |
| Backpressure | Hard to implement; push-based architecture can overwhelm the server or drop requests | Two levels: DynamicBatcher queue blocks producers when full, and Flyte’s task scheduling automatically queues new inference tasks when replicas are busy — backpressure propagates end-to-end without any extra code |
| Utilization | Servers are often over-provisioned to maintain availability, leading to low average utilization | Batcher continuously fills the GPU with work from concurrent producers |
| Multi-model | Each model needs its own serving deployment, load balancer, and scaling config | Multiple models can time-share the same GPU — when one model finishes, the next is loaded automatically via reusable containers, no container orchestration required |
| Scaling | Requires separate infrastructure for the serving layer (load balancers, autoscalers, health checks) | Scales with Flyte — replicas auto-scale based on demand |
| Cost | Pay for always-on serving infrastructure even during low-traffic periods | Pay only for the duration of the batch job |
| Fault tolerance | Need retries, circuit breakers, and timeout handling for network failures | Failures are local; Flyte handles retries and recovery at the task level |
| Best for | Real-time / low-latency serving with unpredictable request patterns | Large-scale batch processing with known datasets |
For batch workloads, in-process inference eliminates the network overhead and infrastructure complexity of a serving layer while achieving higher GPU utilization through intelligent batching.
## Solution: DynamicBatcher

`DynamicBatcher` from `flyte.extras` solves the utilization problem by separating data loading from inference and running them concurrently. Multiple async producers load and preprocess data while a single consumer feeds the GPU in optimally-sized batches:
```mermaid
flowchart LR
    subgraph producers ["Concurrent producers (CPU/IO)"]
        P1["Stream 1: load + preprocess"]
        P2["Stream 2: load + preprocess"]
        P3["Stream N: load + preprocess"]
    end
    subgraph batcher ["DynamicBatcher"]
        Q["Queue with backpressure"]
        A["Aggregation loop<br/>(assembles cost-budgeted batches)"]
        Q --> A
    end
    subgraph consumer ["Processing loop (GPU)"]
        G["process_fn / inference_fn<br/>(batched forward pass)"]
    end
    P1 --> Q
    P2 --> Q
    P3 --> Q
    A --> G
```
The batcher runs two internal loops:
- Aggregation loop — drains the submission queue and assembles batches that respect a cost budget (`target_batch_cost`), a maximum size (`max_batch_size`), and a timeout (`batch_timeout_s`). This ensures the GPU always receives optimally-sized batches.
- Processing loop — pulls assembled batches and calls your processing function, resolving each record's future with its result.
This pipelining means the GPU is processing batch N while data for batch N+1 is being loaded and assembled — eliminating idle time.
### Basic usage
```python
import asyncio

from flyte.extras import DynamicBatcher

async def process(batch: list[dict]) -> list[str]:
    """Your batch processing function. Must return results in the same order as the input."""
    return [heavy_computation(item) for item in batch]

async with DynamicBatcher(
    process_fn=process,
    target_batch_cost=1000,  # cost budget per batch
    max_batch_size=64,       # hard cap on records per batch
    batch_timeout_s=0.05,    # max wait time before dispatching a partial batch
    max_queue_size=5_000,    # queue size for backpressure
) as batcher:
    futures = []
    for record in my_records:
        future = await batcher.submit(record, estimated_cost=10)
        futures.append(future)
    results = await asyncio.gather(*futures)
```

Each call to `submit()` is non-blocking — it enqueues the record and immediately returns a `Future`. When the queue is full, `submit()` awaits until space is available, providing natural backpressure that prevents producers from overwhelming the GPU.
### Cost estimation
The batcher uses cost estimates to decide how many records to group into each batch. You can provide costs in several ways (checked in order of precedence):
- Explicit — pass `estimated_cost` to `submit()`
- Estimator function — pass `cost_estimator` to the constructor
- Protocol — implement `estimate_cost()` on your record type
- Default — falls back to `default_cost` (default: `1`)
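The precedence order can be illustrated with a small standalone sketch. This is a re-implementation of the documented rules for clarity, not the library's actual code; the `Doc` class and `resolve_cost` helper are hypothetical:

```python
from typing import Any, Callable, Optional

DEFAULT_COST = 1  # mirrors the documented default_cost

def resolve_cost(
    record: Any,
    estimated_cost: Optional[float] = None,
    cost_estimator: Optional[Callable[[Any], float]] = None,
) -> float:
    """Illustrative re-implementation of the documented precedence order."""
    if estimated_cost is not None:        # 1. explicit value passed to submit()
        return estimated_cost
    if cost_estimator is not None:        # 2. estimator function from the constructor
        return cost_estimator(record)
    if hasattr(record, "estimate_cost"):  # 3. CostEstimator protocol on the record type
        return record.estimate_cost()
    return DEFAULT_COST                   # 4. fall back to default_cost

class Doc:
    """Hypothetical record type implementing the CostEstimator protocol."""
    def __init__(self, text: str):
        self.text = text
    def estimate_cost(self) -> int:
        return len(self.text)

print(resolve_cost(Doc("hello"), estimated_cost=10))  # 10 — explicit wins
print(resolve_cost(Doc("hello"), cost_estimator=lambda r: 7))  # 7 — estimator beats protocol
print(resolve_cost(Doc("hello")))                     # 5 — protocol
print(resolve_cost("plain string"))                   # 1 — default
```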
### TokenBatcher for LLM inference

For LLM workloads, `TokenBatcher` is a convenience subclass that uses token-aware parameter names:
```python
from dataclasses import dataclass

from flyte.extras import TokenBatcher

@dataclass
class Prompt:
    text: str

    def estimate_tokens(self) -> int:
        """Rough token estimate (~4 chars per token)."""
        return len(self.text) // 4 + 1

async def inference(batch: list[Prompt]) -> list[str]:
    """Run batched inference through your model."""
    texts = [p.text for p in batch]
    outputs = model.generate(texts, sampling_params)
    return [o.outputs[0].text for o in outputs]

async with TokenBatcher(
    inference_fn=inference,
    target_batch_tokens=32_000,  # token budget per batch
    max_batch_size=256,
) as batcher:
    future = await batcher.submit(Prompt(text="What is 2+2?"))
    result = await future
```

`TokenBatcher` checks the `TokenEstimator` protocol (`estimate_tokens()`) in addition to `CostEstimator` (`estimate_cost()`), making it natural to work with prompt types.
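As a quick sanity check of the heuristic, here is how the ~4-chars-per-token estimate translates a token budget into batch sizes. This is illustrative arithmetic only, not the batcher's internal logic:

```python
def estimate_tokens(text: str) -> int:
    """Same ~4-chars-per-token heuristic as the Prompt example above."""
    return len(text) // 4 + 1

def prompts_per_batch(avg_chars: int, target_batch_tokens: int = 32_000) -> int:
    """How many prompts of avg_chars characters fit under the token budget."""
    return target_batch_tokens // estimate_tokens("x" * avg_chars)

print(estimate_tokens("What is 2+2?"))  # 12 chars -> 4 tokens
print(prompts_per_batch(2000))          # ~501 tokens each -> 63 prompts per batch
```

In practice the aggregation loop does this packing dynamically per batch, so prompts of mixed lengths are grouped under the same budget without any manual sizing.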
## Monitoring utilization

`DynamicBatcher` exposes a `stats` property with real-time metrics:
```python
stats = batcher.stats
print(f"Utilization: {stats.utilization:.1%}")        # fraction of time spent processing
print(f"Records processed: {stats.total_completed}")
print(f"Batches dispatched: {stats.total_batches}")
print(f"Avg batch size: {stats.avg_batch_size:.1f}")
print(f"Busy time: {stats.busy_time_s:.1f}s")
print(f"Idle time: {stats.idle_time_s:.1f}s")
```

| Metric | Description |
|---|---|
| `utilization` | Fraction of wall-clock time spent inside `process_fn` (0.0–1.0). Target: > 0.9. |
| `total_submitted` | Total records submitted via `submit()` |
| `total_completed` | Total records whose futures have been resolved |
| `total_batches` | Number of batches dispatched to `process_fn` |
| `avg_batch_size` | Running average records per batch |
| `avg_batch_cost` | Running average cost per batch |
| `busy_time_s` | Cumulative seconds spent inside `process_fn` |
| `idle_time_s` | Cumulative seconds the processing loop waited for batches |
If utilization is low, consider:
- Increasing concurrency — more concurrent producers means the batcher has more records to assemble into batches
- Reducing `batch_timeout_s` — dispatch partial batches faster instead of waiting
- Increasing `max_queue_size` — allow more records to be buffered ahead of the GPU
- Adding more data streams — ensure the GPU always has work queued up
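The utilization figure relates to the two time counters in a straightforward way. A sketch of that relationship, based on the field descriptions above rather than the library's exact bookkeeping:

```python
def utilization(busy_time_s: float, idle_time_s: float) -> float:
    """Fraction of the processing loop's wall-clock time spent inside process_fn."""
    total = busy_time_s + idle_time_s
    return busy_time_s / total if total > 0 else 0.0

print(utilization(45.0, 5.0))  # 0.9, right at the documented target
```

Watching `busy_time_s` and `idle_time_s` separately also tells you which lever to pull: growing idle time with small average batches points at starved producers, while high utilization with a full queue means the GPU itself is the bottleneck.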