Ray
The Ray plugin lets you run Ray jobs natively on Kubernetes. Flyte provisions a transient Ray cluster for each task execution using KubeRay and tears it down on completion.
When to use this plugin
- Distributed Python workloads (parallel computation, data processing)
- ML training with Ray Train or hyperparameter tuning with Ray Tune
- Ray Serve inference workloads
- Any workload that benefits from Ray’s actor model or task parallelism
Installation
pip install flyteplugins-ray

Your task image must also include a compatible version of Ray:
image = (
flyte.Image.from_debian_base(name="ray")
.with_pip_packages("ray[default]==2.46.0", "flyteplugins-ray")
)

Configuration
Create a RayJobConfig and pass it as plugin_config to a TaskEnvironment:
from flyteplugins.ray import HeadNodeConfig, RayJobConfig, WorkerNodeConfig
ray_config = RayJobConfig(
head_node_config=HeadNodeConfig(ray_start_params={"log-color": "True"}),
worker_node_config=[WorkerNodeConfig(group_name="ray-group", replicas=2)],
runtime_env={"pip": ["numpy", "pandas"]},
enable_autoscaling=False,
shutdown_after_job_finishes=True,
ttl_seconds_after_finished=300,
)
ray_env = flyte.TaskEnvironment(
name="ray_env",
plugin_config=ray_config,
image=image,
)

RayJobConfig parameters
| Parameter | Type | Description |
|---|---|---|
| `worker_node_config` | `List[WorkerNodeConfig]` | Required. List of worker group configurations |
| `head_node_config` | `HeadNodeConfig` | Head node configuration (optional) |
| `enable_autoscaling` | `bool` | Enable Ray autoscaler (default: `False`) |
| `runtime_env` | `dict` | Ray runtime environment (pip packages, env vars, etc.) |
| `address` | `str` | Connect to an existing Ray cluster instead of provisioning one |
| `shutdown_after_job_finishes` | `bool` | Shut down the cluster after the job completes (default: `False`) |
| `ttl_seconds_after_finished` | `int` | Seconds to keep the cluster after completion before cleanup |
WorkerNodeConfig parameters
| Parameter | Type | Description |
|---|---|---|
| `group_name` | `str` | Required. Name of this worker group |
| `replicas` | `int` | Required. Number of worker replicas |
| `min_replicas` | `int` | Minimum replicas (for autoscaling) |
| `max_replicas` | `int` | Maximum replicas (for autoscaling) |
| `ray_start_params` | `Dict[str, str]` | Ray start parameters for workers |
| `requests` | `Resources` | Resource requests per worker |
| `limits` | `Resources` | Resource limits per worker |
| `pod_template` | `PodTemplate` | Full pod template (mutually exclusive with requests/limits) |
HeadNodeConfig parameters
| Parameter | Type | Description |
|---|---|---|
| `ray_start_params` | `Dict[str, str]` | Ray start parameters for the head node |
| `requests` | `Resources` | Resource requests for the head node |
| `limits` | `Resources` | Resource limits for the head node |
| `pod_template` | `PodTemplate` | Full pod template (mutually exclusive with requests/limits) |
Connecting to an existing cluster
To connect to an existing Ray cluster instead of provisioning a new one, set the address parameter:
ray_config = RayJobConfig(
worker_node_config=[WorkerNodeConfig(group_name="ray-group", replicas=2)],
address="ray://existing-cluster:10001",
)

Examples
The following example shows how to configure Ray in a TaskEnvironment. Flyte automatically provisions a Ray cluster for each task using this configuration:
# /// script
# requires-python = "==3.13"
# dependencies = [
# "flyte>=2.0.0b52",
# "flyteplugins-ray",
# "ray[default]==2.46.0"
# ]
# main = "hello_ray_nested"
# params = "3"
# ///
import asyncio
import typing
import ray
from flyteplugins.ray.task import HeadNodeConfig, RayJobConfig, WorkerNodeConfig
import flyte.remote
import flyte.storage
# Ray remote function: squares its argument on a Ray worker.
# Invoked via f.remote(i); results collected with ray.get().
@ray.remote
def f(x):
    return x * x
# Cluster topology for this example: one head node plus a single worker
# group with 2 replicas. The transient cluster is cleaned up 300 s after
# the job finishes.
ray_config = RayJobConfig(
    head_node_config=HeadNodeConfig(ray_start_params={"log-color": "True"}),
    worker_node_config=[WorkerNodeConfig(group_name="ray-group", replicas=2)],
    runtime_env={"pip": ["numpy", "pandas"]},  # extra pip deps installed in the Ray runtime env
    enable_autoscaling=False,
    shutdown_after_job_finishes=True,
    ttl_seconds_after_finished=300,
)
# Container image shared by both environments; it must bundle a Ray
# version compatible with the plugin (2.46.0 here).
image = (
    flyte.Image.from_debian_base(name="ray")
    .with_apt_packages("wget")
    .with_pip_packages("ray[default]==2.46.0", "flyteplugins-ray", "pip", "mypy")
)
# Plain (non-Ray) environment used by the nested subtask.
task_env = flyte.TaskEnvironment(
    name="hello_ray", resources=flyte.Resources(cpu=(1, 2), memory=("400Mi", "1000Mi")), image=image
)
# Ray-backed environment: tasks declared here run on the cluster
# described by ray_config.
ray_env = flyte.TaskEnvironment(
    name="ray_env",
    plugin_config=ray_config,
    image=image,
    resources=flyte.Resources(cpu=(3, 4), memory=("3000Mi", "5000Mi")),
    depends_on=[task_env],
)
@task_env.task()
async def hello_ray():
    """Plain (non-Ray) subtask: sleep briefly, then log a greeting."""
    await asyncio.sleep(20)
    print("Hello from the Ray task!")
