# Dask
The Dask plugin lets you run Dask jobs natively on Kubernetes. Flyte provisions a transient Dask cluster for each task execution using the Dask Kubernetes Operator and tears it down on completion.
## When to use this plugin
- Parallel Python workloads that outgrow a single machine
- Distributed DataFrame operations on large datasets
- Workloads that use Dask’s task scheduler for arbitrary computation graphs
- Jobs that need to scale NumPy, pandas, or scikit-learn workflows across multiple nodes
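The workloads above share a fan-out/fan-in shape: apply a function across many inputs, then collect the results. The plugin's value is scaling that shape across pods; for intuition, here is the same pattern on a single machine with the standard library (a local sketch only, not using Dask or the plugin):

```python
from concurrent.futures import ThreadPoolExecutor


def increment(x: int) -> int:
    return x + 1


# Fan out the work (analogous to client.map), then collect results
# (analogous to client.gather); here everything runs in one process.
with ThreadPoolExecutor(max_workers=4) as pool:
    results = list(pool.map(increment, range(5)))

print(results)  # [1, 2, 3, 4, 5]
```

When this pattern saturates one machine, Dask's `client.map`/`client.gather` distributes the same computation across worker pods.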
## Installation

```shell
pip install flyteplugins-dask
```

Your task image must also include the Dask distributed scheduler:

```python
image = flyte.Image.from_debian_base(name="dask").with_pip_packages("flyteplugins-dask")
```

## Configuration
Create a `Dask` configuration and pass it as `plugin_config` to a `TaskEnvironment`:

```python
import flyte
from flyteplugins.dask import Dask, Scheduler, WorkerGroup

dask_config = Dask(
    scheduler=Scheduler(),
    workers=WorkerGroup(number_of_workers=4),
)

dask_env = flyte.TaskEnvironment(
    name="dask_env",
    plugin_config=dask_config,
    image=image,
)
```

## Dask parameters
| Parameter | Type | Description |
|---|---|---|
| `scheduler` | `Scheduler` | Scheduler pod configuration (defaults to `Scheduler()`) |
| `workers` | `WorkerGroup` | Worker group configuration (defaults to `WorkerGroup()`) |
## Scheduler parameters

| Parameter | Type | Description |
|---|---|---|
| `image` | `str` | Custom scheduler image (must include `dask[distributed]`) |
| `resources` | `Resources` | Resource requests for the scheduler pod |
## WorkerGroup parameters

| Parameter | Type | Description |
|---|---|---|
| `number_of_workers` | `int` | Number of worker pods (default: 1) |
| `image` | `str` | Custom worker image (must include `dask[distributed]`) |
| `resources` | `Resources` | Resource requests per worker pod |
The scheduler and all workers should use the same Python environment to avoid serialization issues.
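One way to keep those environments aligned is to pin the same image on both the scheduler and the worker group explicitly, alongside per-component resources. A sketch under assumptions: the image URI below is a placeholder, and the sizing values are illustrative, not recommendations:

```python
from flyte import Resources
from flyteplugins.dask import Dask, Scheduler, WorkerGroup

# Placeholder image URI; in practice build it (e.g. with flyte.Image)
# and make sure it includes dask[distributed].
DASK_IMAGE = "ghcr.io/example/dask-runtime:latest"

dask_config = Dask(
    # Using one image for both components keeps the Python environments
    # identical, avoiding serialization (pickle) mismatches.
    scheduler=Scheduler(
        image=DASK_IMAGE,
        resources=Resources(cpu="1", memory="2Gi"),
    ),
    workers=WorkerGroup(
        number_of_workers=8,
        image=DASK_IMAGE,
        resources=Resources(cpu="2", memory="4Gi"),
    ),
)
```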
## Accessing the Dask client

Inside a Dask task, create a `distributed.Client()` with no arguments. It automatically connects to the provisioned cluster:

```python
from distributed import Client


@dask_env.task
async def my_dask_task(n: int) -> list:
    client = Client()
    futures = client.map(lambda x: x + 1, range(n))
    return client.gather(futures)
```

## Example
```python
# /// script
# requires-python = "==3.13"
# dependencies = [
#    "flyte>=2.0.0b52",
#    "flyteplugins-dask",
#    "distributed"
# ]
# main = "hello_dask_nested"
# params = ""
# ///

import asyncio
import typing

from distributed import Client
from flyteplugins.dask import Dask, Scheduler, WorkerGroup

import flyte
from flyte import Resources

image = flyte.Image.from_debian_base(python_version=(3, 12)).with_pip_packages("flyteplugins-dask")

dask_config = Dask(
    scheduler=Scheduler(),
    workers=WorkerGroup(number_of_workers=4),
)

task_env = flyte.TaskEnvironment(
    name="hello_dask",
    resources=Resources(cpu=(1, 2), memory=("400Mi", "1000Mi")),
    image=image,
)

dask_env = flyte.TaskEnvironment(
    name="dask_env",
    plugin_config=dask_config,
    image=image,
    resources=Resources(cpu="1", memory="1Gi"),
    depends_on=[task_env],
)


@task_env.task()
async def hello_dask():
    await asyncio.sleep(5)
    print("Hello from the Dask task!")


@dask_env.task
async def hello_dask_nested(n: int = 3) -> typing.List[int]:
    print("running dask task")
    # Kick off a plain task concurrently while the Dask work runs.
    t = asyncio.create_task(hello_dask())
    client = Client()
    futures = client.map(lambda x: x + 1, range(n))
    res = client.gather(futures)
    await t
    return res


if __name__ == "__main__":
    flyte.init_from_config()
    r = flyte.run(hello_dask_nested)
    print(r.name)
    print(r.url)
    r.wait()
```
## API reference

See the Dask API reference for full details.