Pandera
The Pandera plugin validates dataframes at task boundaries using DataFrameModel schemas. When a task receives or returns a pandera-typed dataframe, the plugin automatically validates the data, raises or warns on schema violations, and writes an HTML validation report to the Flyte deck.
Pandera supports multiple dataframe backends. The flyteplugins-pandera plugin handles:
| Pandera typing module | DataFrame library | Additional plugin |
|---|---|---|
| pandera.typing.pandas | pandas | — |
| pandera.typing.polars | Polars (eager and lazy) | flyteplugins-polars |
| pandera.typing.pyspark_sql | PySpark SQL | flyteplugins-spark |
When to use this plugin
- You want schema contracts, declared in type annotations and enforced at runtime, guaranteeing that data flowing between tasks conforms to a declared schema
- You need column-level type, constraint, and statistical checks on task inputs and outputs
- You want automatic validation reports visible in the Flyte UI
Installation
Install the plugin with the pandera extras for your dataframe backend:
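The package names below match the dependency lists used in the examples later on this page; exact version pins may differ in your environment:

```shell
# pandas backend
pip install flyteplugins-pandera "pandera[pandas]"

# Polars backend (adds the flyteplugins-polars handler)
pip install flyteplugins-pandera flyteplugins-polars "pandera[polars]"

# PySpark SQL backend (adds the flyteplugins-spark handler)
pip install flyteplugins-pandera flyteplugins-spark "pandera[pyspark]"
```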
Defining schemas
Schemas are defined as Python classes that inherit from pandera’s DataFrameModel. Each field declares a column name,
type, and optional constraints:
import pandera.pandas as pa
class EmployeeSchema(pa.DataFrameModel):
employee_id: int = pa.Field(ge=0)
name: str
class EmployeeSchemaWithStatus(EmployeeSchema):
    status: str = pa.Field(isin=["active", "inactive"])

Schemas compose through inheritance: EmployeeSchemaWithStatus includes all columns from EmployeeSchema plus the status column.
For full details on schema definition—including custom checks, regex column matching, and Config options—see the
pandera DataFrameModel documentation.
Using schemas in tasks
Annotate task inputs and outputs with pandera’s generic DataFrame type. The plugin validates data on every
encode (output) and decode (input):
import pandera.typing.pandas as pt
@env.task(report=True)
async def build_employees() -> pt.DataFrame[EmployeeSchema]:
return pd.DataFrame({
"employee_id": [1, 2, 3],
"name": ["Ada", "Grace", "Barbara"],
})
@env.task(report=True)
async def add_status(
df: pt.DataFrame[EmployeeSchema],
) -> pt.DataFrame[EmployeeSchemaWithStatus]:
    return df.assign(status="active")

Setting report=True on the task makes validation reports visible as deck tabs in the Flyte UI.
Error handling with ValidationConfig
By default, a validation failure raises an exception and fails the task. To downgrade failures to warnings instead,
annotate the parameter with ValidationConfig(on_error="warn"):
from typing import Annotated
from flyteplugins.pandera import ValidationConfig
@env.task(report=True)
async def lenient_pass_through(
df: Annotated[pt.DataFrame[EmployeeSchema], ValidationConfig(on_error="warn")],
) -> Annotated[pt.DataFrame[EmployeeSchemaWithStatus], ValidationConfig(on_error="warn")]:
    ...

| on_error value | Behavior |
|---|---|
| "raise" (default) | Validation failure raises pandera.errors.SchemaError and the task fails |
| "warn" | Validation failure logs a warning and writes the report, but the task continues |
You can mix "raise" and "warn" across inputs and outputs of the same task. For example, use "warn" on inputs
to accept best-effort data while still enforcing strict output contracts.
Image configuration
Include the plugin in your task image. The exact setup depends on your dataframe backend. For pandas:
import flyte
img = flyte.Image.from_debian_base(
python_version=(3, 12),
).with_pip_packages("flyteplugins-pandera")
env = flyte.TaskEnvironment(
"pandera_pandas",
image=img,
resources=flyte.Resources(cpu="1", memory="2Gi"),
)

For Polars:

import flyte
img = (
flyte.Image.from_debian_base(python_version=(3, 12))
.with_pip_packages("flyteplugins-polars", "pandera[polars]")
)
env = flyte.TaskEnvironment(
"pandera_polars",
image=img,
resources=flyte.Resources(cpu="1", memory="2Gi"),
)

For PySpark SQL:

import flyte
from flyteplugins.spark.task import Spark
image = (
flyte.Image.from_base("apache/spark-py:v3.4.0")
.clone(name="pandera-pyspark-sql", python_version=(3, 10), extendable=True)
.with_pip_packages("flyteplugins-spark", "pandera[pyspark]")
)
spark_conf = Spark(
spark_conf={
"spark.driver.memory": "1000M",
"spark.executor.memory": "1000M",
"spark.executor.cores": "1",
"spark.executor.instances": "2",
"spark.driver.cores": "1",
},
)
env = flyte.TaskEnvironment(
name="pandera_pyspark",
plugin_config=spark_conf,
image=image,
resources=flyte.Resources(cpu="1", memory="2Gi"),
)

Polars lazy frames
The Polars backend supports both pt.DataFrame (eager) and pt.LazyFrame (lazy). With lazy frames, pandera
validates the data when the frame is materialized at task I/O boundaries:
import pandera.typing.polars as pt
import polars as pl
@env.task(report=True)
async def create_lazy() -> pt.LazyFrame[MetricsSchema]:
return pl.LazyFrame({"item": ["x", "y"], "value": [3.0, 4.0]})
@env.task(report=True)
async def consume_lazy(
lf: pt.LazyFrame[MetricsSchema],
) -> pt.DataFrame[MetricsSchema]:
    return lf.filter(pl.col("value") > 0.0).collect()

Examples

pandas:
# /// script
# requires-python = ">=3.12"
# dependencies = [
# "flyte",
# "flyteplugins-pandera",
# "pandera[pandas]",
# ]
# main = "main"
# ///
from __future__ import annotations
from typing import Annotated
import pandas as pd
import pandera.pandas as pa
import pandera.typing.pandas as pt
from flyteplugins.pandera import ValidationConfig
import flyte
img = flyte.Image.from_debian_base(python_version=(3, 12)).with_pip_packages(
"flyteplugins-pandera", "pandera[pandas]"
)
env = flyte.TaskEnvironment(
"pandera_pandas_schema",
image=img,
resources=flyte.Resources(cpu="1", memory="2Gi"),
)
class EmployeeSchema(pa.DataFrameModel):
employee_id: int = pa.Field(ge=0)
name: str
class EmployeeSchemaWithStatus(EmployeeSchema):
status: str = pa.Field(isin=["active", "inactive"])
@env.task(report=True)
async def build_valid_employees() -> pt.DataFrame[EmployeeSchema]:
return pd.DataFrame(
{
"employee_id": [1, 2, 3],
"name": ["Ada", "Grace", "Barbara"],
}
)
@env.task(report=True)
async def pass_through(
df: pt.DataFrame[EmployeeSchema],
) -> pt.DataFrame[EmployeeSchemaWithStatus]:
return df.assign(status="active")
@env.task(report=True)
async def pass_through_with_error_warn(
df: Annotated[
pt.DataFrame[EmployeeSchema], ValidationConfig(on_error="warn")
],
) -> Annotated[
pt.DataFrame[EmployeeSchemaWithStatus], ValidationConfig(on_error="warn")
]:
del df["name"]
return df
@env.task(report=True)
async def pass_through_with_error_raise(
df: Annotated[
pt.DataFrame[EmployeeSchema], ValidationConfig(on_error="warn")
],
) -> Annotated[
pt.DataFrame[EmployeeSchemaWithStatus], ValidationConfig(on_error="raise")
]:
del df["name"]
return df
@env.task(report=True)
async def main() -> pt.DataFrame[EmployeeSchemaWithStatus]:
df = await build_valid_employees()
df2 = await pass_through(df)
await pass_through_with_error_warn(df.drop(["employee_id"], axis="columns"))
await pass_through_with_error_warn(df.assign(employee_id=-1))
try:
await pass_through_with_error_raise(df)
except Exception as exc:
print(exc)
return df2
if __name__ == "__main__":
flyte.init_from_config()
run = flyte.run(main)
print(run.url)
run.wait()
    print("pandas pandera example OK:", run.outputs()[0])

Polars:

# /// script
# requires-python = ">=3.12"
# dependencies = [
# "flyte>=2.0.0b52",
# "flyteplugins-pandera",
# "flyteplugins-polars",
# "pandera[polars]",
# ]
# main = "main"
# ///
from __future__ import annotations
from typing import Annotated
import pandera.polars as pa
import pandera.typing.polars as pt
import polars as pl
from flyteplugins.pandera import ValidationConfig
import flyte
img = (
flyte.Image.from_debian_base(python_version=(3, 12))
.with_pip_packages("flyteplugins-pandera", "flyteplugins-polars", "pandera[polars]")
)
env = flyte.TaskEnvironment(
"pandera_polars_schema",
image=img,
resources=flyte.Resources(cpu="1", memory="2Gi"),
)
class EmployeeSchema(pa.DataFrameModel):
employee_id: int = pa.Field(ge=0)
name: str
class EmployeeSchemaWithStatus(EmployeeSchema):
status: str = pa.Field(isin=["active", "inactive"])
class MetricsSchema(pa.DataFrameModel):
item: str
value: float
@env.task(report=True)
async def build_valid_employees() -> pt.DataFrame[EmployeeSchema]:
return pl.DataFrame(
{
"employee_id": [1, 2, 3],
"name": ["Ada", "Grace", "Barbara"],
}
)
@env.task(report=True)
async def pass_through(
df: pt.DataFrame[EmployeeSchema],
) -> pt.DataFrame[EmployeeSchemaWithStatus]:
return df.with_columns(pl.lit("active").alias("status"))
@env.task(report=True)
async def pass_through_with_error_warn(
df: Annotated[
pt.DataFrame[EmployeeSchema], ValidationConfig(on_error="warn")
],
) -> Annotated[
pt.DataFrame[EmployeeSchemaWithStatus], ValidationConfig(on_error="warn")
]:
return df.drop("name")
@env.task(report=True)
async def pass_through_with_error_raise(
df: Annotated[
pt.DataFrame[EmployeeSchema], ValidationConfig(on_error="warn")
],
) -> Annotated[
pt.DataFrame[EmployeeSchemaWithStatus], ValidationConfig(on_error="raise")
]:
return df.drop("name")
@env.task(report=True)
async def metrics_eager() -> pt.DataFrame[MetricsSchema]:
return pl.DataFrame({"item": ["a", "b"], "value": [1.0, 2.0]})
@env.task(report=True)
async def metrics_lazy() -> pt.LazyFrame[MetricsSchema]:
return pl.LazyFrame({"item": ["x", "y"], "value": [3.0, 4.0]})
@env.task(report=True)
async def filter_metrics(
lf: pt.LazyFrame[MetricsSchema],
) -> pt.DataFrame[MetricsSchema]:
return lf.filter(pl.col("value") > 0.0).collect()
@env.task(report=True)
async def main() -> pt.DataFrame[EmployeeSchemaWithStatus]:
df = await build_valid_employees()
df2 = await pass_through(df)
await pass_through_with_error_warn(df.drop("employee_id"))
await pass_through_with_error_warn(
df.with_columns(pl.lit(-1).alias("employee_id"))
)
try:
await pass_through_with_error_raise(df)
except Exception as exc:
print(exc)
_ = await metrics_eager()
lazy = await metrics_lazy()
_ = await filter_metrics(lazy)
return df2
if __name__ == "__main__":
flyte.init_from_config()
run = flyte.run(main)
print(run.url)
run.wait()
    print("polars pandera example OK:", run.outputs()[0])

PySpark SQL:

# /// script
# requires-python = ">=3.10"
# dependencies = [
# "flyte>=2.0.0b52",
# "flyteplugins-pandera",
# "flyteplugins-spark",
# "pandera[pyspark]",
# ]
# main = "main"
# ///
from __future__ import annotations
from typing import Annotated, cast
import pandera.typing.pyspark_sql as pt
import pyspark.sql.types as T
from flyteplugins.pandera import ValidationConfig
from flyteplugins.spark.task import Spark
from pandera.pyspark import DataFrameModel, Field
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
import flyte
image = (
flyte.Image.from_base("apache/spark-py:v3.4.0")
.clone(name="pandera-pyspark-sql", python_version=(3, 10), extendable=True)
.with_pip_packages(
"flyteplugins-pandera",
"flyteplugins-spark",
"pandera[pyspark]",
)
)
spark_conf = Spark(
spark_conf={
"spark.driver.memory": "1000M",
"spark.executor.memory": "1000M",
"spark.executor.cores": "1",
"spark.executor.instances": "2",
"spark.driver.cores": "1",
"spark.kubernetes.file.upload.path": "/opt/spark/work-dir",
"spark.jars": (
"https://storage.googleapis.com/hadoop-lib/gcs/"
"gcs-connector-hadoop3-latest.jar,"
"https://repo1.maven.org/maven2/org/apache/hadoop/"
"hadoop-aws/3.2.2/hadoop-aws-3.2.2.jar,"
"https://repo1.maven.org/maven2/com/amazonaws/"
"aws-java-sdk-bundle/1.12.262/aws-java-sdk-bundle-1.12.262.jar"
),
},
)
env = flyte.TaskEnvironment(
name="pandera_pyspark_sql_schema",
plugin_config=spark_conf,
image=image,
resources=flyte.Resources(cpu="1", memory="2Gi"),
)
class EmployeeSchema(DataFrameModel):
employee_id: int = Field(ge=0)
name: str = Field()
job_title: str = Field()
class EmployeeSchemaWithStatus(EmployeeSchema):
status: str = Field(isin=["active", "inactive"])
@env.task(report=True)
async def build_valid_employees() -> pt.DataFrame[EmployeeSchema]:
spark = cast(SparkSession, flyte.ctx().data["spark_session"])
data = [
(1, "Ada", "Engineer"),
(2, "Grace", "Mathematician"),
(3, "Barbara", "Computer scientist"),
]
schema = T.StructType(
[
T.StructField("employee_id", T.IntegerType(), False),
T.StructField("name", T.StringType(), False),
T.StructField("job_title", T.StringType(), False),
]
)
return spark.createDataFrame(data, schema=schema)
@env.task(report=True)
async def pass_through(
df: pt.DataFrame[EmployeeSchema],
) -> pt.DataFrame[EmployeeSchemaWithStatus]:
return df.withColumn("status", F.lit("active"))
@env.task(report=True)
async def pass_through_with_error_warn(
df: Annotated[
pt.DataFrame[EmployeeSchema], ValidationConfig(on_error="warn")
],
) -> Annotated[
pt.DataFrame[EmployeeSchemaWithStatus], ValidationConfig(on_error="warn")
]:
return df.drop("name")
@env.task(report=True)
async def pass_through_with_error_raise(
df: Annotated[
pt.DataFrame[EmployeeSchema], ValidationConfig(on_error="warn")
],
) -> Annotated[
pt.DataFrame[EmployeeSchemaWithStatus], ValidationConfig(on_error="raise")
]:
return df.drop("name")
@env.task(report=True)
async def main() -> pt.DataFrame[EmployeeSchemaWithStatus]:
df = await build_valid_employees()
df2 = await pass_through(df)
await pass_through_with_error_warn(df.drop("employee_id"))
await pass_through_with_error_warn(df.withColumn("employee_id", F.lit(-1)))
try:
await pass_through_with_error_raise(df)
except Exception as exc:
print(exc)
return df2
if __name__ == "__main__":
flyte.init_from_config()
run = flyte.run(main)
print(run.url)
run.wait()
print("pyspark_sql pandera example OK:", run.outputs()[0])