PyTorch

The PyTorch plugin lets you run distributed PyTorch training jobs natively on Kubernetes. It uses the Kubeflow Training Operator to manage multi-node training with PyTorch’s elastic launch (torchrun).

When to use this plugin

  • Single-node or multi-node distributed training with DistributedDataParallel (DDP)
  • Elastic training that can scale up and down during execution
  • Any workload that uses torch.distributed for data-parallel or model-parallel training

Installation

pip install flyteplugins-pytorch

Configuration

Create an Elastic configuration and pass it as plugin_config to a TaskEnvironment:

from flyteplugins.pytorch import Elastic

torch_env = flyte.TaskEnvironment(
    name="torch_env",
    resources=flyte.Resources(cpu=(1, 2), memory=("1Gi", "2Gi")),
    plugin_config=Elastic(
        nnodes=2,
        nproc_per_node=1,
    ),
    image=image,
)

Elastic parameters

  • nnodes (int or str): Required. Number of nodes. Use an int for a fixed count or a range string (e.g., "2:4") for elastic training
  • nproc_per_node (int): Required. Number of processes (workers) per node
  • rdzv_backend (str): Rendezvous backend: "c10d" (default), "etcd", or "etcd-v2"
  • max_restarts (int): Maximum worker group restarts (default: 3)
  • monitor_interval (int): Agent health check interval in seconds (default: 3)
  • run_policy (RunPolicy): Job run policy (pod cleanup, TTL, deadlines, retries)
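For example, an elastic job that can scale between two and four nodes might be configured like this (a sketch; the parameter values are illustrative, not recommendations):

```python
from flyteplugins.pytorch import Elastic

# A range string for nnodes enables elastic scaling: the job can start
# with 2 nodes and grow to 4 while running.
elastic_config = Elastic(
    nnodes="2:4",          # min:max node count
    nproc_per_node=1,      # one worker process per node
    rdzv_backend="c10d",   # default rendezvous backend
    max_restarts=3,        # restart the worker group up to 3 times
)
```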

RunPolicy parameters

  • clean_pod_policy (str): Pod cleanup policy: "None", "All", or "Running"
  • ttl_seconds_after_finished (int): Seconds to keep pods after job completion
  • active_deadline_seconds (int): Maximum wall-clock time the job may run, in seconds
  • backoff_limit (int): Number of retries before the job is marked failed
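A run policy is attached via the run_policy parameter of Elastic. A minimal sketch (assuming RunPolicy is importable alongside Elastic; the exact import path and values below are illustrative):

```python
from flyteplugins.pytorch import Elastic, RunPolicy  # RunPolicy import path is an assumption

policy = RunPolicy(
    clean_pod_policy="None",          # keep all pods for post-run debugging
    ttl_seconds_after_finished=3600,  # delete job resources an hour after completion
    active_deadline_seconds=86400,    # fail the job if it runs longer than a day
    backoff_limit=2,                  # retry the job twice before marking it failed
)

elastic_config = Elastic(nnodes=2, nproc_per_node=1, run_policy=policy)
```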

NCCL tuning parameters

The plugin includes built-in NCCL timeout tuning to reduce failure-detection latency (PyTorch's default collective timeout is 1800 seconds):

  • nccl_heartbeat_timeout_sec (int, default 300): NCCL heartbeat timeout in seconds
  • nccl_async_error_handling (bool, default False): Enable asynchronous NCCL error handling
  • nccl_collective_timeout_sec (int, default None): Timeout for NCCL collective operations, in seconds
  • nccl_enable_monitoring (bool, default True): Enable NCCL monitoring
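These knobs are set directly on the Elastic config. A sketch with tighter-than-default failure detection (the values are illustrative, not recommendations):

```python
from flyteplugins.pytorch import Elastic

elastic_config = Elastic(
    nnodes=2,
    nproc_per_node=1,
    nccl_heartbeat_timeout_sec=120,   # consider a rank dead after 2 minutes of silence
    nccl_async_error_handling=True,   # surface NCCL errors without blocking
    nccl_collective_timeout_sec=600,  # abort collectives stuck for > 10 minutes
    nccl_enable_monitoring=True,      # keep the NCCL watchdog enabled
)
```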

Writing a distributed training task

Tasks using this plugin do not need to be async. Initialize the process group and use DistributedDataParallel as you normally would with torchrun:

import torch
import torch.distributed
from torch.nn.parallel import DistributedDataParallel as DDP

@torch_env.task
def train(epochs: int) -> float:
    torch.distributed.init_process_group("gloo")
    model = DDP(MyModel())
    # ... training loop ...
    return final_loss

When nnodes=1, the task runs as a regular Python task (no Kubernetes training job is created). Set nnodes >= 2 for multi-node distributed training.
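Under the hood, torchrun exports the standard environment variables (RANK, WORLD_SIZE, LOCAL_RANK) to every worker process before the task body runs. A minimal stdlib sketch of rank-aware logic that also works outside a cluster:

```python
import os

def describe_worker() -> str:
    # torchrun sets these for each worker; fall back to a single-process
    # layout when the task runs locally with nnodes=1.
    rank = int(os.environ.get("RANK", "0"))
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    return f"rank {rank}/{world_size} (local rank {local_rank})"

print(describe_worker())
```

This is useful for rank-conditional behavior such as logging or checkpointing only from rank 0.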

Example

pytorch_example.py
# /// script
# requires-python = "==3.13"
# dependencies = [
#    "flyte>=2.0.0b52",
#    "flyteplugins-pytorch",
#    "torch"
# ]
# main = "torch_distributed_train"
# params = "3"
# ///

import typing

import torch
import torch.distributed
import torch.nn as nn
import torch.optim as optim
from flyteplugins.pytorch.task import Elastic
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset

import flyte

image = flyte.Image.from_debian_base(name="torch").with_pip_packages("flyteplugins-pytorch", pre=True)

torch_env = flyte.TaskEnvironment(
    name="torch_env",
    resources=flyte.Resources(cpu=(1, 2), memory=("1Gi", "2Gi")),
    plugin_config=Elastic(
        nproc_per_node=1,
        # if you want to do local testing set nnodes=1
        nnodes=2,
    ),
    image=image,
)


class LinearRegressionModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(1, 1)

    def forward(self, x):
        return self.linear(x)


def prepare_dataloader(rank: int, world_size: int, batch_size: int = 2) -> DataLoader:
    """
    Prepare a DataLoader with a DistributedSampler so each rank
    gets a shard of the dataset.
    """
    # Dummy dataset
    x_train = torch.tensor([[1.0], [2.0], [3.0], [4.0]])
    y_train = torch.tensor([[3.0], [5.0], [7.0], [9.0]])
    dataset = TensorDataset(x_train, y_train)

    # Distributed-aware sampler
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=True)

    return DataLoader(dataset, batch_size=batch_size, sampler=sampler)


def train_loop(epochs: int = 3) -> float:
    """
    A simple training loop for linear regression.
    """
    torch.distributed.init_process_group("gloo")
    model = DDP(LinearRegressionModel())

    rank = torch.distributed.get_rank()
    world_size = torch.distributed.get_world_size()

    dataloader = prepare_dataloader(
        rank=rank,
        world_size=world_size,
        batch_size=64,
    )

    criterion = nn.MSELoss()
    optimizer = optim.SGD(model.parameters(), lr=0.01)

    final_loss = 0.0

    for _ in range(epochs):
        for x, y in dataloader:
            outputs = model(x)
            loss = criterion(outputs, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            final_loss = loss.item()
        if rank == 0:
            print(f"Loss: {final_loss}")

    # Tear down the process group before returning from the task.
    torch.distributed.destroy_process_group()
    return final_loss


@torch_env.task
def torch_distributed_train(epochs: int) -> typing.Optional[float]:
    """
    A task that runs a simple distributed training job using PyTorch's DistributedDataParallel.
    """
    print("starting launcher")
    loss = train_loop(epochs=epochs)
    print("Training complete")
    return loss

if __name__ == "__main__":
    flyte.init_from_config()
    r = flyte.run(torch_distributed_train, epochs=3)
    print(r.name)
    print(r.url)
    r.wait()

API reference

See the PyTorch API reference for full details.