BERT emotion classification
Code available here.
This tutorial fine-tunes a BERT-style model (ModernBERT by default) on the dair-ai/emotion Twitter dataset for six-way emotion classification: sadness, joy, love, anger, fear, and surprise. The pipeline trains the classifier, evaluates with a confusion matrix and per-class F1, and explores inference with attention and token-importance visualizations in Flyte reports.
Flyte provides:
- GPU fine-tuning with live training loss charts.
- Rich evaluation reports including confusion matrices and confidence bars.
- Cached dataset loading for repeatable experiments.
Define the task environments
import os
# One container image shared by both environments, built from this script
# (see the inline `# /// script` dependency metadata).
main_img = flyte.Image.from_uv_script(
    __file__,
    name="bert-fine-tuning-emotion",
    pre=True,
)

# The "huggingface-token" secret is surfaced to GPU tasks as $HF_TOKEN.
_hf_secret = flyte.Secret(key="huggingface-token", as_env_var="HF_TOKEN")
_gpu_resources = flyte.Resources(cpu=4, memory="16Gi", gpu=1)
_cpu_resources = flyte.Resources(cpu=2, memory="8Gi")

# Environment for GPU-backed tasks.
gpu_env = flyte.TaskEnvironment(
    name="bert-fine-tuning-emotion-gpu",
    image=main_img,
    resources=_gpu_resources,
    secrets=[_hf_secret],
)

# Environment for CPU-only tasks. It declares a dependency on gpu_env --
# presumably because CPU tasks invoke tasks that run there (TODO confirm).
cpu_env = flyte.TaskEnvironment(
    name="bert-fine-tuning-emotion-cpu",
    image=main_img,
    resources=_cpu_resources,
    depends_on=[gpu_env],
)

# Read once at import time; None when the variable is not set.
HF_TOKEN = os.getenv("HF_TOKEN")
# /// script
# requires-python = ">=3.12"
# dependencies = [
# "flyte>=2.4.0",
# "torch>=2.1.0",
# "transformers>=4.45.0",
# "datasets>=3.0.0",
# "scikit-learn",
# ...
# ]
# ///
Orchestrate the pipeline
async def _report_progress(
    model_name: str,
    step_index: int,
    steps: list[str],
    message: str,
) -> None:
    """Replace the task report with the pipeline header, step indicator, and a status card."""
    await flyte.report.replace.aio(
        wrap_report(
            "<h2>Emotion Classification Pipeline</h2>"
            f"<h3>{model_name}</h3>"
            f"{pipeline_step_indicator(step_index, steps)}"
            f'<div class="card"><p>{message}</p></div>'
        ),
        do_flush=True,
    )


@cpu_env.task(report=True)
async def pipeline(
    model_name: str = "answerdotai/ModernBERT-base",
    epochs: int = 3,
    lr: float = 2e-5,
    batch_size: int = 16,
    warmup_steps: int = 100,
    max_train_samples: int = 10000,
    max_eval_samples: int = 2000,
    num_eval_examples: int = 200,
    num_explore_examples: int = 12,
) -> flyte.io.Dir:
    """
    ModernBERT emotion classification pipeline.

    Returns the fine-tuned model directory (used by serve.py for deployment).

    1. Download emotion dataset (6 classes from Twitter text)
    2. Fine-tune ModernBERT for sequence classification
    3. Evaluate: base vs fine-tuned with confusion matrix
    4. Explore inference: attention heatmaps + token importance

    The task report is replaced before each step so it always shows the
    step currently running, and is replaced once more at the end with the
    summary metrics.

    Args:
        model_name: HuggingFace encoder model to fine-tune.
        epochs: Forwarded to ``train``.
        lr: Forwarded to ``train``.
        batch_size: Forwarded to ``train``.
        warmup_steps: Forwarded to ``train``.
        max_train_samples: Forwarded to ``get_data``.
        max_eval_samples: Forwarded to ``get_data``.
        num_eval_examples: Forwarded to ``evaluate``.
        num_explore_examples: Number of examples for attention/attribution analysis.

    Returns:
        The directory produced by ``train``.
    """
    log.info(f"Pipeline: {model_name} | emotion classification")
    steps = ["Get Data", "Train", "Evaluate", "Explore Inference"]

    # Step 1: Get data
    await _report_progress(model_name, 0, steps, "Downloading emotion dataset...")
    data_dir = await get_data(max_train_samples, max_eval_samples)

    # Step 2: Train
    await _report_progress(model_name, 1, steps, "Fine-tuning for emotion classification...")
    finetuned_dir = await train(model_name, data_dir, epochs, lr, batch_size, warmup_steps)

    # Step 3: Evaluate -- the task returns its metrics as a JSON string.
    await _report_progress(model_name, 2, steps, "Evaluating base vs fine-tuned model...")
    eval_result = await evaluate(model_name, finetuned_dir, data_dir, num_eval_examples)
    eval_metrics = json.loads(eval_result)

    # Step 4: Explore inference -- awaited for completion only; its return
    # value is not used by this task.
    await _report_progress(
        model_name, 3, steps, "Analyzing attention patterns and token importance..."
    )
    await explore_inference(finetuned_dir, data_dir, num_explore_examples)

    # -- Final report --
    improvement = eval_metrics["improvement"]
    if improvement > 0:
        imp_badge = "badge-success"
    elif improvement < 0:
        imp_badge = "badge-danger"
    else:
        imp_badge = "badge-info"

    # Index 4 is one past the last step -- presumably renders every step as done.
    await flyte.report.replace.aio(
        wrap_report(
            "<h2>Emotion Classification Pipeline Complete</h2>"
            f"<h3>{model_name}</h3>"
            f"{pipeline_step_indicator(4, steps)}"
            '<div class="stat-grid">'
            f' <div class="stat"><div class="value">{eval_metrics["base_accuracy"]}%</div><div class="label">Base Accuracy</div></div>'
            f' <div class="stat"><div class="value">{eval_metrics["finetuned_accuracy"]}%</div><div class="label">Fine-tuned Accuracy</div></div>'
            f' <div class="stat"><div class="value"><span class="badge {imp_badge}">{improvement:+.1f}pp</span></div><div class="label">Improvement</div></div>'
            f' <div class="stat"><div class="value">{eval_metrics["finetuned_f1"]}%</div><div class="label">Weighted F1</div></div>'
            "</div>"
        ),
        do_flush=True,
    )
    log.info(f"Pipeline complete. Accuracy improvement: {improvement:+.1f}pp")
    return finetuned_dir
Run the workflow
From the example directory:
cd v2/tutorials/bert_fine_tuning_emotion
uv run --script bert_fine_tuning_emotion.py
Quick smoke test with a small sample:
flyte run bert_fine_tuning_emotion.py pipeline --max_train_samples 200 --max_eval_samples 50 --epochs 1
Open the evaluate and explore_inference task reports for confusion matrices and attention visualizations.