flyteplugins.mlflow
Key features:
- Automatic MLflow run management with
@mlflow_rundecorator - Built-in autologging support via
autolog=Trueparameter - Auto-generated MLflow UI links via
link_hostconfig and theMlflowlink class - Parent/child task support with run sharing
- Distributed training support (only rank 0 logs to MLflow)
- Configuration management with
mlflow_config()
Basic usage:
-
Manual logging with
@mlflow_run:from flyteplugins.mlflow import mlflow_run, get_mlflow_run @mlflow_run( tracking_uri="http://localhost:5000", experiment_name="my-experiment", tags={"team": "ml"}, ) @env.task async def train_model(learning_rate: float) -> str: import mlflow mlflow.log_param("lr", learning_rate) mlflow.log_metric("loss", 0.5) run = get_mlflow_run() return run.info.run_id -
Automatic logging with
@mlflow_run(autolog=True):from flyteplugins.mlflow import mlflow_run @mlflow_run( autolog=True, framework="sklearn", tracking_uri="http://localhost:5000", log_models=True, log_datasets=False, experiment_id="846992856162999", ) @env.task async def train_sklearn_model(): from sklearn.linear_model import LogisticRegression model = LogisticRegression() model.fit(X, y) # Autolog captures parameters, metrics, and model -
Workflow-level configuration with
mlflow_config():from flyteplugins.mlflow import mlflow_config r = flyte.with_runcontext( custom_context=mlflow_config( tracking_uri="http://localhost:5000", experiment_id="846992856162999", tags={"team": "ml"}, ) ).run(train_model, learning_rate=0.001) -
Per-task config overrides with context manager:
@mlflow_run @env.task async def parent_task(): # Override config for a specific child task with mlflow_config(run_mode="new", tags={"role": "child"}): await child_task() -
Run modes — control run creation vs sharing:
@mlflow_run # "auto": new run if no parent, else share parent's @mlflow_run(run_mode="new") # Always create a new run -
HPO — objective can be a Flyte task with
run_mode="new":@mlflow_run(run_mode="new") @env.task def objective(params: dict) -> float: mlflow.log_params(params) loss = train(params) mlflow.log_metric("loss", loss) return loss -
Distributed training (only rank 0 logs):
@mlflow_run # Auto-detects rank from RANK env var @env.task async def distributed_train(): ... -
MLflow UI links — auto-generated via
link_host:from flyteplugins.mlflow import Mlflow, mlflow_config # Set link_host at workflow level — children with Mlflow() link # auto-get the URL after the parent creates the run. r = flyte.with_runcontext( custom_context=mlflow_config( tracking_uri="http://localhost:5000", link_host="http://localhost:5000", ) ).run(parent_task) # Attach the link to child tasks: @mlflow_run @env.task(links=[Mlflow()]) async def child_task(): ... # Custom URL template (e.g. Databricks): mlflow_config( link_host="https://dbc-xxx.cloud.databricks.com", link_template="{host}/ml/experiments/{experiment_id}/runs/{run_id}", )
Decorator order: @mlflow_run must be outermost (before @env.task):
@mlflow_run
@env.task
async def my_task(): ...
@mlflow_run(autolog=True, framework="sklearn")
@env.task
async def my_task(): ...Directory
Classes
| Class | Description |
|---|---|
Mlflow |
MLflow UI link for Flyte tasks. |
Methods
| Method | Description |
|---|---|
get_mlflow_context() |
Retrieve current MLflow configuration from Flyte context. |
get_mlflow_run() |
Get the current MLflow run if within a @mlflow_run decorated task or trace. |
mlflow_config() |
Create MLflow configuration. |
mlflow_run() |
Decorator to manage MLflow runs for Flyte tasks and plain functions. |
Methods
get_mlflow_context()
def get_mlflow_context()Retrieve current MLflow configuration from Flyte context.
get_mlflow_run()
def get_mlflow_run()Get the current MLflow run if within a @mlflow_run decorated task or trace.
The run is started when the @mlflow_run decorator enters.
Returns None if not within an mlflow_run context.
Returns: mlflow.ActiveRun | None: The current MLflow active run or None.
mlflow_config()
def mlflow_config(
tracking_uri: typing.Optional[str],
experiment_name: typing.Optional[str],
experiment_id: typing.Optional[str],
run_name: typing.Optional[str],
run_id: typing.Optional[str],
tags: typing.Optional[dict[str, str]],
run_mode: typing.Literal['auto', 'new', 'nested'],
autolog: bool,
framework: typing.Optional[str],
log_models: typing.Optional[bool],
log_datasets: typing.Optional[bool],
autolog_kwargs: typing.Optional[dict[str, typing.Any]],
link_host: typing.Optional[str],
link_template: typing.Optional[str],
kwargs: **kwargs,
) -> flyteplugins.mlflow._context._MLflowConfigCreate MLflow configuration.
Works in two contexts:
- With
flyte.with_runcontext()for global configuration - As a context manager to override configuration
| Parameter | Type | Description |
|---|---|---|
tracking_uri |
typing.Optional[str] |
MLflow tracking server URI. |
experiment_name |
typing.Optional[str] |
MLflow experiment name. |
experiment_id |
typing.Optional[str] |
MLflow experiment ID. |
run_name |
typing.Optional[str] |
Human-readable run name. |
run_id |
typing.Optional[str] |
Explicit MLflow run ID. |
tags |
typing.Optional[dict[str, str]] |
MLflow run tags. |
run_mode |
typing.Literal['auto', 'new', 'nested'] |
Flyte-specific run mode (“auto”, “new”, “nested”). |
autolog |
bool |
Enable MLflow autologging. |
framework |
typing.Optional[str] |
Framework-specific autolog (e.g. “sklearn”, “pytorch”). |
log_models |
typing.Optional[bool] |
Whether to log models automatically. |
log_datasets |
typing.Optional[bool] |
Whether to log datasets automatically. |
autolog_kwargs |
typing.Optional[dict[str, typing.Any]] |
Extra parameters passed to mlflow.autolog(). |
link_host |
typing.Optional[str] |
MLflow UI host for auto-generating task links. |
link_template |
typing.Optional[str] |
Custom URL template. Defaults to standard MLflow UI format. Available placeholders: {host}, {experiment_id}, {run_id}. |
kwargs |
**kwargs |
mlflow_run()
def mlflow_run(
_func: typing.Optional[~F],
run_mode: typing.Literal['auto', 'new', 'nested'],
tracking_uri: typing.Optional[str],
experiment_name: typing.Optional[str],
experiment_id: typing.Optional[str],
run_name: typing.Optional[str],
run_id: typing.Optional[str],
tags: typing.Optional[dict[str, str]],
autolog: bool,
framework: typing.Optional[str],
log_models: typing.Optional[bool],
log_datasets: typing.Optional[bool],
autolog_kwargs: typing.Optional[dict[str, typing.Any]],
rank: typing.Optional[int],
kwargs,
) -> ~FDecorator to manage MLflow runs for Flyte tasks and plain functions.
Handles both manual logging and autologging. For autologging, pass
autolog=True and optionally framework to select a specific
framework (e.g. "sklearn").
Decorator Order: @mlflow_run must be the outermost decorator::
@mlflow_run
@env.task
async def my_task():
...
| Parameter | Type | Description |
|---|---|---|
_func |
typing.Optional[~F] |
|
run_mode |
typing.Literal['auto', 'new', 'nested'] |
“auto” (default), “new”, or “nested”. - “auto”: reuse parent run if available, else create new. - “new”: always create a new independent run. - “nested”: create a new run nested under the parent via mlflow.parentRunId tag. Works across processes/containers. |
tracking_uri |
typing.Optional[str] |
MLflow tracking server URL. |
experiment_name |
typing.Optional[str] |
MLflow experiment name (exclusive with experiment_id). |
experiment_id |
typing.Optional[str] |
MLflow experiment ID (exclusive with experiment_name). |
run_name |
typing.Optional[str] |
Human-readable run name (exclusive with run_id). |
run_id |
typing.Optional[str] |
MLflow run ID (exclusive with run_name). |
tags |
typing.Optional[dict[str, str]] |
Dictionary of tags for the run. |
autolog |
bool |
Enable MLflow autologging. |
framework |
typing.Optional[str] |
MLflow framework name for autolog (e.g. “sklearn”, “pytorch”). |
log_models |
typing.Optional[bool] |
Whether to log models automatically (requires autolog). |
log_datasets |
typing.Optional[bool] |
Whether to log datasets automatically (requires autolog). |
autolog_kwargs |
typing.Optional[dict[str, typing.Any]] |
Extra parameters passed to mlflow.autolog(). |
rank |
typing.Optional[int] |
Process rank for distributed training (only rank 0 logs). |
kwargs |
**kwargs |