2.0.3
Elastic
Package: flyteplugins.pytorch
Elastic defines the configuration for running a PyTorch elastic job using torch.distributed.
When a worker fails (e.g. CUDA OOM), the elastic agent detects the failure and restarts all workers as a group. Each restart cycle has a cost determined by the NCCL timeout settings below. The total worst-case time before the job fails is::
(max_restarts + 1) * (nccl_collective_timeout_sec + nccl_heartbeat_timeout_sec)
For example, with defaults (max_restarts=3, collective=600s, heartbeat=300s): 4 * 900s = 60 min. With aggressive settings (max_restarts=0, collective=60s, heartbeat=60s): 1 * 120s = 2 min.
class Elastic(
nnodes: typing.Union[int, str],
nproc_per_node: int,
rdzv_backend: typing.Literal['c10d', 'etcd', 'etcd-v2'],
run_policy: typing.Optional[flyteplugins.pytorch.task.RunPolicy],
monitor_interval: int,
max_restarts: int,
rdzv_configs: typing.Dict[str, typing.Any],
nccl_heartbeat_timeout_sec: typing.Optional[int],
nccl_async_error_handling: bool,
nccl_collective_timeout_sec: typing.Optional[int],
nccl_enable_monitoring: bool,
)| Parameter | Type | Description |
|---|---|---|
nnodes |
typing.Union[int, str] |
Number of nodes to use. Can be a fixed int or a range string (e.g., “2:4” for elastic training). |
nproc_per_node |
int |
Number of processes to launch per node. |
rdzv_backend |
typing.Literal['c10d', 'etcd', 'etcd-v2'] |
Rendezvous backend to use. Typically “c10d”. Defaults to “c10d”. |
run_policy |
typing.Optional[flyteplugins.pytorch.task.RunPolicy] |
Run policy applied to the job execution. Defaults to None. |
monitor_interval |
int |
Interval (in seconds) the elastic agent polls worker process health. Once a worker process exits, detection takes at most this long. Defaults to 3. |
max_restarts |
int |
Maximum number of worker group restarts before the elastic agent gives up and raises ChildFailedError. Each restart kills all workers and relaunches the entire group. If the failure is deterministic (e.g. model too large for GPU memory), restarts just repeat the same failure — set to 0 to fail immediately. Use higher values for transient failures (e.g. spot instance preemption, occasional OOM from variable batch sizes). Defaults to 3. |
rdzv_configs |
typing.Dict[str, typing.Any] |
Rendezvous configuration key-value pairs. Defaults to {“timeout”: 900, “join_timeout”: 900}. |
nccl_heartbeat_timeout_sec |
typing.Optional[int] |
Timeout in seconds for the NCCL heartbeat monitor thread. After the collective timeout fires and the NCCL watchdog aborts the communicator, the heartbeat monitor waits this long before sending SIGABRT to kill the worker process. This is the second phase of failure detection — it converts a stuck NCCL abort into a hard process kill. Defaults to 300 (5 min) instead of PyTorch’s 1800s (30 min). Set to None to use PyTorch default. |
nccl_async_error_handling |
bool |
When True, sets TORCH_NCCL_ASYNC_ERROR_HANDLING=1 so that NCCL aborts stuck collectives asynchronously instead of blocking indefinitely. This causes the worker process to crash-exit on a stuck collective, which the elastic agent detects within monitor_interval seconds (~3s by default) — much faster than waiting for the heartbeat timeout. Defaults to False (PyTorch default behavior). |
nccl_collective_timeout_sec |
typing.Optional[int] |
Timeout in seconds for individual NCCL collective operations (e.g. all-reduce inside loss.backward()). This is the timeout passed to torch.distributed.init_process_group. When a worker desyncs (e.g. skips a collective after OOM), surviving workers block in the collective for this long before the NCCL watchdog fires. This is the first phase of failure detection. PyTorch default is 600s (10 min). Set to None to use PyTorch default. |
nccl_enable_monitoring |
bool |
When True, sets TORCH_NCCL_ENABLE_MONITORING=1 to activate NCCL’s built-in monitoring thread. The monitoring thread checks each worker’s heartbeat counter and sends SIGABRT when it stalls, which is what drives nccl_heartbeat_timeout_sec. Defaults to True. |