# MLE Bot: an autonomous ML engineer

> [!NOTE]
> Code available [here](https://github.com/unionai/unionai-examples/tree/main/v2/tutorials/mle_bot).

You have a dataset and a business question. Today, going from a raw CSV to a trained, evaluated model with a written report takes an ML engineer hours of experimentation: profiling the data, picking algorithms, engineering features, tuning hyperparameters, analyzing results, and iterating. What if you could describe the problem in plain English and let an agent handle the rest?

This tutorial walks you through building exactly that. You'll construct an autonomous ML engineer that takes a problem description and a dataset, designs experiments, runs them on cloud infrastructure, analyzes results, iterates, and produces a report summarizing the best model it found.

## TL;DR

- You'll build an agent that takes a natural language problem description and a CSV file, then produces a trained model and a detailed report comparing the results.
- The LLM reasons over dataset statistics, never raw data. Trusted tools compute statistics in the cloud, and only those statistics reach the LLM.
- LLM-generated orchestration code runs inside Flyte's sandbox: no imports, no network access, no filesystem. It can only call pre-approved tool functions.
- Each tool function runs as a durable Flyte task in the cloud, with retries, observability, and full traceability.

## The problem with LLMs and ML pipelines

If you ask an LLM to "train a model on this dataset," you run into a few issues fast. The LLM might hallucinate sklearn APIs that don't exist. It has no access to real compute, so it can't actually train anything. It runs everything in a single context with no way to handle large datasets. And if something goes wrong, there's no structured way to iterate.

The core tension is that LLMs are genuinely good at reasoning about *what* to try. Given a dataset profile showing class imbalance and temporal structure, a capable model will suggest rolling window features and appropriate resampling strategies. But LLMs are unreliable at *executing* those plans. They generate buggy code, lose track of variable names, and have no way to dispatch real compute.

The solution is to separate the two concerns. Let the LLM handle the planning: which algorithms to try, what feature engineering to apply, which hyperparameters to tune. Then hand the execution to trusted tool functions that run on real infrastructure. The LLM controls *what* happens. The tools control *how*.

Think of it like giving a junior engineer access to a curated set of approved tools and reviewing their work. They can compose those tools in creative ways, but they can't go off-script and install random packages or hit arbitrary endpoints.
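The allowlist idea can be sketched in a few lines. Here the "planner" output is a hard-coded stand-in for an LLM response, and the tool functions are stubs, not the MLE Bot's actual tasks; the point is only that the dispatcher refuses anything not on the approved list.

```python
import asyncio

# Hypothetical approved tools -- stand-ins for the real Flyte task functions.
async def profile_dataset(path: str) -> dict:
    return {"rows": 1000, "imbalanced": True}

async def train_model(path: str, algorithm: str) -> str:
    return f"model trained with {algorithm}"

APPROVED_TOOLS = {"profile_dataset": profile_dataset, "train_model": train_model}

async def dispatch(plan: list[dict]) -> list:
    """Execute a planner-produced list of {tool, args} steps, allowlist-enforced."""
    results = []
    for step in plan:
        fn = APPROVED_TOOLS.get(step["tool"])
        if fn is None:
            raise ValueError(f"tool {step['tool']!r} is not approved")
        results.append(await fn(**step["args"]))
    return results

# A plan the planner might produce: compose approved tools, nothing else.
plan = [
    {"tool": "profile_dataset", "args": {"path": "pumps.csv"}},
    {"tool": "train_model", "args": {"path": "pumps.csv", "algorithm": "xgboost"}},
]
results = asyncio.run(dispatch(plan))
```

A plan step naming anything outside `APPROVED_TOOLS` fails loudly instead of executing, which is the same property the sandbox enforces at the language level.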

## How it works

The agent runs in five phases:

1. **Profile** the dataset using a trusted tool. The tool returns statistics (shape, class balance, feature correlations, missing values). The LLM never touches the raw data.
2. **Design** a batch of experiments. The LLM reads the profile and proposes 2 to 3 experiments, each with an algorithm, hyperparameters, and a feature engineering strategy.
3. **Execute** each experiment in parallel. For each one, the LLM generates Python orchestration code that chains together pre-approved tool functions. That code runs inside a restricted sandbox, and each tool call dispatches as a durable Flyte task on cloud compute.
4. **Analyze** the results. The LLM reviews metrics across experiments, optionally requests targeted data explorations (e.g., "are failures concentrated on specific machines?"), and decides whether to iterate with new experiments.
5. **Produce a report** summarizing the winning model, the experiment journey, and deployment recommendations.
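The five phases can be sketched as a loop. Everything below is a stub: names like `llm_design` and `run_experiment` are illustrative stand-ins for the real LLM calls and Flyte tasks, not the MLE Bot's actual API.

```python
import asyncio

# Stubs standing in for the real LLM calls and durable Flyte tasks.
async def profile_dataset(data): return {"shape": [175000, 12], "is_imbalanced": True}
async def llm_design(profile, problem): return [{"name": "baseline"}, {"name": "xgb_rolling"}]
async def run_experiment(exp, data): return {"name": exp["name"], "roc_auc": 0.78}
async def llm_analyze(results): return {"should_continue": False, "next_experiments": []}

async def agent(data, problem, max_iterations=3):
    profile = await profile_dataset(data)                  # 1. profile (trusted tool)
    experiments = await llm_design(profile, problem)       # 2. design (LLM)
    all_results = []
    for _ in range(max_iterations):
        batch = await asyncio.gather(                      # 3. execute in parallel
            *(run_experiment(e, data) for e in experiments)
        )
        all_results.extend(batch)
        decision = await llm_analyze(all_results)          # 4. analyze (LLM)
        if not decision["should_continue"]:
            break
        experiments = decision["next_experiments"]
    return max(all_results, key=lambda r: r["roc_auc"])    # 5. report the winner

best = asyncio.run(agent("pumps.csv", "predict pump failures"))
```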

Two things make this work. First, the LLM never sees raw data. The profiling tool runs in the cloud on managed compute and returns only aggregated statistics (plus a three-row sample so the LLM can see column formats). This keeps prompt sizes manageable and avoids leaking sensitive data into LLM API calls. Second, the LLM-generated code runs inside Flyte's sandbox, where the only thing it can do is call your pre-approved tool functions. More on that shortly.

### What to expect

Here's what an actual run looks like on a synthetic predictive maintenance dataset (175k rows of sensor data from 20 industrial pumps, ~3% failure rate).

In the first iteration, the agent designed three experiments: a logistic regression baseline, an XGBoost model with rolling window features, and a random forest with lag features. After reviewing the results, it decided to continue. It requested two targeted explorations ("do failure cases show meaningfully higher vibration?" and "how do feature-target correlations vary by pump?"), then used those findings to design a second round of experiments with tuned feature engineering and class weighting.

After two iterations and five total experiments, the final rankings looked like this:

| Rank | Experiment | ROC-AUC | F1 | Recall | Precision |
|------|-----------|---------|------|--------|-----------|
| 1 | **Random Forest with Balanced Class Weights** | 0.7983 | 0.4284 | 0.4561 | 0.4038 |
| 2 | XGBoost with Feature Engineering | 0.7847 | 0.4568 | 0.4722 | 0.4425 |
| 3 | Enhanced XGBoost with Focused Feature Engineering | 0.7821 | 0.3565 | 0.4973 | 0.2778 |
| 4 | Random Forest with Lag Features | 0.7651 | 0.5206 | 0.4104 | 0.7116 |
| 5 | Baseline Logistic Regression | 0.7528 | 0.1180 | 0.6496 | 0.0649 |

The agent autonomously explored different algorithms, feature strategies, and class imbalance techniques, then ranked everything by ROC-AUC. The full report includes the LLM's reasoning and generated code for every experiment, so you can trace exactly why it chose each approach and what code it wrote to implement it. Since the LLM makes different decisions each run, your results will vary, but the overall pattern (profile, design, execute, analyze, iterate) stays the same.
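The ranking step itself is simple: presumably `rank_experiments` sorts results by the primary metric, along these lines (a sketch using the ROC-AUC values from the table above):

```python
# Illustrative ranking by the primary metric, as in the results table above.
results = [
    {"name": "Baseline Logistic Regression", "roc_auc": 0.7528},
    {"name": "Random Forest with Balanced Class Weights", "roc_auc": 0.7983},
    {"name": "XGBoost with Feature Engineering", "roc_auc": 0.7847},
]
ranked = sorted(results, key=lambda r: r["roc_auc"], reverse=True)
```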

## Declaring task environments

Before writing any tasks, you need to declare *where* and *how* they run. In Flyte v2, a `TaskEnvironment` bundles together the container image, resource requirements, secrets, and dependencies for a group of tasks.

The MLE Bot uses two environments, one for the ML tools (pandas, sklearn, xgboost) and one for the agent itself (the OpenAI client and the sandbox runtime):

```python
"""Flyte TaskEnvironment definitions for mle-bot.

Two environments:
- tool_env: Runs the ML tools (data loading, feature engineering, training, evaluation).
            Has sklearn, xgboost, pandas, numpy, joblib.
- agent_env: Runs the orchestrating agent (OpenAI calls, sandbox orchestration).
             Has openai and flyte[sandbox] (the Monty sandbox runtime). Depends on tool_env.
"""

# {{docs-fragment environments}}
import flyte

tool_env = flyte.TaskEnvironment(
    "mle-tools",
    resources=flyte.Resources(cpu=2, memory="4Gi"),
    image=(
        flyte.Image.from_debian_base(name="mle-tools-image").with_pip_packages(
            "pandas>=2.0.0",
            "scikit-learn>=1.3.0",
            "xgboost>=2.0.0",
            "numpy>=1.24.0",
            "joblib>=1.3.0",
        )
    ),
)

agent_env = flyte.TaskEnvironment(
    "mle-agent",
    resources=flyte.Resources(cpu=1, memory="2Gi"),
    secrets=[flyte.Secret(key="OPENAI_API_KEY", as_env_var="OPENAI_API_KEY")],
    env_vars={"PYTHONUNBUFFERED": "1"},
    image=(
        flyte.Image.from_debian_base(name="mle-agent-image")
        .with_apt_packages("git")
        .with_pip_packages(
            "openai>=1.0.0",
            "flyte[sandbox]",
        )
    ),
    depends_on=[tool_env],
)
# {{/docs-fragment environments}}
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/mle_bot/mle_bot/environments.py*

A few things to note. `flyte.Resources` sets the CPU and memory for every task in that environment. `flyte.Image.from_debian_base()` builds a container image on the fly with the packages you declare, so you never need to manage Dockerfiles. `flyte.Secret` injects a secret from your cluster's secret store as an environment variable. And `depends_on=[tool_env]` tells Flyte that the agent environment needs to be able to dispatch tasks in the tool environment. This is what enables the sandbox to call tool functions that run on separate, appropriately-resourced compute.

## Building durable tool functions

Each tool is a regular Python async function decorated with `@env.task`. That decorator turns it into a durable Flyte task: it runs in its own container with the resources declared on the environment, it's automatically retried on transient failures, and every invocation is tracked in the Flyte UI.

Data flows between tasks as `flyte.io.File` objects. A `File` is a reference to data in cloud storage. When a task needs the actual bytes, it calls `await data.download()` to pull them into the container's local filesystem. When it produces output, it creates a `File` from a local path and returns it. Flyte handles the upload to cloud storage when the task completes. The data itself never passes through the agent or the LLM.

Here's what the training tool looks like:

```python
"""Model training tools.

A single unified interface for training classifiers with different algorithms.
The tool handles serialization, class imbalance, and basic hyperparameter passing.
"""

from flyte.io import File

from mle_bot.environments import tool_env
from mle_bot.schemas import (
    GradientBoostingParams,
    LogisticRegressionParams,
    RandomForestParams,
    XGBoostParams,
)

# {{docs-fragment train_model}}
@tool_env.task
async def train_model(
    data: File,
    target_column: str,
    algorithm: str,
    hyperparams: dict,
) -> File:
    """Train a classification model and return the serialized model and training metrics.

    Supports multiple algorithms through a single interface so the agent can
    dispatch different approaches without knowing implementation details.

    Args:
        data: CSV file with training data (features + target column).
        target_column: Name of the column to predict.
        algorithm: One of:
            "xgboost"            — Gradient boosted trees. Good default for tabular data.
                                   Handles missing values and class imbalance natively.
            "random_forest"      — Ensemble of decision trees. More robust to outliers.
            "logistic_regression"— Linear model. Fast baseline, good for linearly separable problems.
            "gradient_boosting"  — Sklearn GradientBoostingClassifier. Slower than xgboost
                                   but sometimes better on small datasets.
        hyperparams: Dict of algorithm-specific hyperparameters. Common keys:
            For xgboost / gradient_boosting:
                n_estimators (int, default 100): Number of trees.
                max_depth (int, default 6): Maximum tree depth.
                learning_rate (float, default 0.1): Step size shrinkage.
                scale_pos_weight (float): Ratio negative/positive — use for imbalanced data.
                                          Set to (n_negative / n_positive) to upweight minority class.
                subsample (float, default 1.0): Fraction of samples used per tree.
                colsample_bytree (float, default 1.0): Fraction of features per tree.
            For random_forest:
                n_estimators (int, default 100): Number of trees.
                max_depth (int or null, default null): Maximum tree depth (null = unlimited).
                min_samples_leaf (int, default 1): Minimum samples at a leaf node.
                class_weight (str, default "balanced"): "balanced" reweights by class frequency.
            For logistic_regression:
                C (float, default 1.0): Inverse regularization strength (higher = less regularization).
                max_iter (int, default 1000): Maximum iterations for solver.
                class_weight (str, default "balanced"): "balanced" reweights by class frequency.

    Returns:
        File — serialized model (joblib format, contains model + feature columns + target column).
    """
    # {{/docs-fragment train_model}}
    import tempfile

    import joblib
    import numpy as np
    import pandas as pd
    from flyte.io import File as FlyteFile
    from sklearn.ensemble import GradientBoostingClassifier, RandomForestClassifier
    from sklearn.linear_model import LogisticRegression
    from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, roc_auc_score

    path = await data.download()
    df = pd.read_csv(path)

    # Only use numeric columns — drop strings like machine_id automatically
    feature_cols = [c for c in df.select_dtypes(include=[np.number]).columns if c != target_column]
    X = df[feature_cols].values
    y = df[target_column].values

    class_dist = {str(k): int(v) for k, v in zip(*np.unique(y, return_counts=True))}
    n_positive = int((y == 1).sum())
    n_negative = int((y == 0).sum())
    default_scale = max(1.0, n_negative / n_positive) if n_positive > 0 else 1.0

    if algorithm == "xgboost":
        from xgboost import XGBClassifier
        p = XGBoostParams.model_validate({**hyperparams, "scale_pos_weight": hyperparams.get("scale_pos_weight", default_scale)})
        params = {**p.model_dump(), "eval_metric": "logloss", "random_state": 42}
        model = XGBClassifier(**params)

    elif algorithm == "random_forest":
        p = RandomForestParams.model_validate(hyperparams)
        params = {**p.model_dump(), "random_state": 42, "n_jobs": -1}
        model = RandomForestClassifier(**params)

    elif algorithm == "gradient_boosting":
        p = GradientBoostingParams.model_validate(hyperparams)
        params = {**p.model_dump(), "random_state": 42}
        model = GradientBoostingClassifier(**params)

    elif algorithm == "logistic_regression":
        p = LogisticRegressionParams.model_validate(hyperparams)
        params = {**p.model_dump(), "random_state": 42}
        model = LogisticRegression(**params)

    else:
        raise ValueError(f"Unknown algorithm: {algorithm!r}. Choose from: xgboost, random_forest, gradient_boosting, logistic_regression")

    model.fit(X, y)
    y_pred = model.predict(X)
    y_prob = model.predict_proba(X)[:, 1] if hasattr(model, "predict_proba") else y_pred

    train_metrics = {
        "accuracy": round(float(accuracy_score(y, y_pred)), 4),
        "f1": round(float(f1_score(y, y_pred, average="binary", zero_division=0)), 4),
        "precision": round(float(precision_score(y, y_pred, average="binary", zero_division=0)), 4),
        "recall": round(float(recall_score(y, y_pred, average="binary", zero_division=0)), 4),
        "roc_auc": round(float(roc_auc_score(y, y_prob)), 4),
    }

    # Feature importance (top 20)
    if hasattr(model, "feature_importances_"):
        importances = model.feature_importances_
        importance_dict = {feature_cols[i]: round(float(importances[i]), 4) for i in range(len(feature_cols))}
        importance_dict = dict(sorted(importance_dict.items(), key=lambda x: x[1], reverse=True)[:20])
    elif hasattr(model, "coef_"):
        importances = abs(model.coef_[0])
        importance_dict = {feature_cols[i]: round(float(importances[i]), 4) for i in range(len(feature_cols))}
        importance_dict = dict(sorted(importance_dict.items(), key=lambda x: x[1], reverse=True)[:20])
    else:
        importance_dict = {}

    model_file = tempfile.NamedTemporaryFile(suffix=".joblib", delete=False)
    joblib.dump(
        {
            "model": model,
            "feature_columns": feature_cols,
            "target_column": target_column,
            "train_metrics": train_metrics,
            "feature_importance": importance_dict,
            "class_distribution": class_dist,
        },
        model_file.name,
    )
    model_file.close()

    return await FlyteFile.from_local(model_file.name)
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/mle_bot/mle_bot/tools/training.py*

And here's the profiling tool, which is the first thing the agent calls. It computes dataset statistics that the LLM will use to design experiments:

```python
"""Data loading, profiling, and splitting tools.

These tools are safe, general-purpose, and side-effect free.
They run as durable Flyte tasks so they execute in the cloud on managed compute.
"""

from flyte.io import File

from mle_bot.environments import tool_env

# {{docs-fragment profile_dataset}}
@tool_env.task
async def profile_dataset(data: File, target_column: str) -> dict:
    """Profile a dataset and return statistics that inform ML problem design.

    Call this first before designing any experiments. The returned profile tells
    you the shape, column types, class balance, missing values, and numeric
    statistics — everything needed to choose algorithms and feature strategies.

    Args:
        data: CSV file to profile.
        target_column: Name of the column to predict.

    Returns a dict with keys:
        - shape: [n_rows, n_cols]
        - columns: list of all column names
        - dtypes: {col: dtype_string, ...}
        - numeric_columns: list of numeric column names (excluding target)
        - categorical_columns: list of non-numeric column names (excluding target)
        - target_distribution: {class_value: count, ...}
        - class_balance: {class_value: pct, ...}  (percentages, summing to 100)
        - missing_pct: {col: pct_missing, ...}
        - numeric_stats: {col: {mean, std, min, max, median}, ...}
        - n_classes: int — number of unique target values
        - is_imbalanced: bool — True if minority class < 20% of data
        - sample: list of 3 example rows as dicts
    """
    import numpy as np
    import pandas as pd

    path = await data.download()
    df = pd.read_csv(path)

    target_counts = df[target_column].value_counts()
    class_balance = (df[target_column].value_counts(normalize=True) * 100).round(2).to_dict()
    minority_pct = float(min(class_balance.values()))

    numeric_cols = df.select_dtypes(include=[np.number]).columns.tolist()
    categorical_cols = df.select_dtypes(exclude=[np.number]).columns.tolist()

    numeric_stats = {}
    for col in numeric_cols:
        if col == target_column:
            continue
        numeric_stats[col] = {
            "mean": round(float(df[col].mean()), 4),
            "std": round(float(df[col].std()), 4),
            "min": round(float(df[col].min()), 4),
            "max": round(float(df[col].max()), 4),
            "median": round(float(df[col].median()), 4),
        }

    # Point-biserial correlation between each numeric feature and the target
    feature_target_corr = {}
    for col in numeric_cols:
        if col == target_column:
            continue
        corr = float(df[col].corr(df[target_column]))
        if not np.isnan(corr):
            feature_target_corr[col] = round(corr, 4)
    # Sort by absolute correlation descending
    feature_target_corr = dict(
        sorted(feature_target_corr.items(), key=lambda x: abs(x[1]), reverse=True)
    )

    return {
        "shape": list(df.shape),
        "columns": list(df.columns),
        "dtypes": {col: str(dtype) for col, dtype in df.dtypes.items()},
        "numeric_columns": [c for c in numeric_cols if c != target_column],
        "categorical_columns": [c for c in categorical_cols if c != target_column],
        "target_distribution": {str(k): int(v) for k, v in target_counts.items()},
        "class_balance": {str(k): float(v) for k, v in class_balance.items()},
        "missing_pct": {col: round(float(pct * 100), 2) for col, pct in df.isnull().mean().items()},
        "numeric_stats": numeric_stats,
        "feature_target_corr": feature_target_corr,
        "n_classes": int(df[target_column].nunique()),
        "is_imbalanced": minority_pct < 20.0,
        "sample": df.head(3).fillna("").to_dict(orient="records"),
    }
# {{/docs-fragment profile_dataset}}

@tool_env.task
async def split_dataset(
    data: File,
    target_column: str,
    test_size: float,
    time_column: str,
    split_type: str,
) -> File:
    """Split a dataset and return either the train or test half.

    Call this twice — once with split_type="train" and once with split_type="test" —
    to get both halves. Always split before feature engineering to prevent data leakage.

    Args:
        data: CSV file to split.
        target_column: Name of the column to predict.
        test_size: Fraction of data to use for testing (e.g. 0.2 for 20%).
        time_column: If non-empty, sort by this column and take the last
                     `test_size` fraction as test (time-based split, no shuffling).
                     If empty string "", use stratified random split.
        split_type: Which half to return — "train" or "test".

    Returns:
        File — CSV file containing the requested split (train or test rows).
    """
    import tempfile

    import pandas as pd
    from flyte.io import File as FlyteFile
    from sklearn.model_selection import train_test_split

    path = await data.download()
    df = pd.read_csv(path)

    if time_column:
        df = df.sort_values(time_column).reset_index(drop=True)
        split_idx = int(len(df) * (1 - test_size))
        train_df = df.iloc[:split_idx]
        test_df = df.iloc[split_idx:]
    else:
        train_df, test_df = train_test_split(
            df,
            test_size=test_size,
            stratify=df[target_column],
            random_state=42,
        )

    selected_df = train_df if split_type == "train" else test_df

    out = tempfile.NamedTemporaryFile(suffix=".csv", delete=False)
    selected_df.to_csv(out.name, index=False)
    out.close()
    return await FlyteFile.from_local(out.name)
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/mle_bot/mle_bot/tools/data.py*

The full tool inventory includes ten functions: `profile_dataset`, `split_dataset`, `explore_dataset`, `engineer_features`, `select_features`, `resample_dataset`, `train_model`, `get_predictions`, `evaluate_model`, and `rank_experiments`. Each one does exactly one thing. The LLM composes them into pipelines, but each tool enforces its own correctness guarantees internally. For example, `resample_dataset` only applies resampling to training data, never test data, regardless of what the LLM asks for.
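To make the composition concrete, here is the shape of the orchestration code the LLM generates for one experiment: split, engineer features, train, evaluate. The tool bodies below are async stubs so the chain is runnable here; the real generated code calls the actual tool tasks inside the Monty sandbox, and the return values shown are illustrative.

```python
import asyncio

# Async stubs standing in for the pre-approved tool tasks.
async def split_dataset(data, target_column, test_size, time_column, split_type):
    return f"{split_type}.csv"

async def engineer_features(data, config):
    return f"features_{data}"

async def train_model(data, target_column, algorithm, hyperparams):
    return f"model_{algorithm}.joblib"

async def evaluate_model(model, data, target_column):
    return {"roc_auc": 0.78, "f1": 0.43}

# One generated pipeline: time-based split, shared feature config, train, evaluate.
async def experiment():
    train = await split_dataset("pumps.csv", "failure", 0.2, "timestamp", "train")
    test = await split_dataset("pumps.csv", "failure", 0.2, "timestamp", "test")
    config = {"group_column": "machine_id", "rolling_columns": ["vibration"], "windows": [6, 12]}
    train_fe = await engineer_features(train, config)
    test_fe = await engineer_features(test, config)
    model = await train_model(train_fe, "failure", "xgboost", {"n_estimators": 200})
    return await evaluate_model(model, test_fe, "failure")

metrics = asyncio.run(experiment())
```

Note that the same `config` is applied to both splits, and the split happens before feature engineering, exactly the leakage-avoiding order the `split_dataset` docstring prescribes.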

## Guiding the LLM with domain knowledge

The quality of the agent's experiments depends heavily on what you tell it. The MLE Bot bakes ML best practices directly into its system prompts, so the LLM starts from a solid foundation rather than relying on whatever it picked up during pretraining.

The orchestration prompt, for example, includes guidance on feature engineering strategies, class imbalance handling, and algorithm selection. It's dynamically built from the dataset profile, so the LLM sees concrete context alongside the general advice:

```python
def _build_orchestration_system_prompt(profile: dict) -> str:
    # Unpack the fields interpolated below (keys match profile_dataset's output)
    shape = profile["shape"]
    numeric_cols = profile["numeric_columns"]
    class_balance = profile["class_balance"]
    is_imbalanced = profile["is_imbalanced"]
    corr_str = ", ".join(f"{k}: {v}" for k, v in profile["feature_target_corr"].items())
    return f"""\
You are an expert ML engineer. Your job is to design and write the best
possible pipeline for a machine learning experiment.

## Dataset context
Shape: {shape[0]:,} rows × {shape[1]} columns
Numeric features: {numeric_cols}
Class balance: {class_balance}, imbalanced: {is_imbalanced}
Feature-target correlations (raw): {corr_str}

## General ML best practices
**Feature engineering**:
- Sequential/time-series data: rolling window features capture trends
  that point-in-time readings miss. Choose window sizes relative to
  the prediction horizon and temporal resolution of the data.
- Consider skipping feature engineering entirely for a baseline.

**Class imbalance** (when is_imbalanced=true):
- Tree ensembles: use class_weight="balanced" or scale_pos_weight.
- The default 0.5 decision threshold may not be optimal.

**Algorithm selection**:
- XGBoost: strong default for tabular data. Start here.
- RandomForest: more robust to outliers, good for noisy data.
- LogisticRegression: fast linear baseline.
...
"""
```

This means the LLM doesn't just get a blank canvas. It gets a structured briefing that combines the actual dataset characteristics with best practices for handling them. When the profile shows class imbalance, the prompt tells it which hyperparameters to adjust and which resampling strategies to consider. When there's a timestamp column, the prompt suggests rolling window features with guidance on window sizing.

The user's problem description also has a significant impact on the agent's behavior. A query like "Predict pump failures 24 hours before they happen based on sensor readings" tells the LLM that this is a time-series classification problem with a specific prediction horizon. That shapes everything: the LLM will favor temporal feature engineering (rolling windows sized relative to that 24-hour horizon), pick algorithms that handle imbalanced binary classification well, and focus on recall as a key metric because missing a failure is worse than a false alarm. Change the query to something like "Classify machine health status from the latest sensor snapshot" and the same dataset would produce a completely different set of experiments, with less emphasis on temporal features and more on cross-sectional patterns.

## The agent loop: profile, design, execute, iterate

The agent's main function orchestrates five phases. Let's walk through each one.

**Phase 1: Profile.** The agent calls `profile_dataset` directly as a trusted tool. This isn't sandboxed because there's nothing to protect against here: the function is your code, running on your compute. The `flyte.group` call organizes this step in the Flyte UI so you can inspect it later.

```python
with flyte.group("profile"):
    profile = await profile_dataset(data, target_column)
```

**Phase 2: Design.** The profile dict goes to the LLM along with the problem description. The LLM returns a structured response matching the `InitialDesign` schema:

```python
"""Pydantic schemas for tool inputs and agent data structures.

These models define the expected shape of configs and results throughout the agent.

Important: Tool functions that are called from the Monty sandbox must accept plain
`dict` at the boundary (Monty can't import or instantiate classes). Each tool parses
its incoming dict into the appropriate model internally for validation. In agent.py,
use `.model_dump()` to convert models back to dicts before passing to the sandbox.
"""

from typing import Literal

from pydantic import BaseModel, Field

# ---------------------------------------------------------------------------
# Feature engineering
# ---------------------------------------------------------------------------

class FeatureConfig(BaseModel):
    """Configuration for the engineer_features tool."""

    group_column: str = Field(
        default="",
        description="Column to group by for rolling/lag features (e.g. 'machine_id'). "
                    "Required when rolling_columns or lag_columns is specified.",
    )
    time_column: str = Field(
        default="",
        description="Timestamp column to sort by before computing rolling/lag features.",
    )
    rolling_columns: list[str] = Field(
        default_factory=list,
        description="Numeric columns to compute rolling statistics for (mean, std, min, max).",
    )
    windows: list[int] = Field(
        default_factory=list,
        description="Rolling window sizes in rows (e.g. [6, 12, 24]).",
    )
    lag_columns: list[str] = Field(
        default_factory=list,
        description="Numeric columns to create lag features for.",
    )
    lags: list[int] = Field(
        default_factory=list,
        description="Lag steps in rows (e.g. [1, 3, 6]).",
    )
    normalize: bool = Field(
        default=False,
        description="If true, z-score normalize all numeric columns except target_column.",
    )
    target_column: str = Field(
        default="",
        description="Column to exclude from normalization. Required when normalize=True.",
    )
    drop_columns: list[str] = Field(
        default_factory=list,
        description="Columns to remove from output (e.g. raw timestamp after rolling).",
    )
    fillna_method: Literal["forward", "zero", "drop"] = Field(
        default="forward",
        description="How to fill NaN values introduced by rolling/lag. "
                    "'forward' forward-fills then fills remaining with 0. "
                    "'zero' fills all NaN with 0. 'drop' drops rows with NaN.",
    )

# ---------------------------------------------------------------------------
# Training hyperparameters (per algorithm)
# ---------------------------------------------------------------------------

class XGBoostParams(BaseModel):
    n_estimators: int = Field(default=100, ge=1)
    max_depth: int = Field(default=6, ge=1, le=20)
    learning_rate: float = Field(default=0.1, gt=0, le=1)
    scale_pos_weight: float = Field(
        default=1.0, ge=0,
        description="Set to n_negative/n_positive for imbalanced datasets.",
    )
    subsample: float = Field(default=1.0, gt=0, le=1)
    colsample_bytree: float = Field(default=1.0, gt=0, le=1)

class RandomForestParams(BaseModel):
    n_estimators: int = Field(default=100, ge=1)
    max_depth: int | None = Field(
        default=None,
        description="Maximum tree depth. None means unlimited.",
    )
    min_samples_leaf: int = Field(default=1, ge=1)
    class_weight: Literal["balanced", "balanced_subsample"] | None = Field(default="balanced")

class GradientBoostingParams(BaseModel):
    n_estimators: int = Field(default=100, ge=1)
    max_depth: int = Field(default=3, ge=1, le=10)
    learning_rate: float = Field(default=0.1, gt=0, le=1)
    subsample: float = Field(default=1.0, gt=0, le=1)

class LogisticRegressionParams(BaseModel):
    C: float = Field(default=1.0, gt=0, description="Inverse regularization strength.")
    max_iter: int = Field(default=1000, ge=100)
    class_weight: Literal["balanced"] | None = Field(default="balanced")

# ---------------------------------------------------------------------------
# Experiment design (used by agent.py, validated when parsing LLM JSON)
# ---------------------------------------------------------------------------

Algorithm = Literal["xgboost", "random_forest", "gradient_boosting", "logistic_regression"]

# {{docs-fragment schemas}}
class ExperimentConfig(BaseModel):
    """One experiment to run — produced by the LLM and executed by the agent."""

    name: str = Field(description="Short descriptive name for this experiment.")
    algorithm: Algorithm
    hyperparams: dict = Field(
        default_factory=dict,
        description="Algorithm-specific hyperparameters. Will be validated inside train_model.",
    )
    feature_config: FeatureConfig = Field(default_factory=FeatureConfig)
    rationale: str = Field(default="", description="Why this experiment is worth running.")

class InitialDesign(BaseModel):
    """LLM response for initial experiment design."""

    problem_type: str = Field(default="binary_classification")
    primary_metric: Literal["roc_auc", "f1", "recall"] = Field(default="roc_auc")
    reasoning: str
    experiments: list[ExperimentConfig]

class IterationDecision(BaseModel):
    """LLM response after analyzing experiment results."""

    should_continue: bool
    reasoning: str
    exploration_requests: list[dict] = Field(
        default_factory=list,
        description="Optional list of explore_dataset config dicts to run before designing "
                    "the next batch. Each dict is passed directly to explore_dataset.",
    )
    next_experiments: list[ExperimentConfig] = Field(default_factory=list)
# {{/docs-fragment schemas}}
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/mle_bot/mle_bot/schemas.py*

The LLM typically proposes 2 to 3 experiments: a baseline with minimal feature engineering, an experiment with rolling window features for temporal data, and perhaps one testing a different algorithm or resampling strategy.
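The raw LLM response is JSON that gets validated against these schemas. A minimal sketch of that parsing step, using trimmed-down inline models rather than the full `mle_bot.schemas` so it runs standalone:

```python
import json
from typing import Literal

from pydantic import BaseModel, Field

# Trimmed-down versions of the schemas above, just for this sketch.
class ExperimentConfig(BaseModel):
    name: str
    algorithm: Literal["xgboost", "random_forest", "gradient_boosting", "logistic_regression"]
    hyperparams: dict = Field(default_factory=dict)
    rationale: str = ""

class InitialDesign(BaseModel):
    primary_metric: Literal["roc_auc", "f1", "recall"] = "roc_auc"
    reasoning: str
    experiments: list[ExperimentConfig]

# A hypothetical LLM response; real responses are produced by the design prompt.
llm_response = json.dumps({
    "primary_metric": "roc_auc",
    "reasoning": "Imbalanced temporal data: start with a baseline plus a boosted model.",
    "experiments": [
        {"name": "baseline_lr", "algorithm": "logistic_regression"},
        {"name": "xgb_rolling", "algorithm": "xgboost", "hyperparams": {"n_estimators": 200}},
    ],
})

design = InitialDesign.model_validate_json(llm_response)
```

If the LLM emits an unknown algorithm or a malformed field, validation fails immediately instead of propagating bad configs into the experiment loop.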

**Phase 3: Execute in parallel.** All experiments in a batch run simultaneously using `asyncio.gather()`. Each experiment dispatches its own set of durable Flyte tasks:

```python
"""MLE Agent — orchestrates ML experiments using Flyte's durable sandbox.

The agent:
  1. Profiles the dataset using a trusted tool (data never touches the LLM).
  2. Asks OpenAI to design a set of experiments (algorithms, hyperparams, feature strategy).
  3. For each experiment, generates Monty orchestration code and executes it via
     flyte.sandbox.orchestrate_local(), which dispatches the heavy compute as durable tasks.
  4. Analyzes results, iterates if needed.
  5. Produces a model card summarizing the winning model.

The Monty sandbox ensures the LLM-generated orchestration code is safe — it can only
call the pre-approved tool functions and has no access to imports, network, or filesystem.
"""

import asyncio
import inspect
import json
import os
import textwrap
from dataclasses import dataclass

import flyte
import flyte.sandbox
from flyte.io import File

from mle_bot.schemas import ExperimentConfig, InitialDesign, IterationDecision

from mle_bot.environments import agent_env
from mle_bot.tools.data import profile_dataset, split_dataset
from mle_bot.tools.evaluation import evaluate_model, rank_experiments
from mle_bot.tools.exploration import explore_dataset
from mle_bot.tools.features import engineer_features
from mle_bot.tools.predictions import get_predictions
from mle_bot.tools.resampling import resample_dataset
from mle_bot.tools.selection import select_features
from mle_bot.tools.training import train_model

# {{docs-fragment tools}}
# All tools exposed to the sandbox.
# Keys must match the function names used in LLM-generated orchestration code.
TOOLS = [
    profile_dataset, split_dataset, explore_dataset,
    engineer_features, resample_dataset, select_features,
    train_model, get_predictions, evaluate_model, rank_experiments,
]
TOOLS_BY_NAME = {t.func.__name__ if hasattr(t, "func") else t.__name__: t for t in TOOLS}
# {{/docs-fragment tools}}

# ---------------------------------------------------------------------------
# Prompt builders
# ---------------------------------------------------------------------------

def _tool_signatures() -> str:
    """Build a summary of available tool signatures and docstrings for the system prompt."""
    parts = []
    for t in TOOLS:
        func = t.func if hasattr(t, "func") else t
        sig = inspect.signature(func)
        doc = inspect.getdoc(func) or ""
        # Trim docstring to first 40 lines so prompt stays manageable
        doc_lines = doc.splitlines()[:40]
        short_doc = "\n    ".join(doc_lines)
        parts.append(f"async def {func.__name__}{sig}:\n    \"\"\"{short_doc}\"\"\"\n    ...")
    return "\n\n".join(parts)

# {{docs-fragment orchestration_prompt}}
def _build_orchestration_system_prompt(profile: dict) -> str:
    monty_rules = flyte.sandbox.ORCHESTRATOR_SYNTAX_PROMPT
    tools_section = _tool_signatures()
    is_imbalanced = profile.get("is_imbalanced", False)
    class_balance = profile.get("class_balance", {})
    columns = profile.get("columns", [])
    numeric_cols = profile.get("numeric_columns", [])
    categorical_cols = profile.get("categorical_columns", [])
    corr = profile.get("feature_target_corr", {})
    corr_str = ", ".join(f"{k}: {v:+.3f}" for k, v in list(corr.items())[:8]) if corr else "n/a"
    shape = profile.get("shape", [0, 0])
    return f"""\
You are an expert ML engineer. Your job is to design and write the best possible
pipeline for a machine learning experiment, then generate the Python orchestration
code to execute it.

The code runs inside a restricted sandbox. The last expression in your code
is returned as the result. All tool calls are made like regular function calls —
you do NOT need to await them.

## Dataset context

Shape: {shape[0]:,} rows × {shape[1]} columns
Numeric features: {numeric_cols}
Categorical features (excluded from model — not supported): {categorical_cols}
Class balance: {class_balance}, imbalanced: {is_imbalanced}
Feature-target correlations (raw, point-biserial): {corr_str}

## General ML best practices — apply these based on the dataset context above

**Feature engineering** (engineer_features tool):
- Sequential/time-series data (timestamp column present, rows ordered over time):
  rolling window features (means, stds, min/max) capture trends that point-in-time
  readings miss. Lag features capture recent history. Choose window sizes relative
  to the prediction horizon and temporal resolution of the data.
- Tabular cross-sectional data: normalization helps linear models and distance-based
  methods. Interaction terms can help if correlations are weak individually.
- Consider skipping feature engineering entirely for a baseline — it establishes
  whether raw features already carry enough signal.

**Class imbalance** (when is_imbalanced=true):
- Tree ensembles: use class_weight="balanced" or scale_pos_weight=n_neg/n_pos.
- Threshold: the default 0.5 decision threshold may not be optimal — the model's
  probability output is what matters, threshold is tuned at deployment time.
- Metric: ROC-AUC is robust to imbalance; avg_precision is better when positives
  are very rare.

**Algorithm selection**:
- XGBoost / GradientBoosting: strong default for tabular data, handles missing
  values, built-in imbalance handling. Start here unless data is very small.
- RandomForest: more robust to outliers, good for noisy data, parallelizes well.
- LogisticRegression: fast linear baseline. Useful to establish whether the
  problem is linearly separable before adding complexity.
- Prefer simpler models when n_samples < 5,000 to avoid overfitting.

**Resampling** (resample_dataset tool) — data-level imbalance handling:
- Use when class_weight/scale_pos_weight alone isn't improving recall adequately,
  or when you want to test whether data-level vs algorithm-level imbalance handling
  works better for this dataset.
- ONLY resample the TRAIN split — never test. Resampling test data gives misleading metrics.
- "oversample": fast, no new information, good first try.
- "smote": synthetic samples via interpolation — more diverse than random oversampling,
  better for high-dimensional or sparse minority classes.
- "undersample": loses majority data — only useful when majority class is very large
  and training speed is a concern.

**Feature selection** (select_features tool) — prune after feature engineering:
- Use after engineer_features when the feature count is large (20+) and you suspect
  many features are redundant or noisy (e.g., rolling stats at many window sizes).
- "mutual_info": ranks by non-linear dependence with target — best general choice.
- "variance_threshold": drops near-constant features — cheap first pass.
- "correlation_filter": drops redundant features that are highly correlated with
  each other — useful when many rolling windows capture the same trend.
- Can be applied before or after splitting. Apply the same selection to both train
  and test to ensure the model sees the same features at evaluation time.

**Prediction output** (get_predictions tool) — enables two advanced patterns:
1. Error analysis: train a model → get_predictions(model, test_file, target) →
   explore_dataset(predictions_file, {{"class_distributions": ["feature_x"],
   "target_column": "correct"}}) to see which examples the model gets wrong.
   Use this to inform feature engineering for the next iteration.
2. Stacking: train base_model → get_predictions(base_model, train_file, target) →
   train a meta_model on the predictions CSV (use "predicted_prob" as a feature
   alongside original features) → evaluate meta_model on test.
   get_predictions returns a CSV with columns: all originals + predicted_prob,
   predicted_class, correct.

**Pipeline structure** — you are not required to follow a fixed sequence.
Design what makes sense for this specific experiment.

## Available tools

{tools_section}

## Monty sandbox restrictions

{monty_rules}

## Critical patterns for using tool results

split_dataset returns a File — call it twice:
    train_file = split_dataset(data, target_column, 0.2, time_column, "train")
    test_file  = split_dataset(data, target_column, 0.2, time_column, "test")

engineer_features returns a File — chain calls freely:
    eng = engineer_features(train_file, {{"rolling_columns": [...], "windows": [...]}})
    eng2 = engineer_features(eng, {{"normalize": true, "target_column": target_column}})

train_model returns a File — pass directly to evaluate_model:
    model_file = train_model(train_file, target_column, algorithm, hyperparams)
    eval_result = evaluate_model(model_file, test_file, target_column)

evaluate_model returns a dict — subscript reads are allowed:
    roc = eval_result["metrics"]["roc_auc"]

Do NOT use augmented assignment (+=), subscript assignment (d["k"]=v), or try/except.
Build dicts as literals only. The last expression (no assignment) is the return value.

## When fixing a previous error

Read the error and the failing code carefully before writing a fix. Identify the root
cause — do not just change variable names or add no-ops. Trace what each tool returns,
what each subsequent call expects, and where the mismatch is. Then fix the underlying
logic, not just the surface symptom.

## Pipeline design — you decide the structure

You are NOT required to follow a fixed sequence. Design the pipeline that makes most
sense for the experiment. Examples of valid approaches:

Baseline (no feature engineering):
    train_file = split_dataset(data, target_column, 0.2, time_column, "train")
    test_file = split_dataset(data, target_column, 0.2, time_column, "test")
    model_file = train_model(train_file, target_column, algorithm, hyperparams)
    eval_result = evaluate_model(model_file, test_file, target_column)
    {{"experiment_name": experiment_name, "algorithm": algorithm, "metrics": eval_result["metrics"], "confusion_matrix": eval_result["confusion_matrix"], "threshold_analysis": eval_result["threshold_analysis"], "n_samples": eval_result["n_samples"]}}

Two-stage feature engineering (rolling then normalize separately):
    train_file = split_dataset(data, target_column, 0.2, time_column, "train")
    test_file = split_dataset(data, target_column, 0.2, time_column, "test")
    rolled_train = engineer_features(train_file, {{"rolling_columns": ["vibration"], "windows": [6, 24]}})
    rolled_test  = engineer_features(test_file,  {{"rolling_columns": ["vibration"], "windows": [6, 24]}})
    eng_train = engineer_features(rolled_train, {{"normalize": true, "target_column": target_column}})
    eng_test  = engineer_features(rolled_test,  {{"normalize": true, "target_column": target_column}})
    model_file = train_model(eng_train, target_column, algorithm, hyperparams)
    eval_result = evaluate_model(model_file, eng_test, target_column)
    {{"experiment_name": experiment_name, "algorithm": algorithm, "metrics": eval_result["metrics"], "confusion_matrix": eval_result["confusion_matrix"], "threshold_analysis": eval_result["threshold_analysis"], "n_samples": eval_result["n_samples"]}}

Compare two class weightings and return the better model:
    train_file = split_dataset(data, target_column, 0.2, time_column, "train")
    test_file = split_dataset(data, target_column, 0.2, time_column, "test")
    model_a = train_model(train_file, target_column, "xgboost", {{"n_estimators": 100, "scale_pos_weight": 10}})
    model_b = train_model(train_file, target_column, "xgboost", {{"n_estimators": 100, "scale_pos_weight": 33}})
    eval_a = evaluate_model(model_a, test_file, target_column)
    eval_b = evaluate_model(model_b, test_file, target_column)
    best_eval = eval_a if eval_a["metrics"]["roc_auc"] > eval_b["metrics"]["roc_auc"] else eval_b
    {{"experiment_name": experiment_name, "algorithm": "xgboost", "metrics": best_eval["metrics"], "confusion_matrix": best_eval["confusion_matrix"], "threshold_analysis": best_eval["threshold_analysis"], "n_samples": best_eval["n_samples"]}}

SMOTE oversampling before training:
    train_file = split_dataset(data, target_column, 0.2, time_column, "train")
    test_file  = split_dataset(data, target_column, 0.2, time_column, "test")
    eng_train  = engineer_features(train_file, {{"rolling_columns": ["vibration_mms"], "windows": [6, 12]}})
    eng_test   = engineer_features(test_file,  {{"rolling_columns": ["vibration_mms"], "windows": [6, 12]}})
    resampled_train = resample_dataset(eng_train, target_column, {{"strategy": "smote", "target_ratio": 0.2}})
    model_file = train_model(resampled_train, target_column, algorithm, hyperparams)
    eval_result = evaluate_model(model_file, eng_test, target_column)
    {{"experiment_name": experiment_name, "algorithm": algorithm, "metrics": eval_result["metrics"], "confusion_matrix": eval_result["confusion_matrix"], "threshold_analysis": eval_result["threshold_analysis"], "n_samples": eval_result["n_samples"]}}

Feature engineering followed by feature selection:
    train_file = split_dataset(data, target_column, 0.2, time_column, "train")
    test_file  = split_dataset(data, target_column, 0.2, time_column, "test")
    eng_train  = engineer_features(train_file, {{"rolling_columns": ["vibration_mms", "temperature_c"], "windows": [6, 12, 24]}})
    eng_test   = engineer_features(test_file,  {{"rolling_columns": ["vibration_mms", "temperature_c"], "windows": [6, 12, 24]}})
    sel_train  = select_features(eng_train, target_column, {{"method": "mutual_info", "k": 15}})
    sel_test   = select_features(eng_test,  target_column, {{"method": "mutual_info", "k": 15}})
    model_file = train_model(sel_train, target_column, algorithm, hyperparams)
    eval_result = evaluate_model(model_file, sel_test, target_column)
    {{"experiment_name": experiment_name, "algorithm": algorithm, "metrics": eval_result["metrics"], "confusion_matrix": eval_result["confusion_matrix"], "threshold_analysis": eval_result["threshold_analysis"], "n_samples": eval_result["n_samples"]}}

Error analysis — explore what the model gets wrong, then return that as insight:
    train_file = split_dataset(data, target_column, 0.2, time_column, "train")
    test_file  = split_dataset(data, target_column, 0.2, time_column, "test")
    model_file = train_model(train_file, target_column, algorithm, hyperparams)
    pred_file  = get_predictions(model_file, test_file, target_column)
    error_analysis = explore_dataset(pred_file, {{"target_column": "correct", "class_distributions": ["vibration_mms", "temperature_c"]}})
    eval_result = evaluate_model(model_file, test_file, target_column)
    {{"experiment_name": experiment_name, "algorithm": algorithm, "metrics": eval_result["metrics"], "confusion_matrix": eval_result["confusion_matrix"], "threshold_analysis": eval_result["threshold_analysis"], "n_samples": eval_result["n_samples"], "error_analysis": error_analysis}}

The last expression MUST be a dict with at minimum these keys:
    experiment_name, algorithm, metrics, confusion_matrix, threshold_analysis, n_samples
Additional keys (e.g. error_analysis) are allowed and will appear in the report.

## Response format

Respond in exactly this format:

## Reasoning
[Your thinking: what pipeline makes sense for this experiment and why. Consider whether
feature engineering helps, whether class imbalance needs special treatment, whether
chaining multiple steps adds value, etc.]

## Code
```python
[your orchestration code]
```
"""
# {{/docs-fragment orchestration_prompt}}

def _build_analysis_system_prompt(max_iterations: int, current_iteration: int) -> str:
    remaining = max_iterations - current_iteration - 1
    return f"""\
You are an expert ML engineer analyzing experiment results to guide the next iteration
of model development.

You must respond with valid JSON only — no markdown, no explanation outside the JSON.

Response format:
{{
  "should_continue": true | false,
  "reasoning": "What you observed, what it tells you, and what to try next",
  "exploration_requests": [
    {{
      "question": "The specific hypothesis you are testing, e.g. 'Do failure cases show meaningfully higher vibration than healthy cases?'",
      "analysis_type": "class_distributions",
      "target_column": "failure_24h",
      "class_distributions": ["vibration_mms", "temperature_c"]
    }}
  ],
  "next_experiments": [
    {{
      "name": "descriptive experiment name",
      "algorithm": "xgboost" | "random_forest" | "gradient_boosting" | "logistic_regression",
      "hyperparams": {{ ... algorithm-specific hyperparams ... }},
      "feature_config": {{
        "group_column": "...",
        "time_column": "...",
        "rolling_columns": [...],
        "windows": [...],
        "lag_columns": [...],
        "lags": [...],
        "normalize": true | false,
        "drop_columns": [...],
        "fillna_method": "forward"
      }},
      "rationale": "Why this experiment is worth trying"
    }}
  ]
}}

exploration_requests rules:
- Max 2 requests per iteration.
- Each request targets EXACTLY ONE analysis_type. Do not mix multiple types in one request.
- Supported analysis_type values and their required config fields:
    "class_distributions" → requires: target_column, class_distributions (list of columns)
    "correlation_matrix"  → requires: correlation_matrix: true
    "temporal_trend"      → requires: temporal_trend: {{time_column, target_column, freq}}
    "group_stats"         → requires: group_stats: {{group_column, target_column}}
    "outlier_summary"     → requires: outlier_summary (list of columns)
    "feature_target_corr_by_group" → requires: feature_target_corr_by_group: {{group_column, target_column, feature_columns}}
- The "question" field is required. It must be a specific testable hypothesis, not a
  general request. Bad: "explore the data". Good: "Is vibration_mms higher for failures?"
- Set exploration_requests to [] if the current results already tell you enough to
  design the next experiments. Only explore when you have a concrete unanswered question.

When deciding next experiments, reason about WHAT WAS TRIED vs what hasn't been explored.
Each result includes used_feature_engineering, used_rolling_features, used_lag_features.
Think systematically: if no feature engineering was tried yet, does the data profile
suggest it would help (weak raw correlations, temporal/sequential structure)?
If feature engineering helped, can it be improved? Avoid experiments identical to ones tried.

Iteration context: this is iteration {current_iteration + 1} of {max_iterations} requested.
Remaining iterations allowed: {remaining}.
Set should_continue=false only if:
- Best ROC-AUC >= 0.97, OR
- No remaining iterations (remaining == 0), OR
- Results have genuinely plateaued (< 0.005 ROC-AUC improvement over last iteration
  AND you have already tried the most promising directions)
Otherwise keep exploring — the user asked for {max_iterations} iterations for a reason.
"""

def _build_initial_design_system_prompt() -> str:
    return """\
You are an expert ML engineer. Given a dataset profile and a problem description,
design the first batch of experiments to run.

You must respond with valid JSON only — no markdown, no explanation outside the JSON.

Response format:
{
  "problem_type": "binary_classification",
  "primary_metric": "roc_auc" | "f1" | "recall",
  "reasoning": "Brief description of your strategy",
  "experiments": [
    {
      "name": "descriptive experiment name",
      "algorithm": "xgboost" | "random_forest" | "gradient_boosting" | "logistic_regression",
      "hyperparams": { ... algorithm-specific hyperparams ... },
      "feature_config": {
        "group_column": "",
        "time_column": "",
        "rolling_columns": [],
        "windows": [],
        "lag_columns": [],
        "lags": [],
        "normalize": false,
        "drop_columns": [],
        "fillna_method": "forward"
      },
      "rationale": "Why this experiment makes sense given the data profile"
    }
  ]
}

Design 2-3 experiments for the first batch. Good first batches typically include:
- A fast baseline to establish a floor (e.g. logistic_regression with default settings)
- Your best initial hypothesis given the data profile
- Optionally one variant that tests a specific idea suggested by the profile

Use the dataset profile to guide your choices:
- feature_target_corr: weak raw correlations suggest feature engineering may help
- categorical_columns: note these are excluded from the model automatically
- is_imbalanced: handle with class_weight or scale_pos_weight
- Shape and column types should inform algorithm complexity (simpler models for small datasets)
- A time column suggests sequential structure; lag/rolling features may capture temporal patterns

The feature_config in each experiment describes what engineer_features should apply.
Leave all fields empty/false if no feature engineering is needed for that experiment.
The orchestration code generator will decide the exact pipeline — your job here is
to specify what the experiment is trying to learn, not to prescribe every implementation detail.
"""

# ---------------------------------------------------------------------------
# LLM client
# ---------------------------------------------------------------------------

def _openai_client():
    from openai import OpenAI
    return OpenAI(api_key=os.environ["OPENAI_API_KEY"])

async def _call_llm(system: str, messages: list[dict], model: str = "gpt-4o") -> str:
    """Call OpenAI and return the response text."""
    client = _openai_client()
    response = await asyncio.to_thread(
        client.chat.completions.create,
        model=model,
        messages=[{"role": "system", "content": system}, *messages],
        temperature=0.2,
    )
    return response.choices[0].message.content

def _extract_code(text: str) -> str:
    """Pull Python code out of a markdown code block."""
    if "```python" in text:
        start = text.index("```python") + len("```python")
        end = text.index("```", start)
        return text[start:end].strip()
    if "```" in text:
        start = text.index("```") + 3
        end = text.index("```", start)
        return text[start:end].strip()
    return text.strip()

def _extract_reasoning(text: str) -> str:
    """Extract the ## Reasoning section from LLM response."""
    if "## Reasoning" in text:
        start = text.index("## Reasoning") + len("## Reasoning")
        if "## Code" in text:
            end = text.index("## Code")
            return text[start:end].strip()
        return text[start:].strip()
    return ""

def _parse_json(text: str) -> dict:
    """Extract and parse JSON from LLM response."""
    text = text.strip()
    if "```json" in text:
        start = text.index("```json") + 7
        end = text.index("```", start)
        text = text[start:end].strip()
    elif "```" in text:
        start = text.index("```") + 3
        end = text.index("```", start)
        text = text[start:end].strip()
    return json.loads(text)

# ---------------------------------------------------------------------------
# Display helpers
# ---------------------------------------------------------------------------

def _recommend_threshold(threshold_analysis: list, min_precision: float = 0.70) -> dict | None:
    """Find the threshold that maximizes recall subject to precision >= min_precision."""
    candidates = [t for t in threshold_analysis if t["precision"] >= min_precision]
    if not candidates:
        return None
    return max(candidates, key=lambda t: t["recall"])

def _print_experiment_table(results: list["ExperimentResult"], best_name: str) -> None:
    """Print a ranked comparison table of all experiments."""
    sorted_results = sorted(results, key=lambda r: r.metrics.get("roc_auc", 0), reverse=True)
    print("\n" + "─" * 78)
    print(f"  {'Rank':<5} {'Experiment':<32} {'ROC-AUC':<9} {'F1':<7} {'Recall':<8} {'Note'}")
    print("─" * 78)
    for rank, r in enumerate(sorted_results, 1):
        note = "◀ winner" if r.name == best_name else ""
        roc = r.metrics.get("roc_auc", 0)
        f1 = r.metrics.get("f1", 0)
        recall = r.metrics.get("recall", 0)
        print(f"  {rank:<5} {r.name:<32} {roc:<9.4f} {f1:<7.4f} {recall:<8.4f} {note}")
    print("─" * 78)

def _print_threshold_recommendation(threshold_analysis: list, default_metrics: dict) -> None:
    """Print the operational threshold recommendation."""
    rec = _recommend_threshold(threshold_analysis)
    if not rec:
        return
    default_recall = default_metrics.get("recall", 0)
    default_precision = default_metrics.get("precision", 0)
    missed_pct = round((1 - rec["recall"]) * 100, 1)
    false_alarm_pct = round((1 - rec["precision"]) * 100, 1)

    print(f"\n  Recommended decision threshold: {rec['threshold']}")
    print(f"  ├─ Precision : {rec['precision']:.0%}   ({false_alarm_pct}% of alerts are false alarms)")
    print(f"  ├─ Recall    : {rec['recall']:.0%}   (catches {rec['recall']*100:.0f}% of actual failures)")
    print(f"  └─ F1        : {rec['f1']:.4f}")
    print(f"  Default threshold (0.5): Precision={default_precision:.0%}, Recall={default_recall:.0%}")
    if rec["recall"] > default_recall:
        extra = round((rec["recall"] - default_recall) * 100, 1)
        print(f"  → Lowering the threshold catches {extra} percentage points more failures at the cost of more alerts")

# ---------------------------------------------------------------------------
# Orchestration code generation (durable Flyte task with Flyte report)
# ---------------------------------------------------------------------------

@agent_env.task
async def plan_experiment(
    experiment_json: str,
    profile_json: str,
    target_column: str,
    time_column: str,
    previous_error: str = "",
    previous_code: str = "",
    llm_model: str = "gpt-4o",
) -> str:
    """LLM plans a single experiment: reasons about the pipeline and generates Monty code.

    Runs as a durable Flyte task so each experiment's planning step is traceable.
    Returns a JSON string: {"code": "...", "reasoning": "..."}.

    Args:
        experiment_json: JSON string of the experiment spec (name, algorithm, hyperparams, ...).
        profile_json: JSON string of the dataset profile from profile_dataset.
        target_column: Name of the target column.
        time_column: Time column for temporal splitting, or empty string.
        previous_error: Error message from the previous attempt (empty on first try).
        previous_code: Code that failed on the previous attempt (empty on first try).
        llm_model: OpenAI model identifier.

    Returns:
        str — JSON string with keys "code" and "reasoning".
    """
    experiment = json.loads(experiment_json)
    profile = json.loads(profile_json)
    exp_name = experiment.get("name", "experiment")

    # Strip rationale — it was written by the design LLM to explain *why* this
    # experiment was chosen. Passing it here causes plan_experiment to parrot it
    # back as "reasoning" instead of independently thinking about *how* to build
    # the best pipeline. Keep only the technical spec.
    pipeline_spec = {
        k: v for k, v in experiment.items()
        if k not in ("rationale",)
    }

    system = _build_orchestration_system_prompt(profile)

    user_content = textwrap.dedent(f"""
        Design and implement the best pipeline for this experiment:

        Name: {exp_name}
        Algorithm: {pipeline_spec.get("algorithm")}
        Hyperparams: {json.dumps(pipeline_spec.get("hyperparams", {}), indent=2)}
        Feature config hint: {json.dumps(pipeline_spec.get("feature_config", {}), indent=2)}

        Available sandbox inputs:
        - data: File  — the full dataset CSV
        - target_column: str = "{target_column}"
        - time_column: str = "{time_column}"  (empty string means no time ordering)
        - experiment_name: str = "{exp_name}"

        The feature config hint is a suggestion from the experiment designer — you can
        follow it, improve on it, or override it if the dataset context and your ML
        judgment suggest a better approach. In your ## Reasoning, explain your actual
        pipeline decisions: what you chose to do (or not do) and why, based on the
        dataset profile above. Do not restate the experiment name or why it was chosen.
    """).strip()

    messages = [{"role": "user", "content": user_content}]
    if previous_code and previous_error:
        messages = [
            {"role": "user", "content": user_content},
            {"role": "assistant", "content": f"```python\n{previous_code}\n```"},
            {"role": "user", "content": f"That code failed with this error:\n\n{previous_error}\n\nPlease fix it."},
        ]

    response = await _call_llm(system, messages, llm_model)
    reasoning = _extract_reasoning(response)
    code = _extract_code(response)
    return json.dumps({"code": code, "reasoning": reasoning})

@flyte.trace
async def design_experiments(
    problem_description: str,
    profile_json: str,
    llm_model: str = "gpt-4o",
) -> str:
    """LLM designs the initial batch of experiments given problem + dataset profile.

    Traced so the prompt/response is visible in the Flyte UI and results are
    cached for deterministic replay on crash/retry.
    Returns raw LLM response (JSON string matching InitialDesign schema).
    """
    design_prompt = textwrap.dedent(f"""
        Problem description: {problem_description}

        Dataset profile:
        {profile_json}

        Design the first batch of experiments.
    """).strip()
    return await _call_llm(
        _build_initial_design_system_prompt(),
        [{"role": "user", "content": design_prompt}],
        llm_model,
    )

@flyte.trace
async def analyze_iteration(
    analysis_prompt: str,
    max_iterations: int,
    current_iteration: int,
    llm_model: str = "gpt-4o",
) -> str:
    """LLM analyzes experiment results and decides whether/how to continue.

    Traced so the prompt/response is visible in the Flyte UI and results are
    cached for deterministic replay on crash/retry.
    Returns raw LLM response (JSON string matching IterationDecision schema).
    """
    return await _call_llm(
        _build_analysis_system_prompt(max_iterations, current_iteration),
        [{"role": "user", "content": analysis_prompt}],
        llm_model,
    )

@flyte.trace
async def plan_followup(
    analysis_prompt: str,
    analysis_response: str,
    followup_prompt: str,
    max_iterations: int,
    current_iteration: int,
    llm_model: str = "gpt-4o",
) -> str:
    """LLM designs next experiments after targeted data explorations.

    Traced so the prompt/response is visible in the Flyte UI and results are
    cached for deterministic replay on crash/retry.
    Returns raw LLM response (JSON string with {"next_experiments": [...]}).
    """
    return await _call_llm(
        _build_analysis_system_prompt(max_iterations, current_iteration),
        [
            {"role": "user", "content": analysis_prompt},
            {"role": "assistant", "content": analysis_response},
            {"role": "user", "content": followup_prompt},
        ],
        llm_model,
    )

def _corrupt_experiment_for_demo(exp_dict: dict) -> dict:
    """Introduce a deliberate error into the first experiment for demo purposes.

    Corrupts the algorithm name so the LLM must recover from a known-bad value.
    The retry loop will catch this, regenerate with the error message, and fix it.
    """
    corrupted = dict(exp_dict)
    corrupted["algorithm"] = corrupted["algorithm"] + "_INVALID"
    return corrupted

# ---------------------------------------------------------------------------
# Main agent loop
# ---------------------------------------------------------------------------

@dataclass
class ExperimentResult:
    name: str
    algorithm: str
    metrics: dict
    confusion_matrix: dict
    threshold_analysis: list
    n_samples: int
    code: str
    attempts: int
    reasoning: str = ""
    error: str = ""

@dataclass
class AgentResult:
    model_card: str
    best_experiment: str
    best_metrics: dict
    all_results: list[ExperimentResult]
    iterations: int
    total_experiments: int

async def _run_experiment(
    exp: "ExperimentConfig",
    exp_dict: dict,
    inject_failure: bool,
    data: File,
    target_column: str,
    time_column: str,
    profile: dict,
    llm_model: str,
    max_retries: int,
) -> "ExperimentResult | None":
    """Run a single experiment with retries. Returns None on total failure."""
    exp_name = exp.name
    profile_json = json.dumps(profile)
    print(f"\n   ┌─ {exp_name}  [{exp.algorithm}]")
    if exp.rationale:
        for line in textwrap.wrap(exp.rationale, width=58):
            print(f"   │  {line}")
    if inject_failure:
        print(f"   │  [injecting failure for demo: algorithm='{exp_dict['algorithm']}']")

    code = ""
    error = ""
    result = None
    attempt = 0

    reasoning = ""
    # {{docs-fragment retry_loop}}
    for attempt in range(max_retries):
        try:
            with flyte.group(exp_name):
                plan_json = await plan_experiment.aio(
                    experiment_json=json.dumps(exp_dict),
                    profile_json=profile_json,
                    target_column=target_column,
                    time_column=time_column,
                    previous_error=error,
                    previous_code=code,
                    llm_model=llm_model,
                )
                plan = json.loads(plan_json)
                code = plan["code"]
                reasoning = plan.get("reasoning", "")
                result = await flyte.sandbox.orchestrate_local(
                    code,
                    inputs={"data": data, "target_column": target_column,
                            "time_column": time_column, "experiment_name": exp_name},
                    tasks=TOOLS,
                )
            error = ""
            break
        except Exception as exc:
            error = str(exc)
            short_error = error[:100] + "..." if len(error) > 100 else error
            print(f"   │  attempt {attempt + 1} failed: {short_error}")
            print(f"   │  → asking LLM to fix and retry...")
            if inject_failure and attempt == 0:
                exp_dict = exp.model_dump()
    # {{/docs-fragment retry_loop}}

    if result and not error:
        exp_result = ExperimentResult(
            name=exp_name,
            algorithm=exp.algorithm,
            metrics=result.get("metrics", {}),
            confusion_matrix=result.get("confusion_matrix", {}),
            threshold_analysis=result.get("threshold_analysis", []),
            n_samples=result.get("n_samples", 0),
            code=code,
            reasoning=reasoning,
            attempts=attempt + 1,
        )
        m = exp_result.metrics
        attempts_note = f" (recovered after {attempt + 1} attempts)" if attempt > 0 else ""
        print(f"   └─ ROC-AUC={m.get('roc_auc')}, F1={m.get('f1')}, Recall={m.get('recall')}{attempts_note}")
        return exp_result

    print(f"   └─ FAILED after {max_retries} attempts — skipping.")
    return None

async def run_agent(
    data: File,
    problem_description: str,
    target_column: str,
    time_column: str = "",
    max_iterations: int = 3,
    max_retries_per_experiment: int = 3,
    llm_model: str = "gpt-4o",
    inject_failure: bool = False,
) -> AgentResult:
    """Run the MLE agent end-to-end.

    Args:
        data: CSV file containing the dataset.
        problem_description: Natural language description of the ML problem.
        target_column: Name of the target column to predict.
        time_column: Optional column to use for time-based train/test split.
        max_iterations: Maximum number of experiment iterations to run.
        max_retries_per_experiment: Max times to retry a failed sandbox execution.
        llm_model: OpenAI model to use (default: gpt-4o).
        inject_failure: If True, corrupts the first experiment to demonstrate self-healing.
    """
    print(f"\n{'='*60}")
    print(f"MLE Agent starting")
    print(f"Problem: {problem_description}")
    print(f"Target: {target_column}")
    if inject_failure:
        print(f"[demo mode: failure injection enabled]")
    print(f"{'='*60}\n")

    # {{docs-fragment phase1_profile}}
    # --- Phase 1: Profile the dataset (trusted tool, LLM never sees raw data) ---
    print(">> Phase 1: Profiling dataset...")
    with flyte.group("profile"):
        profile = await profile_dataset(data, target_column)
    # {{/docs-fragment phase1_profile}}
    print(f"   Shape: {profile['shape']}, Classes: {profile['target_distribution']}")
    print(f"   Imbalanced: {profile['is_imbalanced']}, Columns: {len(profile['columns'])}")
    corr = profile.get("feature_target_corr", {})
    top_corr = list(corr.items())[:5]
    print(f"   Top correlations: {', '.join(f'{k}={v:+.3f}' for k,v in top_corr)}")

    # Stream report: dataset summary
    await flyte.report.log.aio(
        f"<h1>MLE Agent Run</h1>"
        f"<p><b>Problem:</b> {problem_description}</p>"
        f"<p><b>Dataset:</b> {profile['shape'][0]:,} rows × {profile['shape'][1]} cols &nbsp;|&nbsp; "
        f"Class balance: {profile['class_balance']} &nbsp;|&nbsp; Imbalanced: {profile['is_imbalanced']}</p>"
        f"<p><b>Top feature-target correlations (raw):</b> "
        + ", ".join(f"{k}: {v:+.3f}" for k, v in top_corr) +
        f"</p><hr>",
        do_flush=True,
    )

    # --- Phase 2: LLM designs initial experiments ---
    print("\n>> Phase 2: Designing initial experiments...")
    design_response = await design_experiments(
        problem_description=problem_description,
        profile_json=json.dumps(profile),
        llm_model=llm_model,
    )
    design = InitialDesign.model_validate(_parse_json(design_response))
    print(f"   Primary metric: {design.primary_metric}")
    print(f"   Strategy: {design.reasoning}")
    print(f"   Experiments planned: {len(design.experiments)}")

    all_results: list[ExperimentResult] = []
    iteration_log: list[dict] = []  # tracks per-iteration decisions + explorations for summary
    current_experiments: list[ExperimentConfig] = design.experiments
    first_experiment = True

    # --- Phase 3: Iterative experiment loop ---
    for iteration in range(max_iterations):
        experiments = current_experiments

        if not experiments:
            print(f"\n>> No experiments to run in iteration {iteration + 1}. Stopping.")
            break

        print(f"\n>> Phase 3.{iteration + 1}: Running {len(experiments)} experiment(s) in parallel...")

        # Assign names and prepare dicts before launching in parallel
        exp_batch = []
        for i, exp in enumerate(experiments):
            if not exp.name:
                exp.name = f"experiment_{len(all_results) + i + 1}"
            exp_dict = exp.model_dump()
            inject_this = inject_failure and first_experiment and i == 0
            if inject_this:
                exp_dict = _corrupt_experiment_for_demo(exp_dict)
            first_experiment = False
            exp_batch.append((exp, exp_dict, inject_this))

        # {{docs-fragment parallel_execute}}
        batch_results = await asyncio.gather(*[
            _run_experiment(
                exp=exp,
                exp_dict=exp_dict,
                inject_failure=inject_this,
                data=data,
                target_column=target_column,
                time_column=time_column,
                profile=profile,
                llm_model=llm_model,
                max_retries=max_retries_per_experiment,
            )
            for exp, exp_dict, inject_this in exp_batch
        ])
        # {{/docs-fragment parallel_execute}}

        for exp_result in batch_results:
            if exp_result is not None:
                all_results.append(exp_result)
                # Stream report: each experiment as it completes
                m = exp_result.metrics
                html = (
                    f"<h3>Iteration {iteration + 1} — {exp_result.name}</h3>"
                    f"<p><b>Algorithm:</b> {exp_result.algorithm} &nbsp;|&nbsp; "
                    f"<b>ROC-AUC:</b> {m.get('roc_auc')} &nbsp;|&nbsp; "
                    f"<b>F1:</b> {m.get('f1')} &nbsp;|&nbsp; "
                    f"<b>Recall:</b> {m.get('recall')} &nbsp;|&nbsp; "
                    f"<b>Attempts:</b> {exp_result.attempts}</p>"
                )
                if exp_result.reasoning:
                    html += f"<details><summary>Reasoning</summary><pre>{exp_result.reasoning}</pre></details>"
                html += f"<details><summary>Generated Code</summary><pre>{exp_result.code}</pre></details>"
                await flyte.report.log.aio(html, do_flush=True)

        # --- Phase 4: Analyze results, decide whether to iterate ---
        if all_results and iteration < max_iterations - 1:
            print(f"\n>> Phase 4.{iteration + 1}: Analyzing results, deciding next steps...")
            results_summary = [
                {
                    "experiment_name": r.name,
                    "algorithm": r.algorithm,
                    "metrics": r.metrics,
                    "confusion_matrix": r.confusion_matrix,
                    "used_feature_engineering": "engineer_features" in r.code,
                    "used_rolling_features": "rolling_columns" in r.code,
                    "used_lag_features": "lag_columns" in r.code,
                }
                for r in all_results
            ]
            analysis_prompt = textwrap.dedent(f"""
                Problem: {problem_description}
                Dataset profile: shape={profile['shape']}, imbalanced={profile['is_imbalanced']}
                Feature-target correlations (raw): {json.dumps(profile.get('feature_target_corr', {}), indent=2)}

                Experiment results so far (iteration {iteration + 1}):
                {json.dumps(results_summary, indent=2)}

                Should we run more experiments? If yes, request any data explorations
                you need, then specify what experiments to run next.
            """).strip()

            analysis_response = await analyze_iteration(
                analysis_prompt=analysis_prompt,
                max_iterations=max_iterations,
                current_iteration=iteration,
                llm_model=llm_model,
            )
            decision = IterationDecision.model_validate(_parse_json(analysis_response))
            verdict = "continuing" if decision.should_continue else "stopping"
            print(f"   Decision: {verdict}")
            print(f"   Reasoning: {decision.reasoning}")

            # Stream report: analysis decision
            await flyte.report.log.aio(
                f"<h3>Analysis — Iteration {iteration + 1}</h3>"
                f"<p><b>Decision:</b> {verdict}</p>"
                f"<p><b>Reasoning:</b> {decision.reasoning}</p>",
                do_flush=True,
            )

            # Track this iteration for the experiment journey summary
            iter_entry = {
                "iteration": iteration + 1,
                "experiments": [r.name for r in batch_results if r is not None],
                "best_roc_auc": max(
                    (r.metrics.get("roc_auc", 0) for r in all_results), default=0
                ),
                "reasoning": decision.reasoning,
                "explorations": [],
            }

            # --- Targeted exploration before next iteration ---
            if decision.should_continue and decision.exploration_requests:
                print(f"   Running {len(decision.exploration_requests)} exploration request(s)...")
                exploration_questions = []
                exploration_results = []

                for i, req in enumerate(decision.exploration_requests):
                    question = req.get("question", f"Exploration {i + 1}")
                    # Strip agent-level metadata — tool only needs the analysis config
                    tool_config = {k: v for k, v in req.items() if k not in ("question", "analysis_type")}

                    print(f"   Q: {question}")
                    with flyte.group(f"explore_{iteration + 1}_{i + 1}"):
                        result = await explore_dataset(data, tool_config)
                    exploration_questions.append(question)
                    exploration_results.append(result)
                    iter_entry["explorations"].append({"question": question})

                    await flyte.report.log.aio(
                        f"<h4>Exploration {i + 1}</h4>"
                        f"<p><b>Question:</b> {question}</p>"
                        f"<details><summary>Results</summary><pre>{json.dumps(result, indent=2)}</pre></details>",
                        do_flush=True,
                    )

                # Build follow-up that explicitly connects each question to its answer
                qa_pairs = "\n\n".join(
                    f'Question {i + 1}: "{q}"\nResult:\n{json.dumps(r, indent=2)}'
                    for i, (q, r) in enumerate(zip(exploration_questions, exploration_results))
                )
                followup_prompt = textwrap.dedent(f"""
                    You requested {len(exploration_results)} targeted exploration(s).
                    Here is what you asked and what you learned:

                    {qa_pairs}

                    Given what you learned and your earlier reasoning:
                    "{decision.reasoning}"

                    Now specify the next experiments. For each experiment, briefly state
                    which exploration insight informed your choice.
                    Respond with valid JSON: {{"next_experiments": [...same schema as before...]}}
                """).strip()
                followup_response = await plan_followup(
                    analysis_prompt=analysis_prompt,
                    analysis_response=analysis_response,
                    followup_prompt=followup_prompt,
                    max_iterations=max_iterations,
                    current_iteration=iteration,
                    llm_model=llm_model,
                )
                followup = _parse_json(followup_response)
                current_experiments = IterationDecision.model_validate({
                    "should_continue": True,
                    "reasoning": decision.reasoning,
                    "next_experiments": followup.get("next_experiments", []),
                }).next_experiments
                print(f"   Post-exploration: {len(current_experiments)} experiment(s) planned")
            else:
                current_experiments = decision.next_experiments

            iteration_log.append(iter_entry)

            if not decision.should_continue:
                break

    # --- Phase 5: Rank all results and generate model card ---
    print(f"\n>> Phase 5: Ranking {len(all_results)} experiment(s) and generating model card...")

    if not all_results:
        return AgentResult(
            model_card="No experiments completed successfully.",
            best_experiment="",
            best_metrics={},
            all_results=[],
            iterations=iteration + 1,
            total_experiments=0,
        )

    ranking_input = [
        {
            "experiment_name": r.name,
            "metrics": r.metrics,
            "confusion_matrix": r.confusion_matrix,
        }
        for r in all_results
    ]
    with flyte.group("rank"):
        ranking = await rank_experiments(json.dumps(ranking_input))
    best_name = ranking["best_experiment"]
    best_result = next(r for r in all_results if r.name == best_name)

    _print_experiment_table(all_results, best_name)
    _print_threshold_recommendation(best_result.threshold_analysis, best_result.metrics)

    # Stream report: final rankings table
    rows = "".join(
        f"<tr><td>{row['rank']}</td>"
        f"<td>{'<b>' if row['experiment_name'] == best_name else ''}"
        f"{row['experiment_name']}"
        f"{'</b>' if row['experiment_name'] == best_name else ''}</td>"
        f"<td>{row['roc_auc']}</td><td>{row['f1']}</td>"
        f"<td>{row['recall']}</td><td>{row['precision']}</td></tr>"
        for row in ranking.get("ranking", [])
    )
    await flyte.report.log.aio(
        f"<hr><h2>Final Rankings</h2>"
        f"<table border='1' cellpadding='6' cellspacing='0'>"
        f"<tr><th>Rank</th><th>Experiment</th><th>ROC-AUC</th><th>F1</th><th>Recall</th><th>Precision</th></tr>"
        f"{rows}</table>"
        f"<p>{ranking.get('summary', '')}</p>",
        do_flush=True,
    )

    # Stream report: experiment journey summary
    journey_rows = ""
    for entry in iteration_log:
        exps = ", ".join(entry["experiments"]) if entry["experiments"] else "—"
        explorations = "; ".join(e["question"] for e in entry["explorations"]) if entry["explorations"] else "—"
        short_reasoning = (entry["reasoning"][:120] + "…") if len(entry["reasoning"]) > 120 else entry["reasoning"]
        journey_rows += (
            f"<tr>"
            f"<td style='text-align:center'>{entry['iteration']}</td>"
            f"<td>{exps}</td>"
            f"<td style='text-align:center'>{entry['best_roc_auc']:.4f}</td>"
            f"<td>{short_reasoning}</td>"
            f"<td>{explorations}</td>"
            f"</tr>"
        )
    await flyte.report.log.aio(
        f"<hr><h2>Experiment Journey</h2>"
        f"<table border='1' cellpadding='6' cellspacing='0' style='width:100%;border-collapse:collapse'>"
        f"<tr><th>Iter</th><th>Experiments</th><th>Best ROC-AUC</th><th>Key insight</th><th>Explorations</th></tr>"
        f"{journey_rows}"
        f"</table>",
        do_flush=True,
    )

    model_card = await _generate_model_card(
        problem_description=problem_description,
        profile=profile,
        all_results=all_results,
        best_result=best_result,
        ranking=ranking,
        iteration_log=iteration_log,
        llm_model=llm_model,
    )

    print(f"\n{'='*60}")
    print(f"DONE — Best model: {best_name}")
    print(f"       ROC-AUC={best_result.metrics.get('roc_auc')}, F1={best_result.metrics.get('f1')}")
    print(f"{'='*60}\n")

    return AgentResult(
        model_card=model_card,
        best_experiment=best_name,
        best_metrics=best_result.metrics,
        all_results=all_results,
        iterations=iteration + 1,
        total_experiments=len(all_results),
    )

async def _generate_model_card(
    problem_description: str,
    profile: dict,
    all_results: list[ExperimentResult],
    best_result: ExperimentResult,
    ranking: dict,
    iteration_log: list[dict],
    llm_model: str,
) -> str:
    """Generate a markdown model card summarizing the winning model."""
    system = textwrap.dedent("""
        You are an ML engineer writing a model card for a trained model.
        Write in markdown. Be concise but informative. Include:
        - Problem statement
        - Dataset summary
        - Experiment journey (brief per-iteration narrative: what was tried, what was learned, what changed)
        - Experiment summary (table of all experiments with metrics)
        - Winning model details (algorithm, key hyperparams, metrics, threshold analysis)
        - Recommendations for deployment (decision threshold, monitoring)
    """).strip()

    results_text = "\n".join(
        f"- {r.name} ({r.algorithm}): ROC-AUC={r.metrics.get('roc_auc')}, "
        f"F1={r.metrics.get('f1')}, Recall={r.metrics.get('recall')}"
        for r in all_results
    )

    journey_text = ""
    if iteration_log:
        journey_text = "\n\nIteration log:\n" + "\n".join(
            f"  Iteration {e['iteration']}: ran [{', '.join(e['experiments'])}], "
            f"best ROC-AUC so far={e['best_roc_auc']:.4f}. "
            f"Key insight: {e['reasoning'][:200]}. "
            + (f"Explorations: {'; '.join(x['question'] for x in e['explorations'])}" if e['explorations'] else "")
            for e in iteration_log
        )

    user_content = textwrap.dedent(f"""
        Problem: {problem_description}

        Dataset: {profile['shape'][0]} rows × {profile['shape'][1]} cols.
        Class balance: {profile['class_balance']}
        Imbalanced: {profile['is_imbalanced']}
        {journey_text}

        All experiments:
        {results_text}

        Best model: {best_result.name} ({best_result.algorithm})
        Metrics: {json.dumps(best_result.metrics, indent=2)}
        Confusion matrix: {json.dumps(best_result.confusion_matrix, indent=2)}
        Threshold analysis: {json.dumps(best_result.threshold_analysis, indent=2)}

        Ranking summary: {ranking['summary']}
    """).strip()

    response = await _call_llm(system, [{"role": "user", "content": user_content}], llm_model)
    return response

# ---------------------------------------------------------------------------
# Durable entrypoint (runs the agent as a Flyte task in the cloud)
# ---------------------------------------------------------------------------

# {{docs-fragment entrypoint}}
@agent_env.task(retries=1, report=True)
async def mle_agent_task(
    data: File,
    problem_description: str,
    target_column: str,
    time_column: str = "",
    max_iterations: int = 3,
) -> str:
    """Durable Flyte task entrypoint for the MLE agent."""
    result = await run_agent(
        data=data,
        problem_description=problem_description,
        target_column=target_column,
        time_column=time_column,
        max_iterations=max_iterations,
    )
    return result.model_card
# {{/docs-fragment entrypoint}}
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/mle_bot/mle_bot/agent.py*

**Phase 4: Analyze and iterate.** After each batch completes, the LLM reviews the results and decides whether to continue. It can optionally request targeted data explorations before designing the next round. If the LLM requests explorations (e.g., "do failure cases show higher vibration readings?"), the agent runs `explore_dataset` with those configurations, feeds the results back to the LLM, and lets it refine the next batch of experiments based on what it learned. The loop continues until the LLM decides to stop, no new experiments are proposed, or the maximum number of iterations is exhausted.
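The decision the LLM returns at this step is validated against the `IterationDecision` schema. Here is a representative payload and how the agent strips the agent-level keys before calling the exploration tool. The field names match the code in this tutorial; the specific exploration config keys (`class_distributions`, `target_column`) are illustrative, since the real set is defined by the `explore_dataset` tool:

```python
import json

# A representative analysis response, covering the fields the agent reads:
# should_continue, reasoning, exploration_requests, next_experiments.
analysis_response = """
{
  "should_continue": true,
  "reasoning": "XGBoost beats the linear baseline; check whether failures correlate with vibration before adding rolling features.",
  "exploration_requests": [
    {
      "question": "Do failure cases show higher vibration readings?",
      "analysis_type": "class_distributions",
      "class_distributions": ["vibration_mms"],
      "target_column": "failure"
    }
  ],
  "next_experiments": []
}
"""

decision = json.loads(analysis_response)

# Strip agent-level metadata before handing the config to the tool,
# mirroring what the agent loop does.
tool_config = {
    k: v
    for k, v in decision["exploration_requests"][0].items()
    if k not in ("question", "analysis_type")
}
print(tool_config)  # → {'class_distributions': ['vibration_mms'], 'target_column': 'failure'}
```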

## Running LLM-generated code in Flyte's sandbox

This is where it gets interesting. The LLM doesn't just pick parameters from a dropdown. For each experiment, it writes actual Python code that decides how to compose the tool functions into a pipeline. Maybe it splits the data, engineers rolling window features, applies SMOTE resampling on the training split, trains an XGBoost model, and evaluates it. Or maybe it skips feature engineering entirely for a baseline. The LLM decides the structure.

That code runs inside Flyte's sandbox, a restricted execution environment that enforces strict constraints:

- No `import` statements. The only callable functions are the ones you explicitly provide.
- No network access and no filesystem access.
- No `try`/`except`, no `class` definitions, no augmented assignment (`+=`).
- No `with` statements, no generators, no `global`/`nonlocal`.
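To make these restrictions concrete, the snippet below uses Python's `ast` module to flag the banned constructs in a candidate code string. This is only an illustration of the rules listed above, not Flyte's actual enforcement mechanism:

```python
import ast

# Node types corresponding to the constructs the sandbox rejects.
# A simplified illustration, not Flyte's real validator.
BANNED = {
    ast.Import: "import", ast.ImportFrom: "import",
    ast.Try: "try/except", ast.ClassDef: "class",
    ast.AugAssign: "augmented assignment",
    ast.With: "with", ast.AsyncWith: "with",
    ast.Yield: "generator", ast.YieldFrom: "generator",
    ast.Global: "global", ast.Nonlocal: "nonlocal",
}

def banned_constructs(code: str) -> list[str]:
    """Return the banned constructs found in a code string."""
    return [
        label
        for node in ast.walk(ast.parse(code))
        for node_type, label in BANNED.items()
        if isinstance(node, node_type)
    ]

# Plain assignments and function calls pass:
print(banned_constructs('model = train_model(data, "xgboost")'))  # → []

# try/except and += would both be rejected:
print(banned_constructs("try:\n    x += 1\nexcept Exception:\n    pass"))
```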

To the generated code, your pre-approved tool functions look like plain function calls. When the code calls `train_model(...)`, the sandbox pauses execution, dispatches the call to Flyte (which runs it as a durable task on cloud compute with the resources declared on `tool_env`), waits for the result, and resumes. The LLM-generated code looks like synchronous Python, but under the hood each tool call is a full Flyte task execution.

Here's how the sandbox is invoked:

```python
result = await flyte.sandbox.orchestrate_local(
    code,
    inputs={
        "data": data,
        "target_column": target_column,
        "time_column": time_column,
        "experiment_name": exp_name,
    },
    tasks=TOOLS,
)
```

The `code` parameter is a string of Python generated by the LLM. `inputs` provides the variables that the code can reference. `tasks` is the allowlist: a list of Flyte task functions that the code is permitted to call. Nothing else is available.

Here's an example of what the LLM might generate for a single experiment:

```python
train_file = split_dataset(data, target_column, 0.2, time_column, "train")
test_file  = split_dataset(data, target_column, 0.2, time_column, "test")

eng_train = engineer_features(train_file, {
    "rolling_columns": ["vibration_mms", "temperature_c"],
    "windows": [6, 12, 24],
    "group_column": "machine_id",
    "time_column": "timestamp"
})
eng_test = engineer_features(test_file, {
    "rolling_columns": ["vibration_mms", "temperature_c"],
    "windows": [6, 12, 24],
    "group_column": "machine_id",
    "time_column": "timestamp"
})

model_file = train_model(eng_train, target_column, "xgboost", {
    "n_estimators": 200, "max_depth": 8, "scale_pos_weight": 33
})
eval_result = evaluate_model(model_file, eng_test, target_column)

{"experiment_name": experiment_name, "algorithm": "xgboost",
 "metrics": eval_result["metrics"],
 "confusion_matrix": eval_result["confusion_matrix"],
 "threshold_analysis": eval_result["threshold_analysis"],
 "n_samples": eval_result["n_samples"]}
```

Each function call in that snippet dispatches a separate Flyte task. The `split_dataset` calls run on the tool environment's compute (2 CPU, 4Gi memory). The `train_model` call trains an actual XGBoost model. The last expression (a dict literal) is returned as the sandbox result.
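The "last expression is the result" convention is easy to emulate: execute everything but the final statement, then evaluate the final expression and return its value. The following toy interpreter illustrates that convention only; it is not how Flyte's sandbox is implemented, and it enforces none of the restrictions above:

```python
import ast

def run_and_return_last(code: str, env: dict):
    """Execute a code string and return the value of its final expression.

    Toy illustration of the sandbox's return convention -- no
    restrictions are enforced here.
    """
    tree = ast.parse(code)
    *body, last = tree.body
    if not isinstance(last, ast.Expr):
        raise ValueError("code must end with an expression")
    # Run all statements except the last...
    exec(compile(ast.Module(body=body, type_ignores=[]), "<sandbox>", "exec"), env)
    # ...then evaluate the final expression and hand its value back.
    return eval(compile(ast.Expression(last.value), "<sandbox>", "eval"), env)

# Stub tool standing in for a real Flyte task:
env = {"evaluate_model": lambda name: {"roc_auc": 0.91, "model": name}}
code = 'result = evaluate_model("xgboost")\n{"metrics": result, "experiment_name": "exp_1"}'
print(run_and_return_last(code, env))
```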

Sometimes the LLM generates code with bugs, like a wrong variable name or a missing argument. The agent handles this with a retry loop. If the sandbox raises an exception, the error message and the failing code are fed back to the LLM, which gets a chance to fix the issue:

```python
"""MLE Agent — orchestrates ML experiments using Flyte's durable sandbox.

The agent:
  1. Profiles the dataset using a trusted tool (data never touches the LLM).
  2. Asks OpenAI to design a set of experiments (algorithms, hyperparams, feature strategy).
  3. For each experiment, generates Monty orchestration code and executes it via
     flyte.sandbox.orchestrate_local(), which dispatches the heavy compute as durable tasks.
  4. Analyzes results, iterates if needed.
  5. Produces a model card summarizing the winning model.

The Monty sandbox ensures the LLM-generated orchestration code is safe — it can only
call the pre-approved tool functions and has no access to imports, network, or filesystem.
"""

import asyncio
import inspect
import json
import os
import textwrap
from dataclasses import dataclass

import flyte
import flyte.sandbox
from flyte.io import File

from mle_bot.schemas import ExperimentConfig, InitialDesign, IterationDecision

from mle_bot.environments import agent_env
from mle_bot.tools.data import profile_dataset, split_dataset
from mle_bot.tools.evaluation import evaluate_model, rank_experiments
from mle_bot.tools.exploration import explore_dataset
from mle_bot.tools.features import engineer_features
from mle_bot.tools.predictions import get_predictions
from mle_bot.tools.resampling import resample_dataset
from mle_bot.tools.selection import select_features
from mle_bot.tools.training import train_model

# {{docs-fragment tools}}
# All tools exposed to the sandbox.
# Keys must match the function names used in LLM-generated orchestration code.
TOOLS = [
    profile_dataset, split_dataset, explore_dataset,
    engineer_features, resample_dataset, select_features,
    train_model, get_predictions, evaluate_model, rank_experiments,
]
TOOLS_BY_NAME = {t.func.__name__ if hasattr(t, "func") else t.__name__: t for t in TOOLS}
# {{/docs-fragment tools}}

# ---------------------------------------------------------------------------
# Prompt builders
# ---------------------------------------------------------------------------

def _tool_signatures() -> str:
    """Build a summary of available tool signatures and docstrings for the system prompt."""
    parts = []
    for t in TOOLS:
        func = t.func if hasattr(t, "func") else t
        sig = inspect.signature(func)
        doc = inspect.getdoc(func) or ""
        # Trim docstring to first 40 lines so prompt stays manageable
        doc_lines = doc.splitlines()[:40]
        short_doc = "\n    ".join(doc_lines)
        parts.append(f"async def {func.__name__}{sig}:\n    \"\"\"{short_doc}\"\"\"\n    ...")
    return "\n\n".join(parts)

# {{docs-fragment orchestration_prompt}}
def _build_orchestration_system_prompt(profile: dict) -> str:
    monty_rules = flyte.sandbox.ORCHESTRATOR_SYNTAX_PROMPT
    tools_section = _tool_signatures()
    is_imbalanced = profile.get("is_imbalanced", False)
    class_balance = profile.get("class_balance", {})
    columns = profile.get("columns", [])
    numeric_cols = profile.get("numeric_columns", [])
    categorical_cols = profile.get("categorical_columns", [])
    corr = profile.get("feature_target_corr", {})
    corr_str = ", ".join(f"{k}: {v:+.3f}" for k, v in list(corr.items())[:8]) if corr else "n/a"
    shape = profile.get("shape", [0, 0])
    return f"""\
You are an expert ML engineer. Your job is to design and write the best possible
pipeline for a machine learning experiment, then generate the Python orchestration
code to execute it.

The code runs inside a restricted sandbox. The last expression in your code
is returned as the result. All tool calls are made like regular function calls —
you do NOT need to await them.

## Dataset context

Shape: {shape[0]:,} rows × {shape[1]} columns
Numeric features: {numeric_cols}
Categorical features (excluded from model — not supported): {categorical_cols}
Class balance: {class_balance}, imbalanced: {is_imbalanced}
Feature-target correlations (raw, point-biserial): {corr_str}

## General ML best practices — apply these based on the dataset context above

**Feature engineering** (engineer_features tool):
- Sequential/time-series data (timestamp column present, rows ordered over time):
  rolling window features (means, stds, min/max) capture trends that point-in-time
  readings miss. Lag features capture recent history. Choose window sizes relative
  to the prediction horizon and temporal resolution of the data.
- Tabular cross-sectional data: normalization helps linear models and distance-based
  methods. Interaction terms can help if correlations are weak individually.
- Consider skipping feature engineering entirely for a baseline — it establishes
  whether raw features already carry enough signal.

**Class imbalance** (when is_imbalanced=true):
- Tree ensembles: use class_weight="balanced" or scale_pos_weight=n_neg/n_pos.
- Threshold: the default 0.5 decision threshold may not be optimal — the model's
  probability output is what matters, threshold is tuned at deployment time.
- Metric: ROC-AUC is robust to imbalance; avg_precision is better when positives
  are very rare.

**Algorithm selection**:
- XGBoost / GradientBoosting: strong default for tabular data, handles missing
  values, built-in imbalance handling. Start here unless data is very small.
- RandomForest: more robust to outliers, good for noisy data, parallelizes well.
- LogisticRegression: fast linear baseline. Useful to establish whether the
  problem is linearly separable before adding complexity.
- Prefer simpler models when n_samples < 5,000 to avoid overfitting.

**Resampling** (resample_dataset tool) — data-level imbalance handling:
- Use when class_weight/scale_pos_weight alone isn't improving recall adequately,
  or when you want to test whether data-level vs algorithm-level imbalance handling
  works better for this dataset.
- ONLY resample the TRAIN split — never test. Resampling test data gives misleading metrics.
- "oversample": fast, no new information, good first try.
- "smote": synthetic samples via interpolation — more diverse than random oversampling,
  better for high-dimensional or sparse minority classes.
- "undersample": loses majority data — only useful when majority class is very large
  and training speed is a concern.

**Feature selection** (select_features tool) — prune after feature engineering:
- Use after engineer_features when the feature count is large (20+) and you suspect
  many features are redundant or noisy (e.g., rolling stats at many window sizes).
- "mutual_info": ranks by non-linear dependence with target — best general choice.
- "variance_threshold": drops near-constant features — cheap first pass.
- "correlation_filter": drops redundant features that are highly correlated with
  each other — useful when many rolling windows capture the same trend.
- Apply after splitting, and apply the same selection to both train and test so the
  model sees the same features at evaluation time. Selecting on the full dataset with
  the target before splitting can leak test information into the choice of features.

**Prediction output** (get_predictions tool) — enables two advanced patterns:
1. Error analysis: train a model → get_predictions(model, test_file, target) →
   explore_dataset(predictions_file, {{"class_distributions": ["feature_x"],
   "target_column": "correct"}}) to see which examples the model gets wrong.
   Use this to inform feature engineering for the next iteration.
2. Stacking: train base_model → get_predictions(base_model, train_file, target) →
   train a meta_model on the predictions CSV (use "predicted_prob" as a feature
   alongside original features) → evaluate meta_model on test.
   get_predictions returns a CSV with columns: all originals + predicted_prob,
   predicted_class, correct.

**Pipeline structure** — you are not required to follow a fixed sequence.
Design what makes sense for this specific experiment.

## Available tools

{tools_section}

## Monty sandbox restrictions

{monty_rules}

## Critical patterns for using tool results

split_dataset returns a File — call it twice:
    train_file = split_dataset(data, target_column, 0.2, time_column, "train")
    test_file  = split_dataset(data, target_column, 0.2, time_column, "test")

engineer_features returns a File — chain calls freely:
    eng = engineer_features(train_file, {{"rolling_columns": [...], "windows": [...]}})
    eng2 = engineer_features(eng, {{"normalize": true, "target_column": target_column}})

train_model returns a File — pass directly to evaluate_model:
    model_file = train_model(train_file, target_column, algorithm, hyperparams)
    eval_result = evaluate_model(model_file, test_file, target_column)

evaluate_model returns a dict — subscript reads are allowed:
    roc = eval_result["metrics"]["roc_auc"]

Do NOT use augmented assignment (+=), subscript assignment (d["k"]=v), or try/except.
Build dicts as literals only. The last expression (no assignment) is the return value.

## When fixing a previous error

Read the error and the failing code carefully before writing a fix. Identify the root
cause — do not just change variable names or add no-ops. Trace what each tool returns,
what each subsequent call expects, and where the mismatch is. Then fix the underlying
logic, not just the surface symptom.

## Pipeline design — you decide the structure

You are NOT required to follow a fixed sequence. Design the pipeline that makes the
most sense for the experiment. Examples of valid approaches:

Baseline (no feature engineering):
    train_file = split_dataset(data, target_column, 0.2, time_column, "train")
    test_file = split_dataset(data, target_column, 0.2, time_column, "test")
    model_file = train_model(train_file, target_column, algorithm, hyperparams)
    eval_result = evaluate_model(model_file, test_file, target_column)
    {{"experiment_name": experiment_name, "algorithm": algorithm, "metrics": eval_result["metrics"], "confusion_matrix": eval_result["confusion_matrix"], "threshold_analysis": eval_result["threshold_analysis"], "n_samples": eval_result["n_samples"]}}

Two-stage feature engineering (rolling then normalize separately):
    train_file = split_dataset(data, target_column, 0.2, time_column, "train")
    test_file = split_dataset(data, target_column, 0.2, time_column, "test")
    rolled_train = engineer_features(train_file, {{"rolling_columns": ["vibration"], "windows": [6, 24]}})
    rolled_test  = engineer_features(test_file,  {{"rolling_columns": ["vibration"], "windows": [6, 24]}})
    eng_train = engineer_features(rolled_train, {{"normalize": true, "target_column": target_column}})
    eng_test  = engineer_features(rolled_test,  {{"normalize": true, "target_column": target_column}})
    model_file = train_model(eng_train, target_column, algorithm, hyperparams)
    eval_result = evaluate_model(model_file, eng_test, target_column)
    {{"experiment_name": experiment_name, "algorithm": algorithm, "metrics": eval_result["metrics"], "confusion_matrix": eval_result["confusion_matrix"], "threshold_analysis": eval_result["threshold_analysis"], "n_samples": eval_result["n_samples"]}}

Compare two class weightings and return the better model:
    train_file = split_dataset(data, target_column, 0.2, time_column, "train")
    test_file = split_dataset(data, target_column, 0.2, time_column, "test")
    model_a = train_model(train_file, target_column, "xgboost", {{"n_estimators": 100, "scale_pos_weight": 10}})
    model_b = train_model(train_file, target_column, "xgboost", {{"n_estimators": 100, "scale_pos_weight": 33}})
    eval_a = evaluate_model(model_a, test_file, target_column)
    eval_b = evaluate_model(model_b, test_file, target_column)
    best_eval = eval_a if eval_a["metrics"]["roc_auc"] > eval_b["metrics"]["roc_auc"] else eval_b
    {{"experiment_name": experiment_name, "algorithm": "xgboost", "metrics": best_eval["metrics"], "confusion_matrix": best_eval["confusion_matrix"], "threshold_analysis": best_eval["threshold_analysis"], "n_samples": best_eval["n_samples"]}}

SMOTE oversampling before training:
    train_file = split_dataset(data, target_column, 0.2, time_column, "train")
    test_file  = split_dataset(data, target_column, 0.2, time_column, "test")
    eng_train  = engineer_features(train_file, {{"rolling_columns": ["vibration_mms"], "windows": [6, 12]}})
    eng_test   = engineer_features(test_file,  {{"rolling_columns": ["vibration_mms"], "windows": [6, 12]}})
    resampled_train = resample_dataset(eng_train, target_column, {{"strategy": "smote", "target_ratio": 0.2}})
    model_file = train_model(resampled_train, target_column, algorithm, hyperparams)
    eval_result = evaluate_model(model_file, eng_test, target_column)
    {{"experiment_name": experiment_name, "algorithm": algorithm, "metrics": eval_result["metrics"], "confusion_matrix": eval_result["confusion_matrix"], "threshold_analysis": eval_result["threshold_analysis"], "n_samples": eval_result["n_samples"]}}

Feature engineering followed by feature selection:
    train_file = split_dataset(data, target_column, 0.2, time_column, "train")
    test_file  = split_dataset(data, target_column, 0.2, time_column, "test")
    eng_train  = engineer_features(train_file, {{"rolling_columns": ["vibration_mms", "temperature_c"], "windows": [6, 12, 24]}})
    eng_test   = engineer_features(test_file,  {{"rolling_columns": ["vibration_mms", "temperature_c"], "windows": [6, 12, 24]}})
    sel_train  = select_features(eng_train, target_column, {{"method": "mutual_info", "k": 15}})
    sel_test   = select_features(eng_test,  target_column, {{"method": "mutual_info", "k": 15}})
    model_file = train_model(sel_train, target_column, algorithm, hyperparams)
    eval_result = evaluate_model(model_file, sel_test, target_column)
    {{"experiment_name": experiment_name, "algorithm": algorithm, "metrics": eval_result["metrics"], "confusion_matrix": eval_result["confusion_matrix"], "threshold_analysis": eval_result["threshold_analysis"], "n_samples": eval_result["n_samples"]}}

Error analysis — explore what the model gets wrong, then return that as insight:
    train_file = split_dataset(data, target_column, 0.2, time_column, "train")
    test_file  = split_dataset(data, target_column, 0.2, time_column, "test")
    model_file = train_model(train_file, target_column, algorithm, hyperparams)
    pred_file  = get_predictions(model_file, test_file, target_column)
    error_analysis = explore_dataset(pred_file, {{"target_column": "correct", "class_distributions": ["vibration_mms", "temperature_c"]}})
    eval_result = evaluate_model(model_file, test_file, target_column)
    {{"experiment_name": experiment_name, "algorithm": algorithm, "metrics": eval_result["metrics"], "confusion_matrix": eval_result["confusion_matrix"], "threshold_analysis": eval_result["threshold_analysis"], "n_samples": eval_result["n_samples"], "error_analysis": error_analysis}}

The last expression MUST be a dict with at minimum these keys:
    experiment_name, algorithm, metrics, confusion_matrix, threshold_analysis, n_samples
Additional keys (e.g. error_analysis) are allowed and will appear in the report.

## Response format

Respond in exactly this format:

## Reasoning
[Your thinking: what pipeline makes sense for this experiment and why. Consider whether
feature engineering helps, whether class imbalance needs special treatment, whether
chaining multiple steps adds value, etc.]

## Code
```python
[your orchestration code]
```
"""
# {{/docs-fragment orchestration_prompt}}

def _build_analysis_system_prompt(max_iterations: int, current_iteration: int) -> str:
    remaining = max_iterations - current_iteration - 1
    return f"""\
You are an expert ML engineer analyzing experiment results to guide the next iteration
of model development.

You must respond with valid JSON only — no markdown, no explanation outside the JSON.

Response format:
{{
  "should_continue": true | false,
  "reasoning": "What you observed, what it tells you, and what to try next",
  "exploration_requests": [
    {{
      "question": "The specific hypothesis you are testing, e.g. 'Do failure cases show meaningfully higher vibration than healthy cases?'",
      "analysis_type": "class_distributions",
      "target_column": "failure_24h",
      "class_distributions": ["vibration_mms", "temperature_c"]
    }}
  ],
  "next_experiments": [
    {{
      "name": "descriptive experiment name",
      "algorithm": "xgboost" | "random_forest" | "gradient_boosting" | "logistic_regression",
      "hyperparams": {{ ... algorithm-specific hyperparams ... }},
      "feature_config": {{
        "group_column": "...",
        "time_column": "...",
        "rolling_columns": [...],
        "windows": [...],
        "lag_columns": [...],
        "lags": [...],
        "normalize": true | false,
        "drop_columns": [...],
        "fillna_method": "forward"
      }},
      "rationale": "Why this experiment is worth trying"
    }}
  ]
}}

exploration_requests rules:
- Max 2 requests per iteration.
- Each request targets EXACTLY ONE analysis_type. Do not mix multiple types in one request.
- Supported analysis_type values and their required config fields:
    "class_distributions" → requires: target_column, class_distributions (list of columns)
    "correlation_matrix"  → requires: correlation_matrix: true
    "temporal_trend"      → requires: temporal_trend: {{time_column, target_column, freq}}
    "group_stats"         → requires: group_stats: {{group_column, target_column}}
    "outlier_summary"     → requires: outlier_summary (list of columns)
    "feature_target_corr_by_group" → requires: feature_target_corr_by_group: {{group_column, target_column, feature_columns}}
- The "question" field is required. It must be a specific testable hypothesis, not a
  general request. Bad: "explore the data". Good: "Is vibration_mms higher for failures?"
- Set exploration_requests to [] if the current results already tell you enough to
  design the next experiments. Only explore when you have a concrete unanswered question.

When deciding next experiments, reason about WHAT WAS TRIED vs what hasn't been explored.
Each result includes used_feature_engineering, used_rolling_features, used_lag_features.
Think systematically: if no feature engineering was tried yet, does the data profile
suggest it would help (weak raw correlations, temporal/sequential structure)?
If feature engineering helped, can it be improved? Avoid experiments identical to ones tried.

Iteration context: this is iteration {current_iteration + 1} of {max_iterations} requested.
Remaining iterations allowed: {remaining}.
Set should_continue=false only if:
- Best ROC-AUC >= 0.97, OR
- No remaining iterations (remaining == 0), OR
- Results have genuinely plateaued (< 0.005 ROC-AUC improvement over last iteration
  AND you have already tried the most promising directions)
Otherwise keep exploring — the user asked for {max_iterations} iterations for a reason.
"""

def _build_initial_design_system_prompt() -> str:
    return """\
You are an expert ML engineer. Given a dataset profile and a problem description,
design the first batch of experiments to run.

You must respond with valid JSON only — no markdown, no explanation outside the JSON.

Response format:
{
  "problem_type": "binary_classification",
  "primary_metric": "roc_auc" | "f1" | "recall",
  "reasoning": "Brief description of your strategy",
  "experiments": [
    {
      "name": "descriptive experiment name",
      "algorithm": "xgboost" | "random_forest" | "gradient_boosting" | "logistic_regression",
      "hyperparams": { ... algorithm-specific hyperparams ... },
      "feature_config": {
        "group_column": "",
        "time_column": "",
        "rolling_columns": [],
        "windows": [],
        "lag_columns": [],
        "lags": [],
        "normalize": false,
        "drop_columns": [],
        "fillna_method": "forward"
      },
      "rationale": "Why this experiment makes sense given the data profile"
    }
  ]
}

Design 2-3 experiments for the first batch. Good first batches typically include:
- A fast baseline to establish a floor (e.g. logistic_regression with default settings)
- Your best initial hypothesis given the data profile
- Optionally one variant that tests a specific idea suggested by the profile

Use the dataset profile to guide your choices:
- feature_target_corr: weak raw correlations suggest feature engineering may help
- categorical_columns: note these are excluded from the model automatically
- is_imbalanced: handle with class_weight or scale_pos_weight
- Shape and column types should inform algorithm complexity (simpler models for small datasets)
- A time column suggests sequential structure; lag/rolling features may capture temporal patterns

The feature_config in each experiment describes what engineer_features should apply.
Leave all fields empty/false if no feature engineering is needed for that experiment.
The orchestration code generator will decide the exact pipeline — your job here is
to specify what the experiment is trying to learn, not to prescribe every implementation detail.
"""

# ---------------------------------------------------------------------------
# LLM client
# ---------------------------------------------------------------------------

def _openai_client():
    from openai import OpenAI
    return OpenAI(api_key=os.environ["OPENAI_API_KEY"])

async def _call_llm(system: str, messages: list[dict], model: str = "gpt-4o") -> str:
    """Call OpenAI and return the response text."""
    client = _openai_client()
    response = await asyncio.to_thread(
        client.chat.completions.create,
        model=model,
        messages=[{"role": "system", "content": system}, *messages],
        temperature=0.2,
    )
    return response.choices[0].message.content

def _extract_code(text: str) -> str:
    """Pull Python code out of a markdown code block."""
    if "```python" in text:
        start = text.index("```python") + len("```python")
        end = text.index("```", start)
        return text[start:end].strip()
    if "```" in text:
        start = text.index("```") + 3
        end = text.index("```", start)
        return text[start:end].strip()
    return text.strip()

def _extract_reasoning(text: str) -> str:
    """Extract the ## Reasoning section from LLM response."""
    if "## Reasoning" in text:
        start = text.index("## Reasoning") + len("## Reasoning")
        if "## Code" in text:
            end = text.index("## Code")
            return text[start:end].strip()
        return text[start:].strip()
    return ""

def _parse_json(text: str) -> dict:
    """Extract and parse JSON from LLM response."""
    text = text.strip()
    if "```json" in text:
        start = text.index("```json") + 7
        end = text.index("```", start)
        text = text[start:end].strip()
    elif "```" in text:
        start = text.index("```") + 3
        end = text.index("```", start)
        text = text[start:end].strip()
    return json.loads(text)
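
Models are inconsistent about fencing their JSON, so `_parse_json` handles both shapes. An illustrative round-trip mirroring the slicing logic above (the response string is invented):

```python
import json

# A typical fenced response from the analysis LLM (hypothetical content).
fenced = '```json\n{"should_continue": true, "next_experiments": []}\n```'

# Skip past the opening marker, stop at the closing fence, parse the rest.
start = fenced.index("```json") + 7
end = fenced.index("```", start)
decision = json.loads(fenced[start:end].strip())

print(decision["should_continue"])  # True
```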

# ---------------------------------------------------------------------------
# Display helpers
# ---------------------------------------------------------------------------

def _recommend_threshold(threshold_analysis: list, min_precision: float = 0.70) -> dict | None:
    """Find the threshold that maximises recall subject to precision >= min_precision."""
    candidates = [t for t in threshold_analysis if t["precision"] >= min_precision]
    if not candidates:
        return None
    return max(candidates, key=lambda t: t["recall"])
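
For intuition, a worked example with made-up threshold rows: the 0.30 row has the best recall but misses the precision floor, so the 0.40 row wins.

```python
# Hypothetical threshold sweep (numbers invented for illustration).
analysis = [
    {"threshold": 0.30, "precision": 0.62, "recall": 0.91, "f1": 0.74},
    {"threshold": 0.40, "precision": 0.74, "recall": 0.83, "f1": 0.78},
    {"threshold": 0.50, "precision": 0.88, "recall": 0.70, "f1": 0.78},
]

# Same rule as _recommend_threshold: filter by the precision floor,
# then take the highest-recall survivor.
candidates = [t for t in analysis if t["precision"] >= 0.70]
best = max(candidates, key=lambda t: t["recall"])

print(best["threshold"])  # 0.4
```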

def _print_experiment_table(results: list["ExperimentResult"], best_name: str) -> None:
    """Print a ranked comparison table of all experiments."""
    sorted_results = sorted(results, key=lambda r: r.metrics.get("roc_auc", 0), reverse=True)
    print("\n" + "─" * 78)
    print(f"  {'Rank':<5} {'Experiment':<32} {'ROC-AUC':<9} {'F1':<7} {'Recall':<8} {'Note'}")
    print("─" * 78)
    for rank, r in enumerate(sorted_results, 1):
        note = "◀ winner" if r.name == best_name else ""
        roc = r.metrics.get("roc_auc", 0)
        f1 = r.metrics.get("f1", 0)
        recall = r.metrics.get("recall", 0)
        print(f"  {rank:<5} {r.name:<32} {roc:<9.4f} {f1:<7.4f} {recall:<8.4f} {note}")
    print("─" * 78)

def _print_threshold_recommendation(threshold_analysis: list, default_metrics: dict) -> None:
    """Print the operational threshold recommendation."""
    rec = _recommend_threshold(threshold_analysis)
    if not rec:
        return
    default_recall = default_metrics.get("recall", 0)
    default_precision = default_metrics.get("precision", 0)
    missed_pct = round((1 - rec["recall"]) * 100, 1)
    false_alarm_pct = round((1 - rec["precision"]) * 100, 1)

    print(f"\n  Recommended decision threshold: {rec['threshold']}")
    print(f"  ├─ Precision : {rec['precision']:.0%}   ({false_alarm_pct}% of alerts are false alarms)")
    print(f"  ├─ Recall    : {rec['recall']:.0%}   (catches {rec['recall']*100:.0f}% of actual failures)")
    print(f"  └─ F1        : {rec['f1']:.4f}")
    print(f"  Default threshold (0.5): Precision={default_precision:.0%}, Recall={default_recall:.0%}")
    if rec["recall"] > default_recall:
        extra = round((rec["recall"] - default_recall) * 100, 1)
        print(f"  → Lowering threshold catches {extra}% more failures at cost of more alerts")

# ---------------------------------------------------------------------------
# Orchestration code generation (durable Flyte task with Flyte report)
# ---------------------------------------------------------------------------

@agent_env.task
async def plan_experiment(
    experiment_json: str,
    profile_json: str,
    target_column: str,
    time_column: str,
    previous_error: str = "",
    previous_code: str = "",
    llm_model: str = "gpt-4o",
) -> str:
    """LLM plans a single experiment: reasons about the pipeline and generates Monty code.

    Runs as a durable Flyte task so each experiment's planning step is traceable.
    Returns a JSON string: {"code": "...", "reasoning": "..."}.

    Args:
        experiment_json: JSON string of the experiment spec (name, algorithm, hyperparams, ...).
        profile_json: JSON string of the dataset profile from profile_dataset.
        target_column: Name of the target column.
        time_column: Time column for temporal splitting, or empty string.
        previous_error: Error message from the previous attempt (empty on first try).
        previous_code: Code that failed on the previous attempt (empty on first try).
        llm_model: OpenAI model identifier.

    Returns:
        str — JSON string with keys "code" and "reasoning".
    """
    experiment = json.loads(experiment_json)
    profile = json.loads(profile_json)
    exp_name = experiment.get("name", "experiment")

    # Strip rationale — it was written by the design LLM to explain *why* this
    # experiment was chosen. Passing it here causes plan_experiment to parrot it
    # back as "reasoning" instead of independently thinking about *how* to build
    # the best pipeline. Keep only the technical spec.
    pipeline_spec = {
        k: v for k, v in experiment.items()
        if k not in ("rationale",)
    }

    system = _build_orchestration_system_prompt(profile)

    user_content = textwrap.dedent(f"""
        Design and implement the best pipeline for this experiment:

        Name: {exp_name}
        Algorithm: {pipeline_spec.get("algorithm")}
        Hyperparams: {json.dumps(pipeline_spec.get("hyperparams", {}), indent=2)}
        Feature config hint: {json.dumps(pipeline_spec.get("feature_config", {}), indent=2)}

        Available sandbox inputs:
        - data: File  — the full dataset CSV
        - target_column: str = "{target_column}"
        - time_column: str = "{time_column}"  (empty string means no time ordering)
        - experiment_name: str = "{exp_name}"

        The feature config hint is a suggestion from the experiment designer — you can
        follow it, improve on it, or override it if the dataset context and your ML
        judgment suggest a better approach. In your ## Reasoning, explain your actual
        pipeline decisions: what you chose to do (or not do) and why, based on the
        dataset profile above. Do not restate the experiment name or why it was chosen.
    """).strip()

    messages = [{"role": "user", "content": user_content}]
    if previous_code and previous_error:
        messages = [
            {"role": "user", "content": user_content},
            {"role": "assistant", "content": f"```python\n{previous_code}\n```"},
            {"role": "user", "content": f"That code failed with this error:\n\n{previous_error}\n\nPlease fix it."},
        ]

    response = await _call_llm(system, messages, llm_model)
    reasoning = _extract_reasoning(response)
    code = _extract_code(response)
    return json.dumps({"code": code, "reasoning": reasoning})

@flyte.trace
async def design_experiments(
    problem_description: str,
    profile_json: str,
    llm_model: str = "gpt-4o",
) -> str:
    """LLM designs the initial batch of experiments given problem + dataset profile.

    Traced so the prompt/response is visible in the Flyte UI and results are
    cached for deterministic replay on crash/retry.
    Returns raw LLM response (JSON string matching InitialDesign schema).
    """
    design_prompt = textwrap.dedent(f"""
        Problem description: {problem_description}

        Dataset profile:
        {profile_json}

        Design the first batch of experiments.
    """).strip()
    return await _call_llm(
        _build_initial_design_system_prompt(),
        [{"role": "user", "content": design_prompt}],
        llm_model,
    )

@flyte.trace
async def analyze_iteration(
    analysis_prompt: str,
    max_iterations: int,
    current_iteration: int,
    llm_model: str = "gpt-4o",
) -> str:
    """LLM analyzes experiment results and decides whether/how to continue.

    Traced so the prompt/response is visible in the Flyte UI and results are
    cached for deterministic replay on crash/retry.
    Returns raw LLM response (JSON string matching IterationDecision schema).
    """
    return await _call_llm(
        _build_analysis_system_prompt(max_iterations, current_iteration),
        [{"role": "user", "content": analysis_prompt}],
        llm_model,
    )

@flyte.trace
async def plan_followup(
    analysis_prompt: str,
    analysis_response: str,
    followup_prompt: str,
    max_iterations: int,
    current_iteration: int,
    llm_model: str = "gpt-4o",
) -> str:
    """LLM designs next experiments after targeted data explorations.

    Traced so the prompt/response is visible in the Flyte UI and results are
    cached for deterministic replay on crash/retry.
    Returns raw LLM response (JSON string with {"next_experiments": [...]}).
    """
    return await _call_llm(
        _build_analysis_system_prompt(max_iterations, current_iteration),
        [
            {"role": "user", "content": analysis_prompt},
            {"role": "assistant", "content": analysis_response},
            {"role": "user", "content": followup_prompt},
        ],
        llm_model,
    )

def _corrupt_experiment_for_demo(exp_dict: dict) -> dict:
    """Introduce a deliberate error into the first experiment for demo purposes.

    Corrupts the algorithm name so the LLM must recover from a known-bad value.
    The retry loop will catch this, regenerate with the error message, and fix it.
    """
    corrupted = dict(exp_dict)
    corrupted["algorithm"] = corrupted["algorithm"] + "_INVALID"
    return corrupted
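
A quick sketch of why the retry loop can recover from this: the corruption mutates a shallow copy, never the original spec, so `exp.model_dump()` still yields the clean algorithm name on the next attempt (spec values below are hypothetical).

```python
# Hypothetical experiment spec.
spec = {"name": "baseline", "algorithm": "xgboost"}

# Mirror of _corrupt_experiment_for_demo: mutate a copy, leave spec intact.
corrupted = dict(spec)
corrupted["algorithm"] = corrupted["algorithm"] + "_INVALID"

print(corrupted["algorithm"])  # xgboost_INVALID
print(spec["algorithm"])       # xgboost
```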

# ---------------------------------------------------------------------------
# Main agent loop
# ---------------------------------------------------------------------------

@dataclass
class ExperimentResult:
    name: str
    algorithm: str
    metrics: dict
    confusion_matrix: dict
    threshold_analysis: list
    n_samples: int
    code: str
    attempts: int
    reasoning: str = ""
    error: str = ""

@dataclass
class AgentResult:
    model_card: str
    best_experiment: str
    best_metrics: dict
    all_results: list[ExperimentResult]
    iterations: int
    total_experiments: int

async def _run_experiment(
    exp: "ExperimentConfig",
    exp_dict: dict,
    inject_failure: bool,
    data: File,
    target_column: str,
    time_column: str,
    profile: dict,
    llm_model: str,
    max_retries: int,
) -> "ExperimentResult | None":
    """Run a single experiment with retries. Returns None on total failure."""
    exp_name = exp.name
    profile_json = json.dumps(profile)
    print(f"\n   ┌─ {exp_name}  [{exp.algorithm}]")
    if exp.rationale:
        for line in textwrap.wrap(exp.rationale, width=58):
            print(f"   │  {line}")
    if inject_failure:
        print(f"   │  [injecting failure for demo: algorithm='{exp_dict['algorithm']}']")

    code = ""
    error = ""
    result = None
    attempt = 0

    reasoning = ""
    # {{docs-fragment retry_loop}}
    for attempt in range(max_retries):
        try:
            with flyte.group(exp_name):
                plan_json = await plan_experiment.aio(
                    experiment_json=json.dumps(exp_dict),
                    profile_json=profile_json,
                    target_column=target_column,
                    time_column=time_column,
                    previous_error=error,
                    previous_code=code,
                    llm_model=llm_model,
                )
                plan = json.loads(plan_json)
                code = plan["code"]
                reasoning = plan.get("reasoning", "")
                result = await flyte.sandbox.orchestrate_local(
                    code,
                    inputs={"data": data, "target_column": target_column,
                            "time_column": time_column, "experiment_name": exp_name},
                    tasks=TOOLS,
                )
            error = ""
            break
        except Exception as exc:
            error = str(exc)
            short_error = error[:100] + "..." if len(error) > 100 else error
            print(f"   │  attempt {attempt + 1} failed: {short_error}")
            print(f"   │  → asking LLM to fix and retry...")
            if inject_failure and attempt == 0:
                exp_dict = exp.model_dump()
    # {{/docs-fragment retry_loop}}

    if result and not error:
        exp_result = ExperimentResult(
            name=exp_name,
            algorithm=exp.algorithm,
            metrics=result.get("metrics", {}),
            confusion_matrix=result.get("confusion_matrix", {}),
            threshold_analysis=result.get("threshold_analysis", []),
            n_samples=result.get("n_samples", 0),
            code=code,
            reasoning=reasoning,
            attempts=attempt + 1,
        )
        m = exp_result.metrics
        attempts_note = f" (recovered after {attempt + 1} attempts)" if attempt > 0 else ""
        print(f"   └─ ROC-AUC={m.get('roc_auc')}, F1={m.get('f1')}, Recall={m.get('recall')}{attempts_note}")
        return exp_result

    print(f"   └─ FAILED after {max_retries} attempts — skipping.")
    return None

async def run_agent(
    data: File,
    problem_description: str,
    target_column: str,
    time_column: str = "",
    max_iterations: int = 3,
    max_retries_per_experiment: int = 3,
    llm_model: str = "gpt-4o",
    inject_failure: bool = False,
) -> AgentResult:
    """Run the MLE agent end-to-end.

    Args:
        data: CSV file containing the dataset.
        problem_description: Natural language description of the ML problem.
        target_column: Name of the target column to predict.
        time_column: Optional column to use for time-based train/test split.
        max_iterations: Maximum number of experiment iterations to run.
        max_retries_per_experiment: Max times to retry a failed sandbox execution.
        llm_model: OpenAI model to use (default: gpt-4o).
        inject_failure: If True, corrupts the first experiment to demonstrate self-healing.
    """
    print(f"\n{'='*60}")
    print(f"MLE Agent starting")
    print(f"Problem: {problem_description}")
    print(f"Target: {target_column}")
    if inject_failure:
        print(f"[demo mode: failure injection enabled]")
    print(f"{'='*60}\n")

    # {{docs-fragment phase1_profile}}
    # --- Phase 1: Profile the dataset (trusted tool, LLM never sees raw data) ---
    print(">> Phase 1: Profiling dataset...")
    with flyte.group("profile"):
        profile = await profile_dataset(data, target_column)
    # {{/docs-fragment phase1_profile}}
    print(f"   Shape: {profile['shape']}, Classes: {profile['target_distribution']}")
    print(f"   Imbalanced: {profile['is_imbalanced']}, Columns: {len(profile['columns'])}")
    corr = profile.get("feature_target_corr", {})
    top_corr = list(corr.items())[:5]
    print(f"   Top correlations: {', '.join(f'{k}={v:+.3f}' for k,v in top_corr)}")

    # Stream report: dataset summary
    await flyte.report.log.aio(
        f"<h1>MLE Agent Run</h1>"
        f"<p><b>Problem:</b> {problem_description}</p>"
        f"<p><b>Dataset:</b> {profile['shape'][0]:,} rows × {profile['shape'][1]} cols &nbsp;|&nbsp; "
        f"Class balance: {profile['class_balance']} &nbsp;|&nbsp; Imbalanced: {profile['is_imbalanced']}</p>"
        f"<p><b>Top feature-target correlations (raw):</b> "
        + ", ".join(f"{k}: {v:+.3f}" for k, v in top_corr) +
        f"</p><hr>",
        do_flush=True,
    )

    # --- Phase 2: LLM designs initial experiments ---
    print("\n>> Phase 2: Designing initial experiments...")
    design_response = await design_experiments(
        problem_description=problem_description,
        profile_json=json.dumps(profile),
        llm_model=llm_model,
    )
    design = InitialDesign.model_validate(_parse_json(design_response))
    print(f"   Primary metric: {design.primary_metric}")
    print(f"   Strategy: {design.reasoning}")
    print(f"   Experiments planned: {len(design.experiments)}")

    all_results: list[ExperimentResult] = []
    iteration_log: list[dict] = []  # tracks per-iteration decisions + explorations for summary
    current_experiments: list[ExperimentConfig] = design.experiments
    first_experiment = True

    # --- Phase 3: Iterative experiment loop ---
    for iteration in range(max_iterations):
        experiments = current_experiments

        if not experiments:
            print(f"\n>> No experiments to run in iteration {iteration + 1}. Stopping.")
            break

        print(f"\n>> Phase 3.{iteration + 1}: Running {len(experiments)} experiment(s) in parallel...")

        # Assign names and prepare dicts before launching in parallel
        exp_batch = []
        for i, exp in enumerate(experiments):
            if not exp.name:
                exp.name = f"experiment_{len(all_results) + i + 1}"
            exp_dict = exp.model_dump()
            inject_this = inject_failure and first_experiment and i == 0
            if inject_this:
                exp_dict = _corrupt_experiment_for_demo(exp_dict)
            first_experiment = False
            exp_batch.append((exp, exp_dict, inject_this))

        # {{docs-fragment parallel_execute}}
        batch_results = await asyncio.gather(*[
            _run_experiment(
                exp=exp,
                exp_dict=exp_dict,
                inject_failure=inject_this,
                data=data,
                target_column=target_column,
                time_column=time_column,
                profile=profile,
                llm_model=llm_model,
                max_retries=max_retries_per_experiment,
            )
            for exp, exp_dict, inject_this in exp_batch
        ])
        # {{/docs-fragment parallel_execute}}

        for exp_result in batch_results:
            if exp_result is not None:
                all_results.append(exp_result)
                # Stream report: each experiment as it completes
                m = exp_result.metrics
                html = (
                    f"<h3>Iteration {iteration + 1} — {exp_result.name}</h3>"
                    f"<p><b>Algorithm:</b> {exp_result.algorithm} &nbsp;|&nbsp; "
                    f"<b>ROC-AUC:</b> {m.get('roc_auc')} &nbsp;|&nbsp; "
                    f"<b>F1:</b> {m.get('f1')} &nbsp;|&nbsp; "
                    f"<b>Recall:</b> {m.get('recall')} &nbsp;|&nbsp; "
                    f"<b>Attempts:</b> {exp_result.attempts}</p>"
                )
                if exp_result.reasoning:
                    html += f"<details><summary>Reasoning</summary><pre>{exp_result.reasoning}</pre></details>"
                html += f"<details><summary>Generated Code</summary><pre>{exp_result.code}</pre></details>"
                await flyte.report.log.aio(html, do_flush=True)

        # --- Phase 4: Analyze results, decide whether to iterate ---
        if all_results and iteration < max_iterations - 1:
            print(f"\n>> Phase 4.{iteration + 1}: Analyzing results, deciding next steps...")
            results_summary = [
                {
                    "experiment_name": r.name,
                    "algorithm": r.algorithm,
                    "metrics": r.metrics,
                    "confusion_matrix": r.confusion_matrix,
                    "used_feature_engineering": "engineer_features" in r.code,
                    "used_rolling_features": "rolling_columns" in r.code,
                    "used_lag_features": "lag_columns" in r.code,
                }
                for r in all_results
            ]
            analysis_prompt = textwrap.dedent(f"""
                Problem: {problem_description}
                Dataset profile: shape={profile['shape']}, imbalanced={profile['is_imbalanced']}
                Feature-target correlations (raw): {json.dumps(profile.get('feature_target_corr', {}), indent=2)}

                Experiment results so far (iteration {iteration + 1}):
                {json.dumps(results_summary, indent=2)}

                Should we run more experiments? If yes, request any data explorations
                you need, then specify what experiments to run next.
            """).strip()

            analysis_response = await analyze_iteration(
                analysis_prompt=analysis_prompt,
                max_iterations=max_iterations,
                current_iteration=iteration,
                llm_model=llm_model,
            )
            decision = IterationDecision.model_validate(_parse_json(analysis_response))
            verdict = "continuing" if decision.should_continue else "stopping"
            print(f"   Decision: {verdict}")
            print(f"   Reasoning: {decision.reasoning}")

            # Stream report: analysis decision
            await flyte.report.log.aio(
                f"<h3>Analysis — Iteration {iteration + 1}</h3>"
                f"<p><b>Decision:</b> {verdict}</p>"
                f"<p><b>Reasoning:</b> {decision.reasoning}</p>",
                do_flush=True,
            )

            # Track this iteration for the experiment journey summary
            iter_entry = {
                "iteration": iteration + 1,
                "experiments": [r.name for r in batch_results if r is not None],
                "best_roc_auc": max(
                    (r.metrics.get("roc_auc", 0) for r in all_results), default=0
                ),
                "reasoning": decision.reasoning,
                "explorations": [],
            }

            # --- Targeted exploration before next iteration ---
            if decision.should_continue and decision.exploration_requests:
                print(f"   Running {len(decision.exploration_requests)} exploration request(s)...")
                exploration_questions = []
                exploration_results = []

                for i, req in enumerate(decision.exploration_requests):
                    question = req.get("question", f"Exploration {i + 1}")
                    # Strip agent-level metadata — tool only needs the analysis config
                    tool_config = {k: v for k, v in req.items() if k not in ("question", "analysis_type")}

                    print(f"   Q: {question}")
                    with flyte.group(f"explore_{iteration + 1}_{i + 1}"):
                        result = await explore_dataset(data, tool_config)
                    exploration_questions.append(question)
                    exploration_results.append(result)
                    iter_entry["explorations"].append({"question": question})

                    await flyte.report.log.aio(
                        f"<h4>Exploration {i + 1}</h4>"
                        f"<p><b>Question:</b> {question}</p>"
                        f"<details><summary>Results</summary><pre>{json.dumps(result, indent=2)}</pre></details>",
                        do_flush=True,
                    )

                # Build follow-up that explicitly connects each question to its answer
                qa_pairs = "\n\n".join(
                    f'Question {i + 1}: "{q}"\nResult:\n{json.dumps(r, indent=2)}'
                    for i, (q, r) in enumerate(zip(exploration_questions, exploration_results))
                )
                followup_prompt = textwrap.dedent(f"""
                    You requested {len(exploration_results)} targeted exploration(s).
                    Here is what you asked and what you learned:

                    {qa_pairs}

                    Given what you learned and your earlier reasoning:
                    "{decision.reasoning}"

                    Now specify the next experiments. For each experiment, briefly state
                    which exploration insight informed your choice.
                    Respond with valid JSON: {{"next_experiments": [...same schema as before...]}}
                """).strip()
                followup_response = await plan_followup(
                    analysis_prompt=analysis_prompt,
                    analysis_response=analysis_response,
                    followup_prompt=followup_prompt,
                    max_iterations=max_iterations,
                    current_iteration=iteration,
                    llm_model=llm_model,
                )
                followup = _parse_json(followup_response)
                current_experiments = IterationDecision.model_validate({
                    "should_continue": True,
                    "reasoning": decision.reasoning,
                    "next_experiments": followup.get("next_experiments", []),
                }).next_experiments
                print(f"   Post-exploration: {len(current_experiments)} experiment(s) planned")
            else:
                current_experiments = decision.next_experiments

            iteration_log.append(iter_entry)

            if not decision.should_continue:
                break

    # --- Phase 5: Rank all results and generate model card ---
    print(f"\n>> Phase 5: Ranking {len(all_results)} experiment(s) and generating model card...")

    if not all_results:
        return AgentResult(
            model_card="No experiments completed successfully.",
            best_experiment="",
            best_metrics={},
            all_results=[],
            iterations=iteration + 1,
            total_experiments=0,
        )

    ranking_input = [
        {
            "experiment_name": r.name,
            "metrics": r.metrics,
            "confusion_matrix": r.confusion_matrix,
        }
        for r in all_results
    ]
    with flyte.group("rank"):
        ranking = await rank_experiments(json.dumps(ranking_input))
    best_name = ranking["best_experiment"]
    best_result = next(r for r in all_results if r.name == best_name)

    _print_experiment_table(all_results, best_name)
    _print_threshold_recommendation(best_result.threshold_analysis, best_result.metrics)

    # Stream report: final rankings table
    rows = "".join(
        f"<tr><td>{row['rank']}</td>"
        f"<td>{'<b>' if row['experiment_name'] == best_name else ''}"
        f"{row['experiment_name']}"
        f"{'</b>' if row['experiment_name'] == best_name else ''}</td>"
        f"<td>{row['roc_auc']}</td><td>{row['f1']}</td>"
        f"<td>{row['recall']}</td><td>{row['precision']}</td></tr>"
        for row in ranking.get("ranking", [])
    )
    await flyte.report.log.aio(
        f"<hr><h2>Final Rankings</h2>"
        f"<table border='1' cellpadding='6' cellspacing='0'>"
        f"<tr><th>Rank</th><th>Experiment</th><th>ROC-AUC</th><th>F1</th><th>Recall</th><th>Precision</th></tr>"
        f"{rows}</table>"
        f"<p>{ranking.get('summary', '')}</p>",
        do_flush=True,
    )

    # Stream report: experiment journey summary
    journey_rows = ""
    for entry in iteration_log:
        exps = ", ".join(entry["experiments"]) if entry["experiments"] else "—"
        explorations = "; ".join(e["question"] for e in entry["explorations"]) if entry["explorations"] else "—"
        short_reasoning = (entry["reasoning"][:120] + "…") if len(entry["reasoning"]) > 120 else entry["reasoning"]
        journey_rows += (
            f"<tr>"
            f"<td style='text-align:center'>{entry['iteration']}</td>"
            f"<td>{exps}</td>"
            f"<td style='text-align:center'>{entry['best_roc_auc']:.4f}</td>"
            f"<td>{short_reasoning}</td>"
            f"<td>{explorations}</td>"
            f"</tr>"
        )
    await flyte.report.log.aio(
        f"<hr><h2>Experiment Journey</h2>"
        f"<table border='1' cellpadding='6' cellspacing='0' style='width:100%;border-collapse:collapse'>"
        f"<tr><th>Iter</th><th>Experiments</th><th>Best ROC-AUC</th><th>Key insight</th><th>Explorations</th></tr>"
        f"{journey_rows}"
        f"</table>",
        do_flush=True,
    )

    model_card = await _generate_model_card(
        problem_description=problem_description,
        profile=profile,
        all_results=all_results,
        best_result=best_result,
        ranking=ranking,
        iteration_log=iteration_log,
        llm_model=llm_model,
    )

    print(f"\n{'='*60}")
    print(f"DONE — Best model: {best_name}")
    print(f"       ROC-AUC={best_result.metrics.get('roc_auc')}, F1={best_result.metrics.get('f1')}")
    print(f"{'='*60}\n")

    return AgentResult(
        model_card=model_card,
        best_experiment=best_name,
        best_metrics=best_result.metrics,
        all_results=all_results,
        iterations=iteration + 1,
        total_experiments=len(all_results),
    )

async def _generate_model_card(
    problem_description: str,
    profile: dict,
    all_results: list[ExperimentResult],
    best_result: ExperimentResult,
    ranking: dict,
    iteration_log: list[dict],
    llm_model: str,
) -> str:
    """Generate a markdown model card summarizing the winning model."""
    system = textwrap.dedent("""
        You are an ML engineer writing a model card for a trained model.
        Write in markdown. Be concise but informative. Include:
        - Problem statement
        - Dataset summary
        - Experiment journey (brief per-iteration narrative: what was tried, what was learned, what changed)
        - Experiment summary (table of all experiments with metrics)
        - Winning model details (algorithm, key hyperparams, metrics, threshold analysis)
        - Recommendations for deployment (decision threshold, monitoring)
    """).strip()

    results_text = "\n".join(
        f"- {r.name} ({r.algorithm}): ROC-AUC={r.metrics.get('roc_auc')}, "
        f"F1={r.metrics.get('f1')}, Recall={r.metrics.get('recall')}"
        for r in all_results
    )

    journey_text = ""
    if iteration_log:
        journey_text = "\n\nIteration log:\n" + "\n".join(
            f"  Iteration {e['iteration']}: ran [{', '.join(e['experiments'])}], "
            f"best ROC-AUC so far={e['best_roc_auc']:.4f}. "
            f"Key insight: {e['reasoning'][:200]}. "
            + (f"Explorations: {'; '.join(x['question'] for x in e['explorations'])}" if e['explorations'] else "")
            for e in iteration_log
        )

    user_content = textwrap.dedent(f"""
        Problem: {problem_description}

        Dataset: {profile['shape'][0]} rows × {profile['shape'][1]} cols.
        Class balance: {profile['class_balance']}
        Imbalanced: {profile['is_imbalanced']}
        {journey_text}

        All experiments:
        {results_text}

        Best model: {best_result.name} ({best_result.algorithm})
        Metrics: {json.dumps(best_result.metrics, indent=2)}
        Confusion matrix: {json.dumps(best_result.confusion_matrix, indent=2)}
        Threshold analysis: {json.dumps(best_result.threshold_analysis, indent=2)}

        Ranking summary: {ranking['summary']}
    """).strip()

    response = await _call_llm(system, [{"role": "user", "content": user_content}], llm_model)
    return response

# ---------------------------------------------------------------------------
# Durable entrypoint (runs the agent as a Flyte task in the cloud)
# ---------------------------------------------------------------------------

# {{docs-fragment entrypoint}}
@agent_env.task(retries=1, report=True)
async def mle_agent_task(
    data: File,
    problem_description: str,
    target_column: str,
    time_column: str = "",
    max_iterations: int = 3,
) -> str:
    """Durable Flyte task entrypoint for the MLE agent."""
    result = await run_agent(
        data=data,
        problem_description=problem_description,
        target_column=target_column,
        time_column=time_column,
        max_iterations=max_iterations,
    )
    return result.model_card
# {{/docs-fragment entrypoint}}
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/mle_bot/mle_bot/agent.py*

On the first attempt, `previous_error` and `previous_code` are empty. On subsequent attempts, the LLM sees exactly what went wrong and can fix it. In practice, most experiments succeed on the first try, with occasional recoveries on the second.
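The error-feedback loop can be sketched in isolation. This is a hypothetical, self-contained illustration of the pattern, not the agent's actual code: `generate_code` and `run_sandboxed` are stand-ins for the real LLM call and the Flyte sandbox, and only the control flow mirrors the agent.

```python
# Hypothetical sketch of the error-feedback retry pattern. `generate_code`
# stands in for the LLM call and `run_sandboxed` for sandbox execution;
# only the control flow (feeding previous_code/previous_error back) is
# what the agent actually does.

def generate_code(task: str, previous_code: str = "", previous_error: str = "") -> str:
    """Stand-in for the LLM: returns buggy code first, a fix once it sees the error."""
    if previous_error:
        return "result = 42"          # "fixed" after seeing the traceback
    return "result = undefined_name"  # first attempt contains a bug

def run_sandboxed(code: str) -> tuple[bool, str]:
    """Stand-in for sandbox execution: returns (success, error_message)."""
    try:
        namespace: dict = {}
        exec(code, namespace)
        return True, ""
    except Exception as e:
        return False, repr(e)

def run_with_retries(task: str, max_retries: int = 3) -> int:
    previous_code, previous_error = "", ""
    for attempt in range(1, max_retries + 1):
        code = generate_code(task, previous_code, previous_error)
        ok, error = run_sandboxed(code)
        if ok:
            return attempt
        # Feed the failing code and its error back into the next prompt.
        previous_code, previous_error = code, error
    raise RuntimeError(f"failed after {max_retries} attempts: {previous_error}")

attempts_used = run_with_retries("train a model")
print(attempts_used)  # prints 2: the stub fails once, then recovers
```

The key design choice is that the retry prompt carries both the failing code and the exact error, so the model debugs rather than regenerates from scratch.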

## Streaming results to a live report

While the agent runs, it streams results to the Flyte UI in real time using `flyte.report.log.aio()`. You don't have to wait for the full run to finish to see how experiments are performing.
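The incremental pattern looks like the sketch below. Note this is an illustration with a stub `Report` class standing in for the real `flyte.report.log.aio()` API; the buffering and flush behavior shown here is assumed for demonstration, not Flyte's actual implementation.

```python
# Illustrative stub of incremental report streaming. The Report class is a
# stand-in for flyte.report; its buffer/flush behavior is assumed, not real.
import asyncio

class Report:
    def __init__(self):
        self.flushed_html: list[str] = []
        self._buffer: list[str] = []

    async def log(self, html: str, do_flush: bool = False) -> None:
        self._buffer.append(html)
        if do_flush:  # push buffered chunks to the UI immediately
            self.flushed_html.extend(self._buffer)
            self._buffer.clear()

async def main() -> list[str]:
    report = Report()
    # Each experiment streams its result as soon as it completes,
    # mirroring the per-experiment log calls in the agent loop.
    for name, auc in [("baseline", 0.81), ("rolling_features", 0.87)]:
        await report.log(f"<h3>{name}</h3><p>ROC-AUC: {auc}</p>", do_flush=True)
    return report.flushed_html

chunks = asyncio.run(main())
print(len(chunks))  # prints 2
```

Flushing after each experiment (`do_flush=True`) is what makes results visible mid-run rather than only at the end.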

The entrypoint task enables this by passing `report=True` to its `@agent_env.task` decorator, as shown in `mle_agent_task` above. For reference, the full agent module follows:

```python
"""MLE Agent — orchestrates ML experiments using Flyte's durable sandbox.

The agent:
  1. Profiles the dataset using a trusted tool (data never touches the LLM).
  2. Asks OpenAI to design a set of experiments (algorithms, hyperparams, feature strategy).
  3. For each experiment, generates Monty orchestration code and executes it via
     flyte.sandbox.orchestrate_local(), which dispatches the heavy compute as durable tasks.
  4. Analyzes results, iterates if needed.
  5. Produces a model card summarizing the winning model.

The Monty sandbox ensures the LLM-generated orchestration code is safe — it can only
call the pre-approved tool functions and has no access to imports, network, or filesystem.
"""

import asyncio
import inspect
import json
import os
import textwrap
from dataclasses import dataclass

import flyte
import flyte.sandbox
from flyte.io import File

from mle_bot.schemas import ExperimentConfig, InitialDesign, IterationDecision

from mle_bot.environments import agent_env
from mle_bot.tools.data import profile_dataset, split_dataset
from mle_bot.tools.evaluation import evaluate_model, rank_experiments
from mle_bot.tools.exploration import explore_dataset
from mle_bot.tools.features import engineer_features
from mle_bot.tools.predictions import get_predictions
from mle_bot.tools.resampling import resample_dataset
from mle_bot.tools.selection import select_features
from mle_bot.tools.training import train_model

# {{docs-fragment tools}}
# All tools exposed to the sandbox.
# Keys must match the function names used in LLM-generated orchestration code.
TOOLS = [
    profile_dataset, split_dataset, explore_dataset,
    engineer_features, resample_dataset, select_features,
    train_model, get_predictions, evaluate_model, rank_experiments,
]
TOOLS_BY_NAME = {t.func.__name__ if hasattr(t, "func") else t.__name__: t for t in TOOLS}
# {{/docs-fragment tools}}

# ---------------------------------------------------------------------------
# Prompt builders
# ---------------------------------------------------------------------------

def _tool_signatures() -> str:
    """Build a summary of available tool signatures and docstrings for the system prompt."""
    parts = []
    for t in TOOLS:
        func = t.func if hasattr(t, "func") else t
        sig = inspect.signature(func)
        doc = inspect.getdoc(func) or ""
        # Trim docstring to first 40 lines so prompt stays manageable
        doc_lines = doc.splitlines()[:40]
        short_doc = "\n    ".join(doc_lines)
        parts.append(f"async def {func.__name__}{sig}:\n    \"\"\"{short_doc}\"\"\"\n    ...")
    return "\n\n".join(parts)

# {{docs-fragment orchestration_prompt}}
def _build_orchestration_system_prompt(profile: dict) -> str:
    monty_rules = flyte.sandbox.ORCHESTRATOR_SYNTAX_PROMPT
    tools_section = _tool_signatures()
    is_imbalanced = profile.get("is_imbalanced", False)
    class_balance = profile.get("class_balance", {})
    columns = profile.get("columns", [])
    numeric_cols = profile.get("numeric_columns", [])
    categorical_cols = profile.get("categorical_columns", [])
    corr = profile.get("feature_target_corr", {})
    corr_str = ", ".join(f"{k}: {v:+.3f}" for k, v in list(corr.items())[:8]) if corr else "n/a"
    shape = profile.get("shape", [0, 0])
    return f"""\
You are an expert ML engineer. Your job is to design and write the best possible
pipeline for a machine learning experiment, then generate the Python orchestration
code to execute it.

The code runs inside a restricted sandbox. The last expression in your code
is returned as the result. All tool calls are made like regular function calls —
you do NOT need to await them.

## Dataset context

Shape: {shape[0]:,} rows × {shape[1]} columns
Numeric features: {numeric_cols}
Categorical features (excluded from model — not supported): {categorical_cols}
Class balance: {class_balance}, imbalanced: {is_imbalanced}
Feature-target correlations (raw, point-biserial): {corr_str}

## General ML best practices — apply these based on the dataset context above

**Feature engineering** (engineer_features tool):
- Sequential/time-series data (timestamp column present, rows ordered over time):
  rolling window features (means, stds, min/max) capture trends that point-in-time
  readings miss. Lag features capture recent history. Choose window sizes relative
  to the prediction horizon and temporal resolution of the data.
- Tabular cross-sectional data: normalization helps linear models and distance-based
  methods. Interaction terms can help if correlations are weak individually.
- Consider skipping feature engineering entirely for a baseline — it establishes
  whether raw features already carry enough signal.

**Class imbalance** (when is_imbalanced=true):
- Tree ensembles: use class_weight="balanced" or scale_pos_weight=n_neg/n_pos.
- Threshold: the default 0.5 decision threshold may not be optimal — the model's
  probability output is what matters, threshold is tuned at deployment time.
- Metric: ROC-AUC is robust to imbalance; avg_precision is better when positives
  are very rare.

**Algorithm selection**:
- XGBoost / GradientBoosting: strong default for tabular data, handles missing
  values, built-in imbalance handling. Start here unless data is very small.
- RandomForest: more robust to outliers, good for noisy data, parallelizes well.
- LogisticRegression: fast linear baseline. Useful to establish whether the
  problem is linearly separable before adding complexity.
- Prefer simpler models when n_samples < 5,000 to avoid overfitting.

**Resampling** (resample_dataset tool) — data-level imbalance handling:
- Use when class_weight/scale_pos_weight alone isn't improving recall adequately,
  or when you want to test whether data-level vs algorithm-level imbalance handling
  works better for this dataset.
- ONLY resample the TRAIN split — never test. Resampling test data gives misleading metrics.
- "oversample": fast, no new information, good first try.
- "smote": synthetic samples via interpolation — more diverse than random oversampling,
  better for high-dimensional or sparse minority classes.
- "undersample": loses majority data — only useful when majority class is very large
  and training speed is a concern.

**Feature selection** (select_features tool) — prune after feature engineering:
- Use after engineer_features when the feature count is large (20+) and you suspect
  many features are redundant or noisy (e.g., rolling stats at many window sizes).
- "mutual_info": ranks by non-linear dependence with target — best general choice.
- "variance_threshold": drops near-constant features — cheap first pass.
- "correlation_filter": drops redundant features that are highly correlated with
  each other — useful when many rolling windows capture the same trend.
- Can be applied before or after splitting. Apply the same selection to both train
  and test to ensure the model sees the same features at evaluation time.

**Prediction output** (get_predictions tool) — enables two advanced patterns:
1. Error analysis: train a model → get_predictions(model, test_file, target) →
   explore_dataset(predictions_file, {{"class_distributions": ["feature_x"],
   "target_column": "correct"}}) to see which examples the model gets wrong.
   Use this to inform feature engineering for the next iteration.
2. Stacking: train base_model → get_predictions(base_model, train_file, target) →
   train a meta_model on the predictions CSV (use "predicted_prob" as a feature
   alongside original features) → evaluate meta_model on test.
   get_predictions returns a CSV with columns: all originals + predicted_prob,
   predicted_class, correct.

**Pipeline structure** — you are not required to follow a fixed sequence.
Design what makes sense for this specific experiment.

## Available tools

{tools_section}

## Monty sandbox restrictions

{monty_rules}

## Critical patterns for using tool results

split_dataset returns a File — call it twice:
    train_file = split_dataset(data, target_column, 0.2, time_column, "train")
    test_file  = split_dataset(data, target_column, 0.2, time_column, "test")

engineer_features returns a File — chain calls freely:
    eng = engineer_features(train_file, {{"rolling_columns": [...], "windows": [...]}})
    eng2 = engineer_features(eng, {{"normalize": true, "target_column": target_column}})

train_model returns a File — pass directly to evaluate_model:
    model_file = train_model(train_file, target_column, algorithm, hyperparams)
    eval_result = evaluate_model(model_file, test_file, target_column)

evaluate_model returns a dict — subscript reads are allowed:
    roc = eval_result["metrics"]["roc_auc"]

Do NOT use augmented assignment (+=), subscript assignment (d["k"]=v), or try/except.
Build dicts as literals only. The last expression (no assignment) is the return value.

## When fixing a previous error

Read the error and the failing code carefully before writing a fix. Identify the root
cause — do not just change variable names or add no-ops. Trace what each tool returns,
what each subsequent call expects, and where the mismatch is. Then fix the underlying
logic, not just the surface symptom.

## Pipeline design — you decide the structure

You are NOT required to follow a fixed sequence. Design the pipeline that makes most
sense for the experiment. Examples of valid approaches:

Baseline (no feature engineering):
    train_file = split_dataset(data, target_column, 0.2, time_column, "train")
    test_file = split_dataset(data, target_column, 0.2, time_column, "test")
    model_file = train_model(train_file, target_column, algorithm, hyperparams)
    eval_result = evaluate_model(model_file, test_file, target_column)
    {{"experiment_name": experiment_name, "algorithm": algorithm, "metrics": eval_result["metrics"], "confusion_matrix": eval_result["confusion_matrix"], "threshold_analysis": eval_result["threshold_analysis"], "n_samples": eval_result["n_samples"]}}

Two-stage feature engineering (rolling then normalize separately):
    train_file = split_dataset(data, target_column, 0.2, time_column, "train")
    test_file = split_dataset(data, target_column, 0.2, time_column, "test")
    rolled_train = engineer_features(train_file, {{"rolling_columns": ["vibration"], "windows": [6, 24]}})
    rolled_test  = engineer_features(test_file,  {{"rolling_columns": ["vibration"], "windows": [6, 24]}})
    eng_train = engineer_features(rolled_train, {{"normalize": true, "target_column": target_column}})
    eng_test  = engineer_features(rolled_test,  {{"normalize": true, "target_column": target_column}})
    model_file = train_model(eng_train, target_column, algorithm, hyperparams)
    eval_result = evaluate_model(model_file, eng_test, target_column)
    {{"experiment_name": experiment_name, "algorithm": algorithm, "metrics": eval_result["metrics"], "confusion_matrix": eval_result["confusion_matrix"], "threshold_analysis": eval_result["threshold_analysis"], "n_samples": eval_result["n_samples"]}}

Compare two class weightings and return the better model:
    train_file = split_dataset(data, target_column, 0.2, time_column, "train")
    test_file = split_dataset(data, target_column, 0.2, time_column, "test")
    model_a = train_model(train_file, target_column, "xgboost", {{"n_estimators": 100, "scale_pos_weight": 10}})
    model_b = train_model(train_file, target_column, "xgboost", {{"n_estimators": 100, "scale_pos_weight": 33}})
    eval_a = evaluate_model(model_a, test_file, target_column)
    eval_b = evaluate_model(model_b, test_file, target_column)
    best_eval = eval_a if eval_a["metrics"]["roc_auc"] > eval_b["metrics"]["roc_auc"] else eval_b
    {{"experiment_name": experiment_name, "algorithm": "xgboost", "metrics": best_eval["metrics"], "confusion_matrix": best_eval["confusion_matrix"], "threshold_analysis": best_eval["threshold_analysis"], "n_samples": best_eval["n_samples"]}}

SMOTE oversampling before training:
    train_file = split_dataset(data, target_column, 0.2, time_column, "train")
    test_file  = split_dataset(data, target_column, 0.2, time_column, "test")
    eng_train  = engineer_features(train_file, {{"rolling_columns": ["vibration_mms"], "windows": [6, 12]}})
    eng_test   = engineer_features(test_file,  {{"rolling_columns": ["vibration_mms"], "windows": [6, 12]}})
    resampled_train = resample_dataset(eng_train, target_column, {{"strategy": "smote", "target_ratio": 0.2}})
    model_file = train_model(resampled_train, target_column, algorithm, hyperparams)
    eval_result = evaluate_model(model_file, eng_test, target_column)
    {{"experiment_name": experiment_name, "algorithm": algorithm, "metrics": eval_result["metrics"], "confusion_matrix": eval_result["confusion_matrix"], "threshold_analysis": eval_result["threshold_analysis"], "n_samples": eval_result["n_samples"]}}

Feature engineering followed by feature selection:
    train_file = split_dataset(data, target_column, 0.2, time_column, "train")
    test_file  = split_dataset(data, target_column, 0.2, time_column, "test")
    eng_train  = engineer_features(train_file, {{"rolling_columns": ["vibration_mms", "temperature_c"], "windows": [6, 12, 24]}})
    eng_test   = engineer_features(test_file,  {{"rolling_columns": ["vibration_mms", "temperature_c"], "windows": [6, 12, 24]}})
    sel_train  = select_features(eng_train, target_column, {{"method": "mutual_info", "k": 15}})
    sel_test   = select_features(eng_test,  target_column, {{"method": "mutual_info", "k": 15}})
    model_file = train_model(sel_train, target_column, algorithm, hyperparams)
    eval_result = evaluate_model(model_file, sel_test, target_column)
    {{"experiment_name": experiment_name, "algorithm": algorithm, "metrics": eval_result["metrics"], "confusion_matrix": eval_result["confusion_matrix"], "threshold_analysis": eval_result["threshold_analysis"], "n_samples": eval_result["n_samples"]}}

Error analysis — explore what the model gets wrong, then return that as insight:
    train_file = split_dataset(data, target_column, 0.2, time_column, "train")
    test_file  = split_dataset(data, target_column, 0.2, time_column, "test")
    model_file = train_model(train_file, target_column, algorithm, hyperparams)
    pred_file  = get_predictions(model_file, test_file, target_column)
    error_analysis = explore_dataset(pred_file, {{"target_column": "correct", "class_distributions": ["vibration_mms", "temperature_c"]}})
    eval_result = evaluate_model(model_file, test_file, target_column)
    {{"experiment_name": experiment_name, "algorithm": algorithm, "metrics": eval_result["metrics"], "confusion_matrix": eval_result["confusion_matrix"], "threshold_analysis": eval_result["threshold_analysis"], "n_samples": eval_result["n_samples"], "error_analysis": error_analysis}}

The last expression MUST be a dict with at minimum these keys:
    experiment_name, algorithm, metrics, confusion_matrix, threshold_analysis, n_samples
Additional keys (e.g. error_analysis) are allowed and will appear in the report.

## Response format

Respond in exactly this format:

## Reasoning
[Your thinking: what pipeline makes sense for this experiment and why. Consider whether
feature engineering helps, whether class imbalance needs special treatment, whether
chaining multiple steps adds value, etc.]

## Code
```python
[your orchestration code]
```
"""
# {{/docs-fragment orchestration_prompt}}

def _build_analysis_system_prompt(max_iterations: int, current_iteration: int) -> str:
    remaining = max_iterations - current_iteration - 1
    return f"""\
You are an expert ML engineer analyzing experiment results to guide the next iteration
of model development.

You must respond with valid JSON only — no markdown, no explanation outside the JSON.

Response format:
{{
  "should_continue": true | false,
  "reasoning": "What you observed, what it tells you, and what to try next",
  "exploration_requests": [
    {{
      "question": "The specific hypothesis you are testing, e.g. 'Do failure cases show meaningfully higher vibration than healthy cases?'",
      "analysis_type": "class_distributions",
      "target_column": "failure_24h",
      "class_distributions": ["vibration_mms", "temperature_c"]
    }}
  ],
  "next_experiments": [
    {{
      "name": "descriptive experiment name",
      "algorithm": "xgboost" | "random_forest" | "gradient_boosting" | "logistic_regression",
      "hyperparams": {{ ... algorithm-specific hyperparams ... }},
      "feature_config": {{
        "group_column": "...",
        "time_column": "...",
        "rolling_columns": [...],
        "windows": [...],
        "lag_columns": [...],
        "lags": [...],
        "normalize": true | false,
        "drop_columns": [...],
        "fillna_method": "forward"
      }},
      "rationale": "Why this experiment is worth trying"
    }}
  ]
}}

exploration_requests rules:
- Max 2 requests per iteration.
- Each request targets EXACTLY ONE analysis_type. Do not mix multiple types in one request.
- Supported analysis_type values and their required config fields:
    "class_distributions" → requires: target_column, class_distributions (list of columns)
    "correlation_matrix"  → requires: correlation_matrix: true
    "temporal_trend"      → requires: temporal_trend: {{time_column, target_column, freq}}
    "group_stats"         → requires: group_stats: {{group_column, target_column}}
    "outlier_summary"     → requires: outlier_summary (list of columns)
    "feature_target_corr_by_group" → requires: feature_target_corr_by_group: {{group_column, target_column, feature_columns}}
- The "question" field is required. It must be a specific testable hypothesis, not a
  general request. Bad: "explore the data". Good: "Is vibration_mms higher for failures?"
- Set exploration_requests to [] if the current results already tell you enough to
  design the next experiments. Only explore when you have a concrete unanswered question.

When deciding next experiments, reason about WHAT WAS TRIED vs what hasn't been explored.
Each result includes used_feature_engineering, used_rolling_features, used_lag_features.
Think systematically: if no feature engineering was tried yet, does the data profile
suggest it would help (weak raw correlations, temporal/sequential structure)?
If feature engineering helped, can it be improved? Avoid experiments identical to ones tried.

Iteration context: this is iteration {current_iteration + 1} of {max_iterations} requested.
Remaining iterations allowed: {remaining}.
Set should_continue=false only if:
- Best ROC-AUC >= 0.97, OR
- No remaining iterations (remaining == 0), OR
- Results have genuinely plateaued (< 0.005 ROC-AUC improvement over last iteration
  AND you have already tried the most promising directions)
Otherwise keep exploring — the user asked for {max_iterations} iterations for a reason.
"""

def _build_initial_design_system_prompt() -> str:
    return """\
You are an expert ML engineer. Given a dataset profile and a problem description,
design the first batch of experiments to run.

You must respond with valid JSON only — no markdown, no explanation outside the JSON.

Response format:
{
  "problem_type": "binary_classification",
  "primary_metric": "roc_auc" | "f1" | "recall",
  "reasoning": "Brief description of your strategy",
  "experiments": [
    {
      "name": "descriptive experiment name",
      "algorithm": "xgboost" | "random_forest" | "gradient_boosting" | "logistic_regression",
      "hyperparams": { ... algorithm-specific hyperparams ... },
      "feature_config": {
        "group_column": "",
        "time_column": "",
        "rolling_columns": [],
        "windows": [],
        "lag_columns": [],
        "lags": [],
        "normalize": false,
        "drop_columns": [],
        "fillna_method": "forward"
      },
      "rationale": "Why this experiment makes sense given the data profile"
    }
  ]
}

Design 2-3 experiments for the first batch. Good first batches typically include:
- A fast baseline to establish a floor (e.g. logistic_regression with default settings)
- Your best initial hypothesis given the data profile
- Optionally one variant that tests a specific idea suggested by the profile

Use the dataset profile to guide your choices:
- feature_target_corr: weak raw correlations suggest feature engineering may help
- categorical_columns: note these are excluded from the model automatically
- is_imbalanced: handle with class_weight or scale_pos_weight
- Shape and column types should inform algorithm complexity (simpler models for small datasets)
- A time column suggests sequential structure; lag/rolling features may capture temporal patterns

The feature_config in each experiment describes what engineer_features should apply.
Leave all fields empty/false if no feature engineering is needed for that experiment.
The orchestration code generator will decide the exact pipeline — your job here is
to specify what the experiment is trying to learn, not to prescribe every implementation detail.
"""

# ---------------------------------------------------------------------------
# LLM client
# ---------------------------------------------------------------------------

def _openai_client():
    from openai import OpenAI
    return OpenAI(api_key=os.environ["OPENAI_API_KEY"])

async def _call_llm(system: str, messages: list[dict], model: str = "gpt-4o") -> str:
    """Call OpenAI and return the response text."""
    client = _openai_client()
    response = await asyncio.to_thread(
        client.chat.completions.create,
        model=model,
        messages=[{"role": "system", "content": system}, *messages],
        temperature=0.2,
    )
    return response.choices[0].message.content

def _extract_code(text: str) -> str:
    """Pull Python code out of a markdown code block."""
    if "```python" in text:
        start = text.index("```python") + len("```python")
        end = text.index("```", start)
        return text[start:end].strip()
    if "```" in text:
        start = text.index("```") + 3
        end = text.index("```", start)
        return text[start:end].strip()
    return text.strip()

def _extract_reasoning(text: str) -> str:
    """Extract the ## Reasoning section from LLM response."""
    if "## Reasoning" in text:
        start = text.index("## Reasoning") + len("## Reasoning")
        if "## Code" in text:
            end = text.index("## Code")
            return text[start:end].strip()
        return text[start:].strip()
    return ""

def _parse_json(text: str) -> dict:
    """Extract and parse JSON from LLM response."""
    text = text.strip()
    if "```json" in text:
        start = text.index("```json") + 7
        end = text.index("```", start)
        text = text[start:end].strip()
    elif "```" in text:
        start = text.index("```") + 3
        end = text.index("```", start)
        text = text[start:end].strip()
    return json.loads(text)
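The three extraction helpers above all rely on the same convention: the LLM wraps structured output in markdown fences, and the parser strips the first matching fence. A standalone sketch of that convention (`strip_fence` is an illustrative reimplementation, not a name from this module):

```python
import json

def strip_fence(text: str, tag: str = "") -> str:
    """Return the contents of the first ```tag fenced block, or the text itself."""
    marker = f"```{tag}"
    if marker in text:
        start = text.index(marker) + len(marker)
        end = text.index("```", start)
        return text[start:end].strip()
    return text.strip()

# A response in the "## Reasoning / ## Code" format used by plan_experiment:
response = "## Reasoning\nBaseline first.\n\n## Code\n```python\nx = 1\n```"
code = strip_fence(response, "python")
print(code)  # x = 1

# A JSON-only response, as the analysis prompt requests:
decision = json.loads(strip_fence('```json\n{"should_continue": true}\n```', "json"))
print(decision["should_continue"])  # True
```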

# ---------------------------------------------------------------------------
# Display helpers
# ---------------------------------------------------------------------------

def _recommend_threshold(threshold_analysis: list, min_precision: float = 0.70) -> dict | None:
    """Find the threshold that maximises recall subject to precision >= min_precision."""
    candidates = [t for t in threshold_analysis if t["precision"] >= min_precision]
    if not candidates:
        return None
    return max(candidates, key=lambda t: t["recall"])
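A worked example of the recommendation rule above: among the thresholds whose precision clears the floor, take the one with the highest recall. The numbers below are illustrative, and `recommend` is a local copy of the same logic:

```python
def recommend(threshold_analysis, min_precision=0.70):
    candidates = [t for t in threshold_analysis if t["precision"] >= min_precision]
    return max(candidates, key=lambda t: t["recall"]) if candidates else None

analysis = [
    {"threshold": 0.3, "precision": 0.62, "recall": 0.91, "f1": 0.74},  # below precision floor
    {"threshold": 0.4, "precision": 0.71, "recall": 0.85, "f1": 0.77},  # best recall above floor
    {"threshold": 0.5, "precision": 0.80, "recall": 0.76, "f1": 0.78},
]
print(recommend(analysis)["threshold"])  # 0.4
```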

def _print_experiment_table(results: list["ExperimentResult"], best_name: str) -> None:
    """Print a ranked comparison table of all experiments."""
    sorted_results = sorted(results, key=lambda r: r.metrics.get("roc_auc", 0), reverse=True)
    print("\n" + "─" * 78)
    print(f"  {'Rank':<5} {'Experiment':<32} {'ROC-AUC':<9} {'F1':<7} {'Recall':<8} {'Note'}")
    print("─" * 78)
    for rank, r in enumerate(sorted_results, 1):
        note = "◀ winner" if r.name == best_name else ""
        roc = r.metrics.get("roc_auc", 0)
        f1 = r.metrics.get("f1", 0)
        recall = r.metrics.get("recall", 0)
        print(f"  {rank:<5} {r.name:<32} {roc:<9.4f} {f1:<7.4f} {recall:<8.4f} {note}")
    print("─" * 78)

def _print_threshold_recommendation(threshold_analysis: list, default_metrics: dict) -> None:
    """Print the operational threshold recommendation."""
    rec = _recommend_threshold(threshold_analysis)
    if not rec:
        return
    default_recall = default_metrics.get("recall", 0)
    default_precision = default_metrics.get("precision", 0)
    false_alarm_pct = round((1 - rec["precision"]) * 100, 1)

    print(f"\n  Recommended decision threshold: {rec['threshold']}")
    print(f"  ├─ Precision : {rec['precision']:.0%}   ({false_alarm_pct}% of alerts are false alarms)")
    print(f"  ├─ Recall    : {rec['recall']:.0%}   (catches {rec['recall']*100:.0f}% of actual failures)")
    print(f"  └─ F1        : {rec['f1']:.4f}")
    print(f"  Default threshold (0.5): Precision={default_precision:.0%}, Recall={default_recall:.0%}")
    if rec["recall"] > default_recall:
        extra = round((rec["recall"] - default_recall) * 100, 1)
        print(f"  → Lowering threshold catches {extra}% more failures at cost of more alerts")

# ---------------------------------------------------------------------------
# Orchestration code generation (durable Flyte task with Flyte report)
# ---------------------------------------------------------------------------

@agent_env.task
async def plan_experiment(
    experiment_json: str,
    profile_json: str,
    target_column: str,
    time_column: str,
    previous_error: str = "",
    previous_code: str = "",
    llm_model: str = "gpt-4o",
) -> str:
    """LLM plans a single experiment: reasons about the pipeline and generates Monty code.

    Runs as a durable Flyte task so each experiment's planning step is traceable.
    Returns a JSON string: {"code": "...", "reasoning": "..."}.

    Args:
        experiment_json: JSON string of the experiment spec (name, algorithm, hyperparams, ...).
        profile_json: JSON string of the dataset profile from profile_dataset.
        target_column: Name of the target column.
        time_column: Time column for temporal splitting, or empty string.
        previous_error: Error message from the previous attempt (empty on first try).
        previous_code: Code that failed on the previous attempt (empty on first try).
        llm_model: OpenAI model identifier.

    Returns:
        str — JSON string with keys "code" and "reasoning".
    """
    experiment = json.loads(experiment_json)
    profile = json.loads(profile_json)
    exp_name = experiment.get("name", "experiment")

    # Strip rationale — it was written by the design LLM to explain *why* this
    # experiment was chosen. Passing it here causes plan_experiment to parrot it
    # back as "reasoning" instead of independently thinking about *how* to build
    # the best pipeline. Keep only the technical spec.
    pipeline_spec = {
        k: v for k, v in experiment.items()
        if k not in ("rationale",)
    }

    system = _build_orchestration_system_prompt(profile)

    user_content = textwrap.dedent(f"""
        Design and implement the best pipeline for this experiment:

        Name: {exp_name}
        Algorithm: {pipeline_spec.get("algorithm")}
        Hyperparams: {json.dumps(pipeline_spec.get("hyperparams", {}), indent=2)}
        Feature config hint: {json.dumps(pipeline_spec.get("feature_config", {}), indent=2)}

        Available sandbox inputs:
        - data: File  — the full dataset CSV
        - target_column: str = "{target_column}"
        - time_column: str = "{time_column}"  (empty string means no time ordering)
        - experiment_name: str = "{exp_name}"

        The feature config hint is a suggestion from the experiment designer — you can
        follow it, improve on it, or override it if the dataset context and your ML
        judgment suggest a better approach. In your ## Reasoning, explain your actual
        pipeline decisions: what you chose to do (or not do) and why, based on the
        dataset profile above. Do not restate the experiment name or why it was chosen.
    """).strip()

    messages = [{"role": "user", "content": user_content}]
    if previous_code and previous_error:
        messages = [
            {"role": "user", "content": user_content},
            {"role": "assistant", "content": f"```python\n{previous_code}\n```"},
            {"role": "user", "content": f"That code failed with this error:\n\n{previous_error}\n\nPlease fix it."},
        ]

    response = await _call_llm(system, messages, llm_model)
    reasoning = _extract_reasoning(response)
    code = _extract_code(response)
    return json.dumps({"code": code, "reasoning": reasoning})

@flyte.trace
async def design_experiments(
    problem_description: str,
    profile_json: str,
    llm_model: str = "gpt-4o",
) -> str:
    """LLM designs the initial batch of experiments given problem + dataset profile.

    Traced so the prompt/response is visible in the Flyte UI and results are
    cached for deterministic replay on crash/retry.
    Returns raw LLM response (JSON string matching InitialDesign schema).
    """
    design_prompt = textwrap.dedent(f"""
        Problem description: {problem_description}

        Dataset profile:
        {profile_json}

        Design the first batch of experiments.
    """).strip()
    return await _call_llm(
        _build_initial_design_system_prompt(),
        [{"role": "user", "content": design_prompt}],
        llm_model,
    )

@flyte.trace
async def analyze_iteration(
    analysis_prompt: str,
    max_iterations: int,
    current_iteration: int,
    llm_model: str = "gpt-4o",
) -> str:
    """LLM analyzes experiment results and decides whether/how to continue.

    Traced so the prompt/response is visible in the Flyte UI and results are
    cached for deterministic replay on crash/retry.
    Returns raw LLM response (JSON string matching IterationDecision schema).
    """
    return await _call_llm(
        _build_analysis_system_prompt(max_iterations, current_iteration),
        [{"role": "user", "content": analysis_prompt}],
        llm_model,
    )

@flyte.trace
async def plan_followup(
    analysis_prompt: str,
    analysis_response: str,
    followup_prompt: str,
    max_iterations: int,
    current_iteration: int,
    llm_model: str = "gpt-4o",
) -> str:
    """LLM designs next experiments after targeted data explorations.

    Traced so the prompt/response is visible in the Flyte UI and results are
    cached for deterministic replay on crash/retry.
    Returns raw LLM response (JSON string with {"next_experiments": [...]}).
    """
    return await _call_llm(
        _build_analysis_system_prompt(max_iterations, current_iteration),
        [
            {"role": "user", "content": analysis_prompt},
            {"role": "assistant", "content": analysis_response},
            {"role": "user", "content": followup_prompt},
        ],
        llm_model,
    )

def _corrupt_experiment_for_demo(exp_dict: dict) -> dict:
    """Introduce a deliberate error into the first experiment for demo purposes.

    Corrupts the algorithm name so the LLM must recover from a known-bad value.
    The retry loop will catch this, regenerate with the error message, and fix it.
    """
    corrupted = dict(exp_dict)
    corrupted["algorithm"] = corrupted["algorithm"] + "_INVALID"
    return corrupted
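The demo hinges on the shallow copy: only the corrupted dict is sent to the first planning attempt, while the pristine spec survives for the retry. A minimal round trip showing that behavior (`corrupt` mirrors the helper above; the spec values are made up):

```python
def corrupt(exp: dict) -> dict:
    bad = dict(exp)                  # shallow copy — the original spec is untouched
    bad["algorithm"] += "_INVALID"
    return bad

spec = {"name": "baseline", "algorithm": "xgboost"}
bad = corrupt(spec)
print(bad["algorithm"])   # xgboost_INVALID
print(spec["algorithm"])  # xgboost — still valid, so attempt 2 can recover
```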

# ---------------------------------------------------------------------------
# Main agent loop
# ---------------------------------------------------------------------------

@dataclass
class ExperimentResult:
    name: str
    algorithm: str
    metrics: dict
    confusion_matrix: dict
    threshold_analysis: list
    n_samples: int
    code: str
    attempts: int
    reasoning: str = ""
    error: str = ""

@dataclass
class AgentResult:
    model_card: str
    best_experiment: str
    best_metrics: dict
    all_results: list[ExperimentResult]
    iterations: int
    total_experiments: int

async def _run_experiment(
    exp: "ExperimentConfig",
    exp_dict: dict,
    inject_failure: bool,
    data: File,
    target_column: str,
    time_column: str,
    profile: dict,
    llm_model: str,
    max_retries: int,
) -> "ExperimentResult | None":
    """Run a single experiment with retries. Returns None on total failure."""
    exp_name = exp.name
    profile_json = json.dumps(profile)
    print(f"\n   ┌─ {exp_name}  [{exp.algorithm}]")
    if exp.rationale:
        for line in textwrap.wrap(exp.rationale, width=58):
            print(f"   │  {line}")
    if inject_failure:
        print(f"   │  [injecting failure for demo: algorithm='{exp_dict['algorithm']}']")

    code = ""
    error = ""
    result = None
    attempt = 0

    reasoning = ""
    # {{docs-fragment retry_loop}}
    for attempt in range(max_retries):
        try:
            with flyte.group(exp_name):
                plan_json = await plan_experiment.aio(
                    experiment_json=json.dumps(exp_dict),
                    profile_json=profile_json,
                    target_column=target_column,
                    time_column=time_column,
                    previous_error=error,
                    previous_code=code,
                    llm_model=llm_model,
                )
                plan = json.loads(plan_json)
                code = plan["code"]
                reasoning = plan.get("reasoning", "")
                result = await flyte.sandbox.orchestrate_local(
                    code,
                    inputs={"data": data, "target_column": target_column,
                            "time_column": time_column, "experiment_name": exp_name},
                    tasks=TOOLS,
                )
            error = ""
            break
        except Exception as exc:
            error = str(exc)
            short_error = error[:100] + "..." if len(error) > 100 else error
            print(f"   │  attempt {attempt + 1} failed: {short_error}")
            print(f"   │  → asking LLM to fix and retry...")
            if inject_failure and attempt == 0:
                exp_dict = exp.model_dump()
    # {{/docs-fragment retry_loop}}

    if result and not error:
        exp_result = ExperimentResult(
            name=exp_name,
            algorithm=exp.algorithm,
            metrics=result.get("metrics", {}),
            confusion_matrix=result.get("confusion_matrix", {}),
            threshold_analysis=result.get("threshold_analysis", []),
            n_samples=result.get("n_samples", 0),
            code=code,
            reasoning=reasoning,
            attempts=attempt + 1,
        )
        m = exp_result.metrics
        attempts_note = f" (recovered after {attempt + 1} attempts)" if attempt > 0 else ""
        print(f"   └─ ROC-AUC={m.get('roc_auc')}, F1={m.get('f1')}, Recall={m.get('recall')}{attempts_note}")
        return exp_result

    print(f"   └─ FAILED after {max_retries} attempts — skipping.")
    return None
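The retry loop inside `_run_experiment` follows a generic plan → execute → feed-the-error-back pattern. A standalone sketch of that pattern with stubbed planner and executor (all names here are hypothetical; the real loop additionally wraps each attempt in `flyte.group` and runs the generated code in the Flyte sandbox):

```python
def run_with_feedback(plan, execute, max_retries=3):
    """Plan code, execute it, and on failure feed the error into the next plan."""
    code, error = "", ""
    for attempt in range(max_retries):
        code = plan(previous_code=code, previous_error=error)
        try:
            return execute(code), attempt + 1
        except Exception as exc:
            error = str(exc)  # becomes previous_error on the next attempt
    return None, max_retries

def plan(previous_code, previous_error):
    # Stub planner: emits broken code first, "fixes" it once it sees an error.
    return "good" if previous_error else "bad"

def execute(code):
    if code != "good":
        raise ValueError("unknown algorithm")
    return {"ok": True}

result, attempts = run_with_feedback(plan, execute)
print(result, attempts)  # {'ok': True} 2
```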

async def run_agent(
    data: File,
    problem_description: str,
    target_column: str,
    time_column: str = "",
    max_iterations: int = 3,
    max_retries_per_experiment: int = 3,
    llm_model: str = "gpt-4o",
    inject_failure: bool = False,
) -> AgentResult:
    """Run the MLE agent end-to-end.

    Args:
        data: CSV file containing the dataset.
        problem_description: Natural language description of the ML problem.
        target_column: Name of the target column to predict.
        time_column: Optional column to use for time-based train/test split.
        max_iterations: Maximum number of experiment iterations to run.
        max_retries_per_experiment: Max times to retry a failed sandbox execution.
        llm_model: OpenAI model to use (default: gpt-4o).
        inject_failure: If True, corrupts the first experiment to demonstrate self-healing.
    """
    print(f"\n{'='*60}")
    print(f"MLE Agent starting")
    print(f"Problem: {problem_description}")
    print(f"Target: {target_column}")
    if inject_failure:
        print(f"[demo mode: failure injection enabled]")
    print(f"{'='*60}\n")

    # {{docs-fragment phase1_profile}}
    # --- Phase 1: Profile the dataset (trusted tool, LLM never sees raw data) ---
    print(">> Phase 1: Profiling dataset...")
    with flyte.group("profile"):
        profile = await profile_dataset(data, target_column)
    # {{/docs-fragment phase1_profile}}
    print(f"   Shape: {profile['shape']}, Classes: {profile['target_distribution']}")
    print(f"   Imbalanced: {profile['is_imbalanced']}, Columns: {len(profile['columns'])}")
    corr = profile.get("feature_target_corr", {})
    top_corr = list(corr.items())[:5]
    print(f"   Top correlations: {', '.join(f'{k}={v:+.3f}' for k,v in top_corr)}")

    # Stream report: dataset summary
    await flyte.report.log.aio(
        f"<h1>MLE Agent Run</h1>"
        f"<p><b>Problem:</b> {problem_description}</p>"
        f"<p><b>Dataset:</b> {profile['shape'][0]:,} rows × {profile['shape'][1]} cols &nbsp;|&nbsp; "
        f"Class balance: {profile['class_balance']} &nbsp;|&nbsp; Imbalanced: {profile['is_imbalanced']}</p>"
        f"<p><b>Top feature-target correlations (raw):</b> "
        + ", ".join(f"{k}: {v:+.3f}" for k, v in top_corr) +
        f"</p><hr>",
        do_flush=True,
    )

    # --- Phase 2: LLM designs initial experiments ---
    print("\n>> Phase 2: Designing initial experiments...")
    design_response = await design_experiments(
        problem_description=problem_description,
        profile_json=json.dumps(profile),
        llm_model=llm_model,
    )
    design = InitialDesign.model_validate(_parse_json(design_response))
    print(f"   Primary metric: {design.primary_metric}")
    print(f"   Strategy: {design.reasoning}")
    print(f"   Experiments planned: {len(design.experiments)}")

    all_results: list[ExperimentResult] = []
    iteration_log: list[dict] = []  # tracks per-iteration decisions + explorations for summary
    current_experiments: list[ExperimentConfig] = design.experiments
    first_experiment = True

    # --- Phase 3: Iterative experiment loop ---
    for iteration in range(max_iterations):
        experiments = current_experiments

        if not experiments:
            print(f"\n>> No experiments to run in iteration {iteration + 1}. Stopping.")
            break

        print(f"\n>> Phase 3.{iteration + 1}: Running {len(experiments)} experiment(s) in parallel...")

        # Assign names and prepare dicts before launching in parallel
        exp_batch = []
        for i, exp in enumerate(experiments):
            if not exp.name:
                exp.name = f"experiment_{len(all_results) + i + 1}"
            exp_dict = exp.model_dump()
            inject_this = inject_failure and first_experiment and i == 0
            if inject_this:
                exp_dict = _corrupt_experiment_for_demo(exp_dict)
            first_experiment = False
            exp_batch.append((exp, exp_dict, inject_this))

        # {{docs-fragment parallel_execute}}
        batch_results = await asyncio.gather(*[
            _run_experiment(
                exp=exp,
                exp_dict=exp_dict,
                inject_failure=inject_this,
                data=data,
                target_column=target_column,
                time_column=time_column,
                profile=profile,
                llm_model=llm_model,
                max_retries=max_retries_per_experiment,
            )
            for exp, exp_dict, inject_this in exp_batch
        ])
        # {{/docs-fragment parallel_execute}}

        for exp_result in batch_results:
            if exp_result is not None:
                all_results.append(exp_result)
                # Stream report: each experiment as it completes
                m = exp_result.metrics
                html = (
                    f"<h3>Iteration {iteration + 1} — {exp_result.name}</h3>"
                    f"<p><b>Algorithm:</b> {exp_result.algorithm} &nbsp;|&nbsp; "
                    f"<b>ROC-AUC:</b> {m.get('roc_auc')} &nbsp;|&nbsp; "
                    f"<b>F1:</b> {m.get('f1')} &nbsp;|&nbsp; "
                    f"<b>Recall:</b> {m.get('recall')} &nbsp;|&nbsp; "
                    f"<b>Attempts:</b> {exp_result.attempts}</p>"
                )
                if exp_result.reasoning:
                    html += f"<details><summary>Reasoning</summary><pre>{exp_result.reasoning}</pre></details>"
                html += f"<details><summary>Generated Code</summary><pre>{exp_result.code}</pre></details>"
                await flyte.report.log.aio(html, do_flush=True)

        # --- Phase 4: Analyze results, decide whether to iterate ---
        if all_results and iteration < max_iterations - 1:
            print(f"\n>> Phase 4.{iteration + 1}: Analyzing results, deciding next steps...")
            results_summary = [
                {
                    "experiment_name": r.name,
                    "algorithm": r.algorithm,
                    "metrics": r.metrics,
                    "confusion_matrix": r.confusion_matrix,
                    "used_feature_engineering": "engineer_features" in r.code,
                    "used_rolling_features": "rolling_columns" in r.code,
                    "used_lag_features": "lag_columns" in r.code,
                }
                for r in all_results
            ]
            analysis_prompt = textwrap.dedent(f"""
                Problem: {problem_description}
                Dataset profile: shape={profile['shape']}, imbalanced={profile['is_imbalanced']}
                Feature-target correlations (raw): {json.dumps(profile.get('feature_target_corr', {}), indent=2)}

                Experiment results so far (iteration {iteration + 1}):
                {json.dumps(results_summary, indent=2)}

                Should we run more experiments? If yes, request any data explorations
                you need, then specify what experiments to run next.
            """).strip()

            analysis_response = await analyze_iteration(
                analysis_prompt=analysis_prompt,
                max_iterations=max_iterations,
                current_iteration=iteration,
                llm_model=llm_model,
            )
            decision = IterationDecision.model_validate(_parse_json(analysis_response))
            verdict = "continuing" if decision.should_continue else "stopping"
            print(f"   Decision: {verdict}")
            print(f"   Reasoning: {decision.reasoning}")

            # Stream report: analysis decision
            await flyte.report.log.aio(
                f"<h3>Analysis — Iteration {iteration + 1}</h3>"
                f"<p><b>Decision:</b> {verdict}</p>"
                f"<p><b>Reasoning:</b> {decision.reasoning}</p>",
                do_flush=True,
            )

            # Track this iteration for the experiment journey summary
            iter_entry = {
                "iteration": iteration + 1,
                "experiments": [r.name for r in batch_results if r is not None],
                "best_roc_auc": max(
                    (r.metrics.get("roc_auc", 0) for r in all_results), default=0
                ),
                "reasoning": decision.reasoning,
                "explorations": [],
            }

            # --- Targeted exploration before next iteration ---
            if decision.should_continue and decision.exploration_requests:
                print(f"   Running {len(decision.exploration_requests)} exploration request(s)...")
                exploration_questions = []
                exploration_results = []

                for i, req in enumerate(decision.exploration_requests):
                    question = req.get("question", f"Exploration {i + 1}")
                    # Strip agent-level metadata — tool only needs the analysis config
                    tool_config = {k: v for k, v in req.items() if k not in ("question", "analysis_type")}

                    print(f"   Q: {question}")
                    with flyte.group(f"explore_{iteration + 1}_{i + 1}"):
                        result = await explore_dataset(data, tool_config)
                    exploration_questions.append(question)
                    exploration_results.append(result)
                    iter_entry["explorations"].append({"question": question})

                    await flyte.report.log.aio(
                        f"<h4>Exploration {i + 1}</h4>"
                        f"<p><b>Question:</b> {question}</p>"
                        f"<details><summary>Results</summary><pre>{json.dumps(result, indent=2)}</pre></details>",
                        do_flush=True,
                    )

                # Build follow-up that explicitly connects each question to its answer
                qa_pairs = "\n\n".join(
                    f'Question {i + 1}: "{q}"\nResult:\n{json.dumps(r, indent=2)}'
                    for i, (q, r) in enumerate(zip(exploration_questions, exploration_results))
                )
                followup_prompt = textwrap.dedent(f"""
                    You requested {len(exploration_results)} targeted exploration(s).
                    Here is what you asked and what you learned:

                    {qa_pairs}

                    Given what you learned and your earlier reasoning:
                    "{decision.reasoning}"

                    Now specify the next experiments. For each experiment, briefly state
                    which exploration insight informed your choice.
                    Respond with valid JSON: {{"next_experiments": [...same schema as before...]}}
                """).strip()
                followup_response = await plan_followup(
                    analysis_prompt=analysis_prompt,
                    analysis_response=analysis_response,
                    followup_prompt=followup_prompt,
                    max_iterations=max_iterations,
                    current_iteration=iteration,
                    llm_model=llm_model,
                )
                followup = _parse_json(followup_response)
                current_experiments = IterationDecision.model_validate({
                    "should_continue": True,
                    "reasoning": decision.reasoning,
                    "next_experiments": followup.get("next_experiments", []),
                }).next_experiments
                print(f"   Post-exploration: {len(current_experiments)} experiment(s) planned")
            else:
                current_experiments = decision.next_experiments

            iteration_log.append(iter_entry)

            if not decision.should_continue:
                break

    # --- Phase 5: Rank all results and generate model card ---
    print(f"\n>> Phase 5: Ranking {len(all_results)} experiment(s) and generating model card...")

    if not all_results:
        return AgentResult(
            model_card="No experiments completed successfully.",
            best_experiment="",
            best_metrics={},
            all_results=[],
            iterations=iteration + 1,
            total_experiments=0,
        )

    ranking_input = [
        {
            "experiment_name": r.name,
            "metrics": r.metrics,
            "confusion_matrix": r.confusion_matrix,
        }
        for r in all_results
    ]
    with flyte.group("rank"):
        ranking = await rank_experiments(json.dumps(ranking_input))
    best_name = ranking["best_experiment"]
    best_result = next(r for r in all_results if r.name == best_name)

    _print_experiment_table(all_results, best_name)
    _print_threshold_recommendation(best_result.threshold_analysis, best_result.metrics)

    # Stream report: final rankings table
    rows = "".join(
        f"<tr><td>{row['rank']}</td>"
        f"<td>{'<b>' if row['experiment_name'] == best_name else ''}"
        f"{row['experiment_name']}"
        f"{'</b>' if row['experiment_name'] == best_name else ''}</td>"
        f"<td>{row['roc_auc']}</td><td>{row['f1']}</td>"
        f"<td>{row['recall']}</td><td>{row['precision']}</td></tr>"
        for row in ranking.get("ranking", [])
    )
    await flyte.report.log.aio(
        f"<hr><h2>Final Rankings</h2>"
        f"<table border='1' cellpadding='6' cellspacing='0'>"
        f"<tr><th>Rank</th><th>Experiment</th><th>ROC-AUC</th><th>F1</th><th>Recall</th><th>Precision</th></tr>"
        f"{rows}</table>"
        f"<p>{ranking.get('summary', '')}</p>",
        do_flush=True,
    )

    # Stream report: experiment journey summary
    journey_rows = ""
    for entry in iteration_log:
        exps = ", ".join(entry["experiments"]) if entry["experiments"] else "—"
        explorations = "; ".join(e["question"] for e in entry["explorations"]) if entry["explorations"] else "—"
        short_reasoning = (entry["reasoning"][:120] + "…") if len(entry["reasoning"]) > 120 else entry["reasoning"]
        journey_rows += (
            f"<tr>"
            f"<td style='text-align:center'>{entry['iteration']}</td>"
            f"<td>{exps}</td>"
            f"<td style='text-align:center'>{entry['best_roc_auc']:.4f}</td>"
            f"<td>{short_reasoning}</td>"
            f"<td>{explorations}</td>"
            f"</tr>"
        )
    await flyte.report.log.aio(
        f"<hr><h2>Experiment Journey</h2>"
        f"<table border='1' cellpadding='6' cellspacing='0' style='width:100%;border-collapse:collapse'>"
        f"<tr><th>Iter</th><th>Experiments</th><th>Best ROC-AUC</th><th>Key insight</th><th>Explorations</th></tr>"
        f"{journey_rows}"
        f"</table>",
        do_flush=True,
    )

    model_card = await _generate_model_card(
        problem_description=problem_description,
        profile=profile,
        all_results=all_results,
        best_result=best_result,
        ranking=ranking,
        iteration_log=iteration_log,
        llm_model=llm_model,
    )

    print(f"\n{'='*60}")
    print(f"DONE — Best model: {best_name}")
    print(f"       ROC-AUC={best_result.metrics.get('roc_auc')}, F1={best_result.metrics.get('f1')}")
    print(f"{'='*60}\n")

    return AgentResult(
        model_card=model_card,
        best_experiment=best_name,
        best_metrics=best_result.metrics,
        all_results=all_results,
        iterations=iteration + 1,
        total_experiments=len(all_results),
    )

async def _generate_model_card(
    problem_description: str,
    profile: dict,
    all_results: list[ExperimentResult],
    best_result: ExperimentResult,
    ranking: dict,
    iteration_log: list[dict],
    llm_model: str,
) -> str:
    """Generate a markdown model card summarizing the winning model."""
    system = textwrap.dedent("""
        You are an ML engineer writing a model card for a trained model.
        Write in markdown. Be concise but informative. Include:
        - Problem statement
        - Dataset summary
        - Experiment journey (brief per-iteration narrative: what was tried, what was learned, what changed)
        - Experiment summary (table of all experiments with metrics)
        - Winning model details (algorithm, key hyperparams, metrics, threshold analysis)
        - Recommendations for deployment (decision threshold, monitoring)
    """).strip()

    results_text = "\n".join(
        f"- {r.name} ({r.algorithm}): ROC-AUC={r.metrics.get('roc_auc')}, "
        f"F1={r.metrics.get('f1')}, Recall={r.metrics.get('recall')}"
        for r in all_results
    )

    journey_text = ""
    if iteration_log:
        journey_text = "\n\nIteration log:\n" + "\n".join(
            f"  Iteration {e['iteration']}: ran [{', '.join(e['experiments'])}], "
            f"best ROC-AUC so far={e['best_roc_auc']:.4f}. "
            f"Key insight: {e['reasoning'][:200]}. "
            + (f"Explorations: {'; '.join(x['question'] for x in e['explorations'])}" if e['explorations'] else "")
            for e in iteration_log
        )

    user_content = textwrap.dedent(f"""
        Problem: {problem_description}

        Dataset: {profile['shape'][0]} rows × {profile['shape'][1]} cols.
        Class balance: {profile['class_balance']}
        Imbalanced: {profile['is_imbalanced']}
        {journey_text}

        All experiments:
        {results_text}

        Best model: {best_result.name} ({best_result.algorithm})
        Metrics: {json.dumps(best_result.metrics, indent=2)}
        Confusion matrix: {json.dumps(best_result.confusion_matrix, indent=2)}
        Threshold analysis: {json.dumps(best_result.threshold_analysis, indent=2)}

        Ranking summary: {ranking['summary']}
    """).strip()

    response = await _call_llm(system, [{"role": "user", "content": user_content}], llm_model)
    return response

# ---------------------------------------------------------------------------
# Durable entrypoint (runs the agent as a Flyte task in the cloud)
# ---------------------------------------------------------------------------

# {{docs-fragment entrypoint}}
@agent_env.task(retries=1, report=True)
async def mle_agent_task(
    data: File,
    problem_description: str,
    target_column: str,
    time_column: str = "",
    max_iterations: int = 3,
) -> str:
    """Durable Flyte task entrypoint for the MLE agent."""
    result = await run_agent(
        data=data,
        problem_description=problem_description,
        target_column=target_column,
        time_column=time_column,
        max_iterations=max_iterations,
    )
    return result.model_card
# {{/docs-fragment entrypoint}}
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/tutorials/mle_bot/mle_bot/agent.py*

As each experiment completes, the agent streams its metrics to the report:

```python
await flyte.report.log.aio(
    f"<h3>Iteration {iteration + 1}: {exp_result.name}</h3>"
    f"<p><b>Algorithm:</b> {exp_result.algorithm} &nbsp;|&nbsp; "
    f"<b>ROC-AUC:</b> {m.get('roc_auc')} &nbsp;|&nbsp; "
    f"<b>F1:</b> {m.get('f1')}</p>",
    do_flush=True,
)
```

The final report includes a dataset summary, per-experiment metrics with expandable reasoning and generated code, the analysis decisions at each iteration, a final rankings table, and an experiment journey summary showing how the agent's strategy evolved.
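The expandable sections are plain HTML: a `<details>` element collapses long reasoning or generated code behind a one-line label. A hypothetical helper (the function name and example strings here are illustrative, not from the source) might look like:

```python
import html

def details_block(summary: str, body: str) -> str:
    """Wrap text in a collapsible <details> element for an HTML report.

    `summary` is the clickable label; `body` is escaped so generated
    code renders as text instead of being interpreted as HTML.
    """
    return (
        f"<details><summary><b>{html.escape(summary)}</b></summary>"
        f"<pre>{html.escape(body)}</pre></details>"
    )

# Collapse the LLM's reasoning behind a one-line label.
snippet = details_block("LLM reasoning", "Chose gradient boosting: <50 features, tabular data")
```

Because the body is escaped, even reasoning text containing angle brackets or generated Python renders safely inside the report.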

## Running it

First, generate the synthetic demo dataset (a predictive maintenance scenario with 175k+ rows of simulated sensor data from 20 industrial pumps):

```bash
uv run main.py generate-data
```

Then submit the agent to your Flyte cluster:

```bash
uv run main.py run \
    --data data/predictive_maintenance.csv \
    --problem "Predict pump failures 24 hours before they happen" \
    --target failure_24h \
    --time-column timestamp \
    --max-iterations 3 \
    --output results/report.md
```

The agent connects to your cluster via `~/.flyte/config.yaml`, uploads the CSV, and submits the agent task. You'll see a URL to track the execution in the Flyte UI, and logs will stream to your terminal.
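The config file follows the standard flytekit convention; a minimal sketch looks something like the following (the endpoint is a placeholder, and the exact keys can vary with your deployment, so check your cluster's docs):

```yaml
# ~/.flyte/config.yaml -- endpoint below is a placeholder
admin:
  endpoint: dns:///your-flyte-endpoint.example.com
  insecure: false
```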

> [!NOTE]
> You'll need to register your OpenAI API key as a cluster secret before running:
> `flyte create secret openai-api-key <YOUR_KEY>`

If you want to see the self-healing retry loop in action, add the `--inject-failure` flag. This deliberately corrupts the first experiment so the agent has to detect the error and recover, which makes for a nice demo of the durability guarantees.

## Why Flyte?

You could build something similar with plain Python and `exec()`. But there are a few things you'd lose.

**Safety.** Flyte's sandbox restricts LLM-generated code to calling your pre-approved functions and nothing else. No imports, no network, no filesystem. If you wouldn't give an intern root access to your production cluster, you probably shouldn't give an LLM unrestricted code execution either.

**Durability.** Every tool call is a Flyte task. If the agent process crashes halfway through iteration 3, the experiments that already completed are cached. You restart and pick up where you left off instead of retraining models from scratch. For long-running ML experiments, this matters.

**Observability.** You can see every LLM prompt, every generated code snippet, every tool invocation, and every result in the Flyte UI. When the agent makes a questionable decision (like skipping feature engineering on temporal data), you can trace exactly why: the prompt it received, the profile it read, the reasoning it generated.

**Compute isolation.** The ML tools run on cloud instances with the CPU and memory they need. The agent itself runs on a small 1-CPU instance since all it does is call the LLM and dispatch tool tasks. You're not bottlenecked by your laptop, and you're not paying for GPU-class compute to run an orchestration loop.

**Parallelism.** Multiple experiments run simultaneously via `asyncio.gather()`, each dispatching its own durable tasks. Flyte handles the scheduling. If you have three experiments in a batch and each involves training + evaluation, that's six durable tasks dispatched to cloud compute, with the three experiment pipelines running in parallel.
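The fan-out itself is ordinary asyncio. A simplified sketch of the pattern, with mock `train` and `evaluate` coroutines standing in for the real durable Flyte tasks (all names here are illustrative):

```python
import asyncio

async def train(name: str) -> dict:
    # Stand-in for a durable Flyte training task.
    await asyncio.sleep(0)
    return {"name": name, "model": f"{name}-model"}

async def evaluate(model: dict) -> dict:
    # Stand-in for a durable Flyte evaluation task.
    await asyncio.sleep(0)
    return {"name": model["name"], "roc_auc": 0.9}

async def run_experiment(name: str) -> dict:
    # Within one experiment, train -> evaluate is sequential.
    model = await train(name)
    return await evaluate(model)

async def run_batch(names: list[str]) -> list[dict]:
    # Experiments in a batch run concurrently; in the real agent,
    # Flyte schedules each underlying task on cloud compute.
    return await asyncio.gather(*(run_experiment(n) for n in names))

results = asyncio.run(run_batch(["baseline_rf", "xgb_tuned", "logreg"]))
```

`asyncio.gather` preserves input order, so results line up with the experiment list regardless of which task finishes first.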

The MLE Bot is a specific example of a more general pattern: giving an LLM the ability to reason about *what* work should be done, while Flyte handles *how* that work gets executed safely, durably, and at scale. The sandbox is the boundary between the two. Everything above the boundary is LLM-generated and untrusted. Everything below it is your code, running on your infrastructure, with all the guarantees you'd expect from a production orchestrator.

---
**Source**: https://github.com/unionai/unionai-docs/blob/main/content/tutorials/mle-bot/_index.md
**HTML**: https://www.union.ai/docs/v2/flyte/tutorials/mle-bot/
