# Spark
The Spark plugin lets you run Apache Spark jobs natively on Kubernetes. Flyte manages the full cluster lifecycle: provisioning a transient Spark cluster for each task execution, running the job, and tearing the cluster down on completion.
Under the hood, the plugin uses the Spark on Kubernetes Operator to create and manage Spark applications. No external Spark service or long-running cluster is required.
## When to use this plugin
- Large-scale data processing and ETL pipelines
- Jobs that benefit from Spark’s distributed execution engine (Spark SQL, PySpark, Spark MLlib)
- Workloads that need Hadoop-compatible storage access (S3, GCS, HDFS)
## Installation

```shell
pip install flyteplugins-spark
```

## Configuration
Create a `Spark` configuration and pass it as `plugin_config` to a `flyte.TaskEnvironment`:
```python
import flyte
from flyteplugins.spark import Spark

spark_config = Spark(
    spark_conf={
        "spark.driver.memory": "3000M",
        "spark.executor.memory": "1000M",
        "spark.executor.cores": "1",
        "spark.executor.instances": "2",
        "spark.driver.cores": "1",
    },
)

spark_env = flyte.TaskEnvironment(
    name="spark_env",
    plugin_config=spark_config,
    image=image,
)
```

## Spark parameters
| Parameter | Type | Description |
|---|---|---|
| `spark_conf` | `Dict[str, str]` | Spark configuration key-value pairs (e.g., executor memory, cores, instances) |
| `hadoop_conf` | `Dict[str, str]` | Hadoop configuration key-value pairs (e.g., S3/GCS access settings) |
| `executor_path` | `str` | Path to the Python binary for PySpark executors |
| `applications_path` | `str` | Path to the main Spark application file |
| `driver_pod` | `PodTemplate` | Pod template for the Spark driver pod |
| `executor_pod` | `PodTemplate` | Pod template for the Spark executor pods |
## Accessing the Spark session

Inside a Spark task, the `SparkSession` is available through the task context:
```python
from flyte._context import internal_ctx

@spark_env.task
async def my_spark_task() -> float:
    ctx = internal_ctx()
    spark = ctx.data.task_context.data["spark_session"]
    # Use spark as a normal SparkSession
    df = spark.read.parquet("s3://my-bucket/data.parquet")
    return df.count()
```

## Overriding configuration at runtime
You can override Spark configuration for individual task calls using `.override()`:
```python
from copy import deepcopy

updated_config = deepcopy(spark_config)
updated_config.spark_conf["spark.executor.instances"] = "4"

result = await my_spark_task.override(plugin_config=updated_config)()
```

## Example
```python
# /// script
# requires-python = "==3.13"
# dependencies = [
#     "flyte>=2.0.0b52",
#     "flyteplugins-spark"
# ]
# main = "hello_spark_nested"
# params = "3"
# ///
import random
from copy import deepcopy
from operator import add

from flyteplugins.spark.task import Spark

import flyte
import flyte.remote
from flyte._context import internal_ctx

image = (
    flyte.Image.from_base("apache/spark-py:v3.4.0")
    .clone(name="spark", python_version=(3, 10), registry="ghcr.io/flyteorg")
    .with_pip_packages("flyteplugins-spark", pre=True)
)

task_env = flyte.TaskEnvironment(
    name="get_pi", resources=flyte.Resources(cpu=(1, 2), memory=("400Mi", "1000Mi")), image=image
)

spark_conf = Spark(
    spark_conf={
        "spark.driver.memory": "3000M",
        "spark.executor.memory": "1000M",
        "spark.executor.cores": "1",
        "spark.executor.instances": "2",
        "spark.driver.cores": "1",
        "spark.kubernetes.file.upload.path": "/opt/spark/work-dir",
        "spark.jars": "https://storage.googleapis.com/hadoop-lib/gcs/gcs-connector-hadoop3-latest.jar,https://repo1.maven.org/maven2/org/apache/hadoop/hadoop-aws/3.2.2/hadoop-aws-3.2.2.jar,https://repo1.maven.org/maven2/com/amazonaws/aws-java-sdk-bundle/1.12.262/aws-java-sdk-bundle-1.12.262.jar",
    },
)

spark_env = flyte.TaskEnvironment(
    name="spark_env",
    resources=flyte.Resources(cpu=(1, 2), memory=("3000Mi", "5000Mi")),
    plugin_config=spark_conf,
    image=image,
    depends_on=[task_env],
)

def f(_):
    x = random.random() * 2 - 1
    y = random.random() * 2 - 1
    return 1 if x**2 + y**2 <= 1 else 0

@task_env.task
async def get_pi(count: int, partitions: int) -> float:
    return 4.0 * count / partitions

@spark_env.task
async def hello_spark_nested(partitions: int = 3) -> float:
    n = 1 * partitions
    ctx = internal_ctx()
    spark = ctx.data.task_context.data["spark_session"]
    count = spark.sparkContext.parallelize(range(1, n + 1), partitions).map(f).reduce(add)
    return await get_pi(count, partitions)

@task_env.task
async def spark_overrider(executor_instances: int = 3, partitions: int = 4) -> float:
    updated_spark_conf = deepcopy(spark_conf)
    updated_spark_conf.spark_conf["spark.executor.instances"] = str(executor_instances)
    return await hello_spark_nested.override(plugin_config=updated_spark_conf)(partitions=partitions)

if __name__ == "__main__":
    flyte.init_from_config()
    r = flyte.run(hello_spark_nested)
    print(r.name)
    print(r.url)
    r.wait()
```
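The example script estimates π by Monte Carlo sampling: each sample lands in the unit circle with probability π/4, so four times the hit fraction approximates π. The same computation can be sketched without Spark, where a plain `sum` over `map` plays the role of `parallelize(...).map(f).reduce(add)`:

```python
import random

def in_unit_circle(_):
    # Sample a point uniformly in the square [-1, 1] x [-1, 1];
    # it lies inside the unit circle with probability pi/4.
    x = random.random() * 2 - 1
    y = random.random() * 2 - 1
    return 1 if x**2 + y**2 <= 1 else 0

random.seed(0)  # deterministic for illustration
n = 100_000
# Serial stand-in for parallelize(range(...), partitions).map(f).reduce(add)
count = sum(map(in_unit_circle, range(n)))
pi_estimate = 4.0 * count / n
print(pi_estimate)  # approximately 3.14
```

The Spark version distributes exactly this map-reduce across executors; the math is unchanged.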
## API reference
See the Spark API reference for full details.