Brain tumor MRI classification

Code available here.

This tutorial builds a medical-imaging pipeline that classifies brain MRI scans into four categories — Glioma, Meningioma, No Tumor, and Pituitary — using a two-phase EfficientNet-B4 transfer-learning strategy. The pipeline downloads the dataset, trains on a GPU with fault-tolerant checkpointing, and renders training curves and a confusion matrix directly in the Union.ai UI.

The example is split into focused modules:

  • config.py — container image, task environments, and the TrainingConfig hyperparameters.
  • dataset.py — downloads the Hugging Face dataset, builds class-balanced data loaders.
  • model.py / training.py — the Lightning module and the two-phase training loop.
  • utils.py — plotting helpers for the report.
  • run.py — the three Flyte tasks and the pipeline driver.

Flyte handles the production concerns:

  • Per-task resources: CPU for download/reporting, a GPU for training.
  • cache="auto" on dataset download and training, so reruns with the same data and config reuse the cached outputs instead of re-executing.
  • retries=3 plus Flyte checkpointing on the training task, so a preempted GPU job resumes from the last saved epoch instead of starting over.
  • Built-in reports to visualize metrics without separate dashboard infrastructure.

Define the container image

A single GPU-ready image is shared by all tasks. with_source_folder copies the local modules (dataset.py, model.py, etc.) into the image.

config.py
import pathlib

import flyte

image = flyte.Image.from_debian_base(
    name="tumor_detection_gpu"
).with_pip_packages(
    "torch",
    "lightning",
    "torchvision",
    "timm",
    "pillow",
    "scikit-learn",
    "plotly",
    "numpy",
    "pandas",
    "torchmetrics",
    "datasets",
    "typing_extensions",
).with_source_folder(
    pathlib.Path(__file__).parent,
    copy_contents_only=True,
)

Define the task environments

Each stage declares the resources it needs. The lightweight pipeline_env lists the other three environments in depends_on, since its driver task calls tasks defined in them.

config.py
# Downloads raw MRI JPEG files — CPU only, no auth needed, result is cached
dataset_env = flyte.TaskEnvironment(
    name="tumor_dataset",
    image=image,
    resources=flyte.Resources(cpu=2, memory="4Gi", disk="8Gi"),
    cache="auto",
)

# GPU training — result is cached so re-running with the same data + config is free
training_env = flyte.TaskEnvironment(
    name="tumor_gpu_training",
    image=image,
    resources=flyte.Resources(
        cpu=8,
        memory="32Gi",
        gpu="T4:1",
        disk="100Gi",
    ),
    env_vars={
        "CUDA_VISIBLE_DEVICES": "0",
        # Forces synchronous kernel launches: easier CUDA debugging, slower training
        "CUDA_LAUNCH_BLOCKING": "1",
        "TORCH_CUDA_MEMORY_FRACTION": "1.0",
        "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True",
    },
    cache="auto",
)

# Report generation — CPU only, reads training results and renders Union UI panels
report_env = flyte.TaskEnvironment(
    name="tumor_report",
    image=image,
    resources=flyte.Resources(cpu=2, memory="4Gi"),
)

# Pipeline driver — lightweight orchestrator that calls the three tasks above
pipeline_env = flyte.TaskEnvironment(
    name="tumor_pipeline",
    image=image,
    resources=flyte.Resources(cpu=2, memory="4Gi"),
    depends_on=[dataset_env, training_env, report_env],
)

Configure training

Hyperparameters are gathered in a single TrainingConfig, serialized to JSON, and passed into the training task so the exact configuration is captured alongside the run.

run.py
TRAINING_CONFIG = TrainingConfig(
    phase1_epochs=8,
    phase2_epochs=25,
    phase1_lr=1e-3,
    phase2_lr=5e-5,
    batch_size=16,
    num_workers=0,
    log_interval=50,
    mixup_alpha=0.0,
    image_size=380,
    focal_gamma=3.0,
)
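
TrainingConfig itself is not shown in this walkthrough. As a minimal sketch, assuming it is a plain dataclass whose to_dict() wraps dataclasses.asdict (the class body and its defaults below are a hypothetical stand-in that simply mirrors the values above), the JSON round trip used by the pipeline looks like this:

```python
import json
from dataclasses import asdict, dataclass


@dataclass
class TrainingConfig:
    """Hypothetical stand-in for the TrainingConfig defined in config.py."""

    phase1_epochs: int = 8
    phase2_epochs: int = 25
    phase1_lr: float = 1e-3
    phase2_lr: float = 5e-5
    batch_size: int = 16
    num_workers: int = 0
    log_interval: int = 50
    mixup_alpha: float = 0.0
    image_size: int = 380
    focal_gamma: float = 3.0

    def to_dict(self) -> dict:
        # Assumed implementation: a flat dict of plain Python values.
        return asdict(self)


# Driver side: serialize the config into a string task input.
config_json = json.dumps(TrainingConfig(batch_size=32).to_dict())

# Task side: rebuild an equal config object from that string.
restored = TrainingConfig(**json.loads(config_json))
```

Passing the configuration as a string makes it an ordinary task input, so a changed hyperparameter should also change the inputs that cache="auto" keys on.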

Load the dataset

The first task downloads the public Brain Tumor MRI dataset from Hugging Face (no auth required) and stores it as a flyte.io.Dir. It’s cached, so subsequent runs reuse it.

run.py
@dataset_env.task
async def load_dataset() -> Dir:
    """
    Download raw Brain Tumor MRI JPEG files from Hugging Face and cache as flyte.io.Dir.
    Runs once — result is reused on subsequent pipeline runs (cache="auto").
    """
    return await download_tumor_dataset()
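
dataset.py is described above as building class-balanced data loaders, but its code is not shown here. One common way to do this (an assumption, not necessarily what dataset.py does) is to weight every sample by the inverse of its class frequency and hand those weights to a weighted sampler such as torch.utils.data.WeightedRandomSampler. The weight computation itself needs only the standard library; the function name and the label counts below are made up for illustration:

```python
from collections import Counter


def inverse_frequency_weights(labels: list[int]) -> list[float]:
    """Return one sampling weight per sample, equal to 1 / (size of its class)."""
    counts = Counter(labels)
    return [1.0 / counts[label] for label in labels]


# Made-up, imbalanced labels: class 0 has 6 samples, class 1 has 2, class 2 has 4.
labels = [0] * 6 + [1] * 2 + [2] * 4
weights = inverse_frequency_weights(labels)

# Summed per class, the weights are equal, so each class is drawn equally often
# in expectation when sampling with replacement.
mass_per_class = {
    c: sum(w for w, l in zip(weights, labels) if l == c) for c in set(labels)
}
```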

Train the model

The training task downloads the dataset Dir, runs two-phase training (frozen backbone, then full fine-tuning), and writes metrics and predictions to an output Dir. It sets retries=3 so that Flyte reruns the task, up to three times, if the GPU node is preempted or the task otherwise fails.

run.py
@training_env.task(retries=3)
async def train_model(dataset_dir: Dir, config_json: str) -> Dir:
    """
    Download the raw dataset Dir, run two-phase training,
    and return training metrics and final predictions as a Dir for the report task.
    """
    from pathlib import Path

    local_dir = Path("/tmp/tumor_local")
    local_dir.mkdir(parents=True, exist_ok=True)
    await dataset_dir.download(local_path=str(local_dir))

    from training import train_tumor_classifier
    config = TrainingConfig(**json.loads(config_json))
    result = train_tumor_classifier(config=config, dataset_path=str(local_dir))

    output_dir = Path("/tmp/training_results")
    output_dir.mkdir(parents=True, exist_ok=True)
    (output_dir / "metrics.json").write_text(json.dumps(result["metrics"]))
    (output_dir / "predictions.json").write_text(json.dumps({
        "preds": result["final_preds"],
        "targets": result["final_targets"],
    }))

    return await Dir.from_local(str(output_dir))
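
train_tumor_classifier lives in training.py and is not reproduced here. As a rough sketch of the two-phase idea, assuming phase 1 trains only the classification head at phase1_lr and phase 2 unfreezes the backbone at phase2_lr, the per-epoch plan implied by TRAINING_CONFIG can be written out as follows (EpochPlan and phase_for_epoch are hypothetical names):

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class EpochPlan:
    phase: int                # 1 = frozen backbone, 2 = full fine-tuning
    learning_rate: float
    backbone_trainable: bool


def phase_for_epoch(epoch: int, phase1_epochs: int = 8, phase2_epochs: int = 25,
                    phase1_lr: float = 1e-3, phase2_lr: float = 5e-5) -> EpochPlan:
    """Map a zero-based global epoch index to its phase, learning rate, and freeze state."""
    total = phase1_epochs + phase2_epochs
    if not 0 <= epoch < total:
        raise ValueError(f"epoch must be in [0, {total}), got {epoch}")
    if epoch < phase1_epochs:
        return EpochPlan(phase=1, learning_rate=phase1_lr, backbone_trainable=False)
    return EpochPlan(phase=2, learning_rate=phase2_lr, backbone_trainable=True)


# Full plan for the configured 8 + 25 epochs.
schedule = [phase_for_epoch(e) for e in range(8 + 25)]
```

In PyTorch, freezing is typically done by setting requires_grad = False on the backbone parameters during phase 1 and restoring it to True for phase 2.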

Resumable checkpointing

To make retries cheap, training mirrors its Lightning checkpoint directory to a flyte.Checkpoint after every epoch, and resumes from the latest checkpoint when the task restarts.

training.py
    class FlyteLightningCheckpointCallback(ModelCheckpoint):
        """Mirrors the checkpoint directory to Flyte after every epoch so retries can resume."""

        def __init__(self, flyte_checkpoint: flyte.Checkpoint, *, dirpath: str, **kwargs):
            super().__init__(dirpath=dirpath, **kwargs)
            self._flyte_checkpoint = flyte_checkpoint

        @override
        def on_train_epoch_end(self, trainer: L.Trainer, pl_module: L.LightningModule) -> None:
            super().on_train_epoch_end(trainer, pl_module)
            if self.dirpath:
                self._flyte_checkpoint.save_sync(pathlib.Path(self.dirpath))

On startup, the training loop looks for a checkpoint from a previous attempt and resumes from it if present:

training.py
    # --- Flyte checkpoint: resume from previous attempt if one exists ---
    resume_ckpt: str | None = None
    ctx = flyte.ctx()
    flyte_checkpoint = getattr(ctx, "checkpoint", None) if ctx else None

    if flyte_checkpoint:
        prev_path = flyte_checkpoint.load_sync()
        if prev_path:
            last = flyte.latest_checkpoint(prev_path)
            if last:
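                # weights_only=False runs the full unpickler; acceptable only because
                # this checkpoint was written by a previous attempt of this same task.
                ck = torch.load(str(last), map_location="cpu", weights_only=False)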
                epoch_start = int(ck.get("epoch", 0))
                resume_ckpt = str(last)
                print(f"Resuming from epoch {epoch_start}, checkpoint: {last}")
    # --------------------------------------------------------------------
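
The selection rule inside flyte.latest_checkpoint is not shown in this walkthrough. The sketch below illustrates the general idea with a hypothetical helper, newest_ckpt, that picks the most recently modified .ckpt file under a restored directory. It is an illustration under that assumption, not Flyte's implementation:

```python
import os
import tempfile
from pathlib import Path
from typing import Optional


def newest_ckpt(root: Path) -> Optional[Path]:
    """Return the most recently modified *.ckpt file under root, or None if there is none."""
    candidates = list(root.rglob("*.ckpt"))
    if not candidates:
        return None
    return max(candidates, key=lambda p: p.stat().st_mtime)


# Demo with a throwaway directory standing in for the restored checkpoint folder.
with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    empty_result = newest_ckpt(root)  # first attempt: nothing to resume from
    for i, name in enumerate(["epoch=0.ckpt", "epoch=1.ckpt", "last.ckpt"]):
        path = root / name
        path.write_bytes(b"")
        os.utime(path, (1_000 + i, 1_000 + i))  # explicit, increasing mtimes
    resume_name = newest_ckpt(root).name
```

Lightning's Trainer.fit accepts such a path through its ckpt_path argument, which is presumably where resume_ckpt ends up.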

Generate the report

The reporting task reads the metrics and predictions, then renders accuracy/loss curves, a confusion matrix, and a per-class F1 chart with Plotly. report=True surfaces the HTML directly in the run’s report panel.

run.py
@report_env.task(report=True)
async def create_report(results_dir: Dir) -> None:
    """
    Download training metrics and render loss/accuracy curves, confusion matrix,
    and per-class F1 chart in the Union UI report panel.
    """
    import numpy as np
    from pathlib import Path

    from utils import create_confusion_matrix_plot, create_metrics_plots, create_per_class_f1_plot

    local_dir = Path("/tmp/tumor_report")
    local_dir.mkdir(parents=True, exist_ok=True)
    await results_dir.download(local_path=str(local_dir))

    matches = list(local_dir.glob("**/metrics.json"))
    if not matches:
        raise RuntimeError(f"metrics.json not found under {local_dir}")
    local_path = matches[0].parent

    history = json.loads((local_path / "metrics.json").read_text())
    predictions = json.loads((local_path / "predictions.json").read_text())

    preds = np.array(predictions["preds"])
    targets = np.array(predictions["targets"])

    loss_fig, acc_fig = create_metrics_plots(history)
    cm_fig = create_confusion_matrix_plot(preds, targets)
    f1_fig = create_per_class_f1_plot(preds, targets)

    combined_html = (
        acc_fig.to_html(include_plotlyjs=True, full_html=False)
        + loss_fig.to_html(include_plotlyjs=False, full_html=False)
        + cm_fig.to_html(include_plotlyjs=False, full_html=False)
        + f1_fig.to_html(include_plotlyjs=False, full_html=False)
    )
    flyte.report.log(combined_html, do_flush=True)
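
The plotting helpers in utils.py are not shown here. As a sketch of the quantities they visualize, the following standard-library code computes a confusion matrix and per-class F1 from prediction and target lists shaped like those in predictions.json. The function names and the toy labels are illustrative assumptions:

```python
def confusion_matrix(preds: list[int], targets: list[int], num_classes: int) -> list[list[int]]:
    """matrix[t][p] counts samples whose true class is t and predicted class is p."""
    matrix = [[0] * num_classes for _ in range(num_classes)]
    for p, t in zip(preds, targets):
        matrix[t][p] += 1
    return matrix


def per_class_f1(matrix: list[list[int]]) -> list[float]:
    """F1 = 2*TP / (2*TP + FP + FN) per class; 0.0 when a class never appears."""
    n = len(matrix)
    scores = []
    for c in range(n):
        tp = matrix[c][c]
        fn = sum(matrix[c]) - tp                       # true class c, predicted elsewhere
        fp = sum(matrix[r][c] for r in range(n)) - tp  # predicted c, true class elsewhere
        denom = 2 * tp + fp + fn
        scores.append(2 * tp / denom if denom else 0.0)
    return scores


# Toy example with four classes, matching the tutorial's label count.
targets = [0, 0, 1, 1, 2, 2, 3, 3]
preds = [0, 1, 1, 1, 2, 0, 3, 3]
cm = confusion_matrix(preds, targets, num_classes=4)
f1 = per_class_f1(cm)
```

Rows index the true class and columns the predicted class, so off-diagonal entries show which tumor types are confused with each other.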

Orchestrate the pipeline

The driver task chains the three steps: it awaits each task and passes the resulting Dir to the next one.

run.py
@pipeline_env.task
async def tumor_detection_pipeline() -> None:
    """Orchestrate dataset loading, GPU training, and report generation."""
    dataset_dir = await load_dataset()
    results_dir = await train_model(
        dataset_dir=dataset_dir,
        config_json=json.dumps(TRAINING_CONFIG.to_dict()),
    )
    await create_report(results_dir=results_dir)

Run the pipeline

This example needs no secrets, since the dataset is public. The pipeline imports sibling modules and uses with_source_folder, so run it from inside the example directory to make sure the local files are picked up.

From the example directory:

cd v2/tutorials/tumor_detection
python run.py

Or submit it with the Flyte CLI from the same directory:

flyte run run.py tumor_detection_pipeline

When the run completes, open the create_report task in the UI to view the training curves, confusion matrix, and per-class F1 scores.