Spark

The Spark plugin lets you run Apache Spark jobs natively on Kubernetes. Flyte manages the full cluster lifecycle: provisioning a transient Spark cluster for each task execution, running the job, and tearing the cluster down on completion.

Under the hood, the plugin uses the Kubernetes Operator for Apache Spark (the "Spark Operator") to create and manage Spark applications. No external Spark service or long-running cluster is required.

When to use this plugin

  • Large-scale data processing and ETL pipelines
  • Jobs that benefit from Spark’s distributed execution engine (Spark SQL, PySpark, Spark MLlib)
  • Workloads that need Hadoop-compatible storage access (S3, GCS, HDFS)

Installation

pip install flyteplugins-spark

Configuration

Create a Spark configuration and pass it as plugin_config to a TaskEnvironment:

import flyte
from flyteplugins.spark import Spark

spark_config = Spark(
    spark_conf={
        "spark.driver.memory": "3000M",
        "spark.executor.memory": "1000M",
        "spark.executor.cores": "1",
        "spark.executor.instances": "2",
        "spark.driver.cores": "1",
    },
)

spark_env = flyte.TaskEnvironment(
    name="spark_env",
    plugin_config=spark_config,
    image=image,  # a flyte.Image built or referenced elsewhere
)

Spark parameters

  • spark_conf (Dict[str, str]): Spark configuration key-value pairs (e.g., executor memory, cores, instances)
  • hadoop_conf (Dict[str, str]): Hadoop configuration key-value pairs (e.g., S3/GCS access settings)
  • executor_path (str): Path to the Python binary for PySpark executors
  • applications_path (str): Path to the main Spark application file
  • driver_pod (PodTemplate): Pod template for the Spark driver pod
  • executor_pod (PodTemplate): Pod template for the Spark executor pods
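For object-store access, spark_conf and hadoop_conf can be combined. A minimal sketch, assuming S3 access via the Hadoop S3A connector; the endpoint and credentials-provider values below are standard Hadoop S3A settings chosen for illustration, not values prescribed by the plugin:

```python
from flyteplugins.spark import Spark

# Sketch: pairing spark_conf with hadoop_conf for S3A access.
# The endpoint and credentials provider are placeholders; substitute
# the values appropriate for your object store.
s3_spark_config = Spark(
    spark_conf={
        "spark.executor.instances": "2",
        "spark.executor.memory": "1000M",
    },
    hadoop_conf={
        "fs.s3a.endpoint": "https://s3.us-east-1.amazonaws.com",
        "fs.s3a.aws.credentials.provider": "com.amazonaws.auth.DefaultAWSCredentialsProviderChain",
    },
)
```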

Accessing the Spark session

Inside a Spark task, the SparkSession is available through the task context:

from flyte._context import internal_ctx

@spark_env.task
async def my_spark_task() -> float:
    ctx = internal_ctx()
    spark = ctx.data.task_context.data["spark_session"]
    # Use spark as a normal SparkSession
    df = spark.read.parquet("s3://my-bucket/data.parquet")
    return float(df.count())
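Any SparkSession API works from there. As a sketch, a second task in the same environment could build an in-memory DataFrame and query it with Spark SQL (the task name, data, and query here are illustrative, not part of the plugin API):

```python
from flyte._context import internal_ctx

@spark_env.task
async def count_labels() -> int:
    # Retrieve the session the plugin injects into the task context.
    ctx = internal_ctx()
    spark = ctx.data.task_context.data["spark_session"]
    # Build an in-memory DataFrame and query it with Spark SQL.
    df = spark.createDataFrame([(1, "a"), (2, "a"), (3, "b")], ["id", "label"])
    df.createOrReplaceTempView("rows")
    return spark.sql("SELECT COUNT(DISTINCT label) AS n FROM rows").collect()[0]["n"]
```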

Overriding configuration at runtime

You can override Spark configuration for individual task calls using .override():

from copy import deepcopy

updated_config = deepcopy(spark_config)
updated_config.spark_conf["spark.executor.instances"] = "4"

result = await my_spark_task.override(plugin_config=updated_config)()
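Because the configuration lives in a mutable dict, deepcopy (rather than a shallow copy) is what keeps the override from leaking back into the shared config object. A stand-in dataclass (not the real plugin class) illustrates the difference:

```python
from copy import copy, deepcopy
from dataclasses import dataclass, field
from typing import Dict

@dataclass
class FakeSpark:
    # Stand-in for flyteplugins.spark.Spark, used only to illustrate copying.
    spark_conf: Dict[str, str] = field(default_factory=dict)

base = FakeSpark(spark_conf={"spark.executor.instances": "2"})

# A shallow copy shares the same spark_conf dict as the original.
shallow = copy(base)
shallow.spark_conf["spark.executor.instances"] = "4"
print(base.spark_conf["spark.executor.instances"])  # "4" -- original mutated!

base.spark_conf["spark.executor.instances"] = "2"   # reset
# A deep copy gets its own dict, so the override stays local.
deep = deepcopy(base)
deep.spark_conf["spark.executor.instances"] = "4"
print(base.spark_conf["spark.executor.instances"])  # "2" -- original untouched
```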

Example

spark_example.py
# /// script
# requires-python = "==3.13"
# dependencies = [
#    "flyte>=2.0.0b52",
#    "flyteplugins-spark"
# ]
# main = "hello_spark_nested"
# params = "3"
# ///

import random
from copy import deepcopy
from operator import add

from flyteplugins.spark.task import Spark

import flyte
from flyte._context import internal_ctx

image = (
    flyte.Image.from_base("apache/spark-py:v3.4.0")
    .clone(name="spark", python_version=(3, 10), registry="ghcr.io/flyteorg")
    .with_pip_packages("flyteplugins-spark", pre=True)
)

task_env = flyte.TaskEnvironment(
    name="get_pi", resources=flyte.Resources(cpu=(1, 2), memory=("400Mi", "1000Mi")), image=image
)

spark_conf = Spark(
    spark_conf={
        "spark.driver.memory": "3000M",
        "spark.executor.memory": "1000M",
        "spark.executor.cores": "1",
        "spark.executor.instances": "2",
        "spark.driver.cores": "1",
        "spark.kubernetes.file.upload.path": "/opt/spark/work-dir",
        "spark.jars": "https://storage.googleapis.com/hadoop-lib/gcs/gcs-connector-hadoop3-latest.jar,https://repo1.maven.org/maven2/org/apache/hadoop/hadoop-aws/3.2.2/hadoop-aws-3.2.2.jar,https://repo1.maven.org/maven2/com/amazonaws/aws-java-sdk-bundle/1.12.262/aws-java-sdk-bundle-1.12.262.jar",
    },
)

spark_env = flyte.TaskEnvironment(
    name="spark_env",
    resources=flyte.Resources(cpu=(1, 2), memory=("3000Mi", "5000Mi")),
    plugin_config=spark_conf,
    image=image,
    depends_on=[task_env],
)


def f(_):
    x = random.random() * 2 - 1
    y = random.random() * 2 - 1
    return 1 if x**2 + y**2 <= 1 else 0


@task_env.task
async def get_pi(count: int, partitions: int) -> float:
    return 4.0 * count / partitions


@spark_env.task
async def hello_spark_nested(partitions: int = 3) -> float:
    n = 1 * partitions
    ctx = internal_ctx()
    spark = ctx.data.task_context.data["spark_session"]
    count = spark.sparkContext.parallelize(range(1, n + 1), partitions).map(f).reduce(add)

    return await get_pi(count, partitions)


@task_env.task
async def spark_overrider(executor_instances: int = 3, partitions: int = 4) -> float:
    updated_spark_conf = deepcopy(spark_conf)
    updated_spark_conf.spark_conf["spark.executor.instances"] = str(executor_instances)
    return await hello_spark_nested.override(plugin_config=updated_spark_conf)(partitions=partitions)

if __name__ == "__main__":
    flyte.init_from_config()
    r = flyte.run(hello_spark_nested)
    print(r.name)
    print(r.url)
    r.wait()

API reference

See the Spark API reference for full details.