2.0.6
TokenBatcher
Package: flyte.extras
Token-aware batcher for LLM inference workloads.
A thin convenience wrapper around :class:`DynamicBatcher` that accepts
token-specific parameter names (`inference_fn`, `token_estimator`,
`target_batch_tokens`, etc.) and maps them to the corresponding base-class parameters.
It also checks for the :class:`TokenEstimator` protocol (`estimate_tokens()`)
in addition to :class:`CostEstimator` (`estimate_cost()`).
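The protocol check described above can be pictured with a small, self-contained sketch. It does not import flyte, and it is not the library's implementation: the protocol definitions, the `resolve_estimate` helper, its precedence order (explicit value, then estimator function, then `estimate_tokens()`, then `estimate_cost()`, then a default), and the `Prompt` class are all assumptions made for illustration.

```python
from dataclasses import dataclass
from typing import Any, Callable, Optional, Protocol, runtime_checkable


@runtime_checkable
class TokenEstimator(Protocol):
    # Assumed shape: a record that can report its own token count.
    def estimate_tokens(self) -> int: ...


@runtime_checkable
class CostEstimator(Protocol):
    # Assumed shape: a record that can report a generic cost.
    def estimate_cost(self) -> int: ...


def resolve_estimate(
    record: Any,
    explicit: Optional[int] = None,
    estimator_fn: Optional[Callable[[Any], int]] = None,
    default: int = 1,
) -> int:
    """Hypothetical helper; the precedence order here is an assumption."""
    if explicit is not None:
        return explicit
    if estimator_fn is not None:
        return estimator_fn(record)
    if isinstance(record, TokenEstimator):
        return record.estimate_tokens()
    if isinstance(record, CostEstimator):
        return record.estimate_cost()
    return default


@dataclass
class Prompt:
    text: str

    def estimate_tokens(self) -> int:
        # Rough heuristic (assumption): about four characters per token.
        return max(1, len(self.text) // 4)
```

With these definitions, `resolve_estimate(Prompt("Hello world!"))` uses the record's own `estimate_tokens()`, while a plain string falls through to the default.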
Example::

    async def inference(batch: list[Prompt]) -> list[str]:
        ...

    async with TokenBatcher(inference_fn=inference) as batcher:
        future = await batcher.submit(Prompt(text="Hello"))
        result = await future
    class TokenBatcher(
        inference_fn: ProcessFn[RecordT, ResultT] | None,
        process_fn: ProcessFn[RecordT, ResultT] | None,
        token_estimator: CostEstimatorFn[RecordT] | None,
        cost_estimator: CostEstimatorFn[RecordT] | None,
        target_batch_tokens: int | None,
        target_batch_cost: int,
        default_token_estimate: int | None,
        default_cost: int,
        max_batch_size: int,
        min_batch_size: int,
        batch_timeout_s: float,
        max_queue_size: int,
        prefetch_batches: int,
    )

| Parameter | Type | Description |
|---|---|---|
| `inference_fn` | `ProcessFn[RecordT, ResultT] \| None` | |
| `process_fn` | `ProcessFn[RecordT, ResultT] \| None` | |
| `token_estimator` | `CostEstimatorFn[RecordT] \| None` | |
| `cost_estimator` | `CostEstimatorFn[RecordT] \| None` | |
| `target_batch_tokens` | `int \| None` | |
| `target_batch_cost` | `int` | |
| `default_token_estimate` | `int \| None` | |
| `default_cost` | `int` | |
| `max_batch_size` | `int` | |
| `min_batch_size` | `int` | |
| `batch_timeout_s` | `float` | |
| `max_queue_size` | `int` | |
| `prefetch_batches` | `int` | |
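The parameter names `target_batch_tokens` and `max_batch_size` suggest that batches are bounded by both a token budget and a record count. The sketch below shows one plausible way such limits could interact, using a greedy grouping over an in-memory list. It is an illustration under that assumption only: `group_by_token_budget` is a hypothetical function, and it ignores `min_batch_size`, `batch_timeout_s`, queueing, and prefetching entirely.

```python
from typing import Callable, Iterable, List, TypeVar

RecordT = TypeVar("RecordT")


def group_by_token_budget(
    records: Iterable[RecordT],
    estimate: Callable[[RecordT], int],
    target_batch_tokens: int,
    max_batch_size: int,
) -> List[List[RecordT]]:
    """Greedily pack records into batches (hypothetical helper)."""
    batches: List[List[RecordT]] = []
    current: List[RecordT] = []
    current_tokens = 0
    for record in records:
        tokens = estimate(record)
        over_budget = current_tokens + tokens > target_batch_tokens
        full = len(current) >= max_batch_size
        # Close the current batch before adding a record that would not fit.
        if current and (over_budget or full):
            batches.append(current)
            current, current_tokens = [], 0
        # A single oversized record still gets a batch of its own.
        current.append(record)
        current_tokens += tokens
    if current:
        batches.append(current)
    return batches
```

For example, with `estimate=len`, a budget of 12, and at most 2 records per batch, the inputs `"aaaa"`, `"bbbbbbbb"`, `"cccc"`, `"dddddddddddd"` are split into three batches: the first two records together, then one record each.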
Properties
| Property | Type | Description |
|---|---|---|
| `is_running` | `None` | Whether the aggregation and processing loops are active. |
| `stats` | `None` | Current :class:`BatchStats` snapshot. |
Methods
| Method | Description |
|---|---|
| `start()` | Start the aggregation and processing loops. |
| `stop()` | Graceful shutdown: process all enqueued work, then stop. |
| `submit()` | Submit a single record for batched inference. |
| `submit_batch()` | Convenience: submit multiple records and return their futures. |
start()
    def start()

Start the aggregation and processing loops.

Raises: `RuntimeError` if the batcher is already running.
stop()
    def stop()

Graceful shutdown: process all enqueued work, then stop.
Blocks until every pending future is resolved.
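The start/stop contract described above (a second `start()` raises `RuntimeError`; `stop()` drains enqueued work before returning) can be illustrated with a toy stand-in. `ToyBatcher` below is hypothetical and processes records one at a time rather than in batches; it only demonstrates the lifecycle semantics, not how the flyte class is implemented.

```python
import asyncio
from typing import Any, Awaitable, Callable, Optional


class ToyBatcher:
    """Hypothetical stand-in that mimics the documented lifecycle only."""

    def __init__(self, process_fn: Callable[[Any], Awaitable[Any]]):
        self._process_fn = process_fn
        self._queue: Optional[asyncio.Queue] = None
        self._worker: Optional[asyncio.Task] = None

    @property
    def is_running(self) -> bool:
        return self._worker is not None

    async def start(self) -> None:
        if self.is_running:
            raise RuntimeError("batcher is already running")
        self._queue = asyncio.Queue()
        self._worker = asyncio.create_task(self._loop())

    async def submit(self, record: Any) -> asyncio.Future:
        if self._queue is None:
            raise RuntimeError("batcher is not running")
        future = asyncio.get_running_loop().create_future()
        await self._queue.put((record, future))
        return future

    async def _loop(self) -> None:
        while True:
            record, future = await self._queue.get()
            try:
                future.set_result(await self._process_fn(record))
            except Exception as exc:
                future.set_exception(exc)
            finally:
                self._queue.task_done()

    async def stop(self) -> None:
        if self._worker is None:
            return
        # Drain first: every queued record is processed before shutdown.
        await self._queue.join()
        self._worker.cancel()
        await asyncio.gather(self._worker, return_exceptions=True)
        self._worker, self._queue = None, None


async def demo():
    async def double(x: int) -> int:
        await asyncio.sleep(0)
        return x * 2

    batcher = ToyBatcher(double)
    await batcher.start()
    second_start_rejected = False
    try:
        await batcher.start()
    except RuntimeError:
        second_start_rejected = True
    futures = [await batcher.submit(n) for n in (1, 2, 3)]
    await batcher.stop()
    # After stop() returns, every future is already resolved.
    return second_start_rejected, [f.done() for f in futures], [f.result() for f in futures]
```

Running `asyncio.run(demo())` exercises both behaviours: the duplicate `start()` is rejected, and all three futures are done by the time `stop()` returns.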
submit()
    def submit(
        record: RecordT,
        estimated_tokens: int | None,
        estimated_cost: int | None,
    ) -> asyncio.Future[ResultT]

Submit a single record for batched inference.
Accepts either `estimated_tokens` or `estimated_cost`.

| Parameter | Type | Description |
|---|---|---|
| `record` | `RecordT` | The input record. |
| `estimated_tokens` | `int \| None` | Optional explicit token count. |
| `estimated_cost` | `int \| None` | Optional explicit cost (base class parameter). |
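A usage sketch for passing an explicit estimate, following the same `async with` and `await batcher.submit(...)` pattern as the class-level example. It requires flyte to be installed; `Prompt`, `run_inference`, and the value `2` are hypothetical placeholders, and the import is deferred into `main()` so the helper definitions can be loaded without flyte. What happens when both `estimated_tokens` and `estimated_cost` are supplied is not documented here, so the sketch passes only one.

```python
import asyncio
from dataclasses import dataclass
from typing import List


@dataclass
class Prompt:
    # Hypothetical record type for illustration.
    text: str


async def run_inference(batch: List[Prompt]) -> List[str]:
    # Placeholder for a real model call: one result per input record.
    return [prompt.text.upper() for prompt in batch]


async def main() -> str:
    # Deferred import: only needed when the example is actually run.
    from flyte.extras import TokenBatcher

    async with TokenBatcher(inference_fn=run_inference) as batcher:
        # Supply the token count directly instead of relying on an estimator.
        future = await batcher.submit(Prompt(text="Hello"), estimated_tokens=2)
        return await future
```

Run it with `asyncio.run(main())` in an environment where flyte is available.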
submit_batch()
    def submit_batch(
        records: Sequence[RecordT],
        estimated_cost: Sequence[int] | None,
    ) -> list[asyncio.Future[ResultT]]

Convenience: submit multiple records and return their futures.

| Parameter | Type | Description |
|---|---|---|
| `records` | `Sequence[RecordT]` | Iterable of input records. |
| `estimated_cost` | `Sequence[int] \| None` | Optional per-record cost estimates. Length must match `records` when provided. |
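Two details of this method lend themselves to a short illustration: the length-matching rule for `estimated_cost`, and collecting results from a list of futures. The sketch below is self-contained and does not call flyte. `pair_with_costs` and `collect` are hypothetical helpers, and raising `ValueError` on a length mismatch is an assumption, since the exception type is not stated above.

```python
import asyncio
from typing import Any, List, Optional, Sequence, Tuple


def pair_with_costs(
    records: Sequence[Any],
    estimated_cost: Optional[Sequence[int]] = None,
) -> List[Tuple[Any, Optional[int]]]:
    """Pair each record with its optional cost estimate (hypothetical)."""
    if estimated_cost is None:
        return [(record, None) for record in records]
    if len(estimated_cost) != len(records):
        # Assumed error type; the documentation only states the length rule.
        raise ValueError(
            f"got {len(estimated_cost)} estimates for {len(records)} records"
        )
    return list(zip(records, estimated_cost))


async def collect(futures: Sequence[asyncio.Future]) -> List[Any]:
    # asyncio.gather returns results in the order the futures were passed,
    # regardless of the order in which they complete.
    return await asyncio.gather(*futures)
```

Given the list of futures returned by `submit_batch()`, `await collect(futures)` would yield the results in the same order as the submitted records.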