PyTorch
Flyte can execute distributed PyTorch jobs natively on a Kubernetes cluster (similar to running a torchrun script), managing the cluster's lifecycle, spin-up, and tear-down. It leverages the open-source Kubeflow training operator. This is like running a transient PyTorch cluster: a cluster spun up for a specific PyTorch job and torn down after completion.
To install the plugin, run the following command:
$ pip install --pre flyteplugins-pytorch
The following example shows how to configure PyTorch in a TaskEnvironment. Flyte automatically provisions a PyTorch cluster for each task that uses this configuration:
# /// script
# requires-python = "==3.13"
# dependencies = [
# "flyte==2.0.0b31",
# "flyteplugins-pytorch",
# "torch"
# ]
# main = "torch_distributed_train"
# params = "3"
# ///
import typing
import torch
import torch.distributed
import torch.nn as nn
import torch.optim as optim
from flyteplugins.pytorch.task import Elastic
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset
import flyte
image = flyte.Image.from_debian_base(name="torch").with_pip_packages("flyteplugins-pytorch", pre=True)
torch_env = flyte.TaskEnvironment(
    name="torch_env",
    resources=flyte.Resources(cpu=(1, 2), memory=("1Gi", "2Gi")),
    plugin_config=Elastic(
        nproc_per_node=1,
        # if you want to do local testing, set nnodes=1
        nnodes=2,
    ),
    image=image,
)
class LinearRegressionModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(1, 1)

    def forward(self, x):
        return self.linear(x)
def prepare_dataloader(rank: int, world_size: int, batch_size: int = 2) -> DataLoader:
    """
    Prepare a DataLoader with a DistributedSampler so each rank
    gets a shard of the dataset.
    """
    # Dummy dataset
    x_train = torch.tensor([[1.0], [2.0], [3.0], [4.0]])
    y_train = torch.tensor([[3.0], [5.0], [7.0], [9.0]])
    dataset = TensorDataset(x_train, y_train)

    # Distributed-aware sampler
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=True)
    return DataLoader(dataset, batch_size=batch_size, sampler=sampler)
def train_loop(epochs: int = 3) -> float:
    """
    A simple training loop for linear regression.
    """
    torch.distributed.init_process_group("gloo")
    model = DDP(LinearRegressionModel())

    rank = torch.distributed.get_rank()
    world_size = torch.distributed.get_world_size()
    dataloader = prepare_dataloader(
        rank=rank,
        world_size=world_size,
        batch_size=64,
    )

    criterion = nn.MSELoss()
    optimizer = optim.SGD(model.parameters(), lr=0.01)

    final_loss = 0.0
    for _ in range(epochs):
        for x, y in dataloader:
            outputs = model(x)
            loss = criterion(outputs, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            final_loss = loss.item()

    if torch.distributed.get_rank() == 0:
        print(f"Loss: {final_loss}")
    return final_loss
@torch_env.task
def torch_distributed_train(epochs: int) -> typing.Optional[float]:
    """
    A task that runs a simple distributed training job using PyTorch's
    DistributedDataParallel.
    """
    print("starting launcher")
    loss = train_loop(epochs=epochs)
    print("Training complete")
    return loss
if __name__ == "__main__":
    flyte.init_from_config()
    r = flyte.run(torch_distributed_train, epochs=3)
    print(r.name)
    print(r.url)
    r.wait()
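To launch the example, run the script directly with Python (the filename torch_distributed_train.py is assumed here; adjust it to wherever you saved the script). The __main__ block initializes Flyte from your configuration, submits the run, prints its name and URL, and blocks until it finishes:
$ python torch_distributed_train.py
For a quick local test without provisioning a multi-node cluster, set nnodes=1 in the Elastic configuration as noted in the comment above.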