
Pandera

The Pandera plugin validates dataframes at task boundaries using DataFrameModel schemas. When a task receives or returns a pandera-typed dataframe, the plugin automatically validates the data, raises or warns on schema violations, and writes an HTML validation report to the Flyte deck.

Pandera supports multiple dataframe backends. The flyteplugins-pandera plugin handles:

| Pandera typing module | DataFrame library | Additional plugin |
|---|---|---|
| `pandera.typing.pandas` | pandas | (none) |
| `pandera.typing.polars` | Polars (eager and lazy) | `flyteplugins-polars` |
| `pandera.typing.pyspark_sql` | PySpark SQL | `flyteplugins-spark` |

When to use this plugin

  • You want compile-time-style guarantees that data flowing between tasks conforms to a declared schema
  • You need column-level type, constraint, and statistical checks on task inputs and outputs
  • You want automatic validation reports visible in the Flyte UI

Installation

Install the plugin with the pandera extras for your dataframe backend:

pandas:

```bash
pip install flyteplugins-pandera 'pandera[pandas]'
```

Polars:

```bash
pip install flyteplugins-pandera flyteplugins-polars 'pandera[polars]'
```

PySpark SQL:

```bash
pip install flyteplugins-pandera flyteplugins-spark 'pandera[pyspark]'
```

Defining schemas

Schemas are defined as Python classes that inherit from pandera’s DataFrameModel. Each field declares a column name, type, and optional constraints:

import pandera.pandas as pa

class EmployeeSchema(pa.DataFrameModel):
    employee_id: int = pa.Field(ge=0)
    name: str

class EmployeeSchemaWithStatus(EmployeeSchema):
    status: str = pa.Field(isin=["active", "inactive"])

Schemas compose through inheritance: EmployeeSchemaWithStatus includes all columns from EmployeeSchema plus the status column.

For full details on schema definition—including custom checks, regex column matching, and Config options—see the pandera DataFrameModel documentation.

Using schemas in tasks

Annotate task inputs and outputs with pandera’s generic DataFrame type. The plugin validates data on every encode (output) and decode (input):

import pandas as pd
import pandera.typing.pandas as pt

@env.task(report=True)
async def build_employees() -> pt.DataFrame[EmployeeSchema]:
    return pd.DataFrame({
        "employee_id": [1, 2, 3],
        "name": ["Ada", "Grace", "Barbara"],
    })

@env.task(report=True)
async def add_status(
    df: pt.DataFrame[EmployeeSchema],
) -> pt.DataFrame[EmployeeSchemaWithStatus]:
    return df.assign(status="active")

Setting report=True on the task makes validation reports visible as deck tabs in the Flyte UI.

Error handling with ValidationConfig

By default, a validation failure raises an exception and fails the task. To downgrade failures to warnings instead, annotate the parameter with ValidationConfig(on_error="warn"):

from typing import Annotated
from flyteplugins.pandera import ValidationConfig

@env.task(report=True)
async def lenient_pass_through(
    df: Annotated[pt.DataFrame[EmployeeSchema], ValidationConfig(on_error="warn")],
) -> Annotated[pt.DataFrame[EmployeeSchemaWithStatus], ValidationConfig(on_error="warn")]:
    ...

| `on_error` value | Behavior |
|---|---|
| `"raise"` (default) | Validation failure raises `pandera.errors.SchemaError` and the task fails |
| `"warn"` | Validation failure logs a warning and writes the report, but the task continues |

You can mix "raise" and "warn" across inputs and outputs of the same task. For example, use "warn" on inputs to accept best-effort data while still enforcing strict output contracts.
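Mechanically, this works because each `Annotated` hint carries its own `ValidationConfig`, which the plugin reads per parameter. The sketch below uses a stand-in `NamedTuple` in place of the real `ValidationConfig` so it runs without Flyte installed; the way metadata attaches to, and is recovered from, each hint is the same:

```python
from typing import Annotated, NamedTuple, get_type_hints


class StubConfig(NamedTuple):
    # Stand-in for flyteplugins.pandera.ValidationConfig (illustration only).
    on_error: str


def ingest(
    df: Annotated[dict, StubConfig(on_error="warn")],  # lenient input
) -> Annotated[dict, StubConfig(on_error="raise")]:    # strict output
    return df


# A plugin can recover each parameter's config independently:
hints = get_type_hints(ingest, include_extras=True)
print(hints["df"].__metadata__[0].on_error)      # warn
print(hints["return"].__metadata__[0].on_error)  # raise
```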

Image configuration

Include the plugin in your task image. The exact setup depends on your dataframe backend:

Pandas:
import flyte

img = flyte.Image.from_debian_base(
    python_version=(3, 12),
).with_pip_packages("flyteplugins-pandera")

env = flyte.TaskEnvironment(
    "pandera_pandas",
    image=img,
    resources=flyte.Resources(cpu="1", memory="2Gi"),
)
Polars:

import flyte

img = (
    flyte.Image.from_debian_base(python_version=(3, 12))
    .with_pip_packages("flyteplugins-polars", "pandera[polars]")
)

env = flyte.TaskEnvironment(
    "pandera_polars",
    image=img,
    resources=flyte.Resources(cpu="1", memory="2Gi"),
)
PySpark SQL:

import flyte
from flyteplugins.spark.task import Spark

image = (
    flyte.Image.from_base("apache/spark-py:v3.4.0")
    .clone(name="pandera-pyspark-sql", python_version=(3, 10), extendable=True)
    .with_pip_packages("flyteplugins-spark", "pandera[pyspark]")
)

spark_conf = Spark(
    spark_conf={
        "spark.driver.memory": "1000M",
        "spark.executor.memory": "1000M",
        "spark.executor.cores": "1",
        "spark.executor.instances": "2",
        "spark.driver.cores": "1",
    },
)

env = flyte.TaskEnvironment(
    name="pandera_pyspark",
    plugin_config=spark_conf,
    image=image,
    resources=flyte.Resources(cpu="1", memory="2Gi"),
)

Polars lazy frames

The Polars backend supports both pt.DataFrame (eager) and pt.LazyFrame (lazy). With lazy frames, pandera validates the data when the frame is materialized at task I/O boundaries:

import pandera.typing.polars as pt
import polars as pl

@env.task(report=True)
async def create_lazy() -> pt.LazyFrame[MetricsSchema]:
    return pl.LazyFrame({"item": ["x", "y"], "value": [3.0, 4.0]})

@env.task(report=True)
async def consume_lazy(
    lf: pt.LazyFrame[MetricsSchema],
) -> pt.DataFrame[MetricsSchema]:
    return lf.filter(pl.col("value") > 0.0).collect()

Examples

pandas_schema.py
# /// script
# requires-python = ">=3.12"
# dependencies = [
#    "flyte",
#    "flyteplugins-pandera",
#    "pandera[pandas]",
# ]
# main = "main"
# ///

from __future__ import annotations

from typing import Annotated

import pandas as pd
import pandera.pandas as pa
import pandera.typing.pandas as pt
from flyteplugins.pandera import ValidationConfig

import flyte

img = flyte.Image.from_debian_base(python_version=(3, 12)).with_pip_packages(
    "flyteplugins-pandera", "pandera[pandas]"
)

env = flyte.TaskEnvironment(
    "pandera_pandas_schema",
    image=img,
    resources=flyte.Resources(cpu="1", memory="2Gi"),
)


class EmployeeSchema(pa.DataFrameModel):
    employee_id: int = pa.Field(ge=0)
    name: str


class EmployeeSchemaWithStatus(EmployeeSchema):
    status: str = pa.Field(isin=["active", "inactive"])


@env.task(report=True)
async def build_valid_employees() -> pt.DataFrame[EmployeeSchema]:
    return pd.DataFrame(
        {
            "employee_id": [1, 2, 3],
            "name": ["Ada", "Grace", "Barbara"],
        }
    )


@env.task(report=True)
async def pass_through(
    df: pt.DataFrame[EmployeeSchema],
) -> pt.DataFrame[EmployeeSchemaWithStatus]:
    return df.assign(status="active")


@env.task(report=True)
async def pass_through_with_error_warn(
    df: Annotated[
        pt.DataFrame[EmployeeSchema], ValidationConfig(on_error="warn")
    ],
) -> Annotated[
    pt.DataFrame[EmployeeSchemaWithStatus], ValidationConfig(on_error="warn")
]:
    del df["name"]
    return df


@env.task(report=True)
async def pass_through_with_error_raise(
    df: Annotated[
        pt.DataFrame[EmployeeSchema], ValidationConfig(on_error="warn")
    ],
) -> Annotated[
    pt.DataFrame[EmployeeSchemaWithStatus], ValidationConfig(on_error="raise")
]:
    del df["name"]
    return df


@env.task(report=True)
async def main() -> pt.DataFrame[EmployeeSchemaWithStatus]:
    df = await build_valid_employees()
    df2 = await pass_through(df)

    await pass_through_with_error_warn(df.drop(["employee_id"], axis="columns"))
    await pass_through_with_error_warn(df.assign(employee_id=-1))

    try:
        await pass_through_with_error_raise(df)
    except Exception as exc:
        print(exc)

    return df2


if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(main)
    print(run.url)
    run.wait()
    print("pandas pandera example OK:", run.outputs()[0])

polars_schema.py
# /// script
# requires-python = ">=3.12"
# dependencies = [
#    "flyte>=2.0.0b52",
#    "flyteplugins-pandera",
#    "flyteplugins-polars",
#    "pandera[polars]",
# ]
# main = "main"
# ///

from __future__ import annotations

from typing import Annotated

import pandera.polars as pa
import pandera.typing.polars as pt
import polars as pl
from flyteplugins.pandera import ValidationConfig

import flyte

img = (
    flyte.Image.from_debian_base(python_version=(3, 12))
    .with_pip_packages("flyteplugins-pandera", "flyteplugins-polars", "pandera[polars]")
)

env = flyte.TaskEnvironment(
    "pandera_polars_schema",
    image=img,
    resources=flyte.Resources(cpu="1", memory="2Gi"),
)


class EmployeeSchema(pa.DataFrameModel):
    employee_id: int = pa.Field(ge=0)
    name: str


class EmployeeSchemaWithStatus(EmployeeSchema):
    status: str = pa.Field(isin=["active", "inactive"])


class MetricsSchema(pa.DataFrameModel):
    item: str
    value: float


@env.task(report=True)
async def build_valid_employees() -> pt.DataFrame[EmployeeSchema]:
    return pl.DataFrame(
        {
            "employee_id": [1, 2, 3],
            "name": ["Ada", "Grace", "Barbara"],
        }
    )


@env.task(report=True)
async def pass_through(
    df: pt.DataFrame[EmployeeSchema],
) -> pt.DataFrame[EmployeeSchemaWithStatus]:
    return df.with_columns(pl.lit("active").alias("status"))


@env.task(report=True)
async def pass_through_with_error_warn(
    df: Annotated[
        pt.DataFrame[EmployeeSchema], ValidationConfig(on_error="warn")
    ],
) -> Annotated[
    pt.DataFrame[EmployeeSchemaWithStatus], ValidationConfig(on_error="warn")
]:
    return df.drop("name")


@env.task(report=True)
async def pass_through_with_error_raise(
    df: Annotated[
        pt.DataFrame[EmployeeSchema], ValidationConfig(on_error="warn")
    ],
) -> Annotated[
    pt.DataFrame[EmployeeSchemaWithStatus], ValidationConfig(on_error="raise")
]:
    return df.drop("name")


@env.task(report=True)
async def metrics_eager() -> pt.DataFrame[MetricsSchema]:
    return pl.DataFrame({"item": ["a", "b"], "value": [1.0, 2.0]})


@env.task(report=True)
async def metrics_lazy() -> pt.LazyFrame[MetricsSchema]:
    return pl.LazyFrame({"item": ["x", "y"], "value": [3.0, 4.0]})


@env.task(report=True)
async def filter_metrics(
    lf: pt.LazyFrame[MetricsSchema],
) -> pt.DataFrame[MetricsSchema]:
    return lf.filter(pl.col("value") > 0.0).collect()


@env.task(report=True)
async def main() -> pt.DataFrame[EmployeeSchemaWithStatus]:
    df = await build_valid_employees()
    df2 = await pass_through(df)

    await pass_through_with_error_warn(df.drop("employee_id"))
    await pass_through_with_error_warn(
        df.with_columns(pl.lit(-1).alias("employee_id"))
    )

    try:
        await pass_through_with_error_raise(df)
    except Exception as exc:
        print(exc)

    _ = await metrics_eager()
    lazy = await metrics_lazy()
    _ = await filter_metrics(lazy)

    return df2


if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(main)
    print(run.url)
    run.wait()
    print("polars pandera example OK:", run.outputs()[0])

pyspark_sql_schema.py
# /// script
# requires-python = ">=3.10"
# dependencies = [
#    "flyte>=2.0.0b52",
#    "flyteplugins-pandera",
#    "flyteplugins-spark",
#    "pandera[pyspark]",
# ]
# main = "main"
# ///

from __future__ import annotations

from typing import Annotated, cast

import pandera.typing.pyspark_sql as pt
import pyspark.sql.types as T
from flyteplugins.pandera import ValidationConfig
from flyteplugins.spark.task import Spark
from pandera.pyspark import DataFrameModel, Field
from pyspark.sql import SparkSession
from pyspark.sql import functions as F

import flyte

image = (
    flyte.Image.from_base("apache/spark-py:v3.4.0")
    .clone(name="pandera-pyspark-sql", python_version=(3, 10), extendable=True)
    .with_pip_packages(
        "flyteplugins-pandera",
        "flyteplugins-spark",
        "pandera[pyspark]",
    )
)

spark_conf = Spark(
    spark_conf={
        "spark.driver.memory": "1000M",
        "spark.executor.memory": "1000M",
        "spark.executor.cores": "1",
        "spark.executor.instances": "2",
        "spark.driver.cores": "1",
        "spark.kubernetes.file.upload.path": "/opt/spark/work-dir",
        "spark.jars": (
            "https://storage.googleapis.com/hadoop-lib/gcs/"
            "gcs-connector-hadoop3-latest.jar,"
            "https://repo1.maven.org/maven2/org/apache/hadoop/"
            "hadoop-aws/3.2.2/hadoop-aws-3.2.2.jar,"
            "https://repo1.maven.org/maven2/com/amazonaws/"
            "aws-java-sdk-bundle/1.12.262/aws-java-sdk-bundle-1.12.262.jar"
        ),
    },
)

env = flyte.TaskEnvironment(
    name="pandera_pyspark_sql_schema",
    plugin_config=spark_conf,
    image=image,
    resources=flyte.Resources(cpu="1", memory="2Gi"),
)


class EmployeeSchema(DataFrameModel):
    employee_id: int = Field(ge=0)
    name: str = Field()
    job_title: str = Field()


class EmployeeSchemaWithStatus(EmployeeSchema):
    status: str = Field(isin=["active", "inactive"])


@env.task(report=True)
async def build_valid_employees() -> pt.DataFrame[EmployeeSchema]:
    spark = cast(SparkSession, flyte.ctx().data["spark_session"])
    data = [
        (1, "Ada", "Engineer"),
        (2, "Grace", "Mathematician"),
        (3, "Barbara", "Computer scientist"),
    ]
    schema = T.StructType(
        [
            T.StructField("employee_id", T.IntegerType(), False),
            T.StructField("name", T.StringType(), False),
            T.StructField("job_title", T.StringType(), False),
        ]
    )
    return spark.createDataFrame(data, schema=schema)


@env.task(report=True)
async def pass_through(
    df: pt.DataFrame[EmployeeSchema],
) -> pt.DataFrame[EmployeeSchemaWithStatus]:
    return df.withColumn("status", F.lit("active"))


@env.task(report=True)
async def pass_through_with_error_warn(
    df: Annotated[
        pt.DataFrame[EmployeeSchema], ValidationConfig(on_error="warn")
    ],
) -> Annotated[
    pt.DataFrame[EmployeeSchemaWithStatus], ValidationConfig(on_error="warn")
]:
    return df.drop("name")


@env.task(report=True)
async def pass_through_with_error_raise(
    df: Annotated[
        pt.DataFrame[EmployeeSchema], ValidationConfig(on_error="warn")
    ],
) -> Annotated[
    pt.DataFrame[EmployeeSchemaWithStatus], ValidationConfig(on_error="raise")
]:
    return df.drop("name")


@env.task(report=True)
async def main() -> pt.DataFrame[EmployeeSchemaWithStatus]:
    df = await build_valid_employees()
    df2 = await pass_through(df)

    await pass_through_with_error_warn(df.drop("employee_id"))
    await pass_through_with_error_warn(df.withColumn("employee_id", F.lit(-1)))

    try:
        await pass_through_with_error_raise(df)
    except Exception as exc:
        print(exc)

    return df2


if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(main)
    print(run.url)
    run.wait()
    print("pyspark_sql pandera example OK:", run.outputs()[0])