Dask

The Dask plugin lets you run Dask jobs natively on Kubernetes. Flyte provisions a transient Dask cluster for each task execution using the Dask Kubernetes Operator and tears it down on completion.

When to use this plugin

  • Parallel Python workloads that outgrow a single machine
  • Distributed DataFrame operations on large datasets
  • Workloads that use Dask’s task scheduler for arbitrary computation graphs
  • Jobs that need to scale NumPy, pandas, or scikit-learn workflows across multiple nodes
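For instance, the arbitrary computation graphs mentioned above can be sketched with dask.delayed. The snippet below runs on Dask's local scheduler, but the same graph can execute unchanged on a cluster provisioned by this plugin:

```python
import dask

# Two delayed functions form an arbitrary task graph; nothing runs
# until .compute() is called.
@dask.delayed
def inc(x: int) -> int:
    return x + 1

@dask.delayed
def total(xs: list) -> int:
    return sum(xs)

# Build the graph lazily, then execute it (locally here).
result = total([inc(i) for i in range(10)]).compute()
print(result)
```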

Installation

pip install flyteplugins-dask

Your task image must also include Dask's distributed runtime (the dask[distributed] package):

image = flyte.Image.from_debian_base(name="dask").with_pip_packages("flyteplugins-dask")

Configuration

Create a Dask configuration and pass it as plugin_config to a TaskEnvironment:

import flyte
from flyteplugins.dask import Dask, Scheduler, WorkerGroup

dask_config = Dask(
    scheduler=Scheduler(),
    workers=WorkerGroup(number_of_workers=4),
)

dask_env = flyte.TaskEnvironment(
    name="dask_env",
    plugin_config=dask_config,
    image=image,
)

Dask parameters

  • scheduler (Scheduler): Scheduler pod configuration (defaults to Scheduler())
  • workers (WorkerGroup): Worker group configuration (defaults to WorkerGroup())

Scheduler parameters

  • image (str): Custom scheduler image (must include dask[distributed])
  • resources (Resources): Resource requests for the scheduler pod

WorkerGroup parameters

  • number_of_workers (int): Number of worker pods (default: 1)
  • image (str): Custom worker image (must include dask[distributed])
  • resources (Resources): Resource requests per worker pod

The scheduler and all workers should use the same Python environment to avoid serialization issues.
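Putting these parameters together, a configuration that sets explicit resources for each component might look like this (a sketch; the resource values and worker count are illustrative):

```python
from flyte import Resources
from flyteplugins.dask import Dask, Scheduler, WorkerGroup

# Illustrative sizing: a 1-CPU scheduler and four 2-CPU workers.
dask_config = Dask(
    scheduler=Scheduler(resources=Resources(cpu="1", memory="2Gi")),
    workers=WorkerGroup(
        number_of_workers=4,
        resources=Resources(cpu="2", memory="8Gi"),
    ),
)
```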

Accessing the Dask client

Inside a Dask task, create a distributed.Client() with no arguments. It automatically connects to the provisioned cluster:

from distributed import Client

@dask_env.task
async def my_dask_task(n: int) -> list:
    client = Client()
    futures = client.map(lambda x: x + 1, range(n))
    return client.gather(futures)

Example

dask_example.py
# /// script
# requires-python = "==3.13"
# dependencies = [
#    "flyte>=2.0.0b52",
#    "flyteplugins-dask",
#    "distributed"
# ]
# main = "hello_dask_nested"
# params = ""
# ///

import asyncio
import typing

from distributed import Client
from flyteplugins.dask import Dask, Scheduler, WorkerGroup

import flyte
from flyte import Resources

image = flyte.Image.from_debian_base(python_version=(3, 12)).with_pip_packages("flyteplugins-dask")

dask_config = Dask(
    scheduler=Scheduler(),
    workers=WorkerGroup(number_of_workers=4),
)

task_env = flyte.TaskEnvironment(
    name="hello_dask", resources=Resources(cpu=(1, 2), memory=("400Mi", "1000Mi")), image=image
)
dask_env = flyte.TaskEnvironment(
    name="dask_env",
    plugin_config=dask_config,
    image=image,
    resources=Resources(cpu="1", memory="1Gi"),
    depends_on=[task_env],
)


@task_env.task()
async def hello_dask():
    await asyncio.sleep(5)
    print("Hello from the Dask task!")


@dask_env.task
async def hello_dask_nested(n: int = 3) -> typing.List[int]:
    print("running dask task")
    t = asyncio.create_task(hello_dask())
    client = Client()
    futures = client.map(lambda x: x + 1, range(n))
    res = client.gather(futures)
    await t
    return res


if __name__ == "__main__":
    flyte.init_from_config()
    r = flyte.run(hello_dask_nested)
    print(r.name)
    print(r.url)
    r.wait()

API reference

See the Dask API reference for full details.