@ray_env.task
async def hello_ray_nested(n: int = 3) -> typing.List[int]:
    """Square 0..n-1 on the Ray cluster while a plain subtask runs concurrently."""
    print("running ray task")
    background = asyncio.create_task(hello_ray())
    refs = list(map(f.remote, range(n)))
    squares = ray.get(refs)
    await background
    return squares
if __name__ == "__main__":
    # Submit the parent task to the configured Flyte backend and block
    # until the run completes.
    flyte.init_from_config()
    r = flyte.run(hello_ray_nested)
    print(r.name)
    print(r.url)
    r.wait()
The next example demonstrates how Flyte can create ephemeral Ray clusters and run a subtask that connects to an existing Ray cluster:
# /// script
# requires-python = "==3.13"
# dependencies = [
# "flyte>=2.0.0b52",
# "flyteplugins-ray",
# "ray[default]==2.46.0"
# ]
# main = "create_ray_cluster"
# params = ""
# ///
import os
import typing
import ray
from flyteplugins.ray.task import HeadNodeConfig, RayJobConfig, WorkerNodeConfig
import flyte.storage
# Ray remote function: squares its argument on a Ray worker.
@ray.remote
def f(x):
    return x * x
# Cluster topology: one head node plus a 2-replica worker group. The
# cluster is kept for 3600 s after the job completes, so other tasks can
# still connect to it within that window.
ray_config = RayJobConfig(
    head_node_config=HeadNodeConfig(ray_start_params={"log-color": "True"}),
    worker_node_config=[WorkerNodeConfig(group_name="ray-group", replicas=2)],
    enable_autoscaling=False,
    shutdown_after_job_finishes=True,
    ttl_seconds_after_finished=3600,
)
# Container image shared by both environments; must bundle a Ray
# version compatible with the plugin (2.46.0 here).
image = (
    flyte.Image.from_debian_base(name="ray")
    .with_apt_packages("wget")
    .with_pip_packages("ray[default]==2.46.0", "flyteplugins-ray")
)
# Plain environment for the client task that connects to the cluster.
task_env = flyte.TaskEnvironment(
    name="ray_client", resources=flyte.Resources(cpu=(1, 2), memory=("400Mi", "1000Mi")), image=image
)
# Ray-backed environment: tasks declared here provision the cluster
# described by ray_config.
ray_env = flyte.TaskEnvironment(
    name="ray_cluster",
    plugin_config=ray_config,
    image=image,
    resources=flyte.Resources(cpu=(2, 4), memory=("2000Mi", "4000Mi")),
    depends_on=[task_env],
)
@task_env.task()
async def hello_ray(cluster_ip: str) -> typing.List[int]:
    """Connect to an already-running Ray cluster and square 0..4 on it."""
    ray.init(address=f"ray://{cluster_ip}:10001")
    refs = [f.remote(i) for i in range(5)]
    return ray.get(refs)
@ray_env.task
async def create_ray_cluster() -> str:
    """
    Provision a Ray cluster (via this environment's plugin_config) and
    return the head node's IP address so other tasks can connect to it.

    Raises:
        ValueError: if MY_POD_IP is not set in the environment
            (presumably injected via the pod spec's downward API —
            confirm against the deployment config).
    """
    print("creating ray cluster")
    cluster_ip = os.getenv("MY_POD_IP")
    if cluster_ip is None:
        raise ValueError("MY_POD_IP environment variable is not set")
    # getenv already returns a str; the original f"{cluster_ip}" wrapper
    # was a redundant identity formatting.
    return cluster_ip
if __name__ == "__main__":
    flyte.init_from_config()

    # Provision the ephemeral Ray cluster and wait for it to finish.
    run = flyte.run(create_ray_cluster)
    run.wait()
    print("run url:", run.url)
    print("cluster created, running ray task")

    # Fetch the outputs once instead of issuing two separate
    # run.outputs() lookups (each call may hit the backend).
    cluster_ip = run.outputs()[0]
    print("ray address:", cluster_ip)

    # Launch the client task that connects to the existing cluster.
    run = flyte.run(hello_ray, cluster_ip=cluster_ip)
    print("run url:", run.url)
API reference
See the Ray API reference for full details.