Multimodal retrieval evaluation


This tutorial builds an experiment framework for benchmarking visual document retrieval on the ViDoRe benchmark. The corpus is a set of PDF page images and the queries are plain-text questions — each retrieval method must find the page that answers a question from the raw image alone.

Three approaches are compared:

  • ColPali-v1.2 — patch-level multi-vector embeddings from a vision-language model (PaliGemma), scored with MaxSim late interaction. No OCR.
  • SigLIP-SO400M — a single global embedding per page from Google’s CLIP successor.
  • OCR + BM25 — a text-only baseline that OCRs each page with docTR on GPU and ranks with BM25.
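The first two approaches differ mainly in the shape of the scoring step. A single-vector model such as SigLIP ranks pages with one cosine similarity per (query, page) pair. The NumPy sketch below illustrates that step and assumes the query and page embeddings are already computed. cosine_top_k is a hypothetical helper, not the example's search_siglip.

```python
import numpy as np


def cosine_top_k(query_vecs: np.ndarray, page_vecs: np.ndarray, k: int) -> np.ndarray:
    """Rank pages per query by cosine similarity of single global embeddings.

    query_vecs: (n_queries, dim); page_vecs: (n_pages, dim).
    Returns (n_queries, k) page indices, best first.
    """
    # L2-normalize both sides so a plain dot product equals cosine similarity.
    q = query_vecs / np.linalg.norm(query_vecs, axis=1, keepdims=True)
    p = page_vecs / np.linalg.norm(page_vecs, axis=1, keepdims=True)
    scores = q @ p.T  # (n_queries, n_pages)
    # Sorting the negated scores ascending yields descending similarity.
    return np.argsort(-scores, axis=1)[:, :k]


queries = np.array([[1.0, 0.0], [0.0, 1.0]])
pages = np.array([[0.9, 0.1], [0.1, 0.9], [0.7, 0.7]])
print(cosine_top_k(queries, pages, k=2))
```

ColPali's late interaction replaces this single dot product with a sum of per-token maxima over patch vectors. That is why its index stores many vectors per page.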

Each experiment is one ExperimentConfig; the pipeline fans them out as concurrent Flyte tasks and returns a ranked comparison table with an interactive HTML report. It’s a strong showcase of several Flyte features working together:

  • ReusePolicy keeps warm GPU containers (with the ~7 GB ColPali weights already in VRAM) alive across task calls.
  • A process-level DynamicBatcher aggregates queries from all concurrent search tasks into single GPU batches.
  • cache="auto" ensures each model’s index is built at most once per corpus and shared across experiments.
  • Typed Pydantic inputs and outputs store every metric alongside the exact config that produced it.
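Conceptually, a cached task's output is looked up by a key derived from the task and its input values. Indexing does not take top_k as an input, so two experiments that differ only in top_k request the same index. The toy memoizer below illustrates that idea under this assumption. ToyCache, cache_key, and build_index are hypothetical, and the sketch does not reflect how Flyte computes or stores cache keys. The two requests run sequentially here, so concurrent identical calls are outside the scope of this toy.

```python
import hashlib
import json
from collections.abc import Callable
from typing import Any


def cache_key(task_name: str, inputs: dict[str, Any]) -> str:
    """Stable key from a task name and its JSON-serializable inputs."""
    payload = json.dumps({"task": task_name, "inputs": inputs}, sort_keys=True)
    return hashlib.sha256(payload.encode()).hexdigest()


class ToyCache:
    """In-memory memoizer for illustration only; not Flyte's cache."""

    def __init__(self) -> None:
        self.store: dict[str, Any] = {}
        self.misses = 0

    def run(self, task_name: str, fn: Callable[..., Any], **inputs: Any) -> Any:
        key = cache_key(task_name, inputs)
        if key not in self.store:
            self.misses += 1  # only a miss pays for the (GPU) work
            self.store[key] = fn(**inputs)
        return self.store[key]


def build_index(page_ids: list[str]) -> str:
    return f"index over {len(page_ids)} pages"  # placeholder for real encoding


cache = ToyCache()
corpus = ["p1", "p2", "p3"]
# colpali-top5 and colpali-top10 both request the index. top_k is not an
# input to indexing, so both requests map to the same key.
for top_k in (5, 10):
    index = cache.run("index_colpali", build_index, page_ids=corpus)
print(cache.misses)  # 1
```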

Define the container image

One image serves every task. unionai-reuse provides the actor bridge required by ReusePolicy.

retrieval_eval.py
image = (
    flyte.Image.from_uv_script(__file__, name="vidore-eval-v2")
    .with_apt_packages("ca-certificates", "libxcb1", "libgl1", "libglib2.0-0")
    # unionai-reuse installs the unionai-actor-bridge binary required by ReusePolicy.
    # Without it every reusable container exits with StartError (exit code 128).
    .with_pip_packages("unionai-reuse>=0.1.11")
)

The Python dependencies (ColPali, transformers, docTR, etc.) are declared in the uv script header at the top of the file.
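The uv script header uses the PEP 723 inline script metadata format, a TOML block written as comment lines at the top of the file. from_uv_script(__file__) reads this block when it builds the image. The fragment below shows the shape only. The package list and the requires-python bound are illustrative assumptions. The actual dependencies and version pins are whatever retrieval_eval.py declares.

```python
# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "flyte",
#     "pydantic",
#     "colpali-engine",
#     "transformers",
#     "python-doctr",
# ]
# ///
```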

Define the task environments

Each model gets its own GPU environment so their warm-container pools scale independently. The ColPali and SigLIP environments use ReusePolicy to keep model weights resident; the driver coordinates orchestration, BM25, evaluation, and reporting.

retrieval_eval.py
colpali_indexer = flyte.TaskEnvironment(
    name="vidore-colpali-indexer",
    image=image,
    resources=flyte.Resources(cpu=4, memory="16Gi", gpu="A10G:1"),
    reusable=flyte.ReusePolicy(
        replicas=1,
        concurrency=8,
        idle_ttl=120,
        scaledown_ttl=60,
    ),
)

# GPU environment for SigLIP image encoding and search.
#
# Separate from the ColPali environment so each model's warm containers
# are managed independently — ColPali and SigLIP experiments can scale
# without contending for the same pool of reusable containers.
siglip_indexer = flyte.TaskEnvironment(
    name="vidore-siglip-indexer",
    image=image,
    resources=flyte.Resources(cpu=4, memory="8Gi", gpu=1),
    reusable=flyte.ReusePolicy(
        replicas=1,
        concurrency=8,
        idle_ttl=120,
        scaledown_ttl=60,
    ),
)

# GPU environment for docTR OCR. docTR runs DBNet (detection) + CRNN (recognition)
# in batches on GPU, which is much faster than CPU Tesseract.
# No ReusePolicy needed: the result is cached, so this task runs at most once per corpus.
ocr_engine = flyte.TaskEnvironment(
    name="vidore-ocr-engine",
    image=image,
    resources=flyte.Resources(cpu=4, memory="20Gi", gpu=1),
)

# Driver: orchestration, BM25 search, evaluation, and reporting.
# depends_on ensures the shared Docker image is built before all environments
# try to schedule tasks.
driver = flyte.TaskEnvironment(
    name="vidore-driver",
    image=image,
    resources=flyte.Resources(cpu=2, memory="12Gi"),
    depends_on=[colpali_indexer, siglip_indexer, ocr_engine],
)

Configuration and data types

An experiment is fully described by an ExperimentConfig. Because it’s a Pydantic model, Flyte serializes it alongside every output.

retrieval_eval.py
class RetrievalModel(str, enum.Enum):
    """Retrieval backend to evaluate."""

    COLPALI = "colpali-v1.2"  # multi-vector patch embeddings, MaxSim
    SIGLIP = "siglip-so400m"  # single-vector global embedding, cosine sim
    OCR_BM25 = "ocr+bm25"  # text extracted by docTR OCR, ranked by BM25


class ExperimentConfig(BaseModel):
    """
    All knobs for one retrieval experiment. Passed as a typed Flyte input.

    Because ExperimentConfig is a Pydantic model, Flyte serializes it
    alongside every task output — so you can always reconstruct which
    config produced which metric without maintaining a separate log.
    """

    name: str  # human-readable label shown in the comparison table
    model: RetrievalModel
    top_k: int = 5  # number of pages to retrieve per query

The corpus, queries, retrieval results, and metrics are likewise typed. Page images are stored as flyte.io.File handles in blob storage, so tasks read images directly rather than re-fetching over HTTP.

retrieval_eval.py
class PageQuery(BaseModel):
    """One retrieval query with its ground-truth page."""

    query_id: str
    text: str  # e.g. "What was revenue growth in Q3?"
    relevant_page_id: str  # one correct page per query


class PageDataset(BaseModel):
    """
    A corpus of document page images paired with text queries.

    page_ids:   unique page identifiers (derived from ViDoRe image filenames).
    page_files: the same pages stored in Flyte's blob store as JPEG File
                handles. Tasks read images directly from here; no live HTTP.
    queries:    text questions with ground-truth page IDs for evaluation.
    """

    page_ids: list[str]
    page_files: list[File]
    queries: list[PageQuery]

    class Config:
        arbitrary_types_allowed = True


class RetrievalResult(BaseModel):
    query_id: str
    ranked_page_ids: list[str]  # ordered best → worst


class Metrics(BaseModel):
    recall_at_k: float
    ndcg_at_k: float
    mrr: float
    k: int


class ExperimentResult(BaseModel):
    config: ExperimentConfig
    metrics: Metrics

load_vidore_pages downloads a ViDoRe subset and uploads each page image to blob storage (cached, with retries). The indexing tasks (index_colpali, index_siglip) encode every page into a .npz index file. The OCR task (extract_page_texts) extracts the page text that the BM25 baseline ranks. The indexing and OCR tasks run on their GPU environments and are cached per corpus.
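The index format is not shown here. One way to keep a multi-vector index in a single .npz file is to concatenate every page's patch embeddings into one matrix and record row offsets. With that layout, pages with different patch counts fit in one array. The sketch below assumes this layout. The function names, array keys, and float16 storage are illustrative choices, not the example's actual format.

```python
import io
from typing import BinaryIO

import numpy as np


def save_multivector_index(buf: BinaryIO, page_ids: list[str], page_embs: list[np.ndarray]) -> None:
    """Pack per-page (n_patches_i, dim) embeddings into one flat matrix plus row offsets."""
    lengths = np.array([e.shape[0] for e in page_embs], dtype=np.int64)
    # Page i occupies rows offsets[i]:offsets[i + 1] of the flat matrix.
    offsets = np.concatenate([[0], np.cumsum(lengths)])
    np.savez_compressed(
        buf,
        page_ids=np.array(page_ids),  # fixed-width unicode array, no pickling needed
        embeddings=np.concatenate(page_embs, axis=0).astype(np.float16),
        offsets=offsets,
    )


def load_multivector_index(buf: BinaryIO) -> tuple[list[str], list[np.ndarray]]:
    """Inverse of save_multivector_index: split the flat matrix back into pages."""
    with np.load(buf) as data:
        ids = data["page_ids"].tolist()
        embs, offs = data["embeddings"], data["offsets"]
    pages = [embs[offs[i]:offs[i + 1]] for i in range(len(ids))]
    return ids, pages


buf = io.BytesIO()
rng = np.random.default_rng(0)
embs = [rng.standard_normal((n, 4)).astype(np.float32) for n in (3, 5)]
save_multivector_index(buf, ["page-a", "page-b"], embs)
buf.seek(0)
ids, pages = load_multivector_index(buf)
print(ids, [p.shape for p in pages])  # ['page-a', 'page-b'] [(3, 4), (5, 4)]
```

Storing float16 halves the file size relative to float32 at a small precision cost. Whether that trade-off suits a given model is a judgment call.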

Search uses the DynamicBatcher so queries from all concurrent search-task invocations on a warm container are merged into a single GPU batch:

retrieval_eval.py
@colpali_indexer.task
async def search_colpali(
    index_file: File,
    queries: list[PageQuery],
    top_k: int,
) -> list[RetrievalResult]:
    """
    Retrieve pages using ColPali MaxSim late interaction via DynamicBatcher.

    MaxSim score for page p given query q:
        score(q, p) = Σ_{t ∈ query tokens} max_{j ∈ page patches} (q_t · p_j)

    Each query is submitted to the process-level DynamicBatcher, which
    aggregates queries from all concurrent search_colpali invocations on the
    same warm container (concurrency=8) into a single GPU batch. This keeps
    the GPU saturated rather than running one small batch per caller.

    The batcher's process_fn runs GPU work in asyncio.to_thread, so the
    aggregation loop stays live while the GPU encodes and scores.
    """
    batcher = await _get_colpali_search_batcher(index_file)
    futures = await batcher.submit_batch(queries)
    all_ranked: list[list[str]] = list(await asyncio.gather(*futures))

    return [
        RetrievalResult(query_id=q.query_id, ranked_page_ids=ranked[:top_k])
        for q, ranked in zip(queries, all_ranked)
    ]
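The MaxSim formula in the docstring maps directly onto NumPy. This unbatched CPU sketch assumes precomputed query-token and page-patch embeddings. maxsim_scores is a hypothetical helper; the example itself does this scoring on GPU inside the batcher's process_fn.

```python
import numpy as np


def maxsim_scores(query_emb: np.ndarray, page_embs: list[np.ndarray]) -> np.ndarray:
    """Late-interaction MaxSim score of one query against every page.

    query_emb: (n_query_tokens, dim); page_embs[i]: (n_patches_i, dim).
    score(q, p) = sum over query tokens t of max over patches j of q_t . p_j
    """
    scores = []
    for page in page_embs:
        sim = query_emb @ page.T  # (n_query_tokens, n_patches): every token-patch dot product
        scores.append(sim.max(axis=1).sum())  # best patch per token, summed over tokens
    return np.array(scores)


q = np.array([[1.0, 0.0], [0.0, 1.0]])  # two query tokens
pages = [
    np.array([[1.0, 0.0], [0.0, 0.2]]),  # strong match for token 0, weak for token 1
    np.array([[0.7, 0.7]]),              # one patch, moderate match for both tokens
]
s = maxsim_scores(q, pages)
print(s, np.argsort(-s))
```

In the toy data the second page wins (1.4 against 1.2), even though the first page holds the single best patch match. MaxSim sums each query token's best match, so it rewards pages that cover every query token.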

The DynamicBatcher implementation lives in the extras/ package next to the example. Run the script from the example directory so the import resolves.
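The extras/ implementation is not reproduced here. The stand-in below sketches the pattern; MiniBatcher is a hypothetical name. It mirrors the submit_batch-returns-futures shape used by search_colpali. The first queued item opens a short collection window, and everything queued by the time the window closes becomes one process_fn call. That blocking call runs in asyncio.to_thread, so the event loop keeps accepting submissions. The sketch omits features a production batcher likely needs, such as backpressure and graceful shutdown.

```python
from __future__ import annotations

import asyncio
from collections.abc import Callable
from typing import Generic, TypeVar

T = TypeVar("T")
R = TypeVar("R")


class MiniBatcher(Generic[T, R]):
    """Toy dynamic batcher: items from concurrent callers share one process_fn call."""

    def __init__(self, process_fn: Callable[[list[T]], list[R]],
                 max_batch_size: int = 64, max_wait_s: float = 0.01) -> None:
        self.process_fn = process_fn  # blocking batch function (e.g. GPU encode + score)
        self.max_batch_size = max_batch_size
        self.max_wait_s = max_wait_s
        self.queue: asyncio.Queue[tuple[T, asyncio.Future[R]]] = asyncio.Queue()
        self.batch_sizes: list[int] = []  # recorded for inspection only
        self._worker: asyncio.Task[None] | None = None

    async def submit_batch(self, items: list[T]) -> list[asyncio.Future[R]]:
        """Enqueue items and return one future per item, in the same order."""
        if self._worker is None:  # lazily start the aggregation loop on first use
            self._worker = asyncio.create_task(self._loop())
        loop = asyncio.get_running_loop()
        futures: list[asyncio.Future[R]] = [loop.create_future() for _ in items]
        for item, fut in zip(items, futures):
            self.queue.put_nowait((item, fut))
        return futures

    async def _loop(self) -> None:
        while True:
            batch = [await self.queue.get()]  # wait for the first item
            await asyncio.sleep(self.max_wait_s)  # short window for other callers to enqueue
            while len(batch) < self.max_batch_size and not self.queue.empty():
                batch.append(self.queue.get_nowait())
            self.batch_sizes.append(len(batch))
            try:
                # Run the blocking call in a worker thread so the event loop stays live.
                results = await asyncio.to_thread(self.process_fn, [item for item, _ in batch])
                if len(results) != len(batch):
                    raise ValueError("process_fn must return one result per item")
                for (_, fut), res in zip(batch, results):
                    if not fut.done():
                        fut.set_result(res)
            except Exception as exc:
                for _, fut in batch:
                    if not fut.done():
                        fut.set_exception(exc)


async def main() -> tuple[list[list[int]], list[int]]:
    # process_fn stands in for GPU encode + score; here it just squares numbers.
    batcher: MiniBatcher[int, int] = MiniBatcher(lambda xs: [x * x for x in xs])

    async def search_task(items: list[int]) -> list[int]:
        futures = await batcher.submit_batch(items)
        return list(await asyncio.gather(*futures))

    # Three concurrent callers, like three search-task invocations on one container.
    outs = await asyncio.gather(search_task([1, 2]), search_task([3]), search_task([4, 5, 6]))
    return list(outs), batcher.batch_sizes


print(asyncio.run(main()))  # ([[1, 4], [9], [16, 25, 36]], [6])
```

Always waiting max_wait_s keeps the loop simple, at the cost of a small fixed latency per batch. A deadline-based collector gives up some of that simplicity for lower latency under light load.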

Run one experiment

run_experiment selects the right index/search path based on the runtime value of config.model — Flyte v2’s dynamic execution means there’s no static DAG to wire up. flyte.group wraps each experiment in a named span in the UI.

retrieval_eval.py
@driver.task
async def run_experiment(config: ExperimentConfig, dataset: PageDataset) -> ExperimentResult:
    """
    End-to-end retrieval pipeline for a single ExperimentConfig.

    Flyte v2's dynamic execution means this driver task can call GPU tasks
    (index_colpali, search_colpali) based on the runtime value of config.model
    — no static DAG wiring required. The if/elif is plain Python; Flyte
    schedules the selected sub-tasks on the appropriate environment.

    Caching: two experiments that share the same model and corpus (e.g. ColPali
    at top_k=5 and top_k=10) will hit the same cached index. GPU work is paid
    at most once per (model, corpus) pair across all experiments.

    Search queries are sharded into chunks of SEARCH_SHARD_SIZE and dispatched
    as concurrent task invocations. All shards land on the single warm container
    (replicas=1) and feed the same DynamicBatcher simultaneously, keeping the
    GPU saturated throughout search rather than processing one large sequential
    batch from a single caller.

    flyte.group wraps each experiment in a named span in the Flyte UI, making
    it easy to compare latencies and drill into individual runs.
    """
    SEARCH_SHARD_SIZE = 256

    with flyte.group(config.name):
        if config.model == RetrievalModel.COLPALI:
            index_file = await index_colpali(dataset.page_ids, dataset.page_files)
            shards = list(_batches(dataset.queries, SEARCH_SHARD_SIZE))
            shard_results = await asyncio.gather(
                *[search_colpali(index_file, shard, config.top_k) for shard in shards]
            )
            results = [r for shard in shard_results for r in shard]

        elif config.model == RetrievalModel.SIGLIP:
            index_file = await index_siglip(dataset.page_ids, dataset.page_files)
            shards = list(_batches(dataset.queries, SEARCH_SHARD_SIZE))
            shard_results = await asyncio.gather(
                *[search_siglip(index_file, shard, config.top_k) for shard in shards]
            )
            results = [r for shard in shard_results for r in shard]

        else:  # RetrievalModel.OCR_BM25
            page_texts = await extract_page_texts(dataset.page_files)
            results = await search_bm25(page_texts, dataset.page_ids, dataset.queries, config.top_k)

        metrics = await evaluate(results, dataset.queries, config.top_k)

    return ExperimentResult(config=config, metrics=metrics)
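_batches is not shown in the excerpt. An order-preserving chunker is enough, since run_experiment flattens the shard results back into one list in shard order. The sketch below uses a hypothetical batches stand-in and works through the shard arithmetic. It takes the 1,854-query figure from the compare_experiments docstring and assumes that one replica with concurrency=8 runs up to eight search invocations at once.

```python
import math
from collections.abc import Iterator, Sequence
from typing import TypeVar

T = TypeVar("T")


def batches(items: Sequence[T], size: int) -> Iterator[list[T]]:
    """Yield consecutive chunks of at most `size` items, preserving order."""
    for start in range(0, len(items), size):
        yield list(items[start:start + size])


SEARCH_SHARD_SIZE = 256
CONCURRENCY = 8  # ReusePolicy(concurrency=8) on a single replica

n_queries = 1854
shards = list(batches(range(n_queries), SEARCH_SHARD_SIZE))
waves = math.ceil(len(shards) / CONCURRENCY)  # rounds of concurrent invocations needed
print(len(shards), len(shards[-1]), waves)  # 8 62 1
```

With SEARCH_SHARD_SIZE = 256, 1,854 queries split into 8 shards: 7 full shards plus one of 62. Every shard therefore fits in a single wave on one warm container.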

Compare experiments

The driver loads the dataset once, fans out across all configs with asyncio.gather, and emits an interactive Chart.js report in the Flyte UI. Experiments sharing a model reuse the cached index, so you only pay GPU time for new work.

retrieval_eval.py
@driver.task
async def compare_experiments(
    configs: list[ExperimentConfig],
    subset: str = "docvqa",
    max_pages: int = 200,
) -> ComparisonReport:
    """
    Fan out over all experiment configs and return a ranked comparison table.

    The dataset is loaded once and shared across all experiments. Each config
    runs as a concurrent Flyte task via asyncio.gather. Experiments that share
    a model reuse the cached index — you only pay GPU time for new work.

    On completion, generate_report emits an interactive Chart.js HTML report
    visible directly in the Flyte execution detail page.

    Defaults to the docvqa subset capped at max_pages=200. To exercise the GPU
    pipeline at scale, pass a larger subset such as vidore_v3_finance_en
    (~2,942 corpus pages, 1,854 queries) with a higher max_pages, e.g. 2,000.
    """
    dataset = await load_vidore_pages(subset=subset, max_pages=max_pages)

    # All experiments launch concurrently. Shared cached outputs (same model,
    # same corpus) are served from cache rather than recomputed.
    experiment_coros = [run_experiment(config=cfg, dataset=dataset) for cfg in configs]
    results: list[ExperimentResult] = list(await asyncio.gather(*experiment_coros))

    report = ComparisonReport(results=results)
    print(report.summary())
    best = report.best_by("recall_at_k")
    print(f"\nBest by Recall@{best.metrics.k}: {best.config.name}")

    # Emit the interactive HTML report in the Flyte UI.
    await generate_report(report)

    return report
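ComparisonReport is not defined in the excerpt. A minimal version only needs to sort results by one Metrics field. The sketch below assumes that design and provides summary and best_by as compare_experiments uses them. It uses dataclasses instead of Pydantic and flattens each result to a name for brevity. The metric values in the usage lines are dummy numbers, not benchmark results.

```python
from dataclasses import dataclass, field


@dataclass
class Metrics:
    recall_at_k: float
    ndcg_at_k: float
    mrr: float
    k: int


@dataclass
class ExperimentResult:
    name: str  # simplification: the real model carries the full ExperimentConfig
    metrics: Metrics


@dataclass
class ComparisonReport:
    results: list[ExperimentResult] = field(default_factory=list)

    def ranked(self, metric: str) -> list[ExperimentResult]:
        """Results sorted best-first by one Metrics field, e.g. "recall_at_k"."""
        return sorted(self.results, key=lambda r: getattr(r.metrics, metric), reverse=True)

    def best_by(self, metric: str) -> ExperimentResult:
        return self.ranked(metric)[0]  # raises IndexError if there are no results

    def summary(self, metric: str = "recall_at_k") -> str:
        lines = [f"{'experiment':<16}{'R@K':>7}{'NDCG@K':>8}{'MRR':>7}"]
        for r in self.ranked(metric):
            m = r.metrics
            lines.append(f"{r.name:<16}{m.recall_at_k:>7.3f}{m.ndcg_at_k:>8.3f}{m.mrr:>7.3f}")
        return "\n".join(lines)


report = ComparisonReport(results=[  # dummy values for illustration only
    ExperimentResult("exp-a", Metrics(recall_at_k=0.50, ndcg_at_k=0.40, mrr=0.35, k=5)),
    ExperimentResult("exp-b", Metrics(recall_at_k=0.75, ndcg_at_k=0.60, mrr=0.55, k=5)),
])
print(report.summary())
print(report.best_by("recall_at_k").name)
```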

Run the evaluation

This example has no secrets — datasets and model weights are pulled from public Hugging Face repositories. It does require GPUs, so run it remotely.

The experiment grid is defined in the entry point. Adding an experiment for an existing model, such as another top_k value, is a one-line change:

retrieval_eval.py
    configs = [
        ExperimentConfig(name="colpali-top5", model=RetrievalModel.COLPALI, top_k=5),
        ExperimentConfig(name="colpali-top10", model=RetrievalModel.COLPALI, top_k=10),
        ExperimentConfig(name="siglip-top5", model=RetrievalModel.SIGLIP, top_k=5),
        ExperimentConfig(name="ocr-bm25-top5", model=RetrievalModel.OCR_BM25, top_k=5),
    ]
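Larger grids can be generated rather than written out by hand. The sketch below uses stand-ins for RetrievalModel and ExperimentConfig (a frozen dataclass in place of the Pydantic model) so it runs on its own. SHORT_NAMES is a hypothetical mapping that reproduces the naming scheme above. Because each model's index is cached and shared across experiments, extra top_k values for a model add search and evaluation work but reuse that model's index.

```python
import enum
import itertools
from dataclasses import dataclass


class RetrievalModel(str, enum.Enum):
    COLPALI = "colpali-v1.2"
    SIGLIP = "siglip-so400m"
    OCR_BM25 = "ocr+bm25"


@dataclass(frozen=True)
class ExperimentConfig:  # dataclass stand-in for the Pydantic model
    name: str
    model: RetrievalModel
    top_k: int = 5


SHORT_NAMES = {
    RetrievalModel.COLPALI: "colpali",
    RetrievalModel.SIGLIP: "siglip",
    RetrievalModel.OCR_BM25: "ocr-bm25",
}

# Every model crossed with every top_k value, in enum definition order.
configs = [
    ExperimentConfig(name=f"{SHORT_NAMES[model]}-top{k}", model=model, top_k=k)
    for model, k in itertools.product(RetrievalModel, (5, 10))
]
print([c.name for c in configs])
```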

Change into the example directory and run the script:

cd v2/tutorials/multimodal-retrieval-evaluation
python retrieval_eval.py

When the run completes, open the generate_report task in the UI to see the summary cards, the grouped Recall@K / NDCG@K / MRR bar chart, and the ranked results table.