PyTorch
The PyTorch plugin lets you run distributed
PyTorch training jobs natively on Kubernetes. It uses the
Kubeflow Training Operator to manage multi-node training with PyTorch’s elastic launch (torchrun).
When to use this plugin
- Single-node or multi-node distributed training with
DistributedDataParallel (DDP)
- Elastic training that can scale up and down during execution
- Any workload that uses
`torch.distributed` for data-parallel or model-parallel training
Installation
pip install flyteplugins-pytorch

Configuration
Create an Elastic configuration and pass it as plugin_config to a TaskEnvironment:
from flyteplugins.pytorch import Elastic
torch_env = flyte.TaskEnvironment(
name="torch_env",
resources=flyte.Resources(cpu=(1, 2), memory=("1Gi", "2Gi")),
plugin_config=Elastic(
nnodes=2,
nproc_per_node=1,
),
image=image,
)

Elastic parameters
| Parameter | Type | Description |
|---|---|---|
| `nnodes` | `int` or `str` | Required. Number of nodes. Use an int for a fixed count or a range string (e.g., `"2:4"`) for elastic training |
| `nproc_per_node` | `int` | Required. Number of processes (workers) per node |
| `rdzv_backend` | `str` | Rendezvous backend: `"c10d"` (default), `"etcd"`, or `"etcd-v2"` |
| `max_restarts` | `int` | Maximum worker group restarts (default: 3) |
| `monitor_interval` | `int` | Agent health check interval in seconds (default: 3) |
| `run_policy` | `RunPolicy` | Job run policy (cleanup, TTL, deadlines, retries) |
RunPolicy parameters
| Parameter | Type | Description |
|---|---|---|
| `clean_pod_policy` | `str` | Pod cleanup policy: `"None"`, `"all"`, or `"Running"` |
| `ttl_seconds_after_finished` | `int` | Seconds to keep pods after job completion |
| `active_deadline_seconds` | `int` | Maximum time the job can run (seconds) |
| `backoff_limit` | `int` | Number of retries before marking the job as failed |
NCCL tuning parameters
The plugin includes built-in NCCL timeout tuning to reduce failure-detection latency (PyTorch defaults to 1800 seconds):
| Parameter | Type | Default | Description |
|---|---|---|---|
| `nccl_heartbeat_timeout_sec` | `int` | 300 | NCCL heartbeat timeout (seconds) |
| `nccl_async_error_handling` | `bool` | False | Enable async NCCL error handling |
| `nccl_collective_timeout_sec` | `int` | None | Timeout for NCCL collective operations |
| `nccl_enable_monitoring` | `bool` | True | Enable NCCL monitoring |
Writing a distributed training task
Tasks using this plugin do not need to be async. Initialize the process group and use DistributedDataParallel as you normally would with torchrun:
import torch
import torch.distributed
from torch.nn.parallel import DistributedDataParallel as DDP
@torch_env.task
def train(epochs: int) -> float:
torch.distributed.init_process_group("gloo")
model = DDP(MyModel())
# ... training loop ...
return final_loss

When nnodes=1, the task runs as a regular Python task (no Kubernetes training job is created). Set nnodes >= 2 for multi-node distributed training.
Example
# /// script
# requires-python = "==3.13"
# dependencies = [
# "flyte>=2.0.0b52",
# "flyteplugins-pytorch",
# "torch"
# ]
# main = "torch_distributed_train"
# params = "3"
# ///
import typing
import torch
import torch.distributed
import torch.nn as nn
import torch.optim as optim
from flyteplugins.pytorch.task import Elastic
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset
import flyte
# Container image: Debian base with the PyTorch plugin installed (pre-releases allowed).
image = flyte.Image.from_debian_base(name="torch").with_pip_packages("flyteplugins-pytorch", pre=True)

# Task environment that launches an elastic PyTorch job: two worker nodes,
# one process per node. For local testing, set nnodes=1 so the task runs as
# a plain Python task instead of a Kubernetes training job.
torch_env = flyte.TaskEnvironment(
    name="torch_env",
    image=image,
    resources=flyte.Resources(cpu=(1, 2), memory=("1Gi", "2Gi")),
    plugin_config=Elastic(nnodes=2, nproc_per_node=1),
)
class LinearRegressionModel(nn.Module):
    """Minimal one-feature linear model: y = w * x + b."""

    def __init__(self):
        super().__init__()
        # Single affine layer mapping one input feature to one output.
        self.linear = nn.Linear(1, 1)

    def forward(self, x):
        # Apply the affine transform directly; no activation needed for regression.
        return self.linear(x)
def prepare_dataloader(rank: int, world_size: int, batch_size: int = 2) -> DataLoader:
    """
    Build a DataLoader backed by a DistributedSampler so each rank
    receives its own shard of the (toy) dataset.

    Args:
        rank: This worker's rank within the process group.
        world_size: Total number of workers sharing the dataset.
        batch_size: Samples per batch on this rank.

    Returns:
        A DataLoader yielding (x, y) batches from this rank's shard.
    """
    # Tiny synthetic regression set following y = 2x + 1.
    features = torch.tensor([[1.0], [2.0], [3.0], [4.0]])
    targets = torch.tensor([[3.0], [5.0], [7.0], [9.0]])
    dataset = TensorDataset(features, targets)

    # Passing rank/num_replicas explicitly means no initialized process
    # group is required to construct the sampler.
    shard_sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=True)
    return DataLoader(dataset, batch_size=batch_size, sampler=shard_sampler)
def train_loop(epochs: int = 3) -> float:
    """
    Run a simple DDP linear-regression training loop.

    Initializes the process group (gloo backend), wraps the model in
    DistributedDataParallel, and trains over a sharded DataLoader.

    Args:
        epochs: Number of full passes over this rank's data shard.

    Returns:
        The loss of the final batch processed on this rank (0.0 if the
        shard yielded no batches).
    """
    torch.distributed.init_process_group("gloo")
    model = DDP(LinearRegressionModel())
    rank = torch.distributed.get_rank()
    world_size = torch.distributed.get_world_size()
    dataloader = prepare_dataloader(
        rank=rank,
        world_size=world_size,
        batch_size=64,
    )
    criterion = nn.MSELoss()
    optimizer = optim.SGD(model.parameters(), lr=0.01)
    final_loss = 0.0
    for epoch in range(epochs):
        # Required when the DistributedSampler uses shuffle=True: without
        # set_epoch, every epoch iterates the data in the same order.
        dataloader.sampler.set_epoch(epoch)
        for x, y in dataloader:
            optimizer.zero_grad()
            outputs = model(x)
            loss = criterion(outputs, y)
            loss.backward()
            optimizer.step()
            final_loss = loss.item()
    if rank == 0:
        print(f"Loss: {final_loss}")
    # Release process-group resources before the worker returns.
    torch.distributed.destroy_process_group()
    return final_loss
@torch_env.task
def torch_distributed_train(epochs: int) -> typing.Optional[float]:
    """
    Entry-point task: launches the distributed training loop and returns
    the final loss observed on this worker.
    """
    print("starting launcher")
    final = train_loop(epochs=epochs)
    print("Training complete")
    return final
if __name__ == "__main__":
    # Submit the training task as a remote run and block until it finishes.
    flyte.init_from_config()
    run = flyte.run(torch_distributed_train, epochs=3)
    print(run.name)
    print(run.url)
    run.wait()
API reference
See the PyTorch API reference for full details.