# Brain tumor MRI classification

> [!NOTE]
> Code available [here](https://github.com/unionai/unionai-examples/tree/main/v2/tutorials/tumor_detection).

This tutorial builds a medical-imaging pipeline that classifies brain MRI scans into four categories (Glioma, Meningioma, No Tumor, and Pituitary) using a two-phase EfficientNet-B4 transfer-learning strategy. The pipeline downloads the dataset, trains on a GPU with fault-tolerant checkpointing, and renders training curves, a confusion matrix, and a per-class F1 chart directly in the Union.ai UI.

The example is split into focused modules:

- `config.py` — container image, task environments, and the `TrainingConfig` hyperparameters.
- `dataset.py` — downloads the Hugging Face dataset, builds class-balanced data loaders.
- `model.py` / `training.py` — the Lightning module and the two-phase training loop.
- `utils.py` — plotting helpers for the report.
- `run.py` — the three Flyte tasks and the pipeline driver.

Flyte handles the production concerns:

- **Per-task resources**: CPU for download/reporting, a GPU for training.
- **`cache="auto"`** on dataset download and training, so reruns with the same inputs reuse the cached outputs instead of recomputing them.
- **`retries=3`** plus **Flyte checkpointing** on the training task, so a preempted GPU job resumes from its last checkpointed epoch instead of starting over.
- **Built-in reports** to visualize metrics without separate dashboard infrastructure.

## Define the container image

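A single GPU-ready image is shared by all tasks. It uses a Debian base with PyTorch, Lightning, `timm`, and the other dependencies installed via pip. `with_source_folder` copies the local modules (`dataset.py`, `model.py`, etc.) into the image. With `copy_contents_only=True`, the folder's contents are copied rather than the folder itself, so the tasks can import the modules directly (for example, `from config import TrainingConfig`).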

```
"""
Configuration for brain tumor MRI classification pipeline.

Defines task environments, resource requirements, and training hyperparameters.
"""

import pathlib

import flyte

# {{docs-fragment image}}
image = flyte.Image.from_debian_base(
    name="tumor_detection_gpu"
).with_pip_packages(
    "torch",
    "lightning",
    "torchvision",
    "timm",
    "pillow",
    "scikit-learn",
    "plotly",
    "numpy",
    "pandas",
    "torchmetrics",
    "datasets",
    "typing_extensions",
).with_source_folder(
    pathlib.Path(__file__).parent,
    copy_contents_only=True,
)
# {{/docs-fragment image}}

# {{docs-fragment envs}}
# Downloads raw MRI JPEG files — CPU only, no auth needed, result is cached
dataset_env = flyte.TaskEnvironment(
    name="tumor_dataset",
    image=image,
    resources=flyte.Resources(cpu=2, memory="4Gi", disk="8Gi"),
    cache="auto",
)

# GPU training — result is cached so re-running with the same data + config is free
training_env = flyte.TaskEnvironment(
    name="tumor_gpu_training",
    image=image,
    resources=flyte.Resources(
        cpu=8,
        memory="32Gi",
        gpu="T4:1",
        disk="100Gi",
    ),
    env_vars={
        "CUDA_VISIBLE_DEVICES": "0",
        "CUDA_LAUNCH_BLOCKING": "1",
        "TORCH_CUDA_MEMORY_FRACTION": "1.0",
        "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True",
    },
    cache="auto",
)

# Report generation — CPU only, reads training results and renders Union UI panels
report_env = flyte.TaskEnvironment(
    name="tumor_report",
    image=image,
    resources=flyte.Resources(cpu=2, memory="4Gi"),
)

# Pipeline driver — lightweight orchestrator that calls the three tasks above
pipeline_env = flyte.TaskEnvironment(
    name="tumor_pipeline",
    image=image,
    resources=flyte.Resources(cpu=2, memory="4Gi"),
    depends_on=[dataset_env, training_env, report_env],
)
# {{/docs-fragment envs}}

class TrainingConfig:
    """Unified training configuration for brain tumor MRI classification."""

    def __init__(
        self,
        image_size: int = 380,
        num_classes: int = 4,
        model_name: str = "efficientnet_b4",
        pretrained: bool = True,
        phase1_epochs: int = 8,
        phase1_lr: float = 1e-3,
        phase1_freeze_backbone: bool = True,
        phase2_epochs: int = 25,
        phase2_lr: float = 5e-5,
        batch_size: int = 16,
        num_workers: int = 0,
        val_split: float = 0.2,
        weight_decay: float = 1e-4,
        warmup_steps: int = 200,
        focal_gamma: float = 2.0,
        mixup_alpha: float = 0.0,
        log_interval: int = 50,
    ):
        self.image_size = image_size
        self.num_classes = num_classes

        self.model_name = model_name
        self.pretrained = pretrained

        self.phase1_epochs = phase1_epochs
        self.phase1_lr = phase1_lr
        self.phase1_freeze_backbone = phase1_freeze_backbone

        self.phase2_epochs = phase2_epochs
        self.phase2_lr = phase2_lr

        self.batch_size = batch_size
        self.num_workers = num_workers
        self.val_split = val_split

        self.weight_decay = weight_decay
        self.warmup_steps = warmup_steps

        self.focal_gamma = focal_gamma
        self.mixup_alpha = mixup_alpha

        self.log_interval = log_interval

    def to_dict(self) -> dict:
        return self.__dict__
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/tumor_detection/config.py*

## Define the task environments

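Each stage declares only the resources it needs. Downloading and reporting get CPUs, and training gets a single T4 GPU with 32Gi of memory and 100Gi of disk. The lightweight `pipeline_env` hosts the driver task. Its `depends_on` list names the dataset, training, and report environments so they are deployed along with it, which lets the driver call their tasks.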

```
# Downloads raw MRI JPEG files — CPU only, no auth needed, result is cached
dataset_env = flyte.TaskEnvironment(
    name="tumor_dataset",
    image=image,
    resources=flyte.Resources(cpu=2, memory="4Gi", disk="8Gi"),
    cache="auto",
)

# GPU training — result is cached so re-running with the same data + config is free
training_env = flyte.TaskEnvironment(
    name="tumor_gpu_training",
    image=image,
    resources=flyte.Resources(
        cpu=8,
        memory="32Gi",
        gpu="T4:1",
        disk="100Gi",
    ),
    env_vars={
        "CUDA_VISIBLE_DEVICES": "0",
        "CUDA_LAUNCH_BLOCKING": "1",
        "TORCH_CUDA_MEMORY_FRACTION": "1.0",
        "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True",
    },
    cache="auto",
)

# Report generation — CPU only, reads training results and renders Union UI panels
report_env = flyte.TaskEnvironment(
    name="tumor_report",
    image=image,
    resources=flyte.Resources(cpu=2, memory="4Gi"),
)

# Pipeline driver — lightweight orchestrator that calls the three tasks above
pipeline_env = flyte.TaskEnvironment(
    name="tumor_pipeline",
    image=image,
    resources=flyte.Resources(cpu=2, memory="4Gi"),
    depends_on=[dataset_env, training_env, report_env],
)
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/tumor_detection/config.py*

## Configure training

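Hyperparameters are gathered in a single `TrainingConfig`, serialized to JSON, and passed into the training task as a string. The task then rebuilds a `TrainingConfig` from that string. Because the string is a task input, the exact configuration is recorded with every run. Changing any hyperparameter changes that input, so `train_model` runs again instead of hitting the cache. `TRAINING_CONFIG` in `run.py` mostly restates the defaults but raises `focal_gamma` from 2.0 to 3.0.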

```
"""
Flyte/Union pipeline for brain tumor MRI classification.

Three-task pipeline:
1. load_dataset  — download Brain Tumor MRI from Hugging Face, cache as Dir (CPU)
2. train_model   — two-phase EfficientNet-B4 training with focal loss (GPU)
3. create_report — render training curves and confusion matrix in the Union UI (CPU)
"""

import json

import flyte
from flyte.io import Dir

from config import TrainingConfig, dataset_env, pipeline_env, report_env, training_env
from dataset import download_tumor_dataset

# {{docs-fragment config}}
TRAINING_CONFIG = TrainingConfig(
    phase1_epochs=8,
    phase2_epochs=25,
    phase1_lr=1e-3,
    phase2_lr=5e-5,
    batch_size=16,
    num_workers=0,
    log_interval=50,
    mixup_alpha=0.0,
    image_size=380,
    focal_gamma=3.0,
)
# {{/docs-fragment config}}

# {{docs-fragment load_dataset}}
@dataset_env.task
async def load_dataset() -> Dir:
    """
    Download raw Brain Tumor MRI JPEG files from Hugging Face and cache as flyte.io.Dir.
    Runs once — result is reused on subsequent pipeline runs (cache="auto").
    """
    return await download_tumor_dataset()
# {{/docs-fragment load_dataset}}

# {{docs-fragment train_model}}
@training_env.task(retries=3)
async def train_model(dataset_dir: Dir, config_json: str) -> Dir:
    """
    Download the raw dataset Dir, run two-phase training,
    and return training metrics and final predictions as a Dir for the report task.
    """
    from pathlib import Path

    local_dir = Path("/tmp/tumor_local")
    local_dir.mkdir(parents=True, exist_ok=True)
    await dataset_dir.download(local_path=str(local_dir))

    from training import train_tumor_classifier
    config = TrainingConfig(**json.loads(config_json))
    result = train_tumor_classifier(config=config, dataset_path=str(local_dir))

    output_dir = Path("/tmp/training_results")
    output_dir.mkdir(parents=True, exist_ok=True)
    (output_dir / "metrics.json").write_text(json.dumps(result["metrics"]))
    (output_dir / "predictions.json").write_text(json.dumps({
        "preds": result["final_preds"],
        "targets": result["final_targets"],
    }))

    return await Dir.from_local(str(output_dir))
# {{/docs-fragment train_model}}

# {{docs-fragment create_report}}
@report_env.task(report=True)
async def create_report(results_dir: Dir) -> None:
    """
    Download training metrics and render loss/accuracy curves, confusion matrix,
    and per-class F1 chart in the Union UI report panel.
    """
    import numpy as np
    from pathlib import Path

    from utils import create_confusion_matrix_plot, create_metrics_plots, create_per_class_f1_plot

    local_dir = Path("/tmp/tumor_report")
    local_dir.mkdir(parents=True, exist_ok=True)
    await results_dir.download(local_path=str(local_dir))

    matches = list(local_dir.glob("**/metrics.json"))
    if not matches:
        raise RuntimeError(f"metrics.json not found under {local_dir}")
    local_path = matches[0].parent

    history = json.loads((local_path / "metrics.json").read_text())
    predictions = json.loads((local_path / "predictions.json").read_text())

    preds = np.array(predictions["preds"])
    targets = np.array(predictions["targets"])

    loss_fig, acc_fig = create_metrics_plots(history)
    cm_fig = create_confusion_matrix_plot(preds, targets)
    f1_fig = create_per_class_f1_plot(preds, targets)

    combined_html = (
        acc_fig.to_html(include_plotlyjs=True, full_html=False)
        + loss_fig.to_html(include_plotlyjs=False, full_html=False)
        + cm_fig.to_html(include_plotlyjs=False, full_html=False)
        + f1_fig.to_html(include_plotlyjs=False, full_html=False)
    )
    flyte.report.log(combined_html, do_flush=True)
# {{/docs-fragment create_report}}

# {{docs-fragment pipeline}}
@pipeline_env.task
async def tumor_detection_pipeline() -> None:
    """Orchestrate dataset loading, GPU training, and report generation."""
    dataset_dir = await load_dataset()
    results_dir = await train_model(
        dataset_dir=dataset_dir,
        config_json=json.dumps(TRAINING_CONFIG.to_dict()),
    )
    await create_report(results_dir=results_dir)
# {{/docs-fragment pipeline}}

if __name__ == "__main__":
    import pathlib
    flyte.init_from_config(root_dir=pathlib.Path(__file__).parent)
    run = flyte.with_runcontext().run(tumor_detection_pipeline)
    print(f"\n✓ Pipeline submitted!")
    print(f"Run URL: {run.url}")
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/tumor_detection/run.py*
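The JSON round trip, and the reason a changed hyperparameter forces retraining, can be shown in a few lines. This sketch uses a trimmed-down, hypothetical `MiniConfig` in place of the full `TrainingConfig`. The SHA-256 digest is only an illustrative stand-in for an input-based cache key; it is not Flyte's actual cache-key algorithm.

```python
import hashlib
import json


class MiniConfig:
    """Hypothetical, trimmed-down stand-in for TrainingConfig."""

    def __init__(self, image_size: int = 380, focal_gamma: float = 2.0, batch_size: int = 16):
        self.image_size = image_size
        self.focal_gamma = focal_gamma
        self.batch_size = batch_size

    def to_dict(self) -> dict:
        return dict(self.__dict__)  # a copy, so callers cannot mutate the config through it


def serialize(config: MiniConfig) -> str:
    # sort_keys makes the string independent of attribute insertion order
    return json.dumps(config.to_dict(), sort_keys=True)


def fake_cache_key(config_json: str) -> str:
    """Illustrative stand-in for an input-based cache key (not Flyte's algorithm)."""
    return hashlib.sha256(config_json.encode("utf-8")).hexdigest()


base = MiniConfig(focal_gamma=3.0)
payload = serialize(base)

# Round trip, mirroring TrainingConfig(**json.loads(config_json)) inside train_model
restored = MiniConfig(**json.loads(payload))
assert restored.to_dict() == base.to_dict()

# Same hyperparameters -> same key; one changed value -> different key
same = fake_cache_key(serialize(MiniConfig(focal_gamma=3.0)))
changed = fake_cache_key(serialize(MiniConfig(focal_gamma=2.0)))
print(same == fake_cache_key(payload), same == changed)  # prints: True False
```

There is one small design difference from the original. The tutorial's `to_dict` returns `self.__dict__` itself, so mutating the returned dictionary would also mutate the config. Returning a copy, as the sketch does, avoids that.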

## Load the dataset

The first task downloads the public [Brain Tumor MRI dataset](https://huggingface.co/datasets/AIOmarRehan/Brain_Tumor_MRI_Dataset) from Hugging Face (no auth required) and stores it as a `flyte.io.Dir`. `dataset_env` sets `cache="auto"` and the task takes no inputs, so later runs reuse the stored `Dir` instead of downloading the dataset again.

```
@dataset_env.task
async def load_dataset() -> Dir:
    """
    Download raw Brain Tumor MRI JPEG files from Hugging Face and cache as flyte.io.Dir.
    Runs once — result is reused on subsequent pipeline runs (cache="auto").
    """
    return await download_tumor_dataset()
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/tumor_detection/run.py*
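After download, `training.py` turns the directory into labels and loss weights. It relies on the classes being indexed alphabetically (Glioma=0, Meningioma=1, No Tumor=2, Pituitary=3), and it calls `compute_class_weights` from `dataset.py`, which this page does not reproduce. A common way to derive such weights is inverse class frequency, `total / (num_classes * count)`. The sketch below assumes that formula and a hypothetical one-folder-per-class layout, so treat it as an illustration rather than what `dataset.py` actually does.

```python
from pathlib import Path

IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png"}


def class_counts(dataset_root: Path) -> dict[str, int]:
    """Count images per class, assuming a hypothetical dataset_root/<ClassName>/<image> layout."""
    return {
        class_dir.name: sum(
            1 for f in class_dir.iterdir() if f.is_file() and f.suffix.lower() in IMAGE_SUFFIXES
        )
        for class_dir in dataset_root.iterdir()
        if class_dir.is_dir()
    }


def inverse_frequency_weights(counts: dict[str, int]) -> tuple[list[str], list[float]]:
    """Return class names in label-index (alphabetical) order and weights total / (K * count)."""
    names = sorted(counts)  # alphabetical order fixes each class's label index
    total = sum(counts.values())
    weights = [total / (len(names) * counts[name]) for name in names]
    return names, weights


names, weights = inverse_frequency_weights(
    {"Pituitary": 50, "Glioma": 100, "No Tumor": 100, "Meningioma": 50}  # made-up counts
)
print(names)    # ['Glioma', 'Meningioma', 'No Tumor', 'Pituitary']
print(weights)  # [0.75, 1.5, 0.75, 1.5]
```

With these weights, a class that appears half as often contributes twice as much per sample to the loss.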

## Train the model

The training task downloads the dataset `Dir` to local disk and runs two-phase training: a frozen backbone first, then full fine-tuning. It writes metrics and predictions to an output `Dir` for the report task. With `retries=3`, Flyte reruns the task up to three more times if an attempt fails, for example because the GPU node was preempted.

```
@training_env.task(retries=3)
async def train_model(dataset_dir: Dir, config_json: str) -> Dir:
    """
    Download the raw dataset Dir, run two-phase training,
    and return training metrics and final predictions as a Dir for the report task.
    """
    from pathlib import Path

    local_dir = Path("/tmp/tumor_local")
    local_dir.mkdir(parents=True, exist_ok=True)
    await dataset_dir.download(local_path=str(local_dir))

    from training import train_tumor_classifier
    config = TrainingConfig(**json.loads(config_json))
    result = train_tumor_classifier(config=config, dataset_path=str(local_dir))

    output_dir = Path("/tmp/training_results")
    output_dir.mkdir(parents=True, exist_ok=True)
    (output_dir / "metrics.json").write_text(json.dumps(result["metrics"]))
    (output_dir / "predictions.json").write_text(json.dumps({
        "preds": result["final_preds"],
        "targets": result["final_targets"],
    }))

    return await Dir.from_local(str(output_dir))
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/tumor_detection/run.py*
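`PhaseChangeCallback` in `training.py` handles the phase boundary. It sets the existing (non-backbone) parameter groups to `phase2_lr` and gives the newly unfrozen backbone a 10x lower rate. It then starts a fresh `CosineAnnealingLR` that decays toward `eta_min=1e-6` over the remaining optimizer steps. The closed-form cosine annealing curve below reproduces that arithmetic with the tutorial's values. `CosineAnnealingLR` is designed to follow this curve, although PyTorch computes it recursively. The step count is a hypothetical example, since the real value depends on dataset size and batch size.

```python
import math


def cosine_lr(step: int, base_lr: float, total_steps: int, eta_min: float = 1e-6) -> float:
    """Closed-form cosine annealing from base_lr at step 0 down to eta_min at total_steps."""
    step = min(step, total_steps)  # hold at eta_min once the schedule is exhausted
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * step / total_steps)) / 2


phase2_lr = 5e-5               # value used in TRAINING_CONFIG
backbone_lr = phase2_lr * 0.1  # backbone is fine-tuned 10x more gently than the head
steps_remaining = 5000         # hypothetical: e.g. 25 epochs x 200 batches

for step in (0, steps_remaining // 2, steps_remaining):
    head = cosine_lr(step, phase2_lr, steps_remaining)
    backbone = cosine_lr(step, backbone_lr, steps_remaining)
    print(f"step {step:>5}: head={head:.2e} backbone={backbone:.2e}")
```

Restarting the schedule at the phase boundary means Phase 2 begins at its full learning rate. Otherwise it would inherit whatever value the Phase 1 schedule had decayed to.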

### Resumable checkpointing

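To make retries cheap, training mirrors its Lightning checkpoint directory to a `flyte.Checkpoint` at the end of every training epoch. A restarted attempt then resumes from the latest saved checkpoint. If the task context provides no Flyte checkpoint, the code falls back to a plain Lightning `ModelCheckpoint` that keeps the three best checkpoints by `val/acc`.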

```
"""
Training pipeline for brain tumor MRI classification.

Implements two-phase training:
- Phase 1: Frozen backbone (feature extractor), train classification head
- Phase 2: Fine-tune full model with differential LRs + cosine annealing
"""

from config import TrainingConfig
from dataset import compute_class_weights, create_data_loaders
from model import TumorClassifierLightningModule
from utils import get_model_size, get_trainable_params

def train_tumor_classifier(
    config: TrainingConfig,
    dataset_path: str,
) -> dict:
    """
    Run two-phase training on the preprocessed dataset and return metrics + final predictions.

    dataset_path: local directory where the flyte.io.Dir was downloaded by the training task.
    """
    import pathlib

    import flyte
    import lightning as L
    import torch
    from lightning.pytorch.callbacks import ModelCheckpoint
    from typing_extensions import override

    # {{docs-fragment flyte_checkpoint}}
    class FlyteLightningCheckpointCallback(ModelCheckpoint):
        """Mirrors the checkpoint directory to Flyte after every epoch so retries can resume."""

        def __init__(self, flyte_checkpoint: flyte.Checkpoint, *, dirpath: str, **kwargs):
            super().__init__(dirpath=dirpath, **kwargs)
            self._flyte_checkpoint = flyte_checkpoint

        @override
        def on_train_epoch_end(self, trainer: L.Trainer, pl_module: L.LightningModule) -> None:
            super().on_train_epoch_end(trainer, pl_module)
            if self.dirpath:
                self._flyte_checkpoint.save_sync(pathlib.Path(self.dirpath))
    # {{/docs-fragment flyte_checkpoint}}

    class MetricsLoggerCallback(L.Callback):
        def __init__(self, phase1_epochs: int):
            super().__init__()
            self.phase1_epochs = phase1_epochs
            self.history = []

        def on_validation_epoch_end(self, trainer, _pl_module):
            epoch = trainer.current_epoch
            metrics = trainer.callback_metrics
            self.history.append({
                "epoch": epoch,
                "phase": 1 if epoch < self.phase1_epochs else 2,
                "train_loss": float(metrics.get("train/loss_epoch", 0)),
                "val_loss": float(metrics.get("val/loss", 0)),
                "val_acc": float(metrics.get("val/acc", 0)),
                "macro_f1": float(metrics.get("val/macro_f1", 0)),
            })

    class PhaseChangeCallback(L.Callback):
        def __init__(self, phase1_epochs: int, phase2_lr: float):
            super().__init__()
            self.phase1_epochs = phase1_epochs
            self.phase2_lr = phase2_lr
            self.phase_changed = False

        def on_train_epoch_end(self, trainer, pl_module):
            if not self.phase_changed and (trainer.current_epoch + 1) == self.phase1_epochs:
                print("\n" + "=" * 80)
                print("TRANSITIONING TO PHASE 2: UNFREEZING BACKBONE AND ADJUSTING LR")
                print("=" * 80 + "\n")

                pl_module.model.unfreeze_backbone()

                for param_group in trainer.optimizers[0].param_groups:
                    param_group["lr"] = self.phase2_lr

                # Add backbone params to optimizer with 10x lower LR.
                # Backbone was excluded at init because it was frozen.
                backbone_lr = self.phase2_lr * 0.1
                backbone_decay, backbone_no_decay = [], []
                for param in pl_module.model.backbone.parameters():
                    if param.ndim >= 2:
                        backbone_decay.append(param)
                    else:
                        backbone_no_decay.append(param)
                optimizer = trainer.optimizers[0]
                optimizer.add_param_group({"params": backbone_decay, "lr": backbone_lr, "weight_decay": pl_module.weight_decay})
                optimizer.add_param_group({"params": backbone_no_decay, "lr": backbone_lr, "weight_decay": 0.0})

                # Fresh cosine schedule over remaining Phase 2 steps to avoid
                # the Phase 1 schedule arriving near-zero before Phase 2 begins.
                steps_remaining = trainer.estimated_stepping_batches - trainer.global_step
                new_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                    trainer.optimizers[0],
                    T_max=max(1, steps_remaining),
                    eta_min=1e-6,
                )
                for lr_scheduler_config in trainer.lr_scheduler_configs:
                    lr_scheduler_config.scheduler = new_scheduler

                print(f"Phase 2 started: lr={self.phase2_lr}")
                print(f"Total parameters: {get_model_size(pl_module.model):,}")
                print(f"Trainable parameters: {get_trainable_params(pl_module.model):,}")
                self.phase_changed = True

    print("\n" + "=" * 80)
    print("BRAIN TUMOR MRI CLASSIFICATION WITH EFFICIENTNET-B4")
    print("=" * 80)
    print(f"Config: {config.to_dict()}\n")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    if torch.cuda.is_available():
        print(f"GPU: {torch.cuda.get_device_name(0)}")
        print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

    print("\nLoading MRI images...")
    train_loader, val_loader = create_data_loaders(
        dataset_path=dataset_path,
        image_size=config.image_size,
        batch_size=config.batch_size,
        num_workers=config.num_workers,
        val_split=config.val_split,
    )
    print(f"Data loaders created: {len(train_loader)} train batches, {len(val_loader)} val batches")

    print("\nComputing class weights for focal loss...")
    class_weights = compute_class_weights(dataset_path)
    print(f"Class weights: {class_weights.tolist()}")

    # Per-class gamma: Meningioma gets 7.0, all others 3.0.
    # CLASS_NAMES alphabetical order: Glioma=0, Meningioma=1, No Tumor=2, Pituitary=3
    gamma_per_class = torch.tensor([3.0, 7.0, 3.0, 3.0])

    print("\nInitializing model...")
    model = TumorClassifierLightningModule(
        num_classes=config.num_classes,
        model_name=config.model_name,
        pretrained=config.pretrained,
        learning_rate=config.phase1_lr,
        freeze_backbone=config.phase1_freeze_backbone,
        weight_decay=config.weight_decay,
        warmup_steps=config.warmup_steps,
        max_epochs=config.phase1_epochs + config.phase2_epochs,
        focal_gamma=config.focal_gamma,
        mixup_alpha=config.mixup_alpha,
        class_weights=class_weights,
        gamma_per_class=gamma_per_class,
    )

    print(f"Model: {config.model_name}")
    print(f"Total parameters: {get_model_size(model.model):,}")
    print(f"Trainable parameters: {get_trainable_params(model.model):,}")

    from pathlib import Path
    checkpoint_dir = Path("/tmp/tumor_checkpoints")
    checkpoint_dir.mkdir(parents=True, exist_ok=True)

    # {{docs-fragment resume}}
    # --- Flyte checkpoint: resume from previous attempt if one exists ---
    resume_ckpt: str | None = None
    ctx = flyte.ctx()
    flyte_checkpoint = getattr(ctx, "checkpoint", None) if ctx else None

    if flyte_checkpoint:
        prev_path = flyte_checkpoint.load_sync()
        if prev_path:
            last = flyte.latest_checkpoint(prev_path)
            if last:
                ck = torch.load(str(last), map_location="cpu", weights_only=False)
                epoch_start = int(ck.get("epoch", 0))
                resume_ckpt = str(last)
                print(f"Resuming from epoch {epoch_start}, checkpoint: {last}")
    # --------------------------------------------------------------------
    # {{/docs-fragment resume}}

    metrics_cb = MetricsLoggerCallback(phase1_epochs=config.phase1_epochs)

    resume_callback = (
        FlyteLightningCheckpointCallback(
            flyte_checkpoint,
            dirpath=str(checkpoint_dir),
            filename="last",
            save_last=True,
            save_top_k=1,
        )
        if flyte_checkpoint else
        ModelCheckpoint(
            dirpath=str(checkpoint_dir),
            filename="best-{epoch:03d}-{val_acc:.3f}",
            monitor="val/acc",
            mode="max",
            save_top_k=3,
            verbose=True,
            auto_insert_metric_name=False,
        )
    )

    callbacks = [
        resume_callback,
        metrics_cb,
        PhaseChangeCallback(
            phase1_epochs=config.phase1_epochs,
            phase2_lr=config.phase2_lr,
        ),
    ]

    trainer = L.Trainer(
        max_epochs=config.phase1_epochs + config.phase2_epochs,
        accelerator="gpu" if torch.cuda.is_available() else "cpu",
        devices=1,
        precision="16-mixed",
        callbacks=callbacks,
        enable_progress_bar=True,
        enable_model_summary=True,
        log_every_n_steps=config.log_interval,
        gradient_clip_val=1.0,
    )

    trainer.fit(model, train_loader, val_loader, ckpt_path=resume_ckpt)

    best_checkpoint = trainer.checkpoint_callback.best_model_path
    print(f"\n✓ Training complete!")
    print(f"Best checkpoint: {best_checkpoint}")

    # Final inference with TTA (test-time augmentation): average logits over the
    # original image, horizontal and vertical flips, and 90°/270° rotations.
    # This costs five forward passes per batch and often gives a small accuracy gain.
    print("\nRunning final inference with TTA for confusion matrix...")
    import numpy as np
    import torchvision.transforms.functional as TF
    model.eval()
    model.to(device)
    all_preds, all_targets = [], []
    with torch.no_grad():
        for images, labels in val_loader:
            images = images.to(device)
            aug_logits = [
                model.model(images),
                model.model(TF.hflip(images)),
                model.model(TF.vflip(images)),
                model.model(torch.rot90(images, k=1, dims=[2, 3])),
                model.model(torch.rot90(images, k=3, dims=[2, 3])),
            ]
            avg_logits = torch.stack(aug_logits).mean(dim=0)
            all_preds.append(avg_logits.argmax(dim=1).cpu())
            all_targets.append(labels.cpu())
    final_preds = torch.cat(all_preds).numpy()
    final_targets = torch.cat(all_targets).numpy()

    return {
        "best_checkpoint": best_checkpoint,
        "metrics": metrics_cb.history,
        "final_preds": final_preds.tolist(),
        "final_targets": final_targets.tolist(),
    }
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/tumor_detection/training.py*
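The final TTA loop in `train_tumor_classifier` averages logits over five views of each validation batch before taking the argmax. The pure-Python sketch below shows the same mechanics on 2D lists so it runs without PyTorch. It is an illustration, not the tutorial's code: `hflip`, `vflip`, and `rot90` are hand-written stand-ins for `TF.hflip`, `TF.vflip`, and `torch.rot90` over the height/width axes, and `corner_model` is a hypothetical toy classifier standing in for the EfficientNet model.

```python
from typing import Callable, List

Image = List[List[float]]  # a single-channel H x W image
Logits = List[float]


def hflip(img: Image) -> Image:
    # Mirror left-to-right (like TF.hflip on the width axis).
    return [row[::-1] for row in img]


def vflip(img: Image) -> Image:
    # Mirror top-to-bottom (like TF.vflip on the height axis).
    return img[::-1]


def rot90(img: Image) -> Image:
    # Rotate 90 degrees counter-clockwise: transpose, then reverse the row order.
    # This matches torch.rot90(x, k=1) over the (H, W) axes.
    return [list(col) for col in zip(*img)][::-1]


def tta_logits(model: Callable[[Image], Logits], img: Image) -> Logits:
    # The same five views as the tutorial: original, h-flip, v-flip, 90 and 270 degrees.
    views = [img, hflip(img), vflip(img), rot90(img), rot90(rot90(rot90(img)))]
    outputs = [model(view) for view in views]
    # Average each class's logit across the views (torch.stack(...).mean(dim=0)).
    return [sum(column) / len(outputs) for column in zip(*outputs)]


def tta_predict(model: Callable[[Image], Logits], img: Image) -> int:
    avg = tta_logits(model, img)
    return max(range(len(avg)), key=avg.__getitem__)  # argmax over classes


def corner_model(img: Image) -> Logits:
    # Hypothetical toy two-class "classifier" that only looks at two corners,
    # so it is deliberately sensitive to flips and rotations.
    return [img[0][0], img[-1][-1]]


image = [[1.0, 0.0], [0.0, 0.0]]
print(tta_logits(corner_model, image))   # [0.2, 0.0]
print(tta_predict(corner_model, image))  # 0
```

Averaging logits, rather than voting on each view's argmax, lets one confident view outweigh several uncertain ones. The technique also assumes the label does not change under flips and rotations, which is plausible for tumor type.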

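On startup, the training loop asks Flyte for the checkpoint directory saved by a previous attempt. If one exists, it passes the latest checkpoint file to `trainer.fit(..., ckpt_path=...)`, so a retried task resumes from the last saved epoch instead of starting over: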

```
# --- Flyte checkpoint: resume from previous attempt if one exists ---
resume_ckpt: str | None = None
ctx = flyte.ctx()
flyte_checkpoint = getattr(ctx, "checkpoint", None) if ctx else None

if flyte_checkpoint:
    prev_path = flyte_checkpoint.load_sync()
    if prev_path:
        last = flyte.latest_checkpoint(prev_path)
        if last:
            ck = torch.load(str(last), map_location="cpu", weights_only=False)
            epoch_start = int(ck.get("epoch", 0))
            resume_ckpt = str(last)
            print(f"Resuming from epoch {epoch_start}, checkpoint: {last}")
# --------------------------------------------------------------------

# ... callbacks and Trainer construction omitted ...

trainer.fit(model, train_loader, val_loader, ckpt_path=resume_ckpt)
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/tumor_detection/training.py*
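The nested `if` chain falls back to `ckpt_path=None`, a fresh run, whenever there is no Flyte checkpoint, no previous upload, or no checkpoint file inside it. The stdlib sketch below isolates that fallback logic. `find_latest_ckpt` is a hypothetical stand-in for `flyte.latest_checkpoint`: it assumes "latest" means the most recently modified `*.ckpt` file, which may not be how the real function selects.

```python
import os
import tempfile
from pathlib import Path
from typing import Optional


def find_latest_ckpt(directory: Path) -> Optional[Path]:
    """Hypothetical stand-in for flyte.latest_checkpoint: newest *.ckpt by mtime."""
    candidates = list(directory.rglob("*.ckpt"))
    if not candidates:
        return None
    return max(candidates, key=lambda p: p.stat().st_mtime)


def resolve_resume_path(prev_path: Optional[str]) -> Optional[str]:
    """Return a path for trainer.fit(ckpt_path=...), or None to train from scratch."""
    if not prev_path:  # first attempt: Flyte restored nothing
        return None
    latest = find_latest_ckpt(Path(prev_path))
    return str(latest) if latest else None


with tempfile.TemporaryDirectory() as tmp:
    print(resolve_resume_path(None))  # None: first attempt
    print(resolve_resume_path(tmp))   # None: directory restored but empty
    old, new = Path(tmp, "epoch-3.ckpt"), Path(tmp, "last.ckpt")
    old.write_bytes(b"")
    new.write_bytes(b"")
    os.utime(old, (1_000, 1_000))  # pin modification times so "newest" is unambiguous
    os.utime(new, (2_000, 2_000))
    print(Path(resolve_resume_path(tmp)).name)  # last.ckpt
```

Passing `None` to `trainer.fit` is what makes the same code path serve both the first attempt and every retry.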

## Generate the report

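The reporting task downloads the `metrics.json` and `predictions.json` files that `train_model` writes, then uses Plotly to render accuracy and loss curves, a confusion matrix, and a per-class F1 chart. Declaring the task with `report=True` and passing the combined HTML to `flyte.report.log` surfaces the charts directly in the run's report panel. Only the first chart embeds the Plotly JavaScript bundle (`include_plotlyjs=True`); the others reuse it, which keeps the report small.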

```
@report_env.task(report=True)
async def create_report(results_dir: Dir) -> None:
    """
    Download training metrics and render loss/accuracy curves, confusion matrix,
    and per-class F1 chart in the Union UI report panel.
    """
    import numpy as np
    from pathlib import Path

    from utils import create_confusion_matrix_plot, create_metrics_plots, create_per_class_f1_plot

    local_dir = Path("/tmp/tumor_report")
    local_dir.mkdir(parents=True, exist_ok=True)
    await results_dir.download(local_path=str(local_dir))

    matches = list(local_dir.glob("**/metrics.json"))
    if not matches:
        raise RuntimeError(f"metrics.json not found under {local_dir}")
    local_path = matches[0].parent

    history = json.loads((local_path / "metrics.json").read_text())
    predictions = json.loads((local_path / "predictions.json").read_text())

    preds = np.array(predictions["preds"])
    targets = np.array(predictions["targets"])

    loss_fig, acc_fig = create_metrics_plots(history)
    cm_fig = create_confusion_matrix_plot(preds, targets)
    f1_fig = create_per_class_f1_plot(preds, targets)

    combined_html = (
        acc_fig.to_html(include_plotlyjs=True, full_html=False)
        + loss_fig.to_html(include_plotlyjs=False, full_html=False)
        + cm_fig.to_html(include_plotlyjs=False, full_html=False)
        + f1_fig.to_html(include_plotlyjs=False, full_html=False)
    )
    flyte.report.log(combined_html, do_flush=True)
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/tumor_detection/run.py*
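`utils.py` is not shown here, so this section does not say exactly how `create_confusion_matrix_plot` and `create_per_class_f1_plot` compute their values. As a hedged illustration of the quantities they chart, the stdlib sketch below builds a confusion matrix from integer labels and derives per-class F1 as 2·TP / (2·TP + FP + FN). It assumes labels follow the alphabetical order noted in `training.py` (Glioma=0, Meningioma=1, No Tumor=2, Pituitary=3); the sample labels are made up for the example.

```python
from typing import List, Sequence

CLASS_NAMES = ["Glioma", "Meningioma", "No Tumor", "Pituitary"]


def confusion_matrix(targets: Sequence[int], preds: Sequence[int], num_classes: int) -> List[List[int]]:
    # cm[t][p] counts samples whose true class is t and predicted class is p.
    cm = [[0] * num_classes for _ in range(num_classes)]
    for t, p in zip(targets, preds):
        cm[t][p] += 1
    return cm


def per_class_f1(cm: List[List[int]]) -> List[float]:
    n = len(cm)
    scores = []
    for c in range(n):
        tp = cm[c][c]
        fp = sum(cm[r][c] for r in range(n)) - tp  # predicted c, actually another class
        fn = sum(cm[c]) - tp                       # actually c, predicted another class
        denom = 2 * tp + fp + fn
        scores.append(2 * tp / denom if denom else 0.0)  # 0.0 for a class never seen or predicted
    return scores


targets = [0, 0, 1, 1, 2, 3]  # made-up example labels
preds = [0, 1, 1, 1, 2, 0]
cm = confusion_matrix(targets, preds, num_classes=len(CLASS_NAMES))
for name, f1 in zip(CLASS_NAMES, per_class_f1(cm)):
    print(f"{name:<11} F1={f1:.2f}")
# Glioma 0.50, Meningioma 0.80, No Tumor 1.00, Pituitary 0.00
```

The unweighted mean of these per-class scores is the usual definition of macro F1, the kind of summary logged as `val/macro_f1` during training. The per-class chart is more informative when one class, such as Meningioma in this dataset, is harder than the rest.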

## Orchestrate the pipeline

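The driver task, `tumor_detection_pipeline`, awaits the three tasks in sequence. It passes the dataset `Dir` from `load_dataset` into `train_model`, serializes `TRAINING_CONFIG` to a JSON string for the training task, and hands the resulting `Dir` of metrics and predictions to `create_report`. Running `run.py` directly submits this driver with `flyte.with_runcontext().run`.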

```
@pipeline_env.task
async def tumor_detection_pipeline() -> None:
    """Orchestrate dataset loading, GPU training, and report generation."""
    dataset_dir = await load_dataset()
    results_dir = await train_model(
        dataset_dir=dataset_dir,
        config_json=json.dumps(TRAINING_CONFIG.to_dict()),
    )
    await create_report(results_dir=results_dir)

if __name__ == "__main__":
    import pathlib
    flyte.init_from_config(root_dir=pathlib.Path(__file__).parent)
    run = flyte.with_runcontext().run(tumor_detection_pipeline)
    print(f"\n✓ Pipeline submitted!")
    print(f"Run URL: {run.url}")
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/tumor_detection/run.py*
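The driver sends the hyperparameters to `train_model` as a JSON string, and the task rebuilds them with `TrainingConfig(**json.loads(config_json))`. That round trip only works if `to_dict()` emits exactly the constructor's keyword arguments with JSON-serializable values. `config.py` is not reproduced in this section, so the sketch below assumes a dataclass-based `TrainingConfig` whose `to_dict()` uses `dataclasses.asdict`; the fields shown are a hypothetical subset of the real ones.

```python
import json
from dataclasses import asdict, dataclass


@dataclass
class TrainingConfig:
    # Hypothetical subset of the tutorial's hyperparameters.
    phase1_epochs: int = 8
    phase2_epochs: int = 25
    phase1_lr: float = 1e-3
    phase2_lr: float = 5e-5
    batch_size: int = 16
    image_size: int = 380

    def to_dict(self) -> dict:
        # asdict() emits fields in declaration order, so equal configs
        # always serialize to identical JSON strings.
        return asdict(self)


# Driver side: turn the config into a plain-string task input.
config_json = json.dumps(TrainingConfig(phase2_epochs=10).to_dict())

# Task side: rebuild the dataclass from that string.
restored = TrainingConfig(**json.loads(config_json))
print(restored.phase2_epochs, restored.phase1_lr)  # 10 0.001
```

Because the serialized string is an input to `train_model`, an unchanged config should yield an identical input and let `cache="auto"` reuse the previous training result. Changing any hyperparameter changes the string and therefore forces a new training run.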

## Run the pipeline

This example needs no secrets, since the dataset is public. The pipeline imports sibling modules (`config`, `dataset`, `training`, `utils`), and the image copies them in with `with_source_folder`, so run it from inside the example directory where those files live.

From the [example directory](https://github.com/unionai/unionai-examples/tree/main/v2/tutorials/tumor_detection):

```
cd v2/tutorials/tumor_detection
python run.py
```

Or submit it with the Flyte CLI from the same directory:

```
flyte run run.py tumor_detection_pipeline
```

When the run completes, open the `create_report` task in the UI to view the training curves, confusion matrix, and per-class F1 scores.

---
**Source**: https://github.com/unionai/unionai-docs/blob/main/content/tutorials/biotech-healthcare/tumor-detection/_index.md
**HTML**: https://www.union.ai/docs/v2/union/tutorials/biotech-healthcare/tumor-detection/
