Dask

Flyte can execute Dask jobs natively on a Kubernetes cluster, with the Dask cluster's lifecycle — spin-up and teardown — managed for you. It leverages the open-source Dask Kubernetes Operator and can be enabled without signing up for any service. This is like running a transient Dask cluster: a cluster spun up for a specific Dask job and torn down after the job completes.

To install the plugin, run the following command:

Install the plugin

To install the Dask plugin, run the following command:

$ pip install --pre flyteplugins-dask

The following example shows how to configure Dask in a TaskEnvironment. Flyte automatically provisions a Dask cluster for each task that uses this configuration:

# /// script
# requires-python = "==3.13"
# dependencies = [
#    "flyte==2.0.0b31",
#    "flyteplugins-dask",
#    "distributed"
# ]
# main = "hello_dask_nested"
# params = ""
# ///

import asyncio
import typing

from distributed import Client
from flyteplugins.dask import Dask, Scheduler, WorkerGroup

import flyte.remote
import flyte.storage
from flyte import Resources

# Container image shared by both environments: a Debian base image on
# Python 3.12 with the Dask plugin package installed.
# NOTE(review): the script header pins Python 3.13 while this image uses
# 3.12 — confirm the difference is intended.
image = flyte.Image.from_debian_base(
    python_version=(3, 12),
).with_pip_packages("flyteplugins-dask")

# Dask plugin configuration: a scheduler with default settings and a worker
# group of four workers.
dask_config = Dask(
    scheduler=Scheduler(),
    workers=WorkerGroup(number_of_workers=4),
)

# Environment for the plain (non-Dask) helper task.
task_env = flyte.TaskEnvironment(
    name="hello_dask",
    image=image,
    resources=Resources(
        cpu=(1, 2),
        memory=("400Mi", "1000Mi"),
    ),
)

# Environment whose tasks carry the Dask plugin configuration. It declares a
# dependency on `task_env` because its task invokes a task from that
# environment.
dask_env = flyte.TaskEnvironment(
    name="dask_env",
    image=image,
    resources=Resources(cpu="1", memory="1Gi"),
    plugin_config=dask_config,
    depends_on=[task_env],
)


@task_env.task()
async def hello_dask():
    """Pause for five seconds, then print a greeting."""
    pause_seconds = 5
    await asyncio.sleep(pause_seconds)
    print("Hello from the Dask task!")


@dask_env.task
async def hello_dask_nested(n: int = 3) -> typing.List[int]:
    """Increment the integers in ``range(n)`` with Dask and await a nested task.

    Schedules ``hello_dask`` as an asyncio task, maps ``x + 1`` over
    ``range(n)`` through a Dask client, gathers the results, and then awaits
    the nested task before returning.

    Args:
        n: Number of integers (``0`` to ``n - 1``) to increment.

    Returns:
        The gathered results of the Dask map.
    """
    print("running dask task")
    # Schedule the nested task before starting the Dask work; it is awaited below.
    nested = asyncio.create_task(hello_dask())
    # Use the client as a context manager so it is closed once the results
    # are gathered, even if the computation raises.
    # NOTE(review): Client() is created with no address — presumably the
    # plugin points it at the provisioned scheduler; confirm.
    with Client() as client:
        futures = client.map(lambda x: x + 1, range(n))
        res = client.gather(futures)
    await nested
    return res

if __name__ == "__main__":
    # Initialize Flyte from its config before submitting anything.
    flyte.init_from_config()
    # Submit the Dask task and keep the handle returned for the run.
    run = flyte.run(hello_dask_nested)
    # Report the run's name and URL, then wait on the run.
    print(run.name)
    print(run.url)
    run.wait()