BERT emotion classification

Code available here.

This tutorial fine-tunes a BERT-style encoder (ModernBERT by default) on the dair-ai/emotion Twitter dataset for six-way emotion classification: sadness, joy, love, anger, fear, and surprise. The pipeline trains the classifier, compares it against the base model using a confusion matrix and per-class F1, and then explores inference through attention and token-importance visualizations rendered in Flyte reports.
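To make the evaluation metrics concrete, here is a minimal pure-Python sketch of a confusion matrix and per-class F1 for the six labels. The helper names `confusion_matrix` and `per_class_f1` and the toy predictions are illustrative assumptions, and the integer ids are assumed to follow the label order listed above. The tutorial's own evaluate task is not shown here and may well use scikit-learn instead.

```python
# Label order as listed in this tutorial; mapping ids 0..5 to it is an assumption.
LABELS = ["sadness", "joy", "love", "anger", "fear", "surprise"]


def confusion_matrix(y_true, y_pred, num_classes):
    """Return matrix[t][p] = number of examples with true class t predicted as p."""
    matrix = [[0] * num_classes for _ in range(num_classes)]
    for t, p in zip(y_true, y_pred):
        matrix[t][p] += 1
    return matrix


def per_class_f1(matrix):
    """Compute F1 = 2*TP / (2*TP + FP + FN) for each class."""
    n = len(matrix)
    scores = []
    for c in range(n):
        tp = matrix[c][c]
        fp = sum(matrix[r][c] for r in range(n)) - tp  # predicted c, true class differs
        fn = sum(matrix[c]) - tp                       # true c, predicted something else
        denom = 2 * tp + fp + fn
        scores.append(2 * tp / denom if denom else 0.0)
    return scores


# Toy labels and predictions, not results from the tutorial.
y_true = [0, 0, 1, 1, 1, 2, 3, 4, 5, 5]
y_pred = [0, 1, 1, 1, 0, 2, 3, 4, 5, 1]
cm = confusion_matrix(y_true, y_pred, len(LABELS))
f1 = per_class_f1(cm)
for name, score in zip(LABELS, f1):
    print(f"{name:>8}: F1 = {score:.2f}")
```

Rows of the matrix are true classes and columns are predictions, so off-diagonal cells show which emotions get confused with each other.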

Flyte provides:

  • GPU fine-tuning with live training loss charts.
  • Rich evaluation reports including confusion matrices and confidence bars.
  • Cached dataset loading for repeatable experiments.
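The confidence bars mentioned above can be pictured as per-class probabilities obtained from the classifier's logits. The sketch below assumes a softmax over six logits and renders text bars. The `confidence_bars` helper and the example logits are hypothetical, and the actual reports render HTML rather than text.

```python
import math

LABELS = ["sadness", "joy", "love", "anger", "fear", "surprise"]


def softmax(logits):
    """Convert raw logits to probabilities, subtracting the max for numerical stability."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def confidence_bars(logits, width=20):
    """Return one text bar per class, sorted from most to least confident."""
    probs = softmax(logits)
    ranked = sorted(zip(LABELS, probs), key=lambda pair: pair[1], reverse=True)
    return [f"{name:<8} {'#' * round(p * width):<{width}} {p:.1%}" for name, p in ranked]


# Made-up logits for a single tweet, for illustration only.
example_logits = [0.2, 3.1, 1.4, -0.5, -1.0, 0.3]
bars = confidence_bars(example_logits)
print("\n".join(bars))
```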

Define the task environments

bert_fine_tuning_emotion.py
# /// script
# requires-python = ">=3.12"
# dependencies = [
#    "flyte>=2.4.0",
#    "torch>=2.1.0",
#    "transformers>=4.45.0",
#    "datasets>=3.0.0",
#    "scikit-learn",
#    ...
# ]
# ///
import os

import flyte

# Build the container image from the inline script metadata at the top of this file.
main_img = flyte.Image.from_uv_script(__file__, name="bert-fine-tuning-emotion", pre=True)

# GPU environment; the Hugging Face token secret is exposed as the HF_TOKEN env var.
gpu_env = flyte.TaskEnvironment(
    name="bert-fine-tuning-emotion-gpu",
    image=main_img,
    resources=flyte.Resources(cpu=4, memory="16Gi", gpu=1),
    secrets=[flyte.Secret(key="huggingface-token", as_env_var="HF_TOKEN")],
)

# CPU environment for orchestration; depends_on declares gpu_env as a dependency.
cpu_env = flyte.TaskEnvironment(
    name="bert-fine-tuning-emotion-cpu",
    image=main_img,
    resources=flyte.Resources(cpu=2, memory="8Gi"),
    depends_on=[gpu_env],
)

# None when the secret is not available in the current process.
HF_TOKEN = os.environ.get("HF_TOKEN")
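Because `HF_TOKEN` is read with `os.environ.get`, it is `None` whenever the secret is not mounted. One way a task might handle that is to pass the token only when it is present. The `hf_auth_kwargs` helper below is a hypothetical sketch and not part of the tutorial code. The commented usage assumes the `token` keyword accepted by `from_pretrained` in recent `transformers` releases.

```python
import os


def hf_auth_kwargs(env=None):
    """Return {'token': ...} when HF_TOKEN is set and non-empty, else an empty dict."""
    env = os.environ if env is None else env
    token = env.get("HF_TOKEN")
    return {"token": token} if token else {}


# Inside a task this might be used as, for example (assumed usage, not from the tutorial):
#   AutoTokenizer.from_pretrained(model_name, **hf_auth_kwargs())
print(hf_auth_kwargs({"HF_TOKEN": "hf_example"}))
print(hf_auth_kwargs({}))
```

Returning an empty dict keeps the call site identical whether or not a token is configured.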

Orchestrate the pipeline

bert_fine_tuning_emotion.py
@cpu_env.task(report=True)
async def pipeline(
    model_name: str = "answerdotai/ModernBERT-base",
    epochs: int = 3,
    lr: float = 2e-5,
    batch_size: int = 16,
    warmup_steps: int = 100,
    max_train_samples: int = 10000,
    max_eval_samples: int = 2000,
    num_eval_examples: int = 200,
    num_explore_examples: int = 12,
) -> flyte.io.Dir:
    """
    ModernBERT emotion classification pipeline.

    Returns the fine-tuned model directory (used by serve.py for deployment).

    1. Download emotion dataset (6 classes from Twitter text)
    2. Fine-tune ModernBERT for sequence classification
    3. Evaluate: base vs fine-tuned with confusion matrix
    4. Explore inference: attention heatmaps + token importance

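    Args:
        model_name: Hugging Face encoder model to fine-tune.
        epochs: Number of training epochs.
        lr: Learning rate.
        batch_size: Batch size used for training.
        warmup_steps: Number of learning-rate warmup steps.
        max_train_samples: Maximum number of training examples to use.
        max_eval_samples: Maximum number of evaluation examples to use.
        num_eval_examples: Number of examples passed to the evaluate task.
        num_explore_examples: Number of examples for attention/attribution analysis.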
    """
    log.info(f"Pipeline: {model_name} | emotion classification")
    steps = ["Get Data", "Train", "Evaluate", "Explore Inference"]

    await flyte.report.replace.aio(
        wrap_report(
            f"<h2>Emotion Classification Pipeline</h2>"
            f"<h3>{model_name}</h3>"
            f"{pipeline_step_indicator(0, steps)}"
            f'<div class="card"><p>Downloading emotion dataset...</p></div>'
        ),
        do_flush=True,
    )

    # Step 1: Get data
    data_dir = await get_data(max_train_samples, max_eval_samples)

    # Step 2: Train
    await flyte.report.replace.aio(
        wrap_report(
            f"<h2>Emotion Classification Pipeline</h2>"
            f"<h3>{model_name}</h3>"
            f"{pipeline_step_indicator(1, steps)}"
            f'<div class="card"><p>Fine-tuning for emotion classification...</p></div>'
        ),
        do_flush=True,
    )

    finetuned_dir = await train(model_name, data_dir, epochs, lr, batch_size, warmup_steps)

    # Step 3: Evaluate
    await flyte.report.replace.aio(
        wrap_report(
            f"<h2>Emotion Classification Pipeline</h2>"
            f"<h3>{model_name}</h3>"
            f"{pipeline_step_indicator(2, steps)}"
            f'<div class="card"><p>Evaluating base vs fine-tuned model...</p></div>'
        ),
        do_flush=True,
    )

    eval_result = await evaluate(model_name, finetuned_dir, data_dir, num_eval_examples)
    eval_metrics = json.loads(eval_result)

    # Step 4: Explore inference
    await flyte.report.replace.aio(
        wrap_report(
            f"<h2>Emotion Classification Pipeline</h2>"
            f"<h3>{model_name}</h3>"
            f"{pipeline_step_indicator(3, steps)}"
            f'<div class="card"><p>Analyzing attention patterns and token importance...</p></div>'
        ),
        do_flush=True,
    )

    explore_result = await explore_inference(finetuned_dir, data_dir, num_explore_examples)

    # -- Final report --
    improvement = eval_metrics["improvement"]
    imp_badge = "badge-success" if improvement > 0 else "badge-danger" if improvement < 0 else "badge-info"

    await flyte.report.replace.aio(
        wrap_report(
            f"<h2>Emotion Classification Pipeline Complete</h2>"
            f"<h3>{model_name}</h3>"
            f"{pipeline_step_indicator(4, steps)}"
            f'<div class="stat-grid">'
            f'  <div class="stat"><div class="value">{eval_metrics["base_accuracy"]}%</div><div class="label">Base Accuracy</div></div>'
            f'  <div class="stat"><div class="value">{eval_metrics["finetuned_accuracy"]}%</div><div class="label">Fine-tuned Accuracy</div></div>'
            f'  <div class="stat"><div class="value"><span class="badge {imp_badge}">{improvement:+.1f}pp</span></div><div class="label">Improvement</div></div>'
            f'  <div class="stat"><div class="value">{eval_metrics["finetuned_f1"]}%</div><div class="label">Weighted F1</div></div>'
            f'</div>'
        ),
        do_flush=True,
    )

    log.info(f"Pipeline complete. Accuracy improvement: {improvement:+.1f}pp")
    return finetuned_dir
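The pipeline parses the string returned by `evaluate` with `json.loads` and reads four keys: `base_accuracy`, `finetuned_accuracy`, `improvement`, and `finetuned_f1`. The sketch below shows one way such a payload could be built, assuming percentages rounded to one decimal and an improvement expressed in percentage points. `build_eval_summary`, `improvement_badge`, and the counts are hypothetical, and the real `evaluate` task may compute or round these values differently.

```python
import json


def build_eval_summary(base_correct, finetuned_correct, total, finetuned_f1):
    """Serialize the four fields the pipeline reads, as percentages with one decimal."""
    base_acc = round(100 * base_correct / total, 1)
    ft_acc = round(100 * finetuned_correct / total, 1)
    return json.dumps({
        "base_accuracy": base_acc,
        "finetuned_accuracy": ft_acc,
        "improvement": round(ft_acc - base_acc, 1),  # percentage points
        "finetuned_f1": round(100 * finetuned_f1, 1),
    })


def improvement_badge(improvement):
    """Same three-way choice as the pipeline's imp_badge expression."""
    if improvement > 0:
        return "badge-success"
    if improvement < 0:
        return "badge-danger"
    return "badge-info"


# Invented counts for illustration, not measured results.
metrics = json.loads(build_eval_summary(36, 181, 200, 0.903))
print(metrics, improvement_badge(metrics["improvement"]))
```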

Run the workflow

From the example directory:

cd v2/tutorials/bert_fine_tuning_emotion
uv run --script bert_fine_tuning_emotion.py

Quick smoke test with a small sample:

flyte run bert_fine_tuning_emotion.py pipeline --max_train_samples 200 --max_eval_samples 50 --epochs 1

Open the evaluate and explore_inference task reports for confusion matrices and attention visualizations.
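As a rough picture of what a token-level visualization in an HTML report can look like, the sketch below colors each token by a non-negative weight such as an attention score or an importance value. The `token_heatmap_html` helper, the color, and the sample weights are assumptions made for illustration. The tutorial's `explore_inference` task is not shown here and may render its heatmaps differently.

```python
import html


def token_heatmap_html(tokens, weights):
    """Render tokens as <span> elements whose background opacity tracks their weight."""
    if len(tokens) != len(weights):
        raise ValueError("tokens and weights must have the same length")
    peak = max(weights, default=0.0)
    spans = []
    for token, weight in zip(tokens, weights):
        # Scale by the largest weight so the strongest token is fully opaque.
        alpha = weight / peak if peak > 0 else 0.0
        spans.append(
            f'<span style="background: rgba(255, 99, 71, {alpha:.2f})">'
            f"{html.escape(token)}</span>"
        )
    return " ".join(spans)


# Made-up tokens and weights, for illustration only.
tokens = ["i", "feel", "so", "happy", "<3"]
weights = [0.05, 0.20, 0.10, 0.60, 0.05]
snippet = token_heatmap_html(tokens, weights)
print(snippet)
```

Escaping each token with `html.escape` matters for tweet text, which often contains characters such as `<` and `&`.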