Satellite Image Classification with EfficientNet

Background

Remote sensing has transformed how we monitor our planet. From tracking deforestation to detecting urban sprawl, satellite imagery provides a bird’s-eye view of land use change at global scale.

But training a model that can reliably classify that imagery, across 10 distinct land-use categories and at production quality, takes more than a good architecture. It requires a pipeline that handles data, compute, caching, experiment tracking, and reporting as first-class concerns.

This tutorial walks through a complete satellite image classification pipeline built on Union.ai, using EfficientNet-B0, a two-phase training strategy, and Weights & Biases for experiment tracking.

Full code available here.

Dataset

EuroSAT is a benchmark dataset of 27,000 labeled satellite images drawn from the Sentinel-2 satellite. Each image is 64×64 pixels and belongs to one of 10 land-use classes: Annual Crop, Forest, Herbaceous Vegetation, Highway, Industrial, Pasture, Permanent Crop, Residential, River, and Sea/Lake.

It’s a well-structured dataset, balanced and clearly labeled, which makes it ideal for demonstrating a production-grade training pipeline without the overhead of massive data infrastructure.

Model

We use EfficientNet-B0 from timm (the PyTorch Image Models library), pretrained on ImageNet. EfficientNet was designed to scale depth, width, and resolution jointly using a compound coefficient, giving strong accuracy with a relatively small parameter count (~5.3M). The ImageNet pretraining means the backbone already understands edges, textures, and shapes - features that transfer well to satellite imagery.

Two-Phase Training

Naively fine-tuning a pretrained model by updating all weights from the first step often leads to catastrophic forgetting: large gradients from the randomly initialized head destroy the backbone’s learned representations before the new task-specific head has had a chance to stabilize.

Instead, we use a two-phase approach:

Phase 1: Feature Extraction (frozen backbone). The EfficientNet backbone is frozen. Only the classification head is trained, at a relatively high learning rate (2e-3). This gives the head 7 epochs to learn to map ImageNet features to EuroSAT categories, without disturbing the pretrained weights.

Phase 2: Fine-tuning (unfrozen backbone). The backbone is unfrozen and added to the optimizer with a 10× lower learning rate than the head (phase2_lr × 0.1). A fresh cosine annealing schedule is initialized over the remaining steps, so the learning rate doesn’t arrive near-zero from Phase 1’s schedule before Phase 2 even begins. This lets the backbone adapt to satellite-specific features while preserving the general representations it learned on ImageNet.
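To make the Phase 2 schedule concrete, here is a small standalone sketch of the closed-form cosine annealing curve applied to both parameter groups. It uses the same formula that PyTorch documents for CosineAnnealingLR, but it is plain Python, not the pipeline's code. The value of phase2_lr and the step count are illustrative assumptions; only the 10× ratio and eta_min=1e-6 come from this tutorial.

```python
import math


def cosine_lr(step: int, total_steps: int, base_lr: float, eta_min: float = 1e-6) -> float:
    """Closed-form cosine annealing: base_lr at step 0, eta_min at total_steps."""
    progress = min(step, total_steps) / total_steps
    return eta_min + 0.5 * (base_lr - eta_min) * (1.0 + math.cos(math.pi * progress))


# Assumed values for illustration only.
phase2_lr = 5e-4
phase2_steps = 1000

head_lr = phase2_lr
backbone_lr = phase2_lr * 0.1  # 10x lower than the head, as described above

for step in (0, 250, 500, 750, 1000):
    head = cosine_lr(step, phase2_steps, head_lr)
    backbone = cosine_lr(step, phase2_steps, backbone_lr)
    print(f"step {step:4d}  head lr={head:.2e}  backbone lr={backbone:.2e}")
```

Because the schedule starts fresh at the Phase 2 boundary, both groups begin at their full Phase 2 learning rate and decay together toward eta_min.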

The transition happens automatically inside a PhaseChangeCallback:

training.py
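    # Imports assumed by this excerpt (the Lightning package alias is an assumption).
    import lightning as L
    import torch

    class PhaseChangeCallback(L.Callback):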
        def __init__(self, phase1_epochs: int, phase2_lr: float):
            super().__init__()
            self.phase1_epochs = phase1_epochs
            self.phase2_lr = phase2_lr
            self.phase_changed = False

        def on_train_epoch_end(self, trainer, pl_module):
            if not self.phase_changed and (trainer.current_epoch + 1) == self.phase1_epochs:
                print("\n" + "=" * 80)
                print("TRANSITIONING TO PHASE 2: UNFREEZING BACKBONE AND ADJUSTING LR")
                print("=" * 80 + "\n")

                pl_module.model.unfreeze_backbone()

                for param_group in trainer.optimizers[0].param_groups:
                    param_group["lr"] = self.phase2_lr

                # Add backbone params to optimizer with 10x lower LR.
                # Backbone was excluded at init because it was frozen.
                backbone_lr = self.phase2_lr * 0.1
                backbone_decay, backbone_no_decay = [], []
                for param in pl_module.model.backbone.parameters():
                    if param.ndim >= 2:
                        backbone_decay.append(param)
                    else:
                        backbone_no_decay.append(param)
                optimizer = trainer.optimizers[0]
                optimizer.add_param_group({"params": backbone_decay, "lr": backbone_lr, "weight_decay": pl_module.weight_decay})
                optimizer.add_param_group({"params": backbone_no_decay, "lr": backbone_lr, "weight_decay": 0.0})

                # Fresh cosine schedule over remaining Phase 2 steps to avoid
                # the Phase 1 schedule arriving near-zero before Phase 2 begins.
                steps_remaining = trainer.estimated_stepping_batches - trainer.global_step
                new_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                    trainer.optimizers[0],
                    T_max=max(1, steps_remaining),
                    eta_min=1e-6,
                )
                for lr_scheduler_config in trainer.lr_scheduler_configs:
                    lr_scheduler_config.scheduler = new_scheduler

                print(f"Phase 2 started: lr={self.phase2_lr}")
                print(f"Total parameters: {get_model_size(pl_module.model):,}")
                print(f"Trainable parameters: {get_trainable_params(pl_module.model):,}")
                self.phase_changed = True

This two-phase strategy consistently reaches >95% validation accuracy on EuroSAT within 17 total epochs.
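The T_max passed to the new scheduler in the callback is simply the number of optimizer steps left after Phase 1. The sketch below reproduces that arithmetic by hand, assuming one optimizer step per batch (no gradient accumulation) and a scheduler that is stepped every batch. The 7 and 17 epoch counts come from this tutorial; the training split size and batch size are assumptions chosen for illustration.

```python
import math


def phase2_steps(num_train: int, batch_size: int, max_epochs: int, phase1_epochs: int) -> int:
    """Optimizer steps remaining once Phase 1 ends (one step per batch)."""
    steps_per_epoch = math.ceil(num_train / batch_size)
    total_steps = steps_per_epoch * max_epochs
    steps_done = steps_per_epoch * phase1_epochs
    # Mirror the callback's max(1, ...) guard so T_max is never zero.
    return max(1, total_steps - steps_done)


# num_train and batch_size are hypothetical; 17 and 7 are the epoch counts from the text.
remaining = phase2_steps(num_train=21_600, batch_size=64, max_epochs=17, phase1_epochs=7)
print(f"T_max for the Phase 2 cosine schedule: {remaining}")
```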

Pipeline

Training a model is only part of the story. The real challenge is building a system that is reproducible, cost-efficient, and easy to iterate on. That’s where Union’s TaskEnvironment model shines: each stage of the pipeline runs in the right compute environment, and results are cached so you never pay for work you’ve already done.

The pipeline has four components, each with its own environment defined in config.py.

Task 1: Data Download (dataset_env)

run.py
@dataset_env.task
async def load_dataset() -> Dir:
    """
    Download raw EuroSAT JPEG files and cache as flyte.io.Dir.
    Runs once — result is reused on subsequent pipeline runs (cache="auto").
    """
    return await download_eurosat()

This task downloads the raw EuroSAT JPEG files via torchvision and packages them as a flyte.io.Dir. It runs on a lightweight CPU container (2 cores, 2 GB RAM) - no GPU needed. With cache="auto", the result is stored and reused on every subsequent run. You pay for the download exactly once.
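As a mental model for why a cached task runs only once, the sketch below derives a deterministic key from a task's name, version, and inputs, and skips execution when the key has been seen before. This is a conceptual illustration in plain Python, not how Union computes or stores cache entries; cache_key, run_cached, and the in-memory dictionary are all hypothetical.

```python
import hashlib
import json

_cache = {}  # hypothetical in-memory store standing in for a real cache backend


def cache_key(task_name: str, task_version: str, inputs: dict) -> str:
    """Hash the task identity and its inputs into a stable key."""
    payload = json.dumps(
        {"task": task_name, "version": task_version, "inputs": inputs},
        sort_keys=True,  # key order must not change the hash
    )
    return hashlib.sha256(payload.encode("utf-8")).hexdigest()


def run_cached(task_name: str, task_version: str, inputs: dict, fn):
    """Run fn(**inputs) only if this exact (task, version, inputs) is new."""
    key = cache_key(task_name, task_version, inputs)
    if key not in _cache:
        _cache[key] = fn(**inputs)
    return _cache[key]


def fake_download(source: str) -> str:
    print(f"downloading from {source}")
    return f"/data/{source}"


# The second call hits the cache, so "downloading" is printed only once.
first = run_cached("load_dataset", "v1", {"source": "eurosat"}, fake_download)
second = run_cached("load_dataset", "v1", {"source": "eurosat"}, fake_download)
print(first == second)
```

The same idea explains why changing an input or the task's code produces a fresh execution: the key changes, so nothing matches.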

No preprocessing happens here. Raw images are passed directly to training so that all transforms - resize, normalization, and augmentation - happen per-batch with the full training context, giving the model properly prepared 224×224 input from the original pixels.

Task 2: GPU Training (training_env)

run.py
@wandb_init
@training_env.task
async def train_model(dataset_dir: Dir, config_json: str) -> Dir:
    """
    Download the raw dataset Dir, run two-phase training,
    and return training metrics as a Dir for the report task.
    """
    from pathlib import Path

    local_dir = Path("/tmp/eurosat_local")
    local_dir.mkdir(parents=True, exist_ok=True)
    await dataset_dir.download(local_path=str(local_dir))

    config = TrainingConfig(**json.loads(config_json))
    result = train_satellite_classifier(config=config, dataset_path=str(local_dir))

    output_dir = Path("/tmp/training_results")
    output_dir.mkdir(parents=True, exist_ok=True)
    (output_dir / "metrics.json").write_text(json.dumps(result["metrics"]))

    return await Dir.from_local(str(output_dir))

This task runs on a T4 GPU with 32 GB RAM. It receives the dataset Dir from Task 1, downloads it locally, then runs the two-phase training loop using PyTorch Lightning.

Two things worth noting:

  • With cache="auto", training results are cached based on the input data and config. If you rerun the pipeline with the same dataset and hyperparameters, Union skips training entirely and returns the cached metrics. This makes hyperparameter search much cheaper: only configurations you haven’t tried before actually execute.

  • With @wandb_init, the flyteplugins-wandb integration initializes a W&B run and makes it available via get_wandb_run(). Every training run then logs metrics, learning rate curves, and t-SNE visualizations of the learned feature space to your W&B project automatically.

training.py
    wandb_logger = WandbLogger(experiment=get_wandb_run(), log_model=False)

Task 3: Report Generation (report_env)

This task reads the metrics.json produced by training and renders interactive Plotly charts - validation accuracy and train/val loss curves - directly in the Union UI. The report=True flag tells Union to render the task output as a rich report panel. A dashed vertical line marks the Phase 1 → Phase 2 transition, making it easy to see how much the backbone fine-tuning contributes.

run.py
@report_env.task(report=True)
async def create_report(results_dir: Dir) -> None:
    """
    Download training metrics and render loss/accuracy curves
    in the Union UI report panel.
    """
    import plotly.graph_objects as go
    from pathlib import Path

    local_dir = Path("/tmp/training_report")
    local_dir.mkdir(parents=True, exist_ok=True)
    await results_dir.download(local_path=str(local_dir))

    matches = list(local_dir.glob("**/metrics.json"))
    if not matches:
        raise RuntimeError(f"metrics.json not found under {local_dir}")
    # Parse the first metrics.json found anywhere under the download directory.
    history = json.loads(matches[0].read_text())

    epochs = [e["epoch"] for e in history]
    val_acc = [e["val_acc"] for e in history]
    val_loss = [e["val_loss"] for e in history]
    train_loss = [e["train_loss"] for e in history]
    # phase_boundary: first epoch where phase 2 begins (frozen → fine-tune transition)
    phase_boundary = next((e["epoch"] for e in history if e["phase"] == 2), None)

    def add_phase_line(fig):
        if phase_boundary is not None:
            fig.add_vline(
                x=phase_boundary,
                line_dash="dash",
                line_color="gray",
                annotation_text="Phase 2 start",
            )

    acc_fig = go.Figure()
    acc_fig.add_trace(go.Scatter(x=epochs, y=val_acc, mode="lines+markers", name="Val Accuracy"))
    acc_fig.update_layout(title="Validation Accuracy", xaxis_title="Epoch", yaxis_title="Accuracy")
    add_phase_line(acc_fig)

    loss_fig = go.Figure()
    loss_fig.add_trace(go.Scatter(x=epochs, y=train_loss, mode="lines+markers", name="Train Loss"))
    loss_fig.add_trace(go.Scatter(x=epochs, y=val_loss, mode="lines+markers", name="Val Loss"))
    loss_fig.update_layout(title="Loss", xaxis_title="Epoch", yaxis_title="Loss")
    add_phase_line(loss_fig)

    combined_html = (
        acc_fig.to_html(include_plotlyjs=True, full_html=False)
        + loss_fig.to_html(include_plotlyjs=False, full_html=False)
    )
    flyte.report.log(combined_html, do_flush=True)
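The report task implies a particular shape for metrics.json: a list with one record per epoch, each carrying epoch, phase, train_loss, val_loss, and val_acc. The sketch below shows a hypothetical file in that shape, a small validator, and the same phase-boundary lookup used above. The numbers are placeholders rather than real training results, and the zero-based epoch numbering is an assumption.

```python
import json

REQUIRED_KEYS = {"epoch", "phase", "train_loss", "val_loss", "val_acc"}


def validate_history(history: list) -> None:
    """Raise early if any epoch record lacks a key the report task reads."""
    for index, record in enumerate(history):
        missing = REQUIRED_KEYS - record.keys()
        if missing:
            raise ValueError(f"record {index} is missing keys: {sorted(missing)}")


def find_phase_boundary(history: list):
    """Return the first epoch tagged as phase 2, or None if there is none."""
    return next((r["epoch"] for r in history if r["phase"] == 2), None)


# Placeholder values only; they do not come from an actual run.
sample_json = """
[
  {"epoch": 0, "phase": 1, "train_loss": 1.0, "val_loss": 0.9, "val_acc": 0.5},
  {"epoch": 1, "phase": 1, "train_loss": 0.8, "val_loss": 0.7, "val_acc": 0.6},
  {"epoch": 2, "phase": 2, "train_loss": 0.6, "val_loss": 0.5, "val_acc": 0.7},
  {"epoch": 3, "phase": 2, "train_loss": 0.5, "val_loss": 0.4, "val_acc": 0.8}
]
"""

history = json.loads(sample_json)
validate_history(history)
print("Phase 2 starts at epoch:", find_phase_boundary(history))
```

Validating the records before plotting turns a confusing KeyError inside a list comprehension into a message that names the offending record.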

Task 4: Orchestration (pipeline_env)

The pipeline task is a lightweight orchestrator. It has no heavy dependencies of its own, just enough to call the three tasks above in sequence. The async/await pattern means each task handoff is non-blocking: Union manages scheduling, retries, and data movement between tasks transparently.

run.py
@pipeline_env.task
async def satellite_classification_pipeline() -> None:
    """Orchestrate dataset loading, GPU training, and report generation."""
    dataset_dir = await load_dataset()
    results_dir = await train_model(
        dataset_dir=dataset_dir,
        config_json=json.dumps(TRAINING_CONFIG.to_dict()),
    )
    await create_report(results_dir=results_dir)
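The pipeline serializes its configuration with TRAINING_CONFIG.to_dict() and train_model rebuilds it with TrainingConfig(**json.loads(config_json)). The TrainingConfig class itself is not shown in this tutorial, so the dataclass below is only a guess at a compatible shape. The field names phase1_epochs, phase2_lr, wandb_project, and wandb_entity appear in the snippets above; the other fields, and every default marked as hypothetical, are assumptions.

```python
import json
from dataclasses import asdict, dataclass


@dataclass
class TrainingConfig:
    phase1_epochs: int = 7          # Phase 1 length stated in the text
    max_epochs: int = 17            # total epochs stated in the text (field name assumed)
    phase1_lr: float = 2e-3         # Phase 1 head learning rate (field name assumed)
    phase2_lr: float = 5e-4         # hypothetical value
    wandb_project: str = "my-project"  # placeholder
    wandb_entity: str = "my-team"      # placeholder

    def to_dict(self) -> dict:
        """Return a JSON-serializable dict of all fields."""
        return asdict(self)


TRAINING_CONFIG = TrainingConfig()

# Round trip: what the pipeline sends is what the training task reconstructs.
config_json = json.dumps(TRAINING_CONFIG.to_dict())
restored = TrainingConfig(**json.loads(config_json))
print(restored == TRAINING_CONFIG)
```

Keeping the config to flat, JSON-friendly fields means the round trip is lossless, and an unknown key in the JSON fails loudly as a TypeError when the dataclass is constructed.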

Running the Pipeline

Submit the pipeline with a single command from the project directory:

uv run run.py

This calls:

run.py
    run = flyte.with_runcontext(
        custom_context=wandb_config(
            project=TRAINING_CONFIG.wandb_project,
            entity=TRAINING_CONFIG.wandb_entity,
        ),
    ).run(satellite_classification_pipeline)

The W&B project and entity are wired in at submission time. Union handles spinning up the right containers, routing data between tasks, and surfacing results in the UI.

What You Get

After the pipeline completes:

  • Union UI: a report panel with interactive accuracy and loss curves, phase transition marker, and full task logs for each stage.

    Figures: Validation Accuracy and Loss curves.

  • Weights & Biases: a complete experiment run with train loss, validation loss and accuracy, and t-SNE visualizations of the model’s learned embeddings. Every few epochs (the interval is configurable), a t-SNE plot of the validation set embeddings is logged, showing how the model’s feature representations evolve over training. Classes that start as an overlapping cloud gradually pull apart into tight, well-separated clusters as the backbone learns satellite-specific features.

    Figure: t-SNE Visualization.

  • Model checkpoints: Lightning’s ModelCheckpoint saves the top 3 best-performing checkpoints by validation accuracy, named best-{epoch}-{val_acc}.ckpt. These are standard PyTorch Lightning checkpoints that can be loaded directly for inference.
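To illustrate what "top 3 by validation accuracy" means in practice, the sketch below picks the epochs whose checkpoints would survive such a policy from a per-epoch history. It is a plain-Python illustration of the selection rule, not Lightning's ModelCheckpoint implementation, and the accuracy values are made-up placeholders.

```python
import heapq


def checkpoints_to_keep(history: list, k: int = 3, metric: str = "val_acc") -> list:
    """Return the epochs of the k records with the highest metric value."""
    best = heapq.nlargest(k, history, key=lambda record: record[metric])
    return sorted(record["epoch"] for record in best)


# Placeholder values only; they do not come from an actual run.
history = [
    {"epoch": 0, "val_acc": 0.70},
    {"epoch": 1, "val_acc": 0.80},
    {"epoch": 2, "val_acc": 0.78},
    {"epoch": 3, "val_acc": 0.90},
    {"epoch": 4, "val_acc": 0.88},
]

print("Epochs whose checkpoints are kept:", checkpoints_to_keep(history))
```

Note that the best checkpoints are not necessarily the last ones: an epoch with a temporary dip in validation accuracy is dropped even if it comes late in training.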