Brain tumor MRI classification
Code available here.
This tutorial builds a medical-imaging pipeline that classifies brain MRI scans into four categories — Glioma, Meningioma, No Tumor, and Pituitary — using a two-phase EfficientNet-B4 transfer-learning strategy. The pipeline downloads the dataset, trains on a GPU with fault-tolerant checkpointing, and renders training curves and a confusion matrix directly in the Union.ai UI.
The example is split into focused modules:
- config.py: container image, task environments, and the TrainingConfig hyperparameters.
- dataset.py: downloads the Hugging Face dataset and builds class-balanced data loaders.
- model.py / training.py: the Lightning module and the two-phase training loop.
- utils.py: plotting helpers for the report.
- run.py: the three Flyte tasks and the pipeline driver.
Flyte handles the production concerns:
- Per-task resources: CPU for download and reporting, a GPU for training.
- cache="auto" on dataset download and training, so reruns with the same data and config reuse the cached outputs instead of recomputing them.
- retries=3 plus Flyte checkpointing on the training task, so a preempted GPU job resumes from the last completed epoch.
- Built-in reports to visualize metrics without separate dashboard infrastructure.
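Conceptually, cache="auto" behaves like input-keyed memoization: when the same task is called again with the same inputs, its earlier output is reused. The sketch below illustrates only that idea in plain Python; it is not Flyte's implementation, and the in-memory store, the explicit version string, and the cache_key / run_cached helpers are all hypothetical.

```python
import hashlib
import json
from typing import Any, Callable

# Hypothetical in-memory stand-in for Flyte's cache store.
_CACHE: dict[str, Any] = {}


def cache_key(task_name: str, version: str, inputs: dict[str, Any]) -> str:
    """Hash the task identity and its JSON-serializable inputs into one key."""
    payload = json.dumps(
        {"task": task_name, "version": version, "inputs": inputs}, sort_keys=True
    )
    return hashlib.sha256(payload.encode()).hexdigest()


def run_cached(task_name: str, version: str, fn: Callable[..., Any], **inputs: Any) -> tuple[Any, bool]:
    """Return (result, cache_hit); fn only executes on a cache miss."""
    key = cache_key(task_name, version, inputs)
    if key in _CACHE:
        return _CACHE[key], True
    result = fn(**inputs)
    _CACHE[key] = result
    return result, False


if __name__ == "__main__":
    calls: list[str] = []

    def train(config_json: str) -> str:
        calls.append(config_json)  # counts how often "training" really runs
        return f"model trained with {config_json}"

    cfg = json.dumps({"phase1_epochs": 8})
    _, hit1 = run_cached("train_model", "v1", train, config_json=cfg)
    _, hit2 = run_cached("train_model", "v1", train, config_json=cfg)
    print(hit1, hit2, len(calls))  # False True 1
```

Because the config travels as a task input, changing any hyperparameter changes the key and forces a fresh training run.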
Define the container image
A single GPU-ready image is shared by all tasks. with_source_folder copies the local modules (dataset.py, model.py, etc.) into the image.
import pathlib
import flyte
# One image for every task: Debian base + Python dependencies + the local example modules.
image = flyte.Image.from_debian_base(
    name="tumor_detection_gpu"
).with_pip_packages(
    "torch",
    "lightning",
    "torchvision",
    "timm",
    "pillow",
    "scikit-learn",
    "plotly",
    "numpy",
    "pandas",
    "torchmetrics",
    "datasets",
    "typing_extensions",
).with_source_folder(
    pathlib.Path(__file__).parent,
    copy_contents_only=True,
)
Define the task environments
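Each stage declares the resources it needs. The driver environment, pipeline_env, is deliberately lightweight; it lists the other three environments in depends_on so they are deployed along with it and its task can call theirs.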
# Downloads raw MRI JPEG files: CPU only, no auth needed, result is cached
dataset_env = flyte.TaskEnvironment(
    name="tumor_dataset",
    image=image,
    resources=flyte.Resources(cpu=2, memory="4Gi", disk="8Gi"),
    cache="auto",
)
# GPU training: result is cached, so re-running with the same data + config is free
training_env = flyte.TaskEnvironment(
    name="tumor_gpu_training",
    image=image,
    resources=flyte.Resources(
        cpu=8,
        memory="32Gi",
        gpu="T4:1",
        disk="100Gi",
    ),
    env_vars={
        "CUDA_VISIBLE_DEVICES": "0",
        # Synchronous kernel launches give clearer CUDA error traces but slow training.
        "CUDA_LAUNCH_BLOCKING": "1",
        "TORCH_CUDA_MEMORY_FRACTION": "1.0",
        "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True",
    },
    cache="auto",
)
# Report generation: CPU only, reads training results and renders Union UI panels
report_env = flyte.TaskEnvironment(
    name="tumor_report",
    image=image,
    resources=flyte.Resources(cpu=2, memory="4Gi"),
)
# Pipeline driver: lightweight orchestrator that calls the three tasks above
pipeline_env = flyte.TaskEnvironment(
    name="tumor_pipeline",
    image=image,
    resources=flyte.Resources(cpu=2, memory="4Gi"),
    depends_on=[dataset_env, training_env, report_env],
)
Configure training
Hyperparameters are gathered in a single TrainingConfig, serialized to JSON, and passed into the training task so the exact configuration is captured alongside the run.
TRAINING_CONFIG = TrainingConfig(
    phase1_epochs=8,
    phase2_epochs=25,
    phase1_lr=1e-3,
    phase2_lr=5e-5,
    batch_size=16,
    num_workers=0,
    log_interval=50,
    mixup_alpha=0.0,
    image_size=380,
    focal_gamma=3.0,
)
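The definition of TrainingConfig in config.py is not reproduced in this tutorial. A minimal sketch, assuming it is a plain dataclass with the fields above and a to_dict method, shows why passing it as a JSON string is safe: TrainingConfig(**json.loads(config_json)) in the training task rebuilds an object equal to the one the driver serialized. The defaults below simply mirror TRAINING_CONFIG and are otherwise assumptions.

```python
import json
from dataclasses import asdict, dataclass


@dataclass(frozen=True)
class TrainingConfig:
    """Hypothetical stand-in for config.TrainingConfig; field names match the tutorial."""
    phase1_epochs: int = 8
    phase2_epochs: int = 25
    phase1_lr: float = 1e-3
    phase2_lr: float = 5e-5
    batch_size: int = 16
    num_workers: int = 0
    log_interval: int = 50
    mixup_alpha: float = 0.0
    image_size: int = 380
    focal_gamma: float = 3.0

    def to_dict(self) -> dict:
        return asdict(self)


# Driver side: serialize to a JSON string, as tumor_detection_pipeline does.
config_json = json.dumps(TrainingConfig().to_dict())
# Task side: rebuild the object, as train_model does.
restored = TrainingConfig(**json.loads(config_json))
assert restored == TrainingConfig()
```

Keeping the whole config in one string input also means the exact hyperparameters are recorded as part of the run's inputs.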
Load the dataset
The first task downloads the public Brain Tumor MRI dataset from Hugging Face (no auth required) and stores it as a flyte.io.Dir. Because the task is cached, subsequent runs reuse it.
@dataset_env.task
async def load_dataset() -> Dir:
    """
    Download raw Brain Tumor MRI JPEG files from Hugging Face and cache as flyte.io.Dir.
    Runs once: the result is reused on subsequent pipeline runs (cache="auto").
    """
    return await download_tumor_dataset()
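dataset.py is described as building class-balanced data loaders, but its code is not shown. One common approach, assumed here rather than taken from the tutorial, is inverse-frequency sampling: each image is drawn with probability proportional to 1 / (size of its class), so every class contributes roughly equally per epoch. In PyTorch such weights are typically passed to torch.utils.data.WeightedRandomSampler; this sketch sticks to the standard library, and inverse_frequency_weights / balanced_sample are hypothetical helpers.

```python
import random
from collections import Counter

CLASSES = ["Glioma", "Meningioma", "No Tumor", "Pituitary"]


def inverse_frequency_weights(labels: list[int]) -> list[float]:
    """Per-sample weight = 1 / (number of samples that share its label)."""
    counts = Counter(labels)
    return [1.0 / counts[y] for y in labels]


def balanced_sample(labels: list[int], k: int, seed: int = 0) -> list[int]:
    """Draw k sample indices with replacement, weighted by inverse class frequency."""
    rng = random.Random(seed)
    weights = inverse_frequency_weights(labels)
    return rng.choices(range(len(labels)), weights=weights, k=k)


if __name__ == "__main__":
    # Deliberately imbalanced toy labels: 700 / 200 / 50 / 50.
    labels = [0] * 700 + [1] * 200 + [2] * 50 + [3] * 50
    drawn = balanced_sample(labels, k=20_000)
    per_class = Counter(labels[i] for i in drawn)
    for idx, name in enumerate(CLASSES):
        print(name, round(per_class[idx] / len(drawn), 3))  # each close to 0.25
```

Sampling with replacement means minority-class images repeat within an epoch, which is one reason augmentation usually accompanies this technique.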
Train the model
The training task downloads the dataset Dir, runs two-phase training (frozen backbone, then full fine-tuning), and writes metrics and predictions to an output Dir. It sets retries=3, so if the GPU node is preempted, Flyte reruns the task up to three more times.
@training_env.task(retries=3)
async def train_model(dataset_dir: Dir, config_json: str) -> Dir:
    """
    Download the raw dataset Dir, run two-phase training,
    and return training metrics and final predictions as a Dir for the report task.
    """
    from pathlib import Path
    local_dir = Path("/tmp/tumor_local")
    local_dir.mkdir(parents=True, exist_ok=True)
    await dataset_dir.download(local_path=str(local_dir))
    from training import train_tumor_classifier
    config = TrainingConfig(**json.loads(config_json))
    result = train_tumor_classifier(config=config, dataset_path=str(local_dir))
    output_dir = Path("/tmp/training_results")
    output_dir.mkdir(parents=True, exist_ok=True)
    (output_dir / "metrics.json").write_text(json.dumps(result["metrics"]))
    (output_dir / "predictions.json").write_text(json.dumps({
        "preds": result["final_preds"],
        "targets": result["final_targets"],
    }))
    return await Dir.from_local(str(output_dir))
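The two phases are driven by the phase1_* and phase2_* fields of TrainingConfig. A minimal sketch of the schedule those fields imply, assuming phase 1 trains only the new classifier head at phase1_lr with the backbone frozen and phase 2 unfreezes everything at the smaller phase2_lr; training.py may add details such as warmup or LR decay. PhasePlan and phase_for_epoch are hypothetical.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class PhasePlan:
    phase: int
    lr: float
    backbone_frozen: bool


def phase_for_epoch(
    epoch: int,
    phase1_epochs: int = 8,
    phase2_epochs: int = 25,
    phase1_lr: float = 1e-3,
    phase2_lr: float = 5e-5,
) -> PhasePlan:
    """Map a 0-based global epoch index to its training phase."""
    total = phase1_epochs + phase2_epochs
    if not 0 <= epoch < total:
        raise ValueError(f"epoch {epoch} outside [0, {total})")
    if epoch < phase1_epochs:
        # Phase 1: backbone frozen, only the freshly initialized head learns (higher LR).
        return PhasePlan(phase=1, lr=phase1_lr, backbone_frozen=True)
    # Phase 2: full fine-tuning with a small LR so pretrained features are not wiped out.
    return PhasePlan(phase=2, lr=phase2_lr, backbone_frozen=False)


if __name__ == "__main__":
    for e in (0, 7, 8, 32):
        print(e, phase_for_epoch(e))
```

With the default config that is 8 + 25 = 33 epochs in total. In PyTorch, freezing usually means setting requires_grad = False on the backbone's parameters so the optimizer leaves them untouched.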
Resumable checkpointing
To make retries cheap, training mirrors its Lightning checkpoint directory to a flyte.Checkpoint after every epoch, and resumes from the latest checkpoint when the task restarts.
import pathlib
import flyte
import lightning as L
from lightning.pytorch.callbacks import ModelCheckpoint
from typing_extensions import override
class FlyteLightningCheckpointCallback(ModelCheckpoint):
    """Mirrors the checkpoint directory to Flyte after every epoch so retries can resume."""
    def __init__(self, flyte_checkpoint: flyte.Checkpoint, *, dirpath: str, **kwargs):
        super().__init__(dirpath=dirpath, **kwargs)
        self._flyte_checkpoint = flyte_checkpoint
    @override
    def on_train_epoch_end(self, trainer: L.Trainer, pl_module: L.LightningModule) -> None:
        # Run ModelCheckpoint's own end-of-epoch logic first, then upload the directory.
        super().on_train_epoch_end(trainer, pl_module)
        if self.dirpath:
            self._flyte_checkpoint.save_sync(pathlib.Path(self.dirpath))
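The callback copies whatever ModelCheckpoint has written into dirpath to durable storage after each epoch. A standard-library sketch of the same mirroring idea, with a second local directory standing in for Flyte's checkpoint store; mirror_checkpoints is hypothetical and says nothing about how flyte.Checkpoint.save_sync is actually implemented. File names follow Lightning's default epoch={epoch}-step={step}.ckpt pattern, with 100 steps per epoch assumed.

```python
import shutil
import tempfile
from pathlib import Path


def mirror_checkpoints(local_dir: Path, durable_dir: Path) -> list[str]:
    """Copy every file under local_dir into durable_dir, overwriting older copies."""
    shutil.copytree(local_dir, durable_dir, dirs_exist_ok=True)
    return sorted(p.name for p in durable_dir.iterdir())


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        local = Path(tmp) / "lightning_ckpts"
        durable = Path(tmp) / "durable"
        local.mkdir()
        for epoch in range(3):
            # Stand-in for ModelCheckpoint writing a file at the end of the epoch.
            (local / f"epoch={epoch}-step={(epoch + 1) * 100}.ckpt").write_text("weights")
            print(mirror_checkpoints(local, durable))
```

shutil.copytree never deletes files, so checkpoints pruned locally (for example by save_top_k) would linger in this mirror; a production version may need to handle that.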
On startup, the training loop looks for a checkpoint from a previous attempt and resumes from it if present:
# --- Flyte checkpoint: resume from previous attempt if one exists ---
resume_ckpt: str | None = None
ctx = flyte.ctx()
flyte_checkpoint = getattr(ctx, "checkpoint", None) if ctx else None
if flyte_checkpoint:
    prev_path = flyte_checkpoint.load_sync()
    if prev_path:
        last = flyte.latest_checkpoint(prev_path)
        if last:
            ck = torch.load(str(last), map_location="cpu", weights_only=False)
            epoch_start = int(ck.get("epoch", 0))
            resume_ckpt = str(last)
            print(f"Resuming from epoch {epoch_start}, checkpoint: {last}")
# --------------------------------------------------------------------
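The resume path relies on flyte.latest_checkpoint to choose a file, and the tutorial does not show how it chooses. The sketch below simulates the whole retry-and-resume loop with the standard library under stated assumptions: checkpoints are small JSON files named like Lightning's epoch=N-step=M.ckpt, the newest is the one with the highest epoch, the task gets one initial attempt plus up to three retries, and a simulated preemption kills only the first attempt. latest_ckpt, train, and run_with_retries are hypothetical.

```python
import json
import re
import tempfile
from pathlib import Path
from typing import Optional

_EPOCH_RE = re.compile(r"epoch=(\d+)")


class Preempted(RuntimeError):
    """Simulated loss of the GPU node mid-epoch."""


def latest_ckpt(ckpt_dir: Path) -> Optional[Path]:
    """Return the *.ckpt file with the highest epoch number, or None if there is none."""
    best: Optional[Path] = None
    best_epoch = -1
    for path in ckpt_dir.glob("*.ckpt"):
        match = _EPOCH_RE.search(path.name)
        # Compare epochs as integers so epoch=10 ranks above epoch=9.
        if match and int(match.group(1)) > best_epoch:
            best, best_epoch = path, int(match.group(1))
    return best


def train(ckpt_dir: Path, total_epochs: int, preempt_at: Optional[int], ran: list[int]) -> None:
    """Resume after the latest checkpoint, then save one checkpoint per finished epoch."""
    last = latest_ckpt(ckpt_dir)
    start = json.loads(last.read_text())["epoch"] + 1 if last else 0
    for epoch in range(start, total_epochs):
        if epoch == preempt_at:
            raise Preempted(f"node lost during epoch {epoch}")
        ran.append(epoch)  # stand-in for one epoch of real training
        name = f"epoch={epoch}-step={(epoch + 1) * 100}.ckpt"  # assumes 100 steps/epoch
        (ckpt_dir / name).write_text(json.dumps({"epoch": epoch}))


def run_with_retries(ckpt_dir: Path, total_epochs: int, preempt_at: Optional[int], retries: int = 3) -> list[int]:
    """One initial attempt plus up to `retries` more; only the first attempt is preempted."""
    ran: list[int] = []
    for attempt in range(retries + 1):
        try:
            train(ckpt_dir, total_epochs, preempt_at if attempt == 0 else None, ran)
            return ran
        except Preempted:
            continue
    raise RuntimeError("retries exhausted")


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        print(run_with_retries(Path(tmp), total_epochs=5, preempt_at=3))  # [0, 1, 2, 3, 4]
```

No epoch runs twice: the second attempt finds epoch=2 as the latest checkpoint and starts at epoch 3, which is the saving that makes retries=3 cheap on a GPU node.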
Generate the report
The reporting task reads the metrics and predictions, then renders accuracy/loss curves, a confusion matrix, and a per-class F1 chart with Plotly. report=True surfaces the HTML directly in the run’s report panel.
@report_env.task(report=True)
async def create_report(results_dir: Dir) -> None:
    """
    Download training metrics and render loss/accuracy curves, confusion matrix,
    and per-class F1 chart in the Union UI report panel.
    """
    import numpy as np
    from pathlib import Path
    from utils import create_confusion_matrix_plot, create_metrics_plots, create_per_class_f1_plot
    local_dir = Path("/tmp/tumor_report")
    local_dir.mkdir(parents=True, exist_ok=True)
    await results_dir.download(local_path=str(local_dir))
    matches = list(local_dir.glob("**/metrics.json"))
    if not matches:
        raise RuntimeError(f"metrics.json not found under {local_dir}")
    local_path = matches[0].parent
    history = json.loads((local_path / "metrics.json").read_text())
    predictions = json.loads((local_path / "predictions.json").read_text())
    preds = np.array(predictions["preds"])
    targets = np.array(predictions["targets"])
    loss_fig, acc_fig = create_metrics_plots(history)
    cm_fig = create_confusion_matrix_plot(preds, targets)
    f1_fig = create_per_class_f1_plot(preds, targets)
    # Embed plotly.js once with the first figure; the other figures reuse it.
    combined_html = (
        acc_fig.to_html(include_plotlyjs=True, full_html=False)
        + loss_fig.to_html(include_plotlyjs=False, full_html=False)
        + cm_fig.to_html(include_plotlyjs=False, full_html=False)
        + f1_fig.to_html(include_plotlyjs=False, full_html=False)
    )
    flyte.report.log(combined_html, do_flush=True)
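The plotting helpers in utils.py are not shown, but the numbers behind the confusion matrix and the per-class F1 chart can be computed directly from preds and targets. A standard-library sketch, assuming labels are integer class indices 0 to 3 in the order Glioma, Meningioma, No Tumor, Pituitary, with rows as true labels and columns as predictions (the same convention as scikit-learn's confusion_matrix). confusion_counts and per_class_f1 are hypothetical helpers, not the tutorial's utils.

```python
from typing import Sequence

CLASS_NAMES = ["Glioma", "Meningioma", "No Tumor", "Pituitary"]


def confusion_counts(preds: Sequence[int], targets: Sequence[int], num_classes: int = 4) -> list[list[int]]:
    """cm[true][pred] = number of samples with that (true, predicted) pair."""
    if len(preds) != len(targets):
        raise ValueError("preds and targets must have the same length")
    cm = [[0] * num_classes for _ in range(num_classes)]
    for p, t in zip(preds, targets):
        cm[t][p] += 1
    return cm


def per_class_f1(cm: list[list[int]]) -> list[float]:
    """F1 = 2*TP / (2*TP + FP + FN) per class; 0.0 if the class never occurs."""
    n = len(cm)
    scores = []
    for k in range(n):
        tp = cm[k][k]
        fn = sum(cm[k]) - tp  # true class k, predicted as something else
        fp = sum(cm[r][k] for r in range(n)) - tp  # predicted k, true class differs
        denom = 2 * tp + fp + fn
        scores.append(2 * tp / denom if denom else 0.0)
    return scores


if __name__ == "__main__":
    preds = [0, 0, 1, 2, 3, 3, 1, 2]
    targets = [0, 1, 1, 2, 3, 3, 1, 0]
    cm = confusion_counts(preds, targets)
    for name, f1 in zip(CLASS_NAMES, per_class_f1(cm)):
        print(f"{name:<11} F1={f1:.2f}")
```

Per-class F1 is worth plotting next to overall accuracy because a model can score well overall while one tumor class is consistently confused with another.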
Orchestrate the pipeline
The driver task wires the three steps together. Each call awaits the previous step's output, so the steps run in sequence.
@pipeline_env.task
async def tumor_detection_pipeline() -> None:
    """Orchestrate dataset loading, GPU training, and report generation."""
    dataset_dir = await load_dataset()
    results_dir = await train_model(
        dataset_dir=dataset_dir,
        config_json=json.dumps(TRAINING_CONFIG.to_dict()),
    )
    await create_report(results_dir=results_dir)
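The three awaits form a data-dependency chain: train_model cannot start until load_dataset returns its Dir, and create_report waits on train_model's. The sketch below reproduces that ordering with plain asyncio coroutines standing in for the Flyte tasks; the stub_* functions and their string return values are hypothetical placeholders for the real tasks and Dir objects.

```python
import asyncio


async def stub_load_dataset(events: list[str]) -> str:
    events.append("load_dataset")
    await asyncio.sleep(0)  # yield to the event loop, as real I/O would
    return "dataset-dir"


async def stub_train_model(dataset_dir: str, config_json: str, events: list[str]) -> str:
    events.append(f"train_model({dataset_dir})")
    await asyncio.sleep(0)
    return "results-dir"


async def stub_create_report(results_dir: str, events: list[str]) -> None:
    events.append(f"create_report({results_dir})")


async def pipeline() -> list[str]:
    events: list[str] = []
    # Each step consumes the previous step's output, so the awaits run strictly in order.
    dataset_dir = await stub_load_dataset(events)
    results_dir = await stub_train_model(dataset_dir, '{"phase1_epochs": 8}', events)
    await stub_create_report(results_dir, events)
    return events


if __name__ == "__main__":
    print(asyncio.run(pipeline()))
    # ['load_dataset', 'train_model(dataset-dir)', 'create_report(results-dir)']
```

Steps with no data dependency on each other could instead be started together with asyncio.gather, the usual way to express fan-out in async Python.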
Run the pipeline
This example needs no secrets: the dataset is public. Because the pipeline imports sibling modules and uses with_source_folder, run it from inside the example directory so the local files are picked up.
Change into the example directory and run the script:
cd v2/tutorials/tumor_detection
python run.py
Or submit it with the Flyte CLI from the same directory:
flyte run run.py tumor_detection_pipeline
When the run completes, open the create_report task in the UI to view the training curves, confusion matrix, and per-class F1 scores